Initial commit.
commit ad0ddd598a
@ -0,0 +1,107 @@
import random
import os

from rich import traceback

import model
import utils
from client import DramaClient

from config import config

traceback.install()


class Bot:
    def __init__(self):
        print("Loading model...")
        self.model = model.Model()
        self.client = DramaClient()
        print("Ready.")

    def post_random_reply(self):
        print("Looking for comments...")

        comments = [
            c
            for c in self.client.fetch_new_comments(limit=50)
            if "author_name" in c
            and c["author_name"] not in ["Bussy-boy", "AutoJanny", "BARD_BOT", "Snappy"]
            and c["post_id"] != 0
        ]

        if len(comments) == 0:
            print("No comments found.")
            return

        random.shuffle(comments)
        comments = comments[: config["num_replies"]]
        for comment in comments:
            self.reply(comment)

    def respond_to_replies(self):
        replies = self.client.fetch_new_replies()

        for reply in replies:
            if "author_name" not in reply:
                continue

            if reply["author_name"] not in [
                "AutoJanny",
                "BARD_BOT",
                "bbbb",
                "longpostbot",
            ]:
                self.reply(reply)

    def make_forced_replies(self):
        file_path = f"{config['data_dir']}/forced.txt"
        if not os.path.isfile(file_path):
            return

        with open(file_path, "r") as f:
            lines = f.read().splitlines()

        for comment_id in lines:
            comment = self.client.get(f"/comment/{comment_id}")
            self.reply(comment)

        os.remove(file_path)

    def reply(self, comment):
        print(f"Generating reply for https://rdrama.net/comment/{comment['id']}")

        post, thread_comments = self.client.fetch_context(comment)

        if not post or not thread_comments:
            print("Could not fetch context!")
            return

        prompt = utils.build_prompt(post, thread_comments)
        utils.log_prompt(prompt)

        candidates = []
        num_candidates = config["num_candidates"] if random.random() < 0.6 else 1
        while len(candidates) < num_candidates:
            gen_text = self.model.generate(prompt)
            reply = utils.extract_reply(gen_text)
            print(f"Generated text: {gen_text}\nReply:\n{reply}")
            reply = utils.format_reply(reply)

            if len(reply) == 0:
                print("Retrying: reply empty after processing.")
            elif utils.is_low_quality(reply, post, thread_comments):
                print("Retrying: low quality reply.")
            else:
                candidates.append(reply)
                print("Accepting reply.")

        reply = max(candidates, key=utils.reply_length)
        self.client.reply(reply, comment)


if __name__ == "__main__":
    bot = Bot()
    bot.make_forced_replies()
    bot.respond_to_replies()
    bot.post_random_reply()
    bot.respond_to_replies()
@ -0,0 +1,148 @@
import requests
import sys
import os
import math
import time
import shelve

from requests.adapters import HTTPAdapter, Retry

from config import config


class DramaClient:
    BASE_URL = "https://rdrama.net"

    def __init__(self):
        self.db = shelve.open(f"{config['data_dir']}/client_state.p", writeback=True)
        self.db.setdefault("processed_replies", set())

        self.session = requests.Session()
        retries = Retry(
            total=5, backoff_factor=5, status_forcelist=[500, 502, 503, 504, 521]
        )
        self.session.mount("https://", HTTPAdapter(max_retries=retries))

        self.chud_phrase = self.get("/@me").get("chud_phrase", "")

    def get(self, endpoint):
        print("GET", endpoint)
        time.sleep(5)

        while True:
            r = self.session.get(
                f"{self.BASE_URL}{endpoint}",
                headers={"Authorization": config["api_token"]},
            )

            if "502 Bad Gateway" in r.text:
                print("Received 502")
                time.sleep(10)
                continue

            break

        # Return None for country club and chudrama posts.
        if r.status_code == 403:
            return None

        if r.status_code != 200:
            print("Error!", r, r.status_code, r.content)
            sys.exit(1)

        return r.json()

    def post(self, endpoint, payload=None, files=[]):
        print("POST", endpoint, f"Payload:\n{payload}")
        time.sleep(5)

        while True:
            r = self.session.post(
                f"{self.BASE_URL}{endpoint}",
                payload,
                headers={"Authorization": config["api_token"]},
                files=files,
            )

            if "502 Bad Gateway" in r.text:
                print("Received 502")
                time.sleep(10)
                continue

            break

        if r.status_code != 200:
            print("Error!", r, r.status_code, r.content)
            sys.exit(1)

        return r.json()

    def fetch_new_comments(self, limit=0):
        comments = []

        last_processed_id = self.db.get("last_processed_id", -1)
        earliest_id = math.inf
        page = 1

        # Fetch comments until we find the last one processed.
        while earliest_id > last_processed_id:
            page_comments = self.fetch_page(page)

            if len(page_comments) == 0:
                break

            earliest_id = min([c["id"] for c in page_comments])
            comments += [c for c in page_comments if c["id"] > last_processed_id]

            if limit > 0 and len(comments) >= limit:
                break

            page += 1

        if not comments:
            return []

        self.db["last_processed_id"] = max(c["id"] for c in comments)
        self.db.sync()

        # New comments may have pushed others to page n+1 while fetching,
        # so drop duplicates by id.
        comments = list({c["id"]: c for c in comments}.values())

        # Oldest first.
        comments.reverse()

        return comments

    def fetch_new_replies(self):
        notifs = self.get("/unread")["data"]
        notifs = [n for n in notifs if n["body"]]
        return notifs

    def fetch_page(self, page):
        return self.get(f"/comments?page={page}")["data"]

    def fetch_context(self, comment):
        post = self.get(f"/post/{comment['post_id']}")

        if not post:
            return None, None

        comments = [comment]
        while parent_id := comments[-1].get("parent_comment_id", None):
            parent = self.get(f"/comment/{parent_id}")
            comments.append(parent)

        comments.reverse()

        return post, comments

    def reply(self, body, comment):
        if self.chud_phrase and self.chud_phrase not in body:
            body += f"\n{self.chud_phrase}"

        payload = {
            "parent_fullname": f"c_{comment['id']}",
            "body": body,
        }

        self.post("/comment", payload)
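

# A minimal usage sketch (illustrative only, not called anywhere; it assumes
# api_token and data_dir are filled in config.yaml):
#
#   client = DramaClient()
#   for c in client.fetch_new_comments(limit=10):
#       post, thread = client.fetch_context(c)
#       if post and thread:
#           client.reply("some reply text", c)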
@ -0,0 +1,10 @@
api_token:
data_dir:
num_replies: 1
num_candidates: 5
prompt_token_limit: 6000
username: Bussy-boy
fake_usernames:
  - TheBussyMan
  - HedgehogInTheFog
  - GodsStrongestMarsey
@ -0,0 +1,8 @@
import yaml
import os

current_dir = os.path.dirname(os.path.realpath(__file__))
config_path = os.path.join(current_dir, "config.yaml")

with open(config_path, "r") as f:
    config = yaml.safe_load(f)
@ -0,0 +1,249 @@
#!/usr/bin/env python
# https://gist.github.com/hynekcer/fa340f3b63826168ffc0c4b33310ae9c
"""Find the longest repeated substring.
"Efficient way to find longest duplicate string for Python (From Programming Pearls)"
http://stackoverflow.com/questions/13560037/

The algorithm is based on "Prefix doubling".
The worst time complexity is O(n (log n)^2). Memory requirements are linear.
"""

import time

from random import randint
import itertools
import sys
import unittest
from itertools import groupby
from operator import itemgetter
import logging

log = logging.getLogger(__name__)
log.setLevel(logging.INFO)
try:
    log.addHandler(logging.NullHandler())
except AttributeError:
    pass


def run():
    if sys.argv[1:] == ["-"]:
        text = sys.stdin.read()
    elif sys.argv[1:]:
        print("Reading data...")
        text = open(sys.argv[1]).read()
    else:
        text = "banana"
    print("Sorting...")
    result = longest_common_substring(text)
    print('Longest common substrings in "{0}..." are:\n{1}'.format(text[:20], result))


def longest_common_substring(text):
    """Get the longest common substrings and their positions.
    >>> longest_common_substring('banana')
    {'ana': [1, 3]}
    >>> text = "not so Agamemnon, who spoke fiercely to "
    >>> sorted(longest_common_substring(text).items())
    [(' s', [3, 21]), ('no', [0, 13]), ('o ', [5, 20, 38])]

    This function can be easily modified for other criteria, e.g. for searching
    the ten longest non-overlapping repeated substrings.
    """
    sa, rsa, lcp = suffix_array(text)
    maxlen = max(lcp)
    result = {}
    for i in range(1, len(text)):
        if lcp[i] == maxlen:
            j1, j2, h = sa[i - 1], sa[i], lcp[i]
            assert text[j1 : j1 + h] == text[j2 : j2 + h]
            substring = text[j1 : j1 + h]
            if substring not in result:
                result[substring] = [j1]
            result[substring].append(j2)
    return dict((k, sorted(v)) for k, v in result.items())


def suffix_array(text, _step=16):
    """Analyze all common strings in the text.

    Short substrings of length _step are pre-sorted first. These results are
    then repeatedly merged so that the guaranteed number of compared
    characters is doubled in every iteration until all substrings are
    sorted exactly.
    Arguments:
        text: The text to be analyzed.
        _step: Is only for optimization and testing. It is the optimal length
            of substrings used for initial pre-sorting. A bigger value is
            faster if there is enough memory. Memory requirements are
            approximately (estimate for 32 bit Python 3.3):
                len(text) * (29 + (_size + 20 if _size > 2 else 0)) + 1MB

    Return value: (tuple)
        (sa, rsa, lcp)
        sa: Suffix array                 for i in range(1, size):
                assert text[sa[i-1]:] < text[sa[i]:]
        rsa: Reverse suffix array        for i in range(size):
                assert rsa[sa[i]] == i
        lcp: Longest common prefix       for i in range(1, size):
                assert text[sa[i-1]:sa[i-1]+lcp[i]] == text[sa[i]:sa[i]+lcp[i]]
                if sa[i-1] + lcp[i] < len(text):
                    assert text[sa[i-1] + lcp[i]] < text[sa[i] + lcp[i]]
    >>> suffix_array(text='banana')
    ([5, 3, 1, 0, 4, 2], [3, 2, 5, 1, 4, 0], [0, 1, 3, 0, 0, 2])

    Explanation: 'a' < 'ana' < 'anana' < 'banana' < 'na' < 'nana'
    The Longest Common String is 'ana': lcp[2] == 3 == len('ana')
    It is between tx[sa[1]:] == 'ana' < 'anana' == tx[sa[2]:]
    """
    tx = text
    t0 = time.time()
    size = len(tx)
    step = min(max(_step, 1), len(tx))
    sa = list(range(len(tx)))
    log.debug("%6.3f pre sort", time.time() - t0)
    sa.sort(key=lambda i: tx[i : i + step])
    log.debug("%6.3f after sort", time.time() - t0)
    grpstart = size * [False] + [True]  # a boolean map for iteration speedup.
    # It helps to skip already resolved values. The last value True is a sentinel.
    rsa = size * [None]
    stgrp, igrp = "", 0
    for i, pos in enumerate(sa):
        st = tx[pos : pos + step]
        if st != stgrp:
            grpstart[igrp] = igrp < i - 1
            stgrp = st
            igrp = i
        rsa[pos] = igrp
        sa[i] = pos
    grpstart[igrp] = igrp < size - 1 or size == 0
    log.debug("%6.3f after group", time.time() - t0)
    while grpstart.index(True) < size:
        # assert step <= size
        nmerge = 0
        nextgr = grpstart.index(True)
        while nextgr < size:
            igrp = nextgr
            nextgr = grpstart.index(True, igrp + 1)
            glist = []
            for ig in range(igrp, nextgr):
                pos = sa[ig]
                if rsa[pos] != igrp:
                    break
                newgr = rsa[pos + step] if pos + step < size else -1
                glist.append((newgr, pos))
            glist.sort()
            for ig, g in groupby(glist, key=itemgetter(0)):
                g = [x[1] for x in g]
                sa[igrp : igrp + len(g)] = g
                grpstart[igrp] = len(g) > 1
                for pos in g:
                    rsa[pos] = igrp
                igrp += len(g)
            nmerge += len(glist)
        log.debug("%6.3f for step=%d nmerge=%d", time.time() - t0, step, nmerge)
        step *= 2
    del grpstart
    # create LCP array
    lcp = size * [None]
    h = 0
    for i in range(size):
        if rsa[i] > 0:
            j = sa[rsa[i] - 1]
            while i != size - h and j != size - h and tx[i + h] == tx[j + h]:
                h += 1
            lcp[rsa[i]] = h
            if h > 0:
                h -= 1
    if size > 0:
        lcp[0] = 0
    log.debug("%6.3f end", time.time() - t0)
    return sa, rsa, lcp
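

def _prefix_doubling_sketch(text):
    """Tiny, self-contained illustration of the prefix-doubling idea above.

    Not used by this module; it returns only the suffix array, e.g.
    _prefix_doubling_sketch('banana') == [5, 3, 1, 0, 4, 2], matching
    suffix_array('banana')[0]. The real suffix_array() above also builds
    rsa and lcp and is considerably more optimized.
    """
    n = len(text)
    if n == 0:
        return []
    sa = list(range(n))
    rank = [ord(c) for c in text]
    k = 1
    while True:
        # Compare suffixes by their first 2*k characters, reusing the ranks
        # already computed for the first k characters.
        def key(i):
            return (rank[i], rank[i + k] if i + k < n else -1)

        sa.sort(key=key)
        new_rank = [0] * n
        for prev, cur in zip(sa, sa[1:]):
            new_rank[cur] = new_rank[prev] + (key(prev) < key(cur))
        rank = new_rank
        if rank[sa[-1]] == n - 1:
            break
        k *= 2
    return sa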


# ---


class TestMixin(object):
    def suffix_verify(self, text, step=16):
        tx = text
        sa, rsa, lcp = suffix_array(text=tx, _step=step)
        self.assertEqual(set(sa), set(range(len(tx))))
        ok = True
        for i0, i1, h in zip(sa[:-1], sa[1:], lcp[1:]):
            self.assertEqual(
                tx[i1 : i1 + h],
                tx[i0 : i0 + h],
                "Verify LCP characters equal on text '%s...'" % text[:20],
            )
            self.assertGreater(
                tx[i1 + h : i1 + h + 1],
                tx[i0 + h : i0 + h + 1],
                "Verify LCP+1 char is different '%s...'" % text[:20],
            )
            self.assertLessEqual(
                max(i0, i1),
                len(tx) - h,
                "Verify LCP is not more than length of string '%s...'" % text[:20],
            )
        self.assertTrue(ok)


class SuffixArrayTest(unittest.TestCase, TestMixin):
    def test_16(self):
        # 'a' < 'ana' < 'anana' < 'banana' < 'na' < 'nana'
        expect = ([5, 3, 1, 0, 4, 2], [3, 2, 5, 1, 4, 0], [0, 1, 3, 0, 0, 2])
        self.assertEqual(suffix_array(text="banana", _step=16), expect)

    def test_1(self):
        expect = ([5, 3, 1, 0, 4, 2], [3, 2, 5, 1, 4, 0], [0, 1, 3, 0, 0, 2])
        self.assertEqual(suffix_array(text="banana", _step=1), expect)

    def test_mini(self):
        self.assertEqual(suffix_array(text="", _step=1), ([], [], []))
        self.assertEqual(suffix_array(text="a", _step=1), ([0], [0], [0]))
        self.assertEqual(suffix_array(text="aa", _step=1), ([1, 0], [1, 0], [0, 1]))
        self.assertEqual(
            suffix_array(text="aaa", _step=1), ([2, 1, 0], [2, 1, 0], [0, 1, 2])
        )

    def test_example(self):
        self.suffix_verify("abracadabra")

    def test_cartesian(self):
        """Test all combinations of alphabet "ABC" up to length 6 characters."""
        for size in range(7):
            for cartesian in itertools.product(*(size * ["ABC"])):
                text = "".join(cartesian)
                log.debug('Testing "%s"', text)
                self.suffix_verify(text, 1)

    def test_lcp(self):
        expect = {"ana": [1, 3]}
        self.assertDictEqual(longest_common_substring("banana"), expect)
        expect = {" s": [3, 21], "no": [0, 13], "o ": [5, 20, 38]}
        self.assertDictEqual(
            longest_common_substring("not so Agamemnon, who spoke fiercely to "), expect
        )


class SlowTests(unittest.TestCase, TestMixin):
    """Slow development tests running many minutes.

    It can be run only by an EXPLICIT command!
    e.g.: python -m unittest maxsubstring.SlowTests._test_random
    """

    def _test_random(self):
        for power in range(2, 21, 2):
            size = randint(2 ** (power - 1), 2**power)
            for alphabet in (2, 4, 16, 256):
                text = "".join(chr(65 + randint(0, alphabet - 1)) for _ in range(size))
                log.debug("%s %s %s", size, alphabet, 1)
                self.suffix_verify(text, 1)
                log.debug("%s %s %s", size, alphabet, 16)
                self.suffix_verify(text, 16)


if __name__ == "__main__":
    run()
@ -0,0 +1,71 @@
import torch
import transformers
from transformers import (
    GPTNeoXTokenizerFast,
    LogitsProcessor,
    AutoConfig,
    AutoModelForCausalLM,
    LogitsProcessorList,
)
from config import config


class StopAfterPlusIsGenerated(LogitsProcessor):
    """Force an EOS token as soon as the previously generated token was "+"."""

    def __init__(self, plus_token_id, eos_token_id):
        super().__init__()

        self.plus_token_id = plus_token_id
        self.eos_token_id = eos_token_id

    def __call__(self, input_ids, scores):
        forced_eos = torch.full((scores.size(1),), -float("inf")).to(
            device=scores.device, dtype=scores.dtype
        )
        forced_eos[self.eos_token_id] = 0

        scores[input_ids[:, -1] == self.plus_token_id] = forced_eos
        return scores
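

# A minimal sketch of what the processor above does (not used anywhere; the
# token ids are illustrative: in Model below, 559 is assumed to be the "+"
# token and the EOS id comes from the tokenizer):
def _stop_processor_demo():
    proc = StopAfterPlusIsGenerated(plus_token_id=559, eos_token_id=0)
    input_ids = torch.tensor([[42, 559]])  # the last generated token was "+"
    scores = torch.zeros((1, 10))  # dummy next-token logits
    out = proc(input_ids, scores)
    # Every logit in that row is now -inf except index 0 (the EOS id), so
    # generation stops immediately after the "+".
    return out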


class Model:
    def __init__(self):
        name = f"{config['data_dir']}/mpt-30b-drama-ba678"
        self.tokenizer = GPTNeoXTokenizerFast.from_pretrained(
            name, pad_token="<|endoftext|>"
        )

        model_config = AutoConfig.from_pretrained(name, trust_remote_code=True)
        model_config.attn_config["attn_impl"] = "triton"
        model_config.init_device = "cuda:0"
        model_config.eos_token_id = self.tokenizer.eos_token_id

        self.model = AutoModelForCausalLM.from_pretrained(
            name,
            config=model_config,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )

        self.logits_processor = LogitsProcessorList(
            [StopAfterPlusIsGenerated(559, self.model.config.eos_token_id)]
        )

    def generate(self, prompt):
        with torch.autocast("cuda", dtype=torch.bfloat16):
            encoded = self.tokenizer(
                prompt, return_tensors="pt", padding=True, truncation=True
            ).to("cuda:0")

            gen_tokens = self.model.generate(
                input_ids=encoded.input_ids,
                attention_mask=encoded.attention_mask,
                pad_token_id=0,
                do_sample=True,
                temperature=0.90,
                use_cache=True,
                max_length=8192,
                logits_processor=self.logits_processor,
            )
        return self.tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)[0][
            len(prompt) :
        ]
@ -0,0 +1,6 @@
fuzzywuzzy==0.18.0
PyYAML==6.0
Requests==2.31.0
rich==13.4.2
torch==2.0.1
transformers==4.31.0
@ -0,0 +1,172 @@
import random
import re

from fuzzywuzzy import fuzz
from transformers import GPTNeoXTokenizerFast

from config import config
from max_substring import longest_common_substring

URL_REGEX = (
    r"http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+"
)
tokenizer = GPTNeoXTokenizerFast.from_pretrained("mosaicml/mpt-7b")


def remove_notifications(text):
    """Change @float-trip to @<i></i>float-trip and carp to c<i></i>arp."""
    text = re.sub(rf"@(?!{config['username']}\b)", "@<i></i>", text)

    notified_users = [
        "aevan",
        "avean",
        "joan",
        "pewkie",
        "carp",
        "idio3",
        "idio ",
        "the_homocracy",
        "schizocel",
        "scitzocel",
        "snakes",
        "sneks",
        "jc",
        "justcool",
        "clit",
        "geese",
        "kippy",
        "mccox",
        "chiobu",
        "donger",
        "soren",
    ]

    for user in notified_users:
        match = re.search(user, text, re.IGNORECASE)
        if match:
            text = f"{text[:match.start() + 1]}<i></i>{text[match.start() + 1:]}"

    return text


def format_reply(text):
    for username in config["fake_usernames"]:
        text = text.replace(username, config["username"])
    text = replace_rdrama_images(text)
    return text.strip()


def is_low_quality(reply, post, comments):
    """
    Label the reply as low quality if:
    - Its fuzzy-match (Levenshtein) ratio against a previous comment in the thread is above 90.
    - Its longest repeated substring is 100 characters or more.
    - After removing links, Markdown images, and quoted text, the length is < 10.
    """
    for comment in comments:
        if fuzz.ratio(reply, comment["body"]) > 90:
            return True

    lcs = list(longest_common_substring(reply).keys())[0]
    if len(lcs) >= 100:
        return True

    if reply_length(reply) < 10:
        return True

    return False


def replace_rdrama_images(text):
    """Replace images pointing to rdrama.net with a loading image."""
    loading = "https://i.rdrama.net/i/l.webp"
    webp_pattern = r"https://\S*\.rdrama\.net/\S*\.webp"
    md_img_pattern = r"!\[[^\]]*\]\((https://\S*\.rdrama\.net)?/\S*\)"

    text = re.sub(webp_pattern, loading, text)
    text = re.sub(md_img_pattern, f"![]({loading})", text)
    return text


def normalize_emojis(s):
    """Bring # and ! to the front of an emoji."""

    def repl(match):
        # Extract the word between colons and the special characters.
        word = match.group(0)
        specials = set(re.findall(r"[#!]", word))
        # Sort specials and append the word without specials.
        new_emoji = "".join(sorted(specials, reverse=True)) + re.sub(r"[#!]", "", word)
        return new_emoji

    emoji_pattern = r"(?<=:)[a-zA-Z@#!]*[#!][a-zA-Z@#!]*(?=:)"
    s = re.sub(emoji_pattern, repl, s)

    return s
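

# Illustrative example of the rewrite above (the emoji name is made up):
# normalize_emojis("look :marsey#love!:") == "look :#!marseylove:"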


def build_prompt(post, comments):
    prompt = (
        f"[Post] [Author] {post['author_name']} "
        f"[Title] {post['title']} [URL] {post['url']} "
        f"[Hole] {post['sub'] or 'N/A'} [Votes] +71 / -0\n\n"
        f"{post['body']}\n\n[Comments]"
    )

    comments.append({"author_name": config["username"], "body": ""})

    for depth, comment in enumerate(comments):
        body = normalize_emojis(comment["body"])
        author = comment["author_name"]

        comment_str = f"\n\n{author} +45 / -0\n{body}"
        indent = depth * " "
        comment_str = "\n".join([indent + line for line in comment_str.split("\n")])
        prompt += comment_str

    prompt = prompt.replace(config["username"], random.choice(config["fake_usernames"]))
    prompt = prompt.strip() + "\n"

    # Truncate the prompt to leave room for generation.
    tokens = tokenizer.tokenize(prompt)
    if len(tokens) > config["prompt_token_limit"]:
        tokens = tokens[-config["prompt_token_limit"] :]
        prompt = tokenizer.convert_tokens_to_string(tokens)

    return prompt
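

# Roughly what build_prompt() produces (author names, title, URL, and bodies
# below are made up; the vote counts are the fixed "+71 / -0" and "+45 / -0"
# used above, each comment is indented one extra space per position in the
# thread, and any mention of the bot's own username is swapped for a random
# fake_usernames entry):
#
#   [Post] [Author] some_user [Title] An example title [URL] https://example.com [Hole] N/A [Votes] +71 / -0
#
#   Post body text.
#
#   [Comments]
#
#   some_user +45 / -0
#   First comment.
#
#    other_user +45 / -0
#    A reply to the first comment.
#
#     TheBussyMan +45 / -0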


def log_prompt(prompt):
    with open(f"{config['data_dir']}/prompts.txt", "a") as f:
        f.write(f"{prompt}\n==========\n")


def reply_length(reply):
    """Return the length of the reply, without Markdown images, URLs, or quoted text."""
    # Remove Markdown images and URLs.
    reply = re.sub(r"!\[.*?\]\(.*?\)", "", reply)
    reply = re.sub(URL_REGEX, "", reply)
    # Remove quoted text.
    lines = reply.splitlines()
    lines = [line for line in lines if not line.lstrip().startswith((">", "\\>"))]
    reply = "\n".join(lines).strip()

    return len(reply)


def count_tokens(text):
    return len(tokenizer(text).input_ids)


def extract_reply(text):
    """
    Generated text will either:
    - Be cut off at the token limit
    - End with the start of a new comment: `float-trip +10`
    For the latter case, drop the last line.
    """
    pattern = r"^ *[\w-]* +\+.*$"
    lines = text.split("\n")
    if re.match(pattern, lines[-1]):
        lines = lines[:-1]
    return "\n".join([line.strip() for line in lines]).strip()