Stream tensor deserialization + add block list.

master
float-trip 2023-07-22 17:03:30 +00:00
parent 826a631fd1
commit a06c66dff4
2 changed files with 23 additions and 9 deletions

15
bot.py
View File

@@ -1,5 +1,6 @@
import os
import random
import sys
import requests
from rich import traceback
@@ -28,6 +29,7 @@ class Bot:
if "author_name" in c
and not c["is_bot"]
and c["author_name"] != config["username"]
and c["author_id"] not in config["ignore_user_ids"]
and c["post_id"] != 0
]
@@ -48,7 +50,10 @@ class Bot:
continue
if not reply["is_bot"]:
self.reply(reply)
try:
self.reply(reply)
except requests.exceptions.RequestException as e:
print(f"Error while replying: {e}")
def make_forced_replies(self):
file_path = f"{config['data_dir']}/forced.txt"
@@ -60,7 +65,10 @@ class Bot:
for comment_id in lines:
comment = self.client.get(f"/comment/{comment_id}")
self.reply(comment)
try:
self.reply(comment)
except requests.exceptions.RequestException as e:
print(f"Error while replying: {e}")
os.remove(file_path)
@@ -121,5 +129,6 @@ if __name__ == "__main__":
bot = Bot()
bot.make_forced_replies()
bot.respond_to_replies()
bot.post_random_reply()
if len(sys.argv) < 2 or sys.argv[1] != "reply":
bot.post_random_reply()
bot.respond_to_replies()

View File

@@ -1,5 +1,6 @@
import torch
import transformers
from tensorizer import TensorDeserializer
from tensorizer.utils import no_init_or_tensor, convert_bytes, get_mem_usage
from transformers import (
GPTNeoXTokenizerFast,
LogitsProcessor,
@@ -39,12 +40,16 @@ class Model:
model_config.init_device = "cuda:0"
model_config.eos_token_id = self.tokenizer.eos_token_id
self.model = AutoModelForCausalLM.from_pretrained(
name,
config=model_config,
torch_dtype=torch.bfloat16,
trust_remote_code=True,
with no_init_or_tensor():
self.model = AutoModelForCausalLM.from_config(
model_config, trust_remote_code=True
)
deserializer = TensorDeserializer(
f"{config['data_dir']}/drama.tensors", plaid_mode=True
)
deserializer.load_into_module(self.model)
self.model.eval()
self.logits_processor = LogitsProcessorList(
[StopAfterPlusIsGenerated(559, self.model.config.eos_token_id)]