Jpete20001 committed
Commit 244da05 · verified · 1 parent: d9cba7a

Update app.py

Files changed (1):
  1. app.py +129 -63
app.py CHANGED
@@ -1,4 +1,4 @@
-# app.py for Gradio App on Hugging Face Spaces
+# app.py for Gradio App on Hugging Face Spaces (with Sandbox Tab)
 import gradio as gr
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -49,8 +49,8 @@ class MedicalSimulatorState:
         self.ordered_tests: Dict[str, str] = {} # e.g., {"cbc": "pending", "xray": "result..."}
         # Add more state variables as needed
 
-# --- Core AI Interaction Function ---
-def get_ai_response(user_input: str, history: List[Tuple[Optional[str], Optional[str]]], patient_profile: Dict, underlying_diagnosis: str) -> str:
+# --- Core AI Interaction Function (Medical Context) ---
+def get_ai_response_medical(user_input: str, history: List[Tuple[Optional[str], Optional[str]]], patient_profile: Dict, underlying_diagnosis: str) -> str:
     if not model or not tokenizer:
         return "Error: AI model is not loaded."
 
@@ -82,6 +82,39 @@ def get_ai_response(user_input: str, history: List[Tuple[Optional[str], Optional
         print(f"Error during AI generation: {e}")
         return f"An error occurred while processing the AI response: {e}"
 
+# --- Core AI Interaction Function (General/Sandbox Context) ---
+def get_ai_response_general(user_input: str, history: List[Tuple[Optional[str], Optional[str]]]) -> str:
+    if not model or not tokenizer:
+        return "Error: AI model is not loaded."
+
+    # Construct a prompt for the AI based on general chat history
+    history_str = "\n".join([f"{'User' if h[0] else 'Assistant'}: {h[0] or h[1]}" for h in history])
+    prompt = f"<|system|>You are a helpful assistant. Answer the user's questions and follow their instructions.<|user|>{history_str}\n{user_input}<|assistant|>"
+
+    try:
+        inputs = tokenizer(prompt, return_tensors="pt")
+        if inputs["input_ids"].shape[1] > 32768: # Check for model max length
+            return "Error: Input prompt is too long for the model."
+
+        # Generate response
+        generate_ids = model.generate(
+            inputs.input_ids,
+            max_new_tokens=512, # Limit generated tokens
+            do_sample=True,
+            temperature=0.7,
+            top_p=0.9,
+            pad_token_id=tokenizer.eos_token_id
+        )
+
+        # Decode the generated response
+        response_text = tokenizer.decode(generate_ids[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
+        return response_text.strip()
+
+    except Exception as e:
+        print(f"Error during AI generation: {e}")
+        return f"An error occurred while processing the AI response: {e}"
+
+
 # --- Tool Functions (Modify State) ---
 def start_case(case_type: str, state: MedicalSimulatorState) -> Tuple[MedicalSimulatorState, List[Tuple[Optional[str], Optional[str]]], str, str]:
     # --- Generate Patient Profile (Simplified Example) ---
@@ -148,7 +181,7 @@ def handle_chat(user_input: str, history: List[Tuple[Optional[str], Optional[str
     history.append((user_input, None))
 
     # Get AI response
-    ai_response = get_ai_response(user_input, history[:-1], state.patient_profile, state.underlying_diagnosis) # Pass history without the user's new message yet
+    ai_response = get_ai_response_medical(user_input, history[:-1], state.patient_profile, state.underlying_diagnosis) # Pass history without the user's new message yet
 
     # Add AI response to history
     history[-1] = (user_input, ai_response) # Update the last entry with the AI's response
@@ -164,7 +197,7 @@ def use_tool(tool_name: str, state: MedicalSimulatorState) -> Tuple[MedicalSimul
     state.total_cost += cost
 
     if tool_name == "ask_question":
-        ai_response = get_ai_response("The user asks a general question to gather more history.", state.chat_history, state.patient_profile, state.underlying_diagnosis)
+        ai_response = get_ai_response_medical("The user asks a general question to gather more history.", state.chat_history, state.patient_profile, state.underlying_diagnosis)
         state.chat_history.append(("System", f"[Action: {tool_name}, Cost: ${cost:.2f}]"))
         state.chat_history.append(("AI Patient", ai_response))
 
@@ -214,69 +247,102 @@ def end_case(state: MedicalSimulatorState) -> Tuple[MedicalSimulatorState, List[
     profile_str = "\n".join([f"{k.replace('_', ' ').title()}: {v}" for k, v in (state.patient_profile or {}).items()])
     return state, state.chat_history, f"${state.total_cost:.2f}", profile_str
 
+# --- Sandbox Chat Handler ---
+def handle_sandbox_chat(user_input: str, history: List[Tuple[Optional[str], Optional[str]]]) -> List[Tuple[Optional[str], Optional[str]]]:
+    if not user_input.strip():
+        return history
+
+    # Add user message to history
+    history.append((user_input, None))
+
+    # Get AI response (general context)
+    ai_response = get_ai_response_general(user_input, history[:-1])
+
+    # Add AI response to history
+    history[-1] = (user_input, ai_response)
+
+    return history
+
 
 # --- Gradio Interface ---
 with gr.Blocks(title="Advanced Medical Simulator") as demo:
-    # State component to hold the simulator state across interactions
-    state = gr.State(lambda: MedicalSimulatorState())
-
     gr.Markdown("# Advanced Medical Simulator")
 
-    with gr.Row():
-        with gr.Column(scale=2):
-            # Chat Interface
-            chatbot = gr.Chatbot(label="Patient Interaction", height=400, bubble_full_width=False)
-            with gr.Row():
-                user_input = gr.Textbox(label="Your Action / Question", placeholder="Type your action or question here...", scale=4)
-                submit_btn = gr.Button("Submit", scale=1)
-
-        with gr.Column(scale=1):
-            # Patient Chart / Info
-            patient_chart = gr.Markdown(label="Patient Chart", value="Click 'Start New Case' to begin.")
-            cost_display = gr.Textbox(label="Total Cost", value="$0.00", interactive=False)
-
-    with gr.Row():
-        # Tool Panel
-        with gr.Column():
-            gr.Markdown("### Tools")
-            with gr.Row():
-                ask_btn = gr.Button("Ask Question ($10)")
-                exam_btn = gr.Button("Physical Exam ($25)")
-            with gr.Row():
-                cbc_btn = gr.Button("Order CBC ($50)")
-                xray_btn = gr.Button("Order X-Ray ($150)")
-            with gr.Row():
-                med_btn = gr.Button("Administer Med ($30)")
-                end_btn = gr.Button("End Case", variant="stop") # Red button for ending
-
-    with gr.Row():
-        # Case Controls
-        start_case_btn = gr.Button("Start New Case (General)")
-        case_type_dropdown = gr.Dropdown(["General", "Psychiatry", "Pediatric", "Dual Diagnosis"], label="Case Type", value="General")
-
-    # Event Handling
-    start_case_btn.click(
-        fn=start_case,
-        inputs=[case_type_dropdown, state],
-        outputs=[state, chatbot, cost_display, patient_chart]
-    )
-
-    submit_btn.click(
-        fn=handle_chat,
-        inputs=[user_input, chatbot, state],
-        outputs=[chatbot, cost_display]
-    ).then(
-        fn=lambda: "", # Clear the input textbox after submission
-        inputs=[],
-        outputs=[user_input]
-    )
-
-    ask_btn.click(fn=lambda s: use_tool("ask_question", s), inputs=[state], outputs=[state, chatbot, cost_display])
-    exam_btn.click(fn=lambda s: use_tool("physical_exam", s), inputs=[state], outputs=[state, chatbot, cost_display])
-    cbc_btn.click(fn=lambda s: use_tool("order_cbc", s), inputs=[state], outputs=[state, chatbot, cost_display])
-    xray_btn.click(fn=lambda s: use_tool("order_xray", s), inputs=[state], outputs=[state, chatbot, cost_display])
-    med_btn.click(fn=lambda s: use_tool("administer_med", s), inputs=[state], outputs=[state, chatbot, cost_display])
-    end_btn.click(fn=end_case, inputs=[state], outputs=[state, chatbot, cost_display, patient_chart])
+    with gr.Tab("Medical Simulation"):
+        # State component to hold the simulator state across interactions
+        state = gr.State(lambda: MedicalSimulatorState())
+
+        with gr.Row():
+            with gr.Column(scale=2):
+                # Chat Interface
+                chatbot = gr.Chatbot(label="Patient Interaction", height=400, bubble_full_width=False)
+                with gr.Row():
+                    user_input = gr.Textbox(label="Your Action / Question", placeholder="Type your action or question here...", scale=4)
+                    submit_btn = gr.Button("Submit", scale=1)
+
+            with gr.Column(scale=1):
+                # Patient Chart / Info
+                patient_chart = gr.Markdown(label="Patient Chart", value="Click 'Start New Case' to begin.")
+                cost_display = gr.Textbox(label="Total Cost", value="$0.00", interactive=False)
+
+        with gr.Row():
+            # Tool Panel
+            with gr.Column():
+                gr.Markdown("### Tools")
+                with gr.Row():
+                    ask_btn = gr.Button("Ask Question ($10)")
+                    exam_btn = gr.Button("Physical Exam ($25)")
+                with gr.Row():
+                    cbc_btn = gr.Button("Order CBC ($50)")
+                    xray_btn = gr.Button("Order X-Ray ($150)")
+                with gr.Row():
+                    med_btn = gr.Button("Administer Med ($30)")
+                    end_btn = gr.Button("End Case", variant="stop") # Red button for ending
+
+        with gr.Row():
+            # Case Controls
+            start_case_btn = gr.Button("Start New Case (General)")
+            case_type_dropdown = gr.Dropdown(["General", "Psychiatry", "Pediatric", "Dual Diagnosis"], label="Case Type", value="General")
+
+        # Event Handling for Medical Simulation Tab
+        start_case_btn.click(
+            fn=start_case,
+            inputs=[case_type_dropdown, state],
+            outputs=[state, chatbot, cost_display, patient_chart]
+        )
+
+        submit_btn.click(
+            fn=handle_chat,
+            inputs=[user_input, chatbot, state],
+            outputs=[chatbot, cost_display]
+        ).then(
+            fn=lambda: "", # Clear the input textbox after submission
+            inputs=[],
+            outputs=[user_input]
+        )
+
+        ask_btn.click(fn=lambda s: use_tool("ask_question", s), inputs=[state], outputs=[state, chatbot, cost_display])
+        exam_btn.click(fn=lambda s: use_tool("physical_exam", s), inputs=[state], outputs=[state, chatbot, cost_display])
+        cbc_btn.click(fn=lambda s: use_tool("order_cbc", s), inputs=[state], outputs=[state, chatbot, cost_display])
+        xray_btn.click(fn=lambda s: use_tool("order_xray", s), inputs=[state], outputs=[state, chatbot, cost_display])
+        med_btn.click(fn=lambda s: use_tool("administer_med", s), inputs=[state], outputs=[state, chatbot, cost_display])
+        end_btn.click(fn=end_case, inputs=[state], outputs=[state, chatbot, cost_display, patient_chart])
+
+    with gr.Tab("AI Sandbox"):
+        sandbox_chatbot = gr.Chatbot(label="General AI Chat", height=500, bubble_full_width=False)
+        with gr.Row():
+            sandbox_input = gr.Textbox(label="Message", placeholder="Ask anything...", scale=4)
+            sandbox_submit = gr.Button("Send", scale=1)
+
+        sandbox_submit.click(
+            fn=handle_sandbox_chat,
+            inputs=[sandbox_input, sandbox_chatbot],
+            outputs=[sandbox_chatbot]
+        ).then(
+            fn=lambda: "", # Clear the input textbox after submission
+            inputs=[],
+            outputs=[sandbox_input]
+        )
 
 # Launch the app
 # For Hugging Face Spaces, Gradio handles the launch.
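
Note: the diff's trailing context stops at the launch comment, so the actual tail of app.py is not visible here. A minimal, hypothetical sketch of how such a file commonly ends (demo.launch() is the standard Gradio call; the __main__ guard and the commented smoke test are illustrative only and not part of this commit):

# Hypothetical tail of app.py -- not shown in the diff above.
if __name__ == "__main__":
    # Optional local smoke test of the sandbox handler added in this commit;
    # handle_sandbox_chat returns the updated history as (user, assistant) tuples.
    # print(handle_sandbox_chat("Hello, what can you do?", []))

    demo.launch()  # On Hugging Face Spaces the Gradio SDK serves `demo` when app.py runs.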