# app.py — PIDGIN Gemma-3 chatbot Space (Pidgin_0.1)
# Import Unsloth first so its performance patches are applied before
# transformers is loaded (Unsloth warns if the order is reversed).
from unsloth import FastLanguageModel

import gc
import json
import time
from threading import Thread

import torch
import gradio as gr
from transformers import TextIteratorStreamer
# ---------------------
# Setup + Model Load
# ---------------------
# Clear out memory before loading
torch.cuda.empty_cache()
gc.collect()
MODEL_ID = "Ephraimmm/PIDGIN_gemma-3"
CONTEXT_LEN = 128000  # Gemma-3 supports a 128K-token context window (the 1B variant is 32K)
print("Using Unsloth Gemma-3 model with 128K context window...")
# Make sure your environment has updated versions:
# pip install -U unsloth unsloth_zoo transformers
# Load the quantized model with Unsloth
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=MODEL_ID,
    max_seq_length=CONTEXT_LEN,
    dtype=None,           # let Unsloth pick an appropriate dtype
    load_in_4bit=True,    # 4-bit quantization to fit in limited VRAM
    trust_remote_code=True,
)
FastLanguageModel.for_inference(model)  # enable Unsloth's fast inference path
print("✅ Model loaded (4-bit dynamic if available)")
# ---------------------
# Chat Streaming Function
# ---------------------
def stream_chat(message, history):
    # Build the message list in the chat format expected by apply_chat_template.
    # The system prompt is intentionally written in Pidgin English.
    messages = [
        {"role": "system", "content": "You be Naija assistant. You must always reply for Pidgin English."}
    ]
    if history:
        for human, bot in history:
            messages.append({"role": "user", "content": human})
            messages.append({"role": "assistant", "content": bot})
    messages.append({"role": "user", "content": message})

    # apply_chat_template (supported by Unsloth) handles Gemma-3's prompt
    # formatting, including the generation prompt for the assistant turn.
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    # skip_prompt=True keeps the echoed input prompt out of the stream;
    # without it the whole conversation would be yielded before the reply.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=inputs,
        streamer=streamer,
        max_new_tokens=512,
        temperature=0.8,
        do_sample=True,
        top_p=0.9,
    )

    # Run generation in a background thread so the streamer can be consumed
    # from this thread as tokens arrive.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    output = ""
    for partial in streamer:
        output += partial
        yield output
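# stream_chat is a generator yielding the cumulative reply so far. Outside
# Gradio you could consume it directly (hypothetical usage):
# reply = ""
# for reply in stream_chat("Wetin dey happen?", []):
#     pass
# print(reply)  # the final complete reply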
# ---------------------
# Save chat to file (JSON format)
# ---------------------
def save_chat(history):
    export = []
    for human, bot in history:
        export.append({"role": "user", "content": human})
        export.append({"role": "assistant", "content": bot})
    timestamp = time.strftime("%Y%m%d-%H%M%S")
    fname = f"conversation_{timestamp}.json"
    with open(fname, "w", encoding="utf-8") as f:
        json.dump(export, f, ensure_ascii=False, indent=2)
    return fname
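# The saved file is a flat list of role/content turns, e.g. (illustrative):
# [
#   {"role": "user", "content": "How you dey?"},
#   {"role": "assistant", "content": "I dey fine! Wetin I fit do for you?"}
# ]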
# ---------------------
# UI with Gradio
# ---------------------
with gr.Blocks(title="🇳🇬 PIDGIN Gemma-3 Chatbot") as demo:
    gr.HTML("<h1><center>🇳🇬 PIDGIN Gemma-3 Chatbot</center></h1>")
    chatbot = gr.Chatbot(height=450, show_label=False)
    with gr.Row():
        msg = gr.Textbox(placeholder="Type your message here...", lines=2, scale=4)
        send = gr.Button("Send", variant="primary", scale=1, size="lg")
    with gr.Row():
        clear = gr.Button("Clear Chat", variant="secondary", scale=1)
        save_btn = gr.Button("💾 Save Conversation", variant="secondary", scale=1)
    download_file = gr.File()

    def respond(message, history):
        if history is None:
            history = []
        # Re-yield the growing reply so the chatbot updates as tokens stream in;
        # the second output clears the textbox.
        response = ""
        for partial in stream_chat(message, history):
            response = partial
            yield history + [(message, response)], ""
        # Final flush, in case the stream produced nothing.
        yield history + [(message, response)], ""

    msg.submit(respond, [msg, chatbot], [chatbot, msg])
    send.click(respond, [msg, chatbot], [chatbot, msg])
    clear.click(lambda: ([], ""), outputs=[chatbot, msg])
    save_btn.click(save_chat, inputs=[chatbot], outputs=[download_file])
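# To run locally (assumes a CUDA GPU with enough VRAM for the 4-bit model):
#   pip install -U unsloth unsloth_zoo transformers gradio
#   python app.py
# share=True below requests a public Gradio link; set it to False for local-only use.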
if __name__ == "__main__":
    demo.launch(share=True, debug=True)