# MikuAI/model.py
from unsloth import FastLanguageModel
from unsloth.chat_templates import get_chat_template
from transformers import TextStreamer
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "scoliono/groupchat_lora",
    max_seq_length = 2048,
    dtype = None,        # None = autodetect (float16 on older GPUs, bfloat16 on Ampere+)
    load_in_4bit = True, # 4-bit quantization to reduce memory usage
)
FastLanguageModel.for_inference(model) # Enable native 2x faster inference
tokenizer = get_chat_template(
    tokenizer,
    chat_template = "chatml", # Supports zephyr, chatml, mistral, llama, alpaca, vicuna, vicuna_old, unsloth
    map_eos_token = True,     # Maps <|im_end|> to </s> instead
)
def inference(messages, max_new_tokens=64, temperature=0.9, repetition_penalty=1.2):
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize = True,
        add_generation_prompt = True, # Must add for generation
        return_tensors = "pt",
    ).to("cuda")
    #text_streamer = TextStreamer(tokenizer)
    token_ids = model.generate(
        input_ids = inputs,
        #streamer = text_streamer,
        max_new_tokens = max_new_tokens,
        use_cache = True,
        temperature = temperature,
        repetition_penalty = repetition_penalty,
    )
    # Note: batch_decode returns the full decoded sequences, prompt included.
    return tokenizer.batch_decode(token_ids)
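
# Minimal usage sketch: `messages` must be a list of {"role", "content"} dicts,
# which is what apply_chat_template expects for the ChatML template configured
# above. The role names and message text here are hypothetical examples.
if __name__ == "__main__":
    history = [
        {"role": "user", "content": "Hello, how are you?"},
    ]
    print(inference(history, max_new_tokens=32))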