import gradio as gr
import torch
from unsloth import FastLanguageModel
from transformers import TextIteratorStreamer
import threading
from peft import PeftModel
import json
import time
import os

# Maximum number of new tokens per response. Note: the base model is loaded with
# max_seq_length=2048, so values this large may exceed the usable context window.
max_token = 9000

# -----------------------------
# 1️⃣ Set device
# -----------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

# -----------------------------
# 2️⃣ Load base model (skip compilation)
# -----------------------------
base_model_name = "unsloth/gemma-3-4b-it-unsloth-bnb-4bit"
# Alternative base model tried previously:
# "unsloth/llama-3.2-3b-instruct-unsloth-bnb-4bit"
base_model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=base_model_name,
    max_seq_length=2048,
    dtype=torch.float16,
    load_in_4bit=True,
)
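# The base checkpoint is a pre-quantized 4-bit (bitsandbytes) Gemma-3 4B instruct
# model; float16 is used as the compute dtype for the non-quantized layers.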

# -----------------------------
# 3️⃣ Load LoRA
# -----------------------------
lora_repo = "Ephraimmm/PIDGIN_gemma-3"
# Alternative adapters tried previously:
# "Ephraimmm/pigin-gemma-3-0.2"
# "Ephraimmm/Pidgin_llamma_model"
lora_model = PeftModel.from_pretrained(base_model, lora_repo, adapter_name="adapter_model")
FastLanguageModel.for_inference(lora_model)
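# for_inference() puts the model into Unsloth's inference mode (training-only
# paths disabled, generation optimized).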

# -----------------------------
# 4️⃣ Streaming generation function
# -----------------------------
def generate_response(user_message):
    messages = [
        {
            "role": "system",
            "content": [{"type": "text", "text": "You are a Nigerian assistant that speaks PIDGIN ENGLISH. When asked 'how far', reply 'I de o, how you de'."}]
        },
        {
            "role": "user",
            "content": [{"type": "text", "text": user_message}]
        }
    ]

    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        tokenize=True,
        return_dict=True
    ).to(device)

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_token,
        temperature=0.1,
        top_p=1.0,
        top_k=None,
        use_cache=False
    )
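    # Note: transformers defaults to greedy decoding (do_sample=False), so the
    # temperature/top_p/top_k values above are ignored unless do_sample=True is
    # also passed. use_cache=False disables the KV cache, which slows generation;
    # it is normally safe to leave the cache enabled for inference.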

    def generate():
        lora_model.generate(**generation_kwargs)

    thread = threading.Thread(target=generate)
    thread.start()

    full_response = ""
    for new_token in streamer:
        if new_token:
            full_response += new_token
    thread.join()
    return full_response
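# Quick sanity check (illustrative only):
# print(generate_response("How far?"))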

# -----------------------------
# 5️⃣ Chat + Save
# -----------------------------
chat_history = []

def chat(user_message):
    bot_response = generate_response(user_message)
    chat_history.append((user_message, bot_response))
    return chat_history, ""  # also clears input box

def save_conversation():
    if not chat_history:
        # Return a small empty txt file instead of None (to avoid Gradio error)
        file_path = "conversation_empty.txt"
        with open(file_path, "w", encoding="utf-8") as f:
            f.write("[]")
        return file_path
    
    conversation = []
    for user_msg, bot_msg in chat_history:
        conversation.append({"role": "user", "content": str(user_msg)})
        conversation.append({"role": "assistant", "content": str(bot_msg)})
    
    timestamp = time.strftime("%Y%m%d-%H%M%S")
    file_path = f"conversation_{timestamp}.txt"   # save as TXT not JSON
    
    with open(file_path, "w", encoding="utf-8") as f:
        json.dump(conversation, f, indent=4, ensure_ascii=False)
    
    return file_path
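# The saved file holds a JSON list of {"role": ..., "content": ...} messages,
# so it can be reloaded later with json.load() if needed.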

# -----------------------------
# 6️⃣ Gradio interface
# -----------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Nigerian PIDGIN Assistant")
    gr.Markdown("Chat with a Nigerian assistant that only speaks Pidgin English.")

    chatbot = gr.Chatbot(label="Conversation")
    user_input = gr.Textbox(label="Your message", placeholder="Type your message here...")
    
    with gr.Row():
        send_button = gr.Button("Send")
        save_button = gr.Button("Save Conversation")
        download_file = gr.File(label="Download Conversation")

    send_button.click(chat, inputs=user_input, outputs=[chatbot, user_input])
    save_button.click(save_conversation, outputs=download_file)

demo.launch()