Initial commit.

master
float-trip 2023-07-09 23:00:43 +00:00
commit ad0ddd598a
8 changed files with 771 additions and 0 deletions

107
bot.py 100644

@@ -0,0 +1,107 @@
import random
import os
from rich import traceback
import model
import utils
from client import DramaClient
from config import config
traceback.install()
class Bot:
def __init__(self):
print("Loading model...")
self.model = model.Model()
self.client = DramaClient()
print("Ready.")
def post_random_reply(self):
print("Looking for comments...")
comments = [
c
for c in self.client.fetch_new_comments(limit=50)
if "author_name" in c
and c["author_name"] not in ["Bussy-boy", "AutoJanny", "BARD_BOT", "Snappy"]
and c["post_id"] != 0
]
if len(comments) == 0:
print("No comments found.")
return
random.shuffle(comments)
comments = comments[: config["num_replies"]]
for comment in comments:
self.reply(comment)
def respond_to_replies(self):
replies = self.client.fetch_new_replies()
for reply in replies:
if "author_name" not in reply:
continue
if reply["author_name"] not in [
"AutoJanny",
"BARD_BOT",
"bbbb",
"longpostbot",
]:
self.reply(reply)
def make_forced_replies(self):
file_path = f"{config['data_dir']}/forced.txt"
if not os.path.isfile(file_path):
return
with open(file_path, "r") as f:
lines = f.read().splitlines()
for comment_id in lines:
comment = self.client.get(f"/comment/{comment_id}")
self.reply(comment)
os.remove(file_path)
def reply(self, comment):
print(f"Generating reply for https://rdrama.net/comment/{comment['id']}")
post, thread_comments = self.client.fetch_context(comment)
        if not post or not thread_comments:
            print("Could not fetch context!")
            return
prompt = utils.build_prompt(post, thread_comments)
utils.log_prompt(prompt)
candidates = []
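        # Most of the time (p = 0.6), generate several candidates and keep the
        # longest one; otherwise accept the first reply that passes the checks.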
num_candidates = config["num_candidates"] if random.random() < 0.6 else 1
while len(candidates) < num_candidates:
gen_text = self.model.generate(prompt)
reply = utils.extract_reply(gen_text)
print(f"Generated text: {gen_text}\nReply:\n{reply}")
reply = utils.format_reply(reply)
if len(reply) == 0:
print("Retrying: reply empty after processing.")
elif utils.is_low_quality(reply, post, thread_comments):
print("Retrying: low quality reply.")
else:
candidates.append(reply)
print("Accepting reply.")
reply = max(candidates, key=utils.reply_length)
self.client.reply(reply, comment)
if __name__ == "__main__":
bot = Bot()
bot.make_forced_replies()
bot.respond_to_replies()
bot.post_random_reply()
bot.respond_to_replies()

148
client.py 100644

@@ -0,0 +1,148 @@
import requests
import sys
import os
import math
import time
import shelve
from requests.adapters import HTTPAdapter, Retry
from config import config
class DramaClient:
BASE_URL = "https://rdrama.net"
def __init__(self):
self.db = shelve.open(f"{config['data_dir']}/client_state.p", writeback=True)
self.db.setdefault("processed_replies", set())
self.session = requests.Session()
retries = Retry(
total=5, backoff_factor=5, status_forcelist=[500, 502, 503, 504, 521]
)
self.session.mount("https://", HTTPAdapter(max_retries=retries))
self.chud_phrase = self.get("/@me").get("chud_phrase", "")
def get(self, endpoint):
print("GET", endpoint)
time.sleep(5)
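        # A 502 page can slip past the status-based retry adapter (e.g. when
        # it arrives with a 200 status), so also check the response body.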
while True:
r = self.session.get(
f"{self.BASE_URL}{endpoint}",
headers={"Authorization": config["api_token"]},
)
if "502 Bad Gateway" in r.text:
print("Received 502")
time.sleep(10)
continue
break
# Return None for country club and chudrama posts.
if r.status_code == 403:
return None
if r.status_code != 200:
print("Error!", r, r.status_code, r.content)
sys.exit(1)
return r.json()
    def post(self, endpoint, payload=None, files=None):
print("POST", endpoint, f"Payload:\n{payload}")
time.sleep(5)
while True:
r = self.session.post(
f"{self.BASE_URL}{endpoint}",
payload,
headers={"Authorization": config["api_token"]},
files=files,
)
if "502 Bad Gateway" in r.text:
print("Received 502")
time.sleep(10)
continue
break
if r.status_code != 200:
print("Error!", r, r.status_code, r.content)
sys.exit(1)
return r.json()
def fetch_new_comments(self, limit=0):
comments = []
last_processed_id = self.db.get("last_processed_id", -1)
earliest_id = math.inf
page = 1
# Fetch comments until we find the last one processed.
while earliest_id > last_processed_id:
page_comments = self.fetch_page(page)
if len(page_comments) == 0:
break
earliest_id = min([c["id"] for c in page_comments])
comments += [c for c in page_comments if c["id"] > last_processed_id]
if limit > 0 and len(comments) >= limit:
break
page += 1
if not comments:
return []
self.db["last_processed_id"] = max(c["id"] for c in comments)
self.db.sync()
        # New comments may have pushed others to page n+1 while fetching,
        # so dedupe by id before returning, oldest first.
        comments = list({c["id"]: c for c in comments}.values())
        comments.sort(key=lambda c: c["id"])
        return comments
def fetch_new_replies(self):
notifs = self.get("/unread")["data"]
notifs = [n for n in notifs if n["body"]]
return notifs
def fetch_page(self, page):
return self.get(f"/comments?page={page}")["data"]
def fetch_context(self, comment):
post = self.get(f"/post/{comment['post_id']}")
if not post:
return None, None
comments = [comment]
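        # Walk up the chain of parent comments to the thread root.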
while parent_id := comments[-1].get("parent_comment_id", None):
parent = self.get(f"/comment/{parent_id}")
comments.append(parent)
comments.reverse()
return post, comments
def reply(self, body, comment):
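        # Chudded accounts must include their chud phrase in every comment,
        # so append it when it's missing.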
if self.chud_phrase and self.chud_phrase not in body:
body += f"\n{self.chud_phrase}"
payload = {
"parent_fullname": f"c_{comment['id']}",
"body": body,
}
self.post("/comment", payload)

10
config.yaml 100644

@@ -0,0 +1,10 @@
api_token:
data_dir:
num_replies: 1
num_candidates: 5
prompt_token_limit: 6000
username: Bussy-boy
fake_usernames:
- TheBussyMan
- HedgehogInTheFog
- GodsStrongestMarsey

8
config.py 100644

@@ -0,0 +1,8 @@
import yaml
import os
current_dir = os.path.dirname(os.path.realpath(__file__))
config_path = os.path.join(current_dir, "config.yaml")
with open(config_path, "r") as f:
config = yaml.safe_load(f)

249
maxsubstring.py 100644

@@ -0,0 +1,249 @@
#!/usr/bin/env python
# https://gist.github.com/hynekcer/fa340f3b63826168ffc0c4b33310ae9c
"""Find the longest repeated substring.
"Efficient way to find longest duplicate string for Python (From Programming Pearls)"
http://stackoverflow.com/questions/13560037/
The algorithm is based on "Prefix doubling".
The worst time complexity is O(n (log n)^2). Memory requirements are linear.
"""
import time
from random import randint
import itertools
import sys
import unittest
from itertools import groupby
from operator import itemgetter
import logging
log = logging.getLogger(__name__)
log.setLevel(logging.INFO)
try:
log.addHandler(logging.NullHandler())
except AttributeError:
pass
def run():
if sys.argv[1:] == ["-"]:
text = sys.stdin.read()
elif sys.argv[1:]:
print("Reading data...")
text = open(sys.argv[1]).read()
else:
text = "banana"
print("Sorting...")
result = longest_common_substring(text)
print('Longest common substrings in "{0}..." are:\n{1}'.format(text[:20], result))
def longest_common_substring(text):
"""Get the longest common substrings and their positions.
>>> longest_common_substring('banana')
{'ana': [1, 3]}
>>> text = "not so Agamemnon, who spoke fiercely to "
>>> sorted(longest_common_substring(text).items())
[(' s', [3, 21]), ('no', [0, 13]), ('o ', [5, 20, 38])]
    This function can be easily modified for other criteria, e.g. to search
    for the ten longest non-overlapping repeated substrings.
"""
sa, rsa, lcp = suffix_array(text)
maxlen = max(lcp)
result = {}
for i in range(1, len(text)):
if lcp[i] == maxlen:
j1, j2, h = sa[i - 1], sa[i], lcp[i]
assert text[j1 : j1 + h] == text[j2 : j2 + h]
substring = text[j1 : j1 + h]
if substring not in result:
result[substring] = [j1]
result[substring].append(j2)
return dict((k, sorted(v)) for k, v in result.items())
def suffix_array(text, _step=16):
"""Analyze all common strings in the text.
    Short substrings of length _step are pre-sorted first. The results are
    then repeatedly merged so that the guaranteed number of compared
    characters is doubled in every iteration until all substrings are
    sorted exactly.
Arguments:
text: The text to be analyzed.
_step: Is only for optimization and testing. It is the optimal length
            of substrings used for initial pre-sorting. A bigger value is
faster if there is enough memory. Memory requirements are
approximately (estimate for 32 bit Python 3.3):
len(text) * (29 + (_size + 20 if _size > 2 else 0)) + 1MB
Return value: (tuple)
(sa, rsa, lcp)
sa: Suffix array for i in range(1, size):
assert text[sa[i-1]:] < text[sa[i]:]
rsa: Reverse suffix array for i in range(size):
assert rsa[sa[i]] == i
lcp: Longest common prefix for i in range(1, size):
assert text[sa[i-1]:sa[i-1]+lcp[i]] == text[sa[i]:sa[i]+lcp[i]]
if sa[i-1] + lcp[i] < len(text):
assert text[sa[i-1] + lcp[i]] < text[sa[i] + lcp[i]]
>>> suffix_array(text='banana')
([5, 3, 1, 0, 4, 2], [3, 2, 5, 1, 4, 0], [0, 1, 3, 0, 0, 2])
Explanation: 'a' < 'ana' < 'anana' < 'banana' < 'na' < 'nana'
The Longest Common String is 'ana': lcp[2] == 3 == len('ana')
It is between tx[sa[1]:] == 'ana' < 'anana' == tx[sa[2]:]
"""
tx = text
t0 = time.time()
size = len(tx)
step = min(max(_step, 1), len(tx))
sa = list(range(len(tx)))
log.debug("%6.3f pre sort", time.time() - t0)
sa.sort(key=lambda i: tx[i : i + step])
log.debug("%6.3f after sort", time.time() - t0)
    grpstart = size * [False] + [True]  # a boolean map for iteration speedup.
    # It helps to skip already-resolved values. The last True is a sentinel.
rsa = size * [None]
stgrp, igrp = "", 0
for i, pos in enumerate(sa):
st = tx[pos : pos + step]
if st != stgrp:
grpstart[igrp] = igrp < i - 1
stgrp = st
igrp = i
rsa[pos] = igrp
sa[i] = pos
grpstart[igrp] = igrp < size - 1 or size == 0
log.debug("%6.3f after group", time.time() - t0)
while grpstart.index(True) < size:
# assert step <= size
nmerge = 0
nextgr = grpstart.index(True)
while nextgr < size:
igrp = nextgr
nextgr = grpstart.index(True, igrp + 1)
glist = []
for ig in range(igrp, nextgr):
pos = sa[ig]
if rsa[pos] != igrp:
break
newgr = rsa[pos + step] if pos + step < size else -1
glist.append((newgr, pos))
glist.sort()
for ig, g in groupby(glist, key=itemgetter(0)):
g = [x[1] for x in g]
sa[igrp : igrp + len(g)] = g
grpstart[igrp] = len(g) > 1
for pos in g:
rsa[pos] = igrp
igrp += len(g)
nmerge += len(glist)
log.debug("%6.3f for step=%d nmerge=%d", time.time() - t0, step, nmerge)
step *= 2
del grpstart
# create LCP array
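    # Kasai's algorithm: visit the suffixes in text order; h can drop by at
    # most one between neighbors, so the LCP array is built in linear time.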
lcp = size * [None]
h = 0
for i in range(size):
if rsa[i] > 0:
j = sa[rsa[i] - 1]
while i != size - h and j != size - h and tx[i + h] == tx[j + h]:
h += 1
lcp[rsa[i]] = h
if h > 0:
h -= 1
if size > 0:
lcp[0] = 0
log.debug("%6.3f end", time.time() - t0)
return sa, rsa, lcp
# ---
class TestMixin(object):
def suffix_verify(self, text, step=16):
tx = text
sa, rsa, lcp = suffix_array(text=tx, _step=step)
self.assertEqual(set(sa), set(range(len(tx))))
ok = True
for i0, i1, h in zip(sa[:-1], sa[1:], lcp[1:]):
self.assertEqual(
tx[i1 : i1 + h],
tx[i0 : i0 + h],
"Verify LCP characters equal on text '%s...'" % text[:20],
)
self.assertGreater(
tx[i1 + h : i1 + h + 1],
tx[i0 + h : i0 + h + 1],
"Verify LCP+1 char is different '%s...'" % text[:20],
)
self.assertLessEqual(
max(i0, i1),
len(tx) - h,
"Verify LCP is not more than length of string '%s...'" % text[:20],
)
self.assertTrue(ok)
class SuffixArrayTest(unittest.TestCase, TestMixin):
def test_16(self):
# 'a' < 'ana' < 'anana' < 'banana' < 'na' < 'nana'
expect = ([5, 3, 1, 0, 4, 2], [3, 2, 5, 1, 4, 0], [0, 1, 3, 0, 0, 2])
self.assertEqual(suffix_array(text="banana", _step=16), expect)
def test_1(self):
expect = ([5, 3, 1, 0, 4, 2], [3, 2, 5, 1, 4, 0], [0, 1, 3, 0, 0, 2])
self.assertEqual(suffix_array(text="banana", _step=1), expect)
def test_mini(self):
self.assertEqual(suffix_array(text="", _step=1), ([], [], []))
self.assertEqual(suffix_array(text="a", _step=1), ([0], [0], [0]))
self.assertEqual(suffix_array(text="aa", _step=1), ([1, 0], [1, 0], [0, 1]))
self.assertEqual(
suffix_array(text="aaa", _step=1), ([2, 1, 0], [2, 1, 0], [0, 1, 2])
)
def test_example(self):
self.suffix_verify("abracadabra")
def test_cartesian(self):
"""Test all combinations of alphabet "ABC" up to length 4 characters"""
for size in range(7):
for cartesian in itertools.product(*(size * ["ABC"])):
text = "".join(cartesian)
log.debug('Testing "%s"', text)
self.suffix_verify(text, 1)
def test_lcp(self):
expect = {"ana": [1, 3]}
self.assertDictEqual(longest_common_substring("banana"), expect)
expect = {" s": [3, 21], "no": [0, 13], "o ": [5, 20, 38]}
self.assertDictEqual(
longest_common_substring("not so Agamemnon, who spoke fiercely to "), expect
)
class SlowTests(unittest.TestCase, TestMixin):
"""Slow development tests running many minutes.
It can be run only by an EXPLICIT command!
e.g.: python -m unittest maxsubstring.SlowTests._test_random
"""
def _test_random(self):
for power in range(2, 21, 2):
size = randint(2 ** (power - 1), 2**power)
for alphabet in (2, 4, 16, 256):
text = "".join(chr(65 + randint(0, alphabet - 1)) for _ in range(size))
log.debug("%s %s %s", size, alphabet, 1)
self.suffix_verify(text, 1)
log.debug("%s %s %s", size, alphabet, 16)
self.suffix_verify(text, 16)
if __name__ == "__main__":
run()

71
model.py 100644

@@ -0,0 +1,71 @@
import torch
import transformers
from transformers import (
GPTNeoXTokenizerFast,
LogitsProcessor,
AutoConfig,
AutoModelForCausalLM,
LogitsProcessorList,
)
from config import config
class StopAfterPlusIsGenerated(LogitsProcessor):
def __init__(self, plus_token_id, eos_token_id):
super().__init__()
self.plus_token_id = plus_token_id
self.eos_token_id = eos_token_id
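    # In the prompt format, "+" begins the vote count of a comment header
    # (e.g. "author +45 / -0"), so forcing EOS right after a "+" token stops
    # generation before the model starts writing the next comment itself.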
def __call__(self, input_ids, scores):
forced_eos = torch.full((scores.size(1),), -float("inf")).to(
device=scores.device, dtype=scores.dtype
)
forced_eos[self.eos_token_id] = 0
scores[input_ids[:, -1] == self.plus_token_id] = forced_eos
return scores
class Model:
def __init__(self):
name = f"{config['data_dir']}/mpt-30b-drama-ba678"
self.tokenizer = GPTNeoXTokenizerFast.from_pretrained(
name, pad_token="<|endoftext|>"
)
model_config = AutoConfig.from_pretrained(name, trust_remote_code=True)
model_config.attn_config["attn_impl"] = "triton"
model_config.init_device = "cuda:0"
model_config.eos_token_id = self.tokenizer.eos_token_id
self.model = AutoModelForCausalLM.from_pretrained(
name,
config=model_config,
torch_dtype=torch.bfloat16,
trust_remote_code=True,
)
self.logits_processor = LogitsProcessorList(
[StopAfterPlusIsGenerated(559, self.model.config.eos_token_id)]
)
def generate(self, prompt):
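        # Prompts are truncated upstream (utils.build_prompt) to
        # config["prompt_token_limit"] tokens, leaving headroom under
        # max_length=8192 for the generated reply.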
with torch.autocast("cuda", dtype=torch.bfloat16):
encoded = self.tokenizer(
prompt, return_tensors="pt", padding=True, truncation=True
).to("cuda:0")
gen_tokens = self.model.generate(
input_ids=encoded.input_ids,
attention_mask=encoded.attention_mask,
pad_token_id=0,
do_sample=True,
temperature=0.90,
use_cache=True,
max_length=8192,
logits_processor=self.logits_processor,
)
return self.tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)[0][
len(prompt) :
]

6
requirements.txt 100644

@@ -0,0 +1,6 @@
fuzzywuzzy==0.18.0
PyYAML==6.0
Requests==2.31.0
rich==13.4.2
torch==2.0.1
transformers==4.31.0

172
utils.py 100644

@@ -0,0 +1,172 @@
import random
import re
from fuzzywuzzy import fuzz
from transformers import GPTNeoXTokenizerFast
from config import config
from maxsubstring import longest_common_substring
URL_REGEX = r"http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+"
tokenizer = GPTNeoXTokenizerFast.from_pretrained("mosaicml/mpt-7b")
def remove_notifications(text):
"""Change @float-trip to @<i></i>float-trip and carp to c<i></i>arp."""
text = re.sub(rf"@(?!{config['username']}\b)", "@<i></i>", text)
notified_users = [
"aevan",
"avean",
"joan",
"pewkie",
"carp",
"idio3",
"idio ",
"the_homocracy",
"schizocel",
"scitzocel",
"snakes",
"sneks",
"jc",
"justcool",
"clit",
"geese",
"kippy",
"mccox",
"chiobu",
"donger",
"soren",
]
for user in notified_users:
match = re.search(user, text, re.IGNORECASE)
if match:
text = f"{text[:match.start() + 1]}<i></i>{text[match.start() + 1:]}"
return text
def format_reply(text):
for username in config["fake_usernames"]:
text.replace(username, config["username"])
text = replace_rdrama_images(text)
return text.strip()
def is_low_quality(reply, post, comments):
"""
Label the reply as low quality if:
    - The Levenshtein ratio (fuzzywuzzy) finds it over 90% similar to a
      previous comment in the thread.
    - Its longest repeated substring is at least 100 characters long.
- After removing links, Markdown images, and quoted text, the length is < 10.
"""
for comment in comments:
if fuzz.ratio(reply, comment["body"]) > 90:
return True
lcs = list(longest_common_substring(reply).keys())[0]
if len(lcs) >= 100:
return True
if reply_length(reply) < 10:
return True
return False
def replace_rdrama_images(text):
"""Replace images pointing to rdrama.net with a loading image."""
loading = "https://i.rdrama.net/i/l.webp"
webp_pattern = r"https://\S*\.rdrama\.net/\S*\.webp"
md_img_pattern = r"!\[[^\]]*\]\((https://\S*\.rdrama\.net)?/\S*\)"
text = re.sub(webp_pattern, loading, text)
text = re.sub(md_img_pattern, f"![]({loading})", text)
return text
def normalize_emojis(s):
"""Bring # and ! to the front of an emoji."""
def repl(match):
# Extract the word between colons and the special characters.
word = match.group(0)
specials = set(re.findall(r"[#!]", word))
# Sort specials and append the word without specials.
new_emoji = "".join(sorted(specials, reverse=True)) + re.sub(r"[#!]", "", word)
return new_emoji
emoji_pattern = r"(?<=:)[a-zA-Z@#!]*[#!][a-zA-Z@#!]*(?=:)"
s = re.sub(emoji_pattern, repl, s)
return s
def build_prompt(post, comments):
prompt = (
f"[Post] [Author] {post['author_name']} "
f"[Title] {post['title']} [URL] {post['url']} "
f"[Hole] {post['sub'] or 'N/A'} [Votes] +71 / -0\n\n"
f"{post['body']}\n\n[Comments]"
)
comments.append({"author_name": config["username"], "body": ""})
for depth, comment in enumerate(comments):
body = normalize_emojis(comment["body"])
author = comment["author_name"]
comment_str = f"\n\n{author} +45 / -0\n{body}"
indent = depth * " "
comment_str = "\n".join([indent + line for line in comment_str.split("\n")])
prompt += comment_str
prompt = prompt.replace(config["username"], random.choice(config["fake_usernames"]))
prompt = prompt.strip() + "\n"
# Truncate the prompt to leave room for generation.
tokens = tokenizer.tokenize(prompt)
if len(tokens) > config["prompt_token_limit"]:
tokens = tokens[-config["prompt_token_limit"] :]
prompt = tokenizer.convert_tokens_to_string(tokens)
return prompt
def log_prompt(prompt):
with open(f"{config['data_dir']}/prompts.txt", "a") as f:
f.write(f"{prompt}\n==========\n")
def reply_length(reply):
"""Return the length of the reply, without Markdown images, URLs, or quoted text."""
# Remove Markdown images and URLs.
reply = re.sub(r"!\[.*?\]\(.*?\)", "", reply)
reply = re.sub(URL_REGEX, "", reply)
# Remove quoted text.
lines = reply.splitlines()
lines = [line for line in lines if not line.lstrip().startswith((">", "\\>"))]
reply = "\n".join(lines).strip()
return len(reply)
def count_tokens(text):
return len(tokenizer(text).input_ids)
def extract_reply(text):
"""
Generated text will either:
- Be cut off at the token limit
- End with the start of a new comment: `float-trip +10`
For the latter case, drop the last line.
"""
pattern = r"^ *[\w-]* +\+.*$"
lines = text.split("\n")
if re.match(pattern, lines[-1]):
lines = lines[:-1]
return "\n".join([line.strip() for line in lines]).strip()