from unsloth import FastLanguageModel
from unsloth.chat_templates import get_chat_template
from transformers import TextStreamer # Optional: streams generated tokens to stdout

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "scoliono/groupchat_lora_abliterated_instruct-3.1-8b",
    max_seq_length = 2048,
    dtype = None,
    load_in_4bit = True,
)
FastLanguageModel.for_inference(model) # Enable native 2x faster inference

tokenizer = get_chat_template(
    tokenizer,
    chat_template = "llama-3", # Supports zephyr, chatml, mistral, llama, alpaca, vicuna, vicuna_old, unsloth
    map_eos_token = True, # Maps the template's end-of-turn token onto the tokenizer's EOS token
)

def inference(messages, max_new_tokens=64, temperature=0.9, repetition_penalty=1.2):
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize = True,
        add_generation_prompt = True, # Must add for generation
        return_tensors = "pt",
    ).to("cuda")

    # Uncomment here and in generate() below to stream tokens as they are produced
    #text_streamer = TextStreamer(tokenizer)

    token_ids = model.generate(
        input_ids = inputs,
        #streamer = text_streamer,
        max_new_tokens = max_new_tokens,
        use_cache = True,
        do_sample = True, # Required; temperature and repetition_penalty are ignored under greedy decoding
        temperature = temperature,
        repetition_penalty = repetition_penalty,
    )

    # Decodes the full sequence, prompt and special tokens included
    return tokenizer.batch_decode(token_ids)
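
# Example usage: a minimal sketch, assuming a CUDA-capable GPU and the model
# loaded above. The messages below are illustrative placeholders, not taken
# from the original script; the format follows the llama-3 chat template
# (a list of {"role": ..., "content": ...} dicts).
if __name__ == "__main__":
    messages = [
        {"role": "user", "content": "Write a one-line greeting for a group chat."},
    ]
    print(inference(messages, max_new_tokens=32)[0])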