from unsloth import FastLanguageModel
from unsloth.chat_templates import get_chat_template
from transformers import TextStreamer

# Load the fine-tuned LoRA model in 4-bit to reduce memory usage.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "scoliono/groupchat_lora_abliterated_8b",
    max_seq_length = 2048,
    dtype = None,  # None = auto-detect (float16 or bfloat16)
    load_in_4bit = True,
)
FastLanguageModel.for_inference(model)  # Enable native 2x faster inference

tokenizer = get_chat_template(
    tokenizer,
    chat_template = "llama-3",  # Supports zephyr, chatml, mistral, llama, alpaca, vicuna, vicuna_old, unsloth
    map_eos_token = True,  # Maps <|im_end|> to </s> instead
)

def inference(messages, max_new_tokens=64, temperature=0.9, repetition_penalty=1.2):
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize = True,
        add_generation_prompt = True,  # Must add for generation
        return_tensors = "pt",
    ).to("cuda")

    #text_streamer = TextStreamer(tokenizer)

    token_ids = model.generate(
        input_ids = inputs,
        #streamer = text_streamer,  # Uncomment (with the line above) to stream tokens as they generate
        max_new_tokens = max_new_tokens,
        use_cache = True,
        temperature = temperature,
        repetition_penalty = repetition_penalty,
    )

    return tokenizer.batch_decode(token_ids)
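
# --- Usage sketch (an addition, not part of the original script) ---
# Assumes a CUDA-capable GPU and that `messages` uses the role/content
# dict format that apply_chat_template expects for the llama-3 template.
# The prompt below is purely illustrative.
if __name__ == "__main__":
    messages = [
        {"role": "user", "content": "Hey, what's everyone up to tonight?"},
    ]
    outputs = inference(messages, max_new_tokens=64)
    print(outputs[0])  # batch_decode returns a list of strings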