Pick the reply with the median length + switch from shelve to SqliteDict.
parent e549796935
commit e0c7556344
bot.py (25 changed lines)

@@ -7,7 +7,7 @@ import model
 import utils
 from client import DramaClient

-from config import config
+from data import config, db

 traceback.install()

@@ -24,7 +24,7 @@ class Bot:
         comments = [
             c
-            for c in self.client.fetch_new_comments(limit=50)
+            for c in self.client.fetch_new_comments()
             if "author_name" in c
             and not c["is_bot"]
             and c["author_name"] != config["username"]
@@ -71,13 +71,13 @@ class Bot:
         if not post or not thread_comments:
             print("Could not fetch context!")
             return

         prompt = utils.build_prompt(post, thread_comments)
         utils.log_prompt(prompt)

         candidates = []
-        num_candidates = config["num_candidates"] if random.random() < 0.6 else 1
-        while len(candidates) < num_candidates:
+        rejects = []
+        while len(candidates) < config["num_candidates"]:
             gen_text = self.model.generate(prompt)
             reply = utils.extract_reply(gen_text)
             print(f"Generated text: {gen_text}\nReply:\n{reply}")
@@ -85,14 +85,25 @@ class Bot:
             if len(reply) == 0:
                 print("Retrying: reply empty after processing.")
+                rejects.append(reply)
             elif utils.is_low_quality(reply, post, thread_comments):
                 print("Retrying: low quality reply.")
+                rejects.append(reply)
             else:
                 candidates.append(reply)
                 print("Accepting reply.")

-        # Get the longest reply, but cap the considered length at 500 chars.
-        reply = max(candidates, key=lambda r: min(utils.reply_length(r), 500))
+        reply = utils.median_by_key(candidates, key=utils.reply_length)
+
+        db["prompts"].append(
+            {
+                "prompt": prompt,
+                "candidates": candidates,
+                "rejects": rejects,
+                "selected": reply,
+            }
+        )

         self.client.reply(reply, comment)
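The selection change above is the behavioral core of this commit: instead of favoring the longest candidate (with the considered length capped at 500 characters), the bot now posts the median-length one, which discards both one-word replies and runaway generations. A self-contained sketch of the two strategies, using plain len() in place of utils.reply_length() and invented candidate strings:

import random

candidates = [
    "lol",
    "a reply of middling length, roughly a sentence or two",
    "an extremely long screed " * 40,
]

# Old behavior: longest reply, with the considered length capped at 500 chars.
old_pick = max(candidates, key=lambda r: min(len(r), 500))

# New behavior: median-length reply, mirroring utils.median_by_key below;
# even-length lists break the tie randomly between the two middle elements.
ranked = sorted(candidates, key=len)
mid = len(ranked) // 2
if len(ranked) % 2 == 0:
    new_pick = random.choice([ranked[mid - 1], ranked[mid]])
else:
    new_pick = ranked[mid]

print(old_pick is candidates[2], new_pick is candidates[1])  # True True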
client.py (15 changed lines)

@@ -7,16 +7,13 @@ import shelve
 from requests.adapters import HTTPAdapter, Retry

-from config import config
+from data import config, db


 class DramaClient:
     BASE_URL = "https://rdrama.net"

     def __init__(self):
-        self.db = shelve.open(f"{config['data_dir']}/client_state.p", writeback=True)
-        self.db.setdefault("processed_replies", set())
-
         self.session = requests.Session()
         retries = Retry(
             total=5, backoff_factor=5, status_forcelist=[500, 502, 503, 504, 521]
@@ -77,22 +74,21 @@ class DramaClient:
         return r.json()

-    def fetch_new_comments(self, limit=0):
+    def fetch_new_comments(self, limit=config["num_replies"] * 25):
         comments = []

-        last_processed_id = self.db.get("last_processed_id", -1)
         earliest_id = math.inf
         page = 1

         # Fetch comments until we find the last one processed.
-        while earliest_id > last_processed_id:
+        while earliest_id > db["last_processed_id"]:
             page_comments = self.fetch_page(page)

             if len(page_comments) == 0:
                 break

             earliest_id = min([c["id"] for c in page_comments])
-            comments += [c for c in page_comments if c["id"] > last_processed_id]
+            comments += [c for c in page_comments if c["id"] > db["last_processed_id"]]

             if limit > 0 and len(comments) >= limit:
                 break
@@ -102,8 +98,7 @@
         if not comments:
             return []

-        self.db["last_processed_id"] = max(c["id"] for c in comments)
-        self.db.sync()
+        db["last_processed_id"] = max(c["id"] for c in comments)

         # New comments may have pushed others to page n+1 while fetching.
         deduped_comments = {c["id"]: c for c in comments}.values()
config.py (deleted)

@@ -1,8 +0,0 @@
-import yaml
-import os
-
-current_dir = os.path.dirname(os.path.realpath(__file__))
-config_path = os.path.join(current_dir, "config.yaml")
-
-with open(config_path, "r") as f:
-    config = yaml.safe_load(f)
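config.py's role moves into the new data module, which the imports above reference (from data import config, db) but whose contents this page doesn't show. A minimal sketch of what it plausibly contains, given those imports and the commit message; the file name data.py, the database path, and the seeded defaults are all assumptions:

# data.py -- hypothetical sketch; the real module added by this commit
# is not visible in the diff.
import os

import yaml
from sqlitedict import SqliteDict

_current_dir = os.path.dirname(os.path.realpath(__file__))

# Carries over the config loading that the deleted config.py did.
with open(os.path.join(_current_dir, "config.yaml"), "r") as f:
    config = yaml.safe_load(f)

# autocommit=True persists every top-level assignment immediately,
# which is why client.py can drop its explicit self.db.sync() call.
db = SqliteDict(f"{config['data_dir']}/db.sqlite", autocommit=True)

# Defaults assumed from how bot.py and client.py read the store.
if "last_processed_id" not in db:
    db["last_processed_id"] = -1
if "prompts" not in db:
    db["prompts"] = []

One caveat: SqliteDict unpickles a fresh copy of a value on each access, so an in-place mutation like bot.py's db["prompts"].append(...) only persists if the list is assigned back to the key afterwards; whether the real module or call site handles that isn't visible here.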
model.py (2 changed lines)

@@ -7,7 +7,7 @@ from transformers import (
     AutoModelForCausalLM,
     LogitsProcessorList,
 )
-from config import config
+from data import config


 class StopAfterPlusIsGenerated(LogitsProcessor):
requirements.txt

@@ -2,5 +2,6 @@ fuzzywuzzy==0.18.0
 PyYAML==6.0
 Requests==2.31.0
 rich==13.4.2
+sqlitedict==2.1.0
 torch==2.0.1
 transformers==4.31.0
utils.py (19 changed lines)

@@ -4,7 +4,7 @@ import re
 from fuzzywuzzy import fuzz
 from transformers import GPTNeoXTokenizerFast

-from config import config
+from data import config
 from maxsubstring import longest_common_substring

 URL_REGEX = (
@@ -131,6 +131,7 @@ def build_prompt(post, comments):
         prompt += comment_str

     prompt = prompt.replace(config["username"], random.choice(config["fake_usernames"]))
+    prompt = prompt.replace("👻", "Ghost")
     prompt = prompt.strip() + "\n"

     # Truncate the prompt to leave room for generation.
@@ -142,11 +143,6 @@ def build_prompt(post, comments):
     return prompt


-def log_prompt(prompt):
-    with open(f"{config['data_dir']}/prompts.txt", "a") as f:
-        f.write(f"{prompt}\n==========\n")
-
-
 def reply_length(reply):
     """Return the length of the reply, without Markdown images, URLs, or quoted text."""
     # Remove Markdown images and URLs.
@@ -160,6 +156,17 @@ def reply_length(reply):
     return len(reply)


+def median_by_key(lst, key):
+    lst = sorted(lst, key=key)
+    mid_index = len(lst) // 2
+
+    # For lists of even length, pick either option as the median.
+    if len(lst) % 2 == 0:
+        return random.choice([lst[mid_index - 1], lst[mid_index]])
+    else:
+        return lst[mid_index]
+
+
 def count_tokens(text):
     return len(tokenizer(text).input_ids)
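A quick usage sketch of the new helper (inputs invented): because the key maps replies to lengths but the median must be an actual reply, an even-length list can't average its two middle values, hence the random tie-break.

replies = ["a", "bb", "ccc", "dddd"]

# Four candidates: the median is drawn at random from the two
# middle ones, so this prints either "bb" or "ccc".
print(median_by_key(replies, key=len))

# Odd-length lists have a unique middle element.
print(median_by_key(replies[:3], key=len))  # always "bb"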