Spaces:
Runtime error
Runtime error
| import torch | |
| import gc | |
| import json | |
| import time | |
| from threading import Thread | |
| import gradio as gr | |
| from unsloth import FastLanguageModel | |
| from transformers import TextIteratorStreamer | |
| # --------------------- | |
| # Setup + Model Load | |
| # --------------------- | |
# Release cached CUDA memory and run a garbage-collection pass before the
# model is loaded.
torch.cuda.empty_cache()
gc.collect()

MODEL_ID = "Ephraimmm/PIDGIN_gemma-3"
CONTEXT_LEN = 128000  # Gemma-3 default context window as per blog

print("Using Unsloth Gemma-3 model with 128K context window...")

# Make sure your environment has updated versions:
#   pip install -U unsloth unsloth_zoo transformers

# Load the quantized model with Unsloth.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=MODEL_ID,
    max_seq_length=CONTEXT_LEN,
    dtype=None,  # Let Unsloth pick appropriate dtype
    load_in_4bit=True,
    trust_remote_code=True,
)
FastLanguageModel.for_inference(model)

# The status emoji was mis-decoded in the original literal; restored here.
print("✅ Model loaded (4-bit dynamic if available)")
| # --------------------- | |
| # Chat Streaming Function | |
| # --------------------- | |
def stream_chat(message, history):
    """Stream the model's reply to ``message`` as a growing string.

    Args:
        message: The latest user message.
        history: Earlier turns as ``(user_text, assistant_text)`` pairs.
            ``None`` or an empty list means there are no earlier turns.

    Yields:
        The reply text accumulated so far, once per chunk produced by the
        streamer.
    """
    # Build the message list in the role/content format expected by
    # ``apply_chat_template``.
    messages = [
        {"role": "system", "content": "You be Naija assistant. You must always reply for Pidgin English."}
    ]
    for human, bot in history or []:
        messages.append({"role": "user", "content": human})
        messages.append({"role": "assistant", "content": bot})
    messages.append({"role": "user", "content": message})

    # Let the tokenizer's chat template handle the prompt formatting.
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    # skip_prompt=True keeps the rendered prompt (system text + history) out
    # of the streamed output; without it the prompt is echoed back to the
    # user ahead of the reply.
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )
    generate_kwargs = dict(
        input_ids=inputs,
        streamer=streamer,
        max_new_tokens=512,
        temperature=0.8,
        do_sample=True,
        top_p=0.9,
    )

    # Generation blocks until finished, so run it in a background thread and
    # consume the streamer from this one.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    output = ""
    for partial in streamer:
        output += partial
        yield output

    # The streamer is exhausted, so generation is finishing; wait for the
    # worker so it does not linger after the reply has been delivered.
    thread.join()
| # --------------------- | |
| # Save chat to file (JSON format) | |
| # --------------------- | |
def save_chat(history):
    """Write the conversation to a timestamped JSON file.

    Args:
        history: Chat turns as ``(user_text, assistant_text)`` pairs.
            ``None`` is treated as an empty conversation.

    Returns:
        The name of the JSON file that was written, relative to the current
        working directory.
    """
    export = []
    # ``history or []`` tolerates a ``None`` history, which would otherwise
    # raise ``TypeError`` when iterated.
    for human, bot in history or []:
        export.append({"role": "user", "content": human})
        export.append({"role": "assistant", "content": bot})

    timestamp = time.strftime("%Y%m%d-%H%M%S")
    fname = f"conversation_{timestamp}.json"
    with open(fname, "w", encoding="utf-8") as f:
        json.dump(export, f, ensure_ascii=False, indent=2)
    return fname
| # --------------------- | |
| # UI with Gradio | |
| # --------------------- | |
# NOTE: the flag and floppy-disk emoji below were mis-decoded in the original
# literals; they are restored here.
with gr.Blocks(title="🇳🇬 PIDGIN Gemma-3 Chatbot") as demo:
    gr.HTML("<h1><center>🇳🇬 PIDGIN Gemma-3 Chatbot</center></h1>")
    chatbot = gr.Chatbot(height=450, show_label=False)

    with gr.Row():
        msg = gr.Textbox(placeholder="Type your message here...", lines=2, scale=4)
        send = gr.Button("Send", variant="primary", scale=1, size="lg")

    with gr.Row():
        clear = gr.Button("Clear Chat", variant="secondary", scale=1)
        save_btn = gr.Button("💾 Save Conversation", variant="secondary", scale=1)

    download_file = gr.File()

    def respond(message, history):
        """Stream the reply for ``message`` into the chat window.

        Args:
            message: Text currently in the input box.
            history: Current chatbot value as ``(user, assistant)`` pairs,
                or ``None`` when the chat is empty.

        Yields:
            ``(updated_history, "")`` tuples; the empty string clears the
            input box.
        """
        if history is None:
            history = []

        # Ignore blank submissions rather than running generation on an
        # empty prompt.
        if not message or not message.strip():
            yield history, ""
            return

        response = ""
        for response in stream_chat(message, history):
            yield history + [(message, response)], ""

        # Emit the final state once more so the turn is shown even if the
        # stream produced no chunks.
        yield history + [(message, response)], ""

    # Enter in the textbox and the Send button both trigger the same handler.
    msg.submit(respond, [msg, chatbot], [chatbot, msg])
    send.click(respond, [msg, chatbot], [chatbot, msg])
    clear.click(lambda: ([], ""), outputs=[chatbot, msg])
    save_btn.click(save_chat, inputs=[chatbot], outputs=[download_file])
# Start the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    launch_options = {"share": True, "debug": True}
    demo.launch(**launch_options)