commit 6a475d584d0542982777c27e212fbb54bf34230f Author: James Shiffer Date: Sun Dec 26 11:24:52 2021 -0800 Initial commit diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000..005fa5a --- /dev/null +++ b/.editorconfig @@ -0,0 +1,89 @@ +root = true + +[*] +charset = utf-8 +end_of_line = lf +indent_size = 4 +indent_style = space +insert_final_newline = false +max_line_length = 120 +tab_width = 4 +ij_continuation_indent_size = 8 +ij_formatter_off_tag = @formatter:off +ij_formatter_on_tag = @formatter:on +ij_formatter_tags_enabled = false +ij_smart_tabs = false +ij_visual_guides = none +ij_wrap_on_typing = false + +[{*.py,*.pyw}] +ij_python_align_collections_and_comprehensions = true +ij_python_align_multiline_imports = true +ij_python_align_multiline_parameters = true +ij_python_align_multiline_parameters_in_calls = true +ij_python_blank_line_at_file_end = true +ij_python_blank_lines_after_imports = 1 +ij_python_blank_lines_after_local_imports = 0 +ij_python_blank_lines_around_class = 1 +ij_python_blank_lines_around_method = 1 +ij_python_blank_lines_around_top_level_classes_functions = 2 +ij_python_blank_lines_before_first_method = 0 +ij_python_call_parameters_new_line_after_left_paren = false +ij_python_call_parameters_right_paren_on_new_line = false +ij_python_call_parameters_wrap = normal +ij_python_dict_alignment = 0 +ij_python_dict_new_line_after_left_brace = false +ij_python_dict_new_line_before_right_brace = false +ij_python_dict_wrapping = 1 +ij_python_from_import_new_line_after_left_parenthesis = false +ij_python_from_import_new_line_before_right_parenthesis = false +ij_python_from_import_parentheses_force_if_multiline = false +ij_python_from_import_trailing_comma_if_multiline = false +ij_python_from_import_wrapping = 1 +ij_python_hang_closing_brackets = false +ij_python_keep_blank_lines_in_code = 1 +ij_python_keep_blank_lines_in_declarations = 1 +ij_python_keep_indents_on_empty_lines = false +ij_python_keep_line_breaks = true 
+ij_python_method_parameters_new_line_after_left_paren = false +ij_python_method_parameters_right_paren_on_new_line = false +ij_python_method_parameters_wrap = normal +ij_python_new_line_after_colon = false +ij_python_new_line_after_colon_multi_clause = true +ij_python_optimize_imports_always_split_from_imports = false +ij_python_optimize_imports_case_insensitive_order = false +ij_python_optimize_imports_join_from_imports_with_same_source = false +ij_python_optimize_imports_sort_by_type_first = true +ij_python_optimize_imports_sort_imports = true +ij_python_optimize_imports_sort_names_in_from_imports = false +ij_python_space_after_comma = true +ij_python_space_after_number_sign = true +ij_python_space_after_py_colon = true +ij_python_space_before_backslash = true +ij_python_space_before_comma = false +ij_python_space_before_for_semicolon = false +ij_python_space_before_lbracket = false +ij_python_space_before_method_call_parentheses = false +ij_python_space_before_method_parentheses = false +ij_python_space_before_number_sign = true +ij_python_space_before_py_colon = false +ij_python_space_within_empty_method_call_parentheses = false +ij_python_space_within_empty_method_parentheses = false +ij_python_spaces_around_additive_operators = true +ij_python_spaces_around_assignment_operators = true +ij_python_spaces_around_bitwise_operators = true +ij_python_spaces_around_eq_in_keyword_argument = false +ij_python_spaces_around_eq_in_named_parameter = false +ij_python_spaces_around_equality_operators = true +ij_python_spaces_around_multiplicative_operators = true +ij_python_spaces_around_power_operator = true +ij_python_spaces_around_relational_operators = true +ij_python_spaces_around_shift_operators = true +ij_python_spaces_within_braces = false +ij_python_spaces_within_brackets = false +ij_python_spaces_within_method_call_parentheses = false +ij_python_spaces_within_method_parentheses = false +ij_python_use_continuation_indent_for_arguments = false 
+ij_python_use_continuation_indent_for_collection_and_comprehensions = false +ij_python_use_continuation_indent_for_parameters = true +ij_python_wrap_long_lines = false diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..ebf52c2 --- /dev/null +++ b/.env.example @@ -0,0 +1,4 @@ +# the bot's token +TOKEN= +# your user token, if you want to scrape messages +USER_TOKEN= \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..bfd2757 --- /dev/null +++ b/.gitignore @@ -0,0 +1,213 @@ +### JetBrains template +# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider +# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 + +# User-specific stuff +.idea/**/workspace.xml +.idea/**/tasks.xml +.idea/**/usage.statistics.xml +.idea/**/dictionaries +.idea/**/shelf + +# Generated files +.idea/**/contentModel.xml + +# Sensitive or high-churn files +.idea/**/dataSources/ +.idea/**/dataSources.ids +.idea/**/dataSources.local.xml +.idea/**/sqlDataSources.xml +.idea/**/dynamic.xml +.idea/**/uiDesigner.xml +.idea/**/dbnavigator.xml + +# Gradle +.idea/**/gradle.xml +.idea/**/libraries + +# Gradle and Maven with auto-import +# When using Gradle or Maven with auto-import, you should exclude module files, +# since they will be recreated, and may cause churn. Uncomment if using +# auto-import. 
+# .idea/artifacts +# .idea/compiler.xml +# .idea/jarRepositories.xml +# .idea/modules.xml +# .idea/*.iml +# .idea/modules +# *.iml +# *.ipr + +# CMake +cmake-build-*/ + +# Mongo Explorer plugin +.idea/**/mongoSettings.xml + +# File-based project format +*.iws + +# IntelliJ +out/ + +# mpeltonen/sbt-idea plugin +.idea_modules/ + +# JIRA plugin +atlassian-ide-plugin.xml + +# Cursive Clojure plugin +.idea/replstate.xml + +# Crashlytics plugin (for Android Studio and IntelliJ) +com_crashlytics_export_strings.xml +crashlytics.properties +crashlytics-build.properties +fabric.properties + +# Editor-based Rest Client +.idea/httpRequests + +# Android studio 3.1+ serialized cache file +.idea/caches/build_file_checksums.ser + +### Python template +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/encodings.xml b/.idea/encodings.xml new file mode 100644 index 0000000..5820c9b --- /dev/null +++ b/.idea/encodings.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000..53a6d78 --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,14 @@ + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/miku.iml b/.idea/miku.iml new file mode 100644 index 0000000..9fe3b52 --- /dev/null +++ b/.idea/miku.iml @@ -0,0 +1,14 @@ + + + + + + + + + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..f5c1a18 --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new 
file mode 100644 index 0000000..fd39f8e --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..94a25f7 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..6d1c17e --- /dev/null +++ b/README.md @@ -0,0 +1,3 @@ +# miku + +Discord bot/companion for the group chatte, powered by the GPT-J language model and modified with a soft prompt to understand all of our esoteric, elaborate inside jokes. diff --git a/chats/.gitignore b/chats/.gitignore new file mode 100644 index 0000000..c96a04f --- /dev/null +++ b/chats/.gitignore @@ -0,0 +1,2 @@ +* +!.gitignore \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..92700e5 Binary files /dev/null and b/requirements.txt differ diff --git a/src/miku/__init__.py b/src/miku/__init__.py new file mode 100644 index 0000000..e69de29
diff --git a/src/miku/__main__.py b/src/miku/__main__.py
new file mode 100644
index 0000000..99a8ed1
--- /dev/null
+++ b/src/miku/__main__.py
@@ -0,0 +1,7 @@
+from dotenv import load_dotenv
+from pathlib import Path
+import os
+from .bot import boot
+load_dotenv(Path.cwd() / '..' / '.env')
+token = os.getenv('TOKEN')
+boot(token)
diff --git a/src/miku/bot.py b/src/miku/bot.py
new file mode 100644
index 0000000..902b224
--- /dev/null
+++ b/src/miku/bot.py
@@ -0,0 +1,99 @@
+"""Discord client that answers chat messages with text generated by a language model."""
+import discord
+import logging
+import sys
+from typing import List, Optional
+
+from .model import GPT2Model, Model
+
+discord_logger = logging.getLogger('discord')
+discord_logger.setLevel(logging.WARNING)
+
+
+class HatsuneMikuBot(discord.Client):
+    """
+    Replies whenever a message mentions the bot or contains TRIGGER_PHRASE.
+    The latest CONTEXT_MSGS messages of the channel are used as the prompt.
+    """
+
+    TRIGGER_PHRASE = 'miku'
+    CONTEXT_MSGS = 5
+
+    def __init__(self, **options):
+        super().__init__(**options)
+        self.log = logging.getLogger('miku')
+        self.log.setLevel(logging.DEBUG)
+        self.log_handler = logging.StreamHandler()
+        self.log_handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s:%(name)s: %(message)s'))
+        self.log.addHandler(self.log_handler)
+        # stays None until on_ready() has finished loading the model
+        self.model: Optional[Model] = None
+
+    async def on_ready(self):
+        self.log.info(f'Logged in as {self.user}')
+        # discord.py may fire on_ready more than once (e.g. after a reconnect);
+        # keep the model that is already loaded instead of loading it again
+        if self.model is None:
+            self.model = GPT2Model(self.user.name)
+
+    def should_respond(self, message: discord.Message) -> bool:
+        """
+        True if the message contains the trigger phrase (case-insensitive) or mentions the bot.
+        """
+        return HatsuneMikuBot.TRIGGER_PHRASE in message.content.lower() or self.user in message.mentions
+
+    @staticmethod
+    def parse_message(message: discord.Message) -> str:
+        """
+        Formats a message as "sender: text". Returns '' if the message has no text.
+        """
+        if not message.content:
+            return ''
+        return f'{message.author.name}: {message.content}'
+
+    async def fetch_message_context(self, chan: discord.TextChannel) -> List[str]:
+        """
+        Returns the latest CONTEXT_MSGS messages of the channel as "sender: text" lines, oldest first.
+        Messages without text are left out.
+        """
+        # history() yields the newest message first
+        msgs = await chan.history(limit=HatsuneMikuBot.CONTEXT_MSGS).flatten()
+        msgs_txt = [txt for txt in map(HatsuneMikuBot.parse_message, reversed(msgs)) if txt]
+        self.log.debug('Found context:\n' + '\n'.join(msgs_txt))
+        return msgs_txt
+
+    async def on_message(self, message: discord.Message):
+        author: discord.User = message.author
+        if author == self.user:
+            return
+
+        if not self.should_respond(message):
+            return
+
+        self.log.debug(f'Message {message.id} from {author.name}: {message.content}')
+        if self.model is None:
+            # messages can arrive before on_ready() has finished loading the model
+            self.log.warning(f'Model not loaded yet; ignoring {message.id}.')
+            return
+
+        chan: discord.TextChannel = message.channel
+        await chan.trigger_typing()
+        self.log.debug(f'Fetching context for {message.id}...')
+        prompt = await self.fetch_message_context(chan)
+        self.log.debug(f'Generating inference for {message.id}...')
+        response = self.model.infer(prompt)
+        if not response.strip():
+            # Discord rejects empty messages
+            self.log.warning(f'Empty inference for {message.id}; not replying.')
+            return
+        await message.reply(response)
+
+
+def boot(token: str):
+    """
+    Runs the bot with the given token; exits with status 1 if the token is empty.
+    """
+    if not token:
+        logging.critical('No bot token supplied!')
+        sys.exit(1)
+    HatsuneMikuBot().run(token)
diff --git a/src/miku/model.py b/src/miku/model.py
new file mode 100644
index 0000000..ed3611c
--- /dev/null
+++ b/src/miku/model.py
@@ -0,0 +1,73 @@
+"""Language model backends that turn a chat transcript into the bot's next message."""
+import logging
+
+from transformers import GPT2Tokenizer, GPTNeoModel, GPTNeoForCausalLM
+from typing import List
+from abc import ABC, abstractmethod
+
+trans_logger = logging.getLogger('transformers')
+trans_logger.setLevel(logging.WARNING)
+
+
+class Model(ABC):
+    """
+    Interface of a text generator that continues a chat transcript.
+    """
+
+    @abstractmethod
+    def infer(self, prompt: List[str]) -> str:
+        """
+        Generates an inference.
+        """
+        return ''
+
+
+class GPT2Model(Model):
+    """
+    GPT-Neo causal language model ('iokru/c1-1.3B' weights) run in half precision on CUDA.
+    """
+
+    # upper bound on the number of tokens generated per reply
+    MAX_NEW_TOKENS = 100
+
+    def __init__(self, username: str):
+        self.log = logging.getLogger('miku')
+        self.model: GPTNeoForCausalLM = GPTNeoForCausalLM.from_pretrained('iokru/c1-1.3B').half().to('cuda')
+        self.tokenizer = GPT2Tokenizer.from_pretrained('EleutherAI/gpt-neo-1.3B')
+        self.log.info('Model loaded.')
+        self.username = username
+
+    def infer(self, prompt: List[str]) -> str:
+        """
+        Generates the bot's next message from the previous messages ("sender: text", oldest first).
+        The prompt list is left untouched. Returns the first line of the generated text, which may be empty.
+        """
+        # i was getting japanese results because the bot is named Hatsune Miku. this is a workaround
+        #prompt = [msg.replace(self.username, '<|ME|>') for msg in prompt]
+        #prompt.append(f'<|ME|>: ')
+        # build a new list rather than appending to the caller's
+        lines = [*prompt, f'{self.username}:']
+        flattened_prompt = self.tokenizer.eos_token.join(lines)
+        input_ids = self.tokenizer.encode(flattened_prompt, return_tensors="pt").cuda()
+        prompt_len = input_ids.shape[-1]
+
+        result = self.model.generate(
+            input_ids=input_ids,
+            do_sample=True,
+            min_length=1,
+            # max_length counts the prompt tokens as well
+            max_length=prompt_len + self.MAX_NEW_TOKENS,
+            temperature=0.6,
+            # NOTE(review): tfs looks like a fork-only generate() argument - confirm against requirements.txt
+            tfs=0.993,
+            repetition_penalty=3.0,
+            pad_token_id=self.tokenizer.eos_token_id,
+            #bad_words_ids=config.bad_words_ids
+        )
+
+        # decode only the newly generated tokens; cutting the decoded text by len(flattened_prompt)
+        # breaks whenever decoding does not reproduce the prompt character for character
+        inferred_txt = self.tokenizer.decode(result[0][prompt_len:], skip_special_tokens=True)
+        self.log.debug(inferred_txt)
+        # TODO: multiline inferences
+        return inferred_txt.split('\n')[0]
diff --git a/src/scraper/__init__.py b/src/scraper/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/scraper/__main__.py b/src/scraper/__main__.py
new file mode 100644
index 0000000..9177783
--- /dev/null
+++ b/src/scraper/__main__.py
@@ -0,0 +1,7 @@
+from dotenv import load_dotenv
+from pathlib import Path
+import os
+from .scraper import boot
+load_dotenv(Path.cwd() / '..' / '.env')
+token = os.getenv('USER_TOKEN')
+boot(token)
diff --git a/src/scraper/scraper.py b/src/scraper/scraper.py
new file mode 100644
index 0000000..26251f6
--- /dev/null
+++ b/src/scraper/scraper.py
@@ -0,0 +1,153 @@
+import logging
+import requests
+from pathlib import Path
+from time import sleep, time
+from typing import List, Optional
+
+logging.basicConfig(level=logging.INFO)
+
+
+class Scraper:
+    """
+    Scrapes the full Discord message history for one channel.
+    """
+
+    # milliseconds between every 25 requests
+    RATE_LIMIT = 1000
+    # seconds to wait for the API to answer before giving up on a request
+    REQUEST_TIMEOUT = 30
+
+    def __init__(self, token: str, channel: str, export: Path):
+        self.token: str = token
+        self.channel: str = channel
+        # the oldest message ID we've encountered so far, used for pagination
+        self.oldest_message: Optional[str] = None
+        # used in rate-limiting
+        self.requests_made: int = 0
+        # the messages we want go in here
+        self.messages: List[str] = []
+        # export file directory
+        self.export_dir: Path = export
+        # last exported line
+        self.export_line: int = 0
+        # export start time
+        self.start_time: Optional[float] = None
+
+    @staticmethod
+    def parse_message(msg: dict) -> Optional[str]:
+        """
+        Parses a message's metadata to get just the relevant text, in "sender: text" format.
+        Returns None if no valuable content.
+        """
+        if not msg['content']:
+            return None
+        # TODO: remove stuff like links
+        return f"{msg['author']['username']}: {msg['content']}"
+
+    @staticmethod
+    def _retry_delay(res: requests.Response, fallback: float) -> float:
+        """
+        Seconds to wait before retrying a rate-limited request: the response's Retry-After header
+        if it holds a number, otherwise the fallback.
+        """
+        try:
+            return max(float(res.headers['Retry-After']), 0.0)
+        except (KeyError, ValueError):
+            return fallback
+
+    def fetch_next_page(self, limit: int = 50) -> List[dict]:
+        """
+        Fetches the next (older) page of messages and moves the pagination cursor to its oldest message.
+        Returns the raw message objects, newest first; the list is empty once the history is exhausted.
+        Raises ConnectionError if the API answers with a status other than 200 or 429.
+        """
+        url = f'https://discord.com/api/v9/channels/{self.channel}/messages?limit={limit}'
+        if self.oldest_message:
+            url += f'&before={self.oldest_message}'
+
+        while True:
+            if self.requests_made > 0 and self.requests_made % 25 == 0:
+                logging.info(f'[fetch_next_page] Waiting {self.RATE_LIMIT} ms.')
+                sleep(self.RATE_LIMIT / 1000)
+
+            res = requests.get(url, timeout=self.REQUEST_TIMEOUT, headers={
+                'Authorization': self.token,
+                'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) '
+                              'discord/0.0.264 Chrome/91.0.4472.164 Electron/13.4.0 Safari/537.36'
+            })
+            self.requests_made += 1
+
+            if res.status_code == 429:
+                # our 1000 ms delay may not be perfect all the time.
+                # wait as long as the API asks us to, then try again
+                delay = self._retry_delay(res, self.RATE_LIMIT / 1000)
+                logging.warning(f'[fetch_next_page] got 429, trying this page again in {delay:.1f}s')
+                sleep(delay)
+                continue
+            elif res.status_code != 200:
+                logging.error(f'[fetch_next_page] Request to {url} failed with status code {res.status_code}!')
+                raise ConnectionError(f'Request to {url} failed with status code {res.status_code}')
+            else:
+                messages = res.json()
+                # an empty page means there is nothing older left; keep the cursor where it is
+                if messages:
+                    self.oldest_message = messages[-1]['id']
+                return messages
+
+    def scrape(self):
+        """
+        Fetches pages until one comes back shorter than the page size or the user hits Ctrl+C,
+        collecting every message with text into self.messages (newest first).
+        """
+        limit = 50
+        total_messages = 0
+        page_count = 0
+        self.start_time = time()
+        try:
+            while True:
+                next_batch = self.fetch_next_page(limit)
+                page_count += 1
+                total_messages += len(next_batch)
+                self.messages.extend(filter(None, map(self.parse_message, next_batch)))
+                logging.info(f'[scrape] pg {page_count}, parsed {total_messages}, kept {len(self.messages)} messages')
+                # a short page means we reached the beginning of the channel
+                if len(next_batch) < limit:
+                    break
+        except KeyboardInterrupt:
+            logging.warning('[scrape] user wants to abort, stopping early.')
+        end_time = time()
+        logging.info(f'[scrape] completed scraping this channel in {end_time - self.start_time:.1f}s.')
+
+    def export(self):
+        """
+        Writes the collected messages, oldest first, to "<channel>_<start time>.txt" in the export
+        directory (created if missing). Does nothing if scrape() has not been started.
+        """
+        if self.start_time is None:
+            # nothing to export
+            logging.warning('[export] scraping has not begun yet; skipping.')
+            return
+        self.export_dir.mkdir(parents=True, exist_ok=True)
+        file_path = self.export_dir / f"{self.channel}_{self.start_time:.0f}.txt"
+        # messages are captured in reverse chronological order; we need to fix that
+        # this also means that we can't really do partial exports
+        joined_msgs = '\n'.join(reversed(self.messages))
+        file_path.write_bytes(joined_msgs.encode('utf-8'))
+        logging.info(f'[export] saved {len(self.messages)} lines.')
+
+
+def boot(token: str):
+    """
+    Asks for whatever is still missing (token, channel ID, export path), then scrapes and exports.
+    """
+    if not token:
+        token = input('Enter your Discord user token (Authorization request header): ')
+    channel = input('Enter channel ID: ')
+    default_export = Path.cwd().parent / 'chats'
+    export = input('Enter path to export transcripts (default "chats"): ')
+    scraper = Scraper(token, channel, Path(export) if export else default_export)
+    try:
+        scraper.scrape()
+    finally:
+        # keep whatever was collected even if scraping died half-way
+        scraper.export()