Initial commit

This commit is contained in:
James Shiffer 2021-12-26 11:24:52 -08:00
commit 6a475d584d
21 changed files with 625 additions and 0 deletions

89
.editorconfig Normal file
View File

@ -0,0 +1,89 @@
root = true
[*]
charset = utf-8
end_of_line = lf
indent_size = 4
indent_style = space
insert_final_newline = false
max_line_length = 120
tab_width = 4
ij_continuation_indent_size = 8
ij_formatter_off_tag = @formatter:off
ij_formatter_on_tag = @formatter:on
ij_formatter_tags_enabled = false
ij_smart_tabs = false
ij_visual_guides = none
ij_wrap_on_typing = false
[{*.py,*.pyw}]
ij_python_align_collections_and_comprehensions = true
ij_python_align_multiline_imports = true
ij_python_align_multiline_parameters = true
ij_python_align_multiline_parameters_in_calls = true
ij_python_blank_line_at_file_end = true
ij_python_blank_lines_after_imports = 1
ij_python_blank_lines_after_local_imports = 0
ij_python_blank_lines_around_class = 1
ij_python_blank_lines_around_method = 1
ij_python_blank_lines_around_top_level_classes_functions = 2
ij_python_blank_lines_before_first_method = 0
ij_python_call_parameters_new_line_after_left_paren = false
ij_python_call_parameters_right_paren_on_new_line = false
ij_python_call_parameters_wrap = normal
ij_python_dict_alignment = 0
ij_python_dict_new_line_after_left_brace = false
ij_python_dict_new_line_before_right_brace = false
ij_python_dict_wrapping = 1
ij_python_from_import_new_line_after_left_parenthesis = false
ij_python_from_import_new_line_before_right_parenthesis = false
ij_python_from_import_parentheses_force_if_multiline = false
ij_python_from_import_trailing_comma_if_multiline = false
ij_python_from_import_wrapping = 1
ij_python_hang_closing_brackets = false
ij_python_keep_blank_lines_in_code = 1
ij_python_keep_blank_lines_in_declarations = 1
ij_python_keep_indents_on_empty_lines = false
ij_python_keep_line_breaks = true
ij_python_method_parameters_new_line_after_left_paren = false
ij_python_method_parameters_right_paren_on_new_line = false
ij_python_method_parameters_wrap = normal
ij_python_new_line_after_colon = false
ij_python_new_line_after_colon_multi_clause = true
ij_python_optimize_imports_always_split_from_imports = false
ij_python_optimize_imports_case_insensitive_order = false
ij_python_optimize_imports_join_from_imports_with_same_source = false
ij_python_optimize_imports_sort_by_type_first = true
ij_python_optimize_imports_sort_imports = true
ij_python_optimize_imports_sort_names_in_from_imports = false
ij_python_space_after_comma = true
ij_python_space_after_number_sign = true
ij_python_space_after_py_colon = true
ij_python_space_before_backslash = true
ij_python_space_before_comma = false
ij_python_space_before_for_semicolon = false
ij_python_space_before_lbracket = false
ij_python_space_before_method_call_parentheses = false
ij_python_space_before_method_parentheses = false
ij_python_space_before_number_sign = true
ij_python_space_before_py_colon = false
ij_python_space_within_empty_method_call_parentheses = false
ij_python_space_within_empty_method_parentheses = false
ij_python_spaces_around_additive_operators = true
ij_python_spaces_around_assignment_operators = true
ij_python_spaces_around_bitwise_operators = true
ij_python_spaces_around_eq_in_keyword_argument = false
ij_python_spaces_around_eq_in_named_parameter = false
ij_python_spaces_around_equality_operators = true
ij_python_spaces_around_multiplicative_operators = true
ij_python_spaces_around_power_operator = true
ij_python_spaces_around_relational_operators = true
ij_python_spaces_around_shift_operators = true
ij_python_spaces_within_braces = false
ij_python_spaces_within_brackets = false
ij_python_spaces_within_method_call_parentheses = false
ij_python_spaces_within_method_parentheses = false
ij_python_use_continuation_indent_for_arguments = false
ij_python_use_continuation_indent_for_collection_and_comprehensions = false
ij_python_use_continuation_indent_for_parameters = true
ij_python_wrap_long_lines = false

4
.env.example Normal file
View File

@ -0,0 +1,4 @@
# the bot's token
TOKEN=
# your user token, if you want to scrape messages
USER_TOKEN=

213
.gitignore vendored Normal file
View File

@ -0,0 +1,213 @@
### JetBrains template
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
# User-specific stuff
.idea/**/workspace.xml
.idea/**/tasks.xml
.idea/**/usage.statistics.xml
.idea/**/dictionaries
.idea/**/shelf
# Generated files
.idea/**/contentModel.xml
# Sensitive or high-churn files
.idea/**/dataSources/
.idea/**/dataSources.ids
.idea/**/dataSources.local.xml
.idea/**/sqlDataSources.xml
.idea/**/dynamic.xml
.idea/**/uiDesigner.xml
.idea/**/dbnavigator.xml
# Gradle
.idea/**/gradle.xml
.idea/**/libraries
# Gradle and Maven with auto-import
# When using Gradle or Maven with auto-import, you should exclude module files,
# since they will be recreated, and may cause churn. Uncomment if using
# auto-import.
# .idea/artifacts
# .idea/compiler.xml
# .idea/jarRepositories.xml
# .idea/modules.xml
# .idea/*.iml
# .idea/modules
# *.iml
# *.ipr
# CMake
cmake-build-*/
# Mongo Explorer plugin
.idea/**/mongoSettings.xml
# File-based project format
*.iws
# IntelliJ
out/
# mpeltonen/sbt-idea plugin
.idea_modules/
# JIRA plugin
atlassian-ide-plugin.xml
# Cursive Clojure plugin
.idea/replstate.xml
# Crashlytics plugin (for Android Studio and IntelliJ)
com_crashlytics_export_strings.xml
crashlytics.properties
crashlytics-build.properties
fabric.properties
# Editor-based Rest Client
.idea/httpRequests
# Android studio 3.1+ serialized cache file
.idea/caches/build_file_checksums.ser
### Python template
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/

8
.idea/.gitignore generated vendored Normal file
View File

@ -0,0 +1,8 @@
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml

6
.idea/encodings.xml generated Normal file
View File

@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="Encoding">
<file url="file://$PROJECT_DIR$/chats/252959763841679360_1640479428.txt" charset="windows-1252" />
</component>
</project>

View File

@ -0,0 +1,14 @@
<component name="InspectionProjectProfileManager">
<profile version="1.0">
<option name="myName" value="Project Default" />
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
<option name="ignoredPackages">
<value>
<list size="1">
<item index="0" class="java.lang.String" itemvalue="pysam" />
</list>
</value>
</option>
</inspection_tool>
</profile>
</component>

View File

@ -0,0 +1,6 @@
<component name="InspectionProjectProfileManager">
<settings>
<option name="USE_PROJECT_PROFILE" value="false" />
<version value="1.0" />
</settings>
</component>

14
.idea/miku.iml generated Normal file
View File

@ -0,0 +1,14 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$">
<excludeFolder url="file://$MODULE_DIR$/venv" />
</content>
<orderEntry type="jdk" jdkName="Python 3.8 (miku)" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="PyDocumentationSettings">
<option name="format" value="PLAIN" />
<option name="myDocStringFormat" value="Plain" />
</component>
</module>

4
.idea/misc.xml generated Normal file
View File

@ -0,0 +1,4 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.8 (miku)" project-jdk-type="Python SDK" />
</project>

8
.idea/modules.xml generated Normal file
View File

@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/miku.iml" filepath="$PROJECT_DIR$/.idea/miku.iml" />
</modules>
</component>
</project>

6
.idea/vcs.xml generated Normal file
View File

@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="$PROJECT_DIR$" vcs="Git" />
</component>
</project>

3
README.md Normal file
View File

@ -0,0 +1,3 @@
# miku
Discord bot/companion for the group chatte, powered by the GPT-J language model and modified with a soft prompt to understand all of our esoteric, elaborate inside jokes.

2
chats/.gitignore vendored Normal file
View File

@ -0,0 +1,2 @@
*
!.gitignore

BIN
requirements.txt Normal file

Binary file not shown.

0
src/miku/__init__.py Normal file
View File

7
src/miku/__main__.py Normal file
View File

@ -0,0 +1,7 @@
from dotenv import load_dotenv
from pathlib import Path
import os
from .bot import boot
load_dotenv(Path.cwd() / '..' / '.env')
token = os.getenv('TOKEN')
boot(token)

64
src/miku/bot.py Normal file
View File

@ -0,0 +1,64 @@
import discord
import logging
import sys
from typing import List, Optional
from .model import GPT2Model, Model
discord_logger = logging.getLogger('discord')
discord_logger.setLevel(logging.WARNING)
class HatsuneMikuBot(discord.Client):
TRIGGER_PHRASE = 'miku'
CONTEXT_MSGS = 5
def __init__(self, **options):
super().__init__(**options)
self.log = logging.getLogger('miku')
self.log.setLevel(logging.DEBUG)
self.log_handler = logging.StreamHandler()
self.log_handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s:%(name)s: %(message)s'))
self.log.addHandler(self.log_handler)
self.model: Optional[Model] = None
async def on_ready(self):
self.log.info(f'Logged in as {self.user}')
self.model = GPT2Model(self.user.name)
def should_respond(self, message: discord.Message) -> bool:
return HatsuneMikuBot.TRIGGER_PHRASE in message.content.lower() or self.user in message.mentions
@staticmethod
def parse_message(message: discord.Message) -> str:
return f'{message.author.name}: {message.content}'
async def fetch_message_context(self, chan: discord.TextChannel) -> List[str]:
msgs = await chan.history(limit=HatsuneMikuBot.CONTEXT_MSGS).flatten()
msgs_txt = [HatsuneMikuBot.parse_message(msg) for msg in msgs]
msgs_txt.reverse()
self.log.debug('Found context:\n' + '\n'.join(msgs_txt))
return list(filter(None, msgs_txt))
async def on_message(self, message: discord.Message):
author: discord.User = message.author
if author == self.user:
return
if self.should_respond(message):
self.log.debug(f'Message {message.id} from {author.name}: {message.content}')
chan: discord.TextChannel = message.channel
await chan.trigger_typing()
self.log.debug(f'Fetching context for {message.id}...')
prompt = await self.fetch_message_context(chan)
self.log.debug(f'Generating inference for {message.id}...')
response = self.model.infer(prompt)
await message.reply(response)
def boot(token: str):
if not token:
logging.critical('No bot token supplied!')
sys.exit(1)
HatsuneMikuBot().run(token)

50
src/miku/model.py Normal file
View File

@ -0,0 +1,50 @@
import logging
from transformers import GPT2Tokenizer, GPTNeoModel, GPTNeoForCausalLM
from typing import List
from abc import ABC, abstractmethod
trans_logger = logging.getLogger('transformers')
trans_logger.setLevel(logging.WARNING)
class Model(ABC):
@abstractmethod
def infer(self, prompt: List[str]) -> str:
"""
Generates an inference.
"""
return ''
class GPT2Model(Model):
def __init__(self, username: str):
self.log = logging.getLogger('miku')
self.model: GPTNeoModel = GPTNeoForCausalLM.from_pretrained('iokru/c1-1.3B').half().to('cuda')
self.tokenizer = GPT2Tokenizer.from_pretrained('EleutherAI/gpt-neo-1.3B')
self.log.info('Model loaded.')
self.username = username
def infer(self, prompt: List[str]) -> str:
# i was getting japanese results because the bot is named Hatsune Miku. this is a workaround
#prompt = [msg.replace(self.username, '<|ME|>') for msg in prompt]
#prompt.append(f'<|ME|>: ')
prompt.append(f'{self.username}:')
flattened_prompt = self.tokenizer.eos_token.join(prompt)
result = self.model.generate(
input_ids=self.tokenizer.encode(flattened_prompt, return_tensors="pt").cuda(),
do_sample=True,
min_length=1,
max_length=100,
temperature=0.6,
tfs=0.993,
repetition_penalty=3.0,
pad_token_id=self.tokenizer.eos_token_id,
#bad_words_ids=config.bad_words_ids
)
inferred_txt = self.tokenizer.decode(result[0])[len(flattened_prompt):]
self.log.debug(inferred_txt)
# TODO: multiline inferences
return inferred_txt.split('\n')[0]

0
src/scraper/__init__.py Normal file
View File

7
src/scraper/__main__.py Normal file
View File

@ -0,0 +1,7 @@
from dotenv import load_dotenv
from pathlib import Path
import os
from .scraper import boot
load_dotenv(Path.cwd() / '..' / '.env')
token = os.getenv('USER_TOKEN')
boot(token)

120
src/scraper/scraper.py Normal file
View File

@ -0,0 +1,120 @@
import logging
import requests
from pathlib import Path
from time import sleep, time
from typing import List, Optional
logging.basicConfig(level=logging.INFO)
class Scraper:
"""
Scrapes the full Discord message history for one channel.
"""
# milliseconds between every 25 requests
RATE_LIMIT = 1000
def __init__(self, token: str, channel: str, export: Path):
self.token: str = token
self.channel: str = channel
# the oldest message ID we've encountered so far, used for pagination
self.oldest_message: Optional[str] = None
# used in rate-limiting
self.requests_made: int = 0
# the messages we want go in here
self.messages: List[str] = []
# export file directory
self.export_dir: Path = export
# last exported line
self.export_line: int = 0
# export start time
self.start_time: Optional[float] = None
@staticmethod
def parse_message(msg: dict) -> Optional[str]:
"""
Parses a message's metadata to get just the relevant text, in "sender: text" format.
Returns None if no valuable content.
"""
if len(msg['content']) == 0:
return None
# TODO: remove stuff like links
return f"{msg['author']['username']}: {msg['content']}"
def fetch_next_page(self, limit: int = 50) -> List[dict]:
url = f'https://discord.com/api/v9/channels/{self.channel}/messages?limit={limit}'
if self.oldest_message:
url += f'&before={self.oldest_message}'
while True:
if self.requests_made > 0 and self.requests_made % 25 == 0:
logging.info(f'[fetch_next_page] Waiting {self.RATE_LIMIT} ms.')
sleep(self.RATE_LIMIT / 1000)
res = requests.get(url, headers={
'Authorization': self.token,
'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) '
'discord/0.0.264 Chrome/91.0.4472.164 Electron/13.4.0 Safari/537.36'
})
self.requests_made += 1
if res.status_code == 429:
# our 1000 ms delay may not be perfect all the time.
# just wait and try again
logging.warning(f'[fetch_next_page] got 429, trying this page again')
continue
elif res.status_code != 200:
logging.error(f'[fetch_next_page] Request to {url} failed with status code {res.status_code}!')
raise ConnectionError
else:
messages = res.json()
self.oldest_message = messages[-1]['id']
return messages
def scrape(self):
flag = True
limit = 50
total_messages = 0
page_count = 0
self.start_time = time()
try:
while flag:
next_batch = self.fetch_next_page(limit)
page_count += 1
total_messages += len(next_batch)
for msg in next_batch:
parsed = self.parse_message(msg)
if parsed:
self.messages.append(parsed)
logging.info(f'[scrape] pg {page_count}, parsed {total_messages}, kept {len(self.messages)} messages')
flag = len(next_batch) == limit
except KeyboardInterrupt:
logging.warning('[scrape] user wants to abort, stopping early.')
pass
end_time = time()
logging.info(f'[scrape] completed scraping this channel in {end_time - self.start_time:.1f}s.')
def export(self):
if not self.start_time:
# nothing to export
logging.warning(f'[export] scraping has not begun yet; skipping.')
return
file_path = self.export_dir / f"{self.channel}_{self.start_time:.0f}.txt"
with open(file_path, 'wb') as file:
# messages are captured in reverse chronological order; we need to fix that
# this also means that we can't really do partial exports
joined_msgs = '\n'.join(self.messages[::-1])
file.write(joined_msgs.encode('utf-8'))
logging.info(f'[export] saved {len(self.messages)} lines.')
def boot(token: str):
if not token:
token = input('Enter your Discord user token (Authorization request header): ')
channel = input('Enter channel ID: ')
default_export = Path.cwd().parent / 'chats'
export = input('Enter path to export transcripts (default "chats"): ')
scraper = Scraper(token, channel, Path(export) if export else default_export)
scraper.scrape()
scraper.export()