nroggendorff committed
Commit ea34ed3 · 1 Parent(s): 2b46203

take everything that made this app special, and eliminate it

Files changed (1):
  1. app.py  +36 -65
app.py CHANGED
@@ -1,84 +1,55 @@
 import gradio as gr
-from transformers import AutoTokenizer, AutoModelForCausalLM
-from spaces import GPU as gpu
+import spaces
+from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
+from threading import Thread
+import torch

+MODEL_ID = "nroggendorff/smallama-it"

-class Delta:
-    def __init__(self, content):
-        self.content = content
-
-
-class Choice:
-    def __init__(self, delta):
-        self.delta = delta
-
-
-class InferenceClient:
-    def __init__(self, model_id="nroggendorff/smallama-it"):
-        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
-        self.model = AutoModelForCausalLM.from_pretrained(model_id)
-
-    class ModelOutput:
-        def __init__(self, client, inputs):
-            self.client = client
-            self.inputs = inputs
-            self.choices = []
-
-        def decode(self, output):
-            decoded_output = self.client.tokenizer.decode(
-                output[0][self.inputs["input_ids"].shape[-1] :],
-                skip_special_tokens=True,
-            )
-            self.choices = [Choice(Delta(decoded_output))]
-            return self
-
-    @gpu
-    def chat_completion(
-        self, messages, max_tokens=256, stream=True, temperature=0.2, top_p=0.95
-    ):
-        inputs = self.tokenizer.apply_chat_template(
-            messages,
-            add_generation_prompt=True,
-            tokenize=True,
-            return_dict=True,
-            return_tensors="pt",
-        ).to(self.model.device)
-
-        model_output = self.ModelOutput(self, inputs)
-
-        for _ in range(max_tokens):
-            output = self.model.generate(
-                **inputs, max_new_tokens=1, temperature=temperature, top_p=top_p
-            )
-            yield model_output.decode(output)
+tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+model = AutoModelForCausalLM.from_pretrained(
+    MODEL_ID, dtype=torch.float16, device_map="auto"
+)


+@spaces.GPU
 def respond(
     message,
     history: list[dict[str, str]],
-    system_message,
     max_tokens,
     temperature,
     top_p,
 ):
-    client = InferenceClient()
-    messages = [{"role": "system", "content": system_message}]
-    messages.extend(history)
+    messages = history
     messages.append({"role": "user", "content": message})

-    response = ""
-
-    for message in client.chat_completion(
+    inputs = tokenizer.apply_chat_template(
         messages,
-        max_tokens=max_tokens,
-        stream=True,
+        add_generation_prompt=True,
+        tokenize=True,
+        return_dict=True,
+        return_tensors="pt",
+    ).to(model.device)
+
+    streamer = TextIteratorStreamer(
+        tokenizer, skip_prompt=True, skip_special_tokens=True
+    )
+
+    generation_kwargs = dict(
+        input_ids=inputs["input_ids"],
+        attention_mask=inputs["attention_mask"],
+        max_new_tokens=max_tokens,
         temperature=temperature,
         top_p=top_p,
-    ):
-        choices = message.choices
-        token = ""
-        if len(choices) and choices[0].delta.content:
-            token = choices[0].delta.content
+        do_sample=True,
+        streamer=streamer,
+    )
+
+    thread = Thread(target=model.generate, kwargs=generation_kwargs)
+    thread.start()
+
+    response = ""
+    for token in streamer:
         response += token
         yield response

@@ -88,7 +59,7 @@ chatbot = gr.ChatInterface(
     type="messages",
     additional_inputs=[
         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-        gr.Slider(minimum=0.1, maximum=4.0, value=0.2, step=0.1, label="Temperature"),
+        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
         gr.Slider(
             minimum=0.1,
             maximum=1.0,
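Note: the substance of the change is replacing the hand-rolled streaming shim (the Delta/Choice/InferenceClient classes that mimicked the huggingface_hub client API) with transformers' built-in TextIteratorStreamer. The removed loop called model.generate with max_new_tokens=1 on the same inputs every iteration, so the context never grew between calls; the new code issues one generate call on a worker thread and iterates the streamer. Below is a self-contained sketch of that pattern with Gradio and the spaces.GPU decorator stripped out; the prompt and sampling values are illustrative, not taken from the Space.

```python
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

MODEL_ID = "nroggendorff/smallama-it"  # the model the Space serves

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)  # default dtype/device for simplicity

messages = [{"role": "user", "content": "Hello!"}]  # illustrative prompt
inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
).to(model.device)

# skip_prompt=True keeps the formatted chat template out of the streamed text;
# extra kwargs (skip_special_tokens) are forwarded to tokenizer.decode.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() blocks until it finishes, so it runs on a worker thread while the
# main thread consumes decoded text chunks as they become available.
thread = Thread(
    target=model.generate,
    kwargs=dict(
        **inputs,
        max_new_tokens=64,
        do_sample=True,
        temperature=0.7,
        top_p=0.95,
        streamer=streamer,
    ),
)
thread.start()

for chunk in streamer:
    print(chunk, end="", flush=True)
thread.join()
```

Each item the streamer yields is a decoded text chunk, so a Gradio callback can simply accumulate the chunks and re-yield the running string, which is exactly what the new respond does.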
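The second hunk only changes the Temperature slider default from 0.2 to 0.7 inside the gr.ChatInterface call, which the diff truncates. The sketch below is a hypothetical reconstruction of that wiring: the lines visible in the diff are copied as-is, while fn=respond, the third slider's value, step, and label, and the launch() call are assumptions.

```python
import gradio as gr

# Hypothetical reconstruction of the truncated gr.ChatInterface call.
# Each slider maps positionally to respond's extra parameters
# (max_tokens, temperature, top_p) via additional_inputs.
chatbot = gr.ChatInterface(
    respond,  # the streaming generator from app.py
    type="messages",
    additional_inputs=[
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,  # assumed; the diff cuts off after maximum=1.0
            step=0.05,  # assumed
            label="Top-p",  # assumed
        ),
    ],
)

if __name__ == "__main__":
    chatbot.launch()
```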