Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
# app.py for Gradio App on Hugging Face Spaces
|
| 2 |
import gradio as gr
|
| 3 |
import torch
|
| 4 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
@@ -49,8 +49,8 @@ class MedicalSimulatorState:
|
|
| 49 |
self.ordered_tests: Dict[str, str] = {} # e.g., {"cbc": "pending", "xray": "result..."}
|
| 50 |
# Add more state variables as needed
|
| 51 |
|
| 52 |
-
# --- Core AI Interaction Function ---
|
| 53 |
-
def
|
| 54 |
if not model or not tokenizer:
|
| 55 |
return "Error: AI model is not loaded."
|
| 56 |
|
|
@@ -82,6 +82,39 @@ def get_ai_response(user_input: str, history: List[Tuple[Optional[str], Optional
|
|
| 82 |
print(f"Error during AI generation: {e}")
|
| 83 |
return f"An error occurred while processing the AI response: {e}"
|
| 84 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
# --- Tool Functions (Modify State) ---
|
| 86 |
def start_case(case_type: str, state: MedicalSimulatorState) -> Tuple[MedicalSimulatorState, List[Tuple[Optional[str], Optional[str]]], str, str]:
|
| 87 |
# --- Generate Patient Profile (Simplified Example) ---
|
|
@@ -148,7 +181,7 @@ def handle_chat(user_input: str, history: List[Tuple[Optional[str], Optional[str
|
|
| 148 |
history.append((user_input, None))
|
| 149 |
|
| 150 |
# Get AI response
|
| 151 |
-
ai_response =
|
| 152 |
|
| 153 |
# Add AI response to history
|
| 154 |
history[-1] = (user_input, ai_response) # Update the last entry with the AI's response
|
|
@@ -164,7 +197,7 @@ def use_tool(tool_name: str, state: MedicalSimulatorState) -> Tuple[MedicalSimul
|
|
| 164 |
state.total_cost += cost
|
| 165 |
|
| 166 |
if tool_name == "ask_question":
|
| 167 |
-
ai_response =
|
| 168 |
state.chat_history.append(("System", f"[Action: {tool_name}, Cost: ${cost:.2f}]"))
|
| 169 |
state.chat_history.append(("AI Patient", ai_response))
|
| 170 |
|
|
@@ -214,69 +247,102 @@ def end_case(state: MedicalSimulatorState) -> Tuple[MedicalSimulatorState, List[
|
|
| 214 |
profile_str = "\n".join([f"{k.replace('_', ' ').title()}: {v}" for k, v in (state.patient_profile or {}).items()])
|
| 215 |
return state, state.chat_history, f"${state.total_cost:.2f}", profile_str
|
| 216 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 217 |
|
| 218 |
# --- Gradio Interface ---
|
| 219 |
with gr.Blocks(title="Advanced Medical Simulator") as demo:
|
| 220 |
-
# State component to hold the simulator state across interactions
|
| 221 |
-
state = gr.State(lambda: MedicalSimulatorState())
|
| 222 |
-
|
| 223 |
gr.Markdown("# Advanced Medical Simulator")
|
| 224 |
|
| 225 |
-
with gr.
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 280 |
|
| 281 |
# Launch the app
|
| 282 |
# For Hugging Face Spaces, Gradio handles the launch.
|
|
|
|
| 1 |
+
# app.py for Gradio App on Hugging Face Spaces (with Sandbox Tab)
|
| 2 |
import gradio as gr
|
| 3 |
import torch
|
| 4 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
|
| 49 |
self.ordered_tests: Dict[str, str] = {} # e.g., {"cbc": "pending", "xray": "result..."}
|
| 50 |
# Add more state variables as needed
|
| 51 |
|
| 52 |
+
# --- Core AI Interaction Function (Medical Context) ---
|
| 53 |
+
def get_ai_response_medical(user_input: str, history: List[Tuple[Optional[str], Optional[str]]], patient_profile: Dict, underlying_diagnosis: str) -> str:
|
| 54 |
if not model or not tokenizer:
|
| 55 |
return "Error: AI model is not loaded."
|
| 56 |
|
|
|
|
| 82 |
print(f"Error during AI generation: {e}")
|
| 83 |
return f"An error occurred while processing the AI response: {e}"
|
| 84 |
|
| 85 |
+
# --- Core AI Interaction Function (General/Sandbox Context) ---
|
| 86 |
+
def get_ai_response_general(user_input: str, history: List[Tuple[Optional[str], Optional[str]]]) -> str:
|
| 87 |
+
if not model or not tokenizer:
|
| 88 |
+
return "Error: AI model is not loaded."
|
| 89 |
+
|
| 90 |
+
# Construct a prompt for the AI based on general chat history
|
| 91 |
+
history_str = "\n".join([f"{'User' if h[0] else 'Assistant'}: {h[0] or h[1]}" for h in history])
|
| 92 |
+
prompt = f"<|system|>You are a helpful assistant. Answer the user's questions and follow their instructions.<|user|>{history_str}\n{user_input}<|assistant|>"
|
| 93 |
+
|
| 94 |
+
try:
|
| 95 |
+
inputs = tokenizer(prompt, return_tensors="pt")
|
| 96 |
+
if inputs["input_ids"].shape[1] > 32768: # Check for model max length
|
| 97 |
+
return "Error: Input prompt is too long for the model."
|
| 98 |
+
|
| 99 |
+
# Generate response
|
| 100 |
+
generate_ids = model.generate(
|
| 101 |
+
inputs.input_ids,
|
| 102 |
+
max_new_tokens=512, # Limit generated tokens
|
| 103 |
+
do_sample=True,
|
| 104 |
+
temperature=0.7,
|
| 105 |
+
top_p=0.9,
|
| 106 |
+
pad_token_id=tokenizer.eos_token_id
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
# Decode the generated response
|
| 110 |
+
response_text = tokenizer.decode(generate_ids[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
|
| 111 |
+
return response_text.strip()
|
| 112 |
+
|
| 113 |
+
except Exception as e:
|
| 114 |
+
print(f"Error during AI generation: {e}")
|
| 115 |
+
return f"An error occurred while processing the AI response: {e}"
|
| 116 |
+
|
| 117 |
+
|
| 118 |
# --- Tool Functions (Modify State) ---
|
| 119 |
def start_case(case_type: str, state: MedicalSimulatorState) -> Tuple[MedicalSimulatorState, List[Tuple[Optional[str], Optional[str]]], str, str]:
|
| 120 |
# --- Generate Patient Profile (Simplified Example) ---
|
|
|
|
| 181 |
history.append((user_input, None))
|
| 182 |
|
| 183 |
# Get AI response
|
| 184 |
+
ai_response = get_ai_response_medical(user_input, history[:-1], state.patient_profile, state.underlying_diagnosis) # Pass history without the user's new message yet
|
| 185 |
|
| 186 |
# Add AI response to history
|
| 187 |
history[-1] = (user_input, ai_response) # Update the last entry with the AI's response
|
|
|
|
| 197 |
state.total_cost += cost
|
| 198 |
|
| 199 |
if tool_name == "ask_question":
|
| 200 |
+
ai_response = get_ai_response_medical("The user asks a general question to gather more history.", state.chat_history, state.patient_profile, state.underlying_diagnosis)
|
| 201 |
state.chat_history.append(("System", f"[Action: {tool_name}, Cost: ${cost:.2f}]"))
|
| 202 |
state.chat_history.append(("AI Patient", ai_response))
|
| 203 |
|
|
|
|
| 247 |
profile_str = "\n".join([f"{k.replace('_', ' ').title()}: {v}" for k, v in (state.patient_profile or {}).items()])
|
| 248 |
return state, state.chat_history, f"${state.total_cost:.2f}", profile_str
|
| 249 |
|
| 250 |
+
# --- Sandbox Chat Handler ---
|
| 251 |
+
def handle_sandbox_chat(user_input: str, history: List[Tuple[Optional[str], Optional[str]]]) -> List[Tuple[Optional[str], Optional[str]]]:
|
| 252 |
+
if not user_input.strip():
|
| 253 |
+
return history
|
| 254 |
+
|
| 255 |
+
# Add user message to history
|
| 256 |
+
history.append((user_input, None))
|
| 257 |
+
|
| 258 |
+
# Get AI response (general context)
|
| 259 |
+
ai_response = get_ai_response_general(user_input, history[:-1])
|
| 260 |
+
|
| 261 |
+
# Add AI response to history
|
| 262 |
+
history[-1] = (user_input, ai_response)
|
| 263 |
+
|
| 264 |
+
return history
|
| 265 |
+
|
| 266 |
|
| 267 |
# --- Gradio Interface ---
|
| 268 |
with gr.Blocks(title="Advanced Medical Simulator") as demo:
|
|
|
|
|
|
|
|
|
|
| 269 |
gr.Markdown("# Advanced Medical Simulator")
|
| 270 |
|
| 271 |
+
with gr.Tab("Medical Simulation"):
|
| 272 |
+
# State component to hold the simulator state across interactions
|
| 273 |
+
state = gr.State(lambda: MedicalSimulatorState())
|
| 274 |
+
|
| 275 |
+
with gr.Row():
|
| 276 |
+
with gr.Column(scale=2):
|
| 277 |
+
# Chat Interface
|
| 278 |
+
chatbot = gr.Chatbot(label="Patient Interaction", height=400, bubble_full_width=False)
|
| 279 |
+
with gr.Row():
|
| 280 |
+
user_input = gr.Textbox(label="Your Action / Question", placeholder="Type your action or question here...", scale=4)
|
| 281 |
+
submit_btn = gr.Button("Submit", scale=1)
|
| 282 |
+
|
| 283 |
+
with gr.Column(scale=1):
|
| 284 |
+
# Patient Chart / Info
|
| 285 |
+
patient_chart = gr.Markdown(label="Patient Chart", value="Click 'Start New Case' to begin.")
|
| 286 |
+
cost_display = gr.Textbox(label="Total Cost", value="$0.00", interactive=False)
|
| 287 |
+
|
| 288 |
+
with gr.Row():
|
| 289 |
+
# Tool Panel
|
| 290 |
+
with gr.Column():
|
| 291 |
+
gr.Markdown("### Tools")
|
| 292 |
+
with gr.Row():
|
| 293 |
+
ask_btn = gr.Button("Ask Question ($10)")
|
| 294 |
+
exam_btn = gr.Button("Physical Exam ($25)")
|
| 295 |
+
with gr.Row():
|
| 296 |
+
cbc_btn = gr.Button("Order CBC ($50)")
|
| 297 |
+
xray_btn = gr.Button("Order X-Ray ($150)")
|
| 298 |
+
with gr.Row():
|
| 299 |
+
med_btn = gr.Button("Administer Med ($30)")
|
| 300 |
+
end_btn = gr.Button("End Case", variant="stop") # Red button for ending
|
| 301 |
+
|
| 302 |
+
with gr.Row():
|
| 303 |
+
# Case Controls
|
| 304 |
+
start_case_btn = gr.Button("Start New Case (General)")
|
| 305 |
+
case_type_dropdown = gr.Dropdown(["General", "Psychiatry", "Pediatric", "Dual Diagnosis"], label="Case Type", value="General")
|
| 306 |
+
|
| 307 |
+
# Event Handling for Medical Simulation Tab
|
| 308 |
+
start_case_btn.click(
|
| 309 |
+
fn=start_case,
|
| 310 |
+
inputs=[case_type_dropdown, state],
|
| 311 |
+
outputs=[state, chatbot, cost_display, patient_chart]
|
| 312 |
+
)
|
| 313 |
+
|
| 314 |
+
submit_btn.click(
|
| 315 |
+
fn=handle_chat,
|
| 316 |
+
inputs=[user_input, chatbot, state],
|
| 317 |
+
outputs=[chatbot, cost_display]
|
| 318 |
+
).then(
|
| 319 |
+
fn=lambda: "", # Clear the input textbox after submission
|
| 320 |
+
inputs=[],
|
| 321 |
+
outputs=[user_input]
|
| 322 |
+
)
|
| 323 |
+
|
| 324 |
+
ask_btn.click(fn=lambda s: use_tool("ask_question", s), inputs=[state], outputs=[state, chatbot, cost_display])
|
| 325 |
+
exam_btn.click(fn=lambda s: use_tool("physical_exam", s), inputs=[state], outputs=[state, chatbot, cost_display])
|
| 326 |
+
cbc_btn.click(fn=lambda s: use_tool("order_cbc", s), inputs=[state], outputs=[state, chatbot, cost_display])
|
| 327 |
+
xray_btn.click(fn=lambda s: use_tool("order_xray", s), inputs=[state], outputs=[state, chatbot, cost_display])
|
| 328 |
+
med_btn.click(fn=lambda s: use_tool("administer_med", s), inputs=[state], outputs=[state, chatbot, cost_display])
|
| 329 |
+
end_btn.click(fn=end_case, inputs=[state], outputs=[state, chatbot, cost_display, patient_chart])
|
| 330 |
+
|
| 331 |
+
with gr.Tab("AI Sandbox"):
|
| 332 |
+
sandbox_chatbot = gr.Chatbot(label="General AI Chat", height=500, bubble_full_width=False)
|
| 333 |
+
with gr.Row():
|
| 334 |
+
sandbox_input = gr.Textbox(label="Message", placeholder="Ask anything...", scale=4)
|
| 335 |
+
sandbox_submit = gr.Button("Send", scale=1)
|
| 336 |
+
|
| 337 |
+
sandbox_submit.click(
|
| 338 |
+
fn=handle_sandbox_chat,
|
| 339 |
+
inputs=[sandbox_input, sandbox_chatbot],
|
| 340 |
+
outputs=[sandbox_chatbot]
|
| 341 |
+
).then(
|
| 342 |
+
fn=lambda: "", # Clear the input textbox after submission
|
| 343 |
+
inputs=[],
|
| 344 |
+
outputs=[sandbox_input]
|
| 345 |
+
)
|
| 346 |
|
| 347 |
# Launch the app
|
| 348 |
# For Hugging Face Spaces, Gradio handles the launch.
|