// James Shiffer — commit 8ef7a03895
// "changed some defaults; added and then decided to drop repetition penalty related hyperparameters; fixed prompt formatting"
// (150 lines, 5.7 KiB, TypeScript)
import { Message } from 'discord.js';
|
|
import { LLMProvider } from './provider';
|
|
import { HfInference } from "@huggingface/inference"
|
|
import 'dotenv/config';
|
|
import { serializeMessageHistory } from '../util';
|
|
import { logError, logInfo } from '../../logging';
|
|
import { LLMConfig } from '../commands/types';
|
|
|
|
|
|
/**
 * Regular-expression source (a string, not a RegExp) matching one serialized message line
 * authored by "Hatsune Miku#1740".
 *
 * It matches: a `Date#toUTCString()`-style timestamp, JSON-escaped `context` and `content`
 * strings, and an optional `reactions` field that is either null or a ":emoji: (count)" list.
 *
 * It is written with `String.raw`, so every backslash below is exactly one backslash in the
 * pattern. Currently only referenced by the commented-out `response_format` option in
 * `HuggingfaceProvider.requestLLMResponse`.
 */
const RESPONSE_REGEX = String.raw`\{"timestamp":"(Sun|Mon|Tue|Wed|Thu|Fri|Sat), \d{2} (Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec) \d{4} \d{2}:\d{2}:\d{2} GMT","author":"Hatsune Miku#1740","name":"Hatsune Miku","context":"([^"\\]|\\.)*","content":"([^"\\]|\\.)*"(,"reactions":("(:\w+: \(\d+\)(, )?)*"|null))?\}`;
|
|
|
|
/**
 * JSON-Schema-style description of one serialized chat message (one line of the prompt
 * history): `timestamp`, `author`, `name` and `content` are required, while `context` and
 * `reactions` default to null.
 *
 * Not referenced elsewhere in this file at present.
 */
const RESPONSE_SCHEMA = (() => {
    // Build a fresh anyOf list for each property so no two properties share one array instance.
    const stringOrNull = () => [{ type: "string" }, { type: "null" }];

    return {
        properties: {
            timestamp: {
                description: `When the message was sent, in RFC 7231 format`,
                title: "Timestamp",
                type: "string",
            },
            author: {
                description: `The author's username, which may be one of the following, or something else: "vinso1445", "f0oby", "1thinker", "scoliono", "ahjc", "cinnaba", "M6481", "hypadrive", "need_correction", "Hatsune Miku#1740" (You)`,
                title: "Author",
                type: "string",
            },
            name: {
                anyOf: stringOrNull(),
                description: `The author's real name, which may be blank or one of the following: "Vincent Iannelli", "Myles Linden", "Samuel Habib", "James Shiffer", "Alex", "Jinsung Park", "Lawrence Liu", "Nazar Khan", "Ethan Cheng", "Hatsune Miku" (You)`,
                title: "Name",
            },
            context: {
                anyOf: stringOrNull(),
                default: null,
                description: `The contents of the message being replied to, if this message is a reply`,
                title: "Context",
            },
            content: {
                description: `The text content of this message`,
                title: "Content",
                type: "string",
            },
            reactions: {
                anyOf: stringOrNull(),
                default: null,
                description: `Optional list of emoji reactions this message received, if any. The following comma-separated format is used: ":skull: (3), :100: (1)"`,
                title: "Reactions",
            },
        },
        required: ["timestamp", "author", "name", "content"],
    };
})();
|
|
|
|
/**
 * Instruction text placed before the serialized conversation in the user turn.
 *
 * Built from explicit paragraphs so that no stray characters (such as the `|` separator
 * lines that had crept into the old multi-line template literal) are sent to the model.
 * Paragraphs are separated by a blank line, and the prompt ends with a blank line, so the
 * JSON history that follows starts as its own paragraph.
 */
const USER_PROMPT = [
    'Continue the following Discord conversation by completing the next message, playing the role of Hatsune Miku. The conversation must progress forward, and you must avoid repeating yourself.',
    'Each message is represented as a line of JSON. Refer to other users by their "name" instead of their "author" field whenever possible.',
    'The conversation is as follows. The last line is the message you have to complete. Please ONLY return the string contents of the "content" field, that go in place of the ellipses. Do not include the enclosing quotation marks in your response.',
].join('\n\n') + '\n\n';
|
|
|
|
|
|
/**
 * LLM provider that writes Hatsune Miku's next Discord message through the Hugging Face
 * Inference API (`HfInference.chatCompletion`).
 *
 * The message history is serialized to one JSON object per line. A placeholder message
 * authored by Miku, with `content: "..."`, is appended, and the model is asked to return
 * only the text that belongs in that placeholder's `content` field.
 */
export class HuggingfaceProvider implements LLMProvider
{
    private readonly client: HfInference;
    private readonly model: string;

    /**
     * @param hf_token Hugging Face API token. Defaults to the `HF_TOKEN` environment variable.
     * @param model    Model ID sent with every chat-completion request.
     * @throws TypeError when no token is given and `HF_TOKEN` is unset or empty.
     */
    constructor(hf_token: string | undefined = process.env.HF_TOKEN, model = "meta-llama/Llama-3.2-3B-Instruct")
    {
        if (!hf_token) {
            throw new TypeError("Huggingface API token was not passed in, and environment variable HF_TOKEN was unset!");
        }
        this.client = new HfInference(hf_token);
        this.model = model;
    }

    /** Human-readable provider label that includes the model ID. */
    name() {
        return 'HuggingFace API: ' + this.model;
    }

    /**
     * Ask the model to write Miku's next message in the conversation.
     *
     * @param history   Discord messages to include. Entries whose serialization is falsy are dropped.
     * @param sysprompt System prompt, sent as the first chat message.
     * @param params    Sampling settings. A missing `temperature` defaults to 0.5, but an explicit 0
     *                  is honored. `top_p` and `max_new_tokens` fall back to 0.9 and 128 when missing or 0.
     * @returns The text of the model's reply.
     * @throws TypeError if no message in `history` serializes to content, or if the API returns
     *         no message text. Errors raised inside the request block are logged and rethrown.
     */
    async requestLLMResponse(history: Message[], sysprompt: string, params: LLMConfig): Promise<string>
    {
        const messageList = (await Promise.all(history.map(serializeMessageHistory))).filter(x => !!x);

        // Every remaining entry is truthy, so a missing last element means the list is empty.
        const lastMsg = messageList[messageList.length - 1];
        if (!lastMsg) {
            throw new TypeError("No messages with content provided in history!");
        }

        // Date the placeholder 5 seconds after the last message. If that timestamp does not
        // parse, fall back to the current time instead of writing "Invalid Date" into the prompt.
        let nextDate = new Date(lastMsg.timestamp);
        if (Number.isNaN(nextDate.getTime())) {
            nextDate = new Date();
        }
        nextDate.setSeconds(nextDate.getSeconds() + 5);

        // Placeholder line for the model to complete, framed as a reply to the last message.
        // NOTE(review): RESPONSE_REGEX / RESPONSE_SCHEMA use "Hatsune Miku#1740" as Miku's author,
        // but this placeholder uses "Hatsune Miku". Confirm which form serializeMessageHistory emits.
        const templateMsgTxt = JSON.stringify({
            timestamp: nextDate.toUTCString(),
            author: "Hatsune Miku",
            name: "Hatsune Miku",
            context: lastMsg.content,
            content: "..."
        });

        const messageHistoryTxt = [...messageList.map(msg => JSON.stringify(msg)), templateMsgTxt].join('\n');
        logInfo(`[hf] Requesting response for message history: ${messageHistoryTxt}`);

        try {
            const chatCompletion = await this.client.chatCompletion({
                model: this.model,
                messages: [
                    { role: "system", content: sysprompt },
                    { role: "user", content: USER_PROMPT + messageHistoryTxt }
                ],
                // `??` so that an explicit temperature of 0 is passed through instead of replaced.
                temperature: params?.temperature ?? 0.5,
                // A value of 0 is not a usable top_p or token budget here, so 0 deliberately still
                // falls back to the default for these two.
                top_p: params?.top_p || 0.9,
                max_tokens: params?.max_new_tokens || 128,
                // Constrained decoding against RESPONSE_REGEX, currently disabled:
                /*response_format: {
                    type: "regex",
                    value: String(RESPONSE_REGEX)
                }*/
            });

            // `choices` could be empty; optional chaining routes that case to the explicit error below.
            const response = chatCompletion.choices[0]?.message?.content;
            logInfo(`[hf] API response: ${response}`);

            if (!response) {
                throw new TypeError("HuggingFace completion API returned no message.");
            }

            return response;
        } catch (err) {
            logError(`[hf] API Error: ` + err);
            throw err;
        }
    }
}
|