# Source: HITL-KG / app.py (Hugging Face Space by avojarot)
# Last commit: "Update app.py" (9144dde, verified)
"""
HITL-KG Medical Reasoning System - Main Application
Human-in-the-Loop Knowledge Graph Visualization for Medical Reasoning
Features:
- Interactive knowledge graph visualization with Cytoscape
- Multilingual support with embedding-based entity extraction
- Session persistence with chat history
- RLHF feedback collection
- Glass-box visualization of reasoning process
Refactored for:
- Simplified state management
- Cleaner callbacks
- Embedding-based search (replacing keyword matching)
- Configuration-driven setup
"""
import logging
import os
import threading
import uuid
from datetime import datetime
from typing import Any, Dict, List, Optional

import dash
import dash_bootstrap_components as dbc
import dash_cytoscape as cyto
from dash import ALL, Input, Output, State, callback, ctx, dcc, html
from dash.exceptions import PreventUpdate
# Configure root logging once at import time, then take this module's logger.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Register dash-cytoscape's optional layouts (the graph uses "dagre").
cyto.load_extra_layouts()
# Import core modules
from src.core import (
load_knowledge_graph,
create_engine,
LLMProvider,
GenerationConfig,
GraphSynchronizer,
NodeType,
NODE_TYPE_INFO,
get_config,
get_session_manager,
detect_language,
)
# Import styles from src.styles; if that module is missing, fall back to a
# minimal built-in theme so the app still renders.
try:
    from src.styles import CYTOSCAPE_STYLESHEET, LAYOUT_CONFIGS, CUSTOM_CSS
except ImportError:
    CYTOSCAPE_STYLESHEET = [
        {"selector": "node", "style": {
            "label": "data(label)",
            "background-color": "#818cf8",
            "color": "#fff",
            "font-size": "10px",
            "text-wrap": "wrap",
            "text-max-width": "100px",
        }},
        {"selector": "edge", "style": {
            "curve-style": "bezier",
            "target-arrow-shape": "triangle",
            "line-color": "#64748b",
            "target-arrow-color": "#64748b",
        }},
        # One colour rule per node class, in the same order as before.
        *(
            {"selector": f".{node_class}", "style": {"background-color": colour}}
            for node_class, colour in (
                ("query", "#38bdf8"),
                ("fact", "#4ade80"),
                ("reasoning", "#818cf8"),
                ("hypothesis", "#fbbf24"),
                ("conclusion", "#f472b6"),
            )
        ),
        # Pruned ("ghost") nodes are also drawn semi-transparent.
        {"selector": ".ghost", "style": {"background-color": "#94a3b8", "opacity": 0.6}},
    ]
    LAYOUT_CONFIGS = {
        "hierarchical": {"name": "dagre", "rankDir": "TB", "spacingFactor": 1.2},
        "force": {"name": "cose", "animate": False},
        "radial": {"name": "concentric", "animate": False},
    }
    CUSTOM_CSS = ""
# ============================================================================
# CONFIGURATION
# ============================================================================
config = get_config()

# The OpenAI provider is only offered (and used by default) when a key is set.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
DEFAULT_PROVIDER = "local" if not OPENAI_API_KEY else "openai"

AVAILABLE_PROVIDERS = [
    *([{"label": "🤖 OpenAI GPT-4", "value": "openai"}] if OPENAI_API_KEY else []),
    {"label": "📊 Local Knowledge Graph", "value": "local"},
]

# Quick-start example queries per UI language ("en", "uk").
EXAMPLE_QUERIES = {
    lang: [{"text": text, "icon": icon} for icon, text in pairs]
    for lang, pairs in (
        ("en", (
            ("🤒", "Fever and cough for 3 days"),
            ("😫", "Headache with fatigue"),
            ("🤧", "Sore throat and runny nose"),
            ("😷", "Shortness of breath"),
        )),
        ("uk", (
            ("🤒", "Температура і кашель"),
            ("😫", "Головний біль з втомою"),
            ("🤧", "Біль у горлі та нежить"),
            ("😷", "Задишка"),
        )),
    )
}
# ============================================================================
# USER STATE MANAGEMENT (Simplified)
# ============================================================================
class UserState:
    """Per-user state: knowledge graph, reasoning engine and session data.

    One instance is kept per browser session id (see ``get_user_state``).
    On construction a knowledge graph is loaded, an engine is built for
    ``DEFAULT_PROVIDER``, and any graph state saved by the session manager
    for ``session_id`` is restored.
    """

    def __init__(self, session_id: str):
        self.session_id = session_id
        self.kg = load_knowledge_graph(use_embeddings=True)
        self.engine = None
        self.synchronizer = None
        self.provider = DEFAULT_PROVIDER
        self.language = "en"
        self._init_engine()
        self._restore_session()

    def _init_engine(self) -> None:
        """(Re)build the reasoning engine and graph synchronizer.

        Uses the OpenAI engine when ``self.provider == "openai"`` and an API
        key is configured; otherwise, or if construction fails, uses the
        local engine. Whenever the local engine ends up in use,
        ``self.provider`` is set to ``"local"`` so that later code choosing a
        model name from ``self.provider`` matches the engine actually running.
        """
        try:
            if self.provider == "openai" and OPENAI_API_KEY:
                from src.core import OpenAIEngine
                self.engine = OpenAIEngine(self.kg, api_key=OPENAI_API_KEY)
            else:
                from src.core import LocalEngine
                self.engine = LocalEngine(self.kg)
                self.provider = "local"
            self.synchronizer = GraphSynchronizer(self.engine, self.kg)
        except Exception as e:
            logger.exception("Engine init failed: %s", e)
            from src.core import LocalEngine
            self.engine = LocalEngine(self.kg)
            self.synchronizer = GraphSynchronizer(self.engine, self.kg)
            # Keep provider consistent with the fallback engine.
            self.provider = "local"

    def _restore_session(self) -> None:
        """Restore graph state and language from the session manager, if saved."""
        sm = get_session_manager()
        session = sm.get_session(self.session_id)
        if session and session.graph_state:
            try:
                self.kg.restore_state(session.graph_state)
                self.language = session.language
            except Exception as e:
                # Best effort: a corrupt saved state should not block the user.
                logger.warning(f"Failed to restore session: {e}")

    def set_provider(self, provider: str) -> None:
        """Switch LLM provider, rebuilding the engine only when it changes."""
        if provider != self.provider:
            self.provider = provider
            self._init_engine()

    def reset(self) -> None:
        """Reset reasoning state: graph reasoning, chat history, saved graph and language."""
        self.kg.clear_reasoning()

        sm = get_session_manager()
        session = sm.get_or_create(self.session_id)
        session.chat_history.clear()
        session.graph_state = None
        sm.save_session(self.session_id)

        self.language = "en"
        logger.info(f"Session {self.session_id} reset. Graph now has {len(self.kg.nodes)} nodes")

    def save(self) -> None:
        """Push the current graph state to the session manager."""
        sm = get_session_manager()
        sm.update_graph_state(self.session_id, self.kg.get_state())

    def get_chat_history(self) -> List[Dict]:
        """Return the session's chat history as ``{"role", "content"}`` dicts."""
        sm = get_session_manager()
        session = sm.get_or_create(self.session_id)
        return [{"role": m.role, "content": m.content} for m in session.chat_history]

    def add_message(self, role: str, content: str) -> None:
        """Append a message (tagged with the current language) and save the session."""
        sm = get_session_manager()
        session = sm.get_or_create(self.session_id)
        session.add_message(role, content, self.language)
        sm.save_session(self.session_id)
# User state storage: one UserState per session id. The lock makes lookup +
# creation atomic so concurrent callbacks for the same session share a state.
_user_states: Dict[str, UserState] = {}
_user_states_lock = threading.Lock()


def get_user_state(session_id: str) -> UserState:
    """Return the cached ``UserState`` for ``session_id``, creating it on first use.

    Note: creation (which loads a knowledge graph) happens while holding the
    global lock, so it briefly blocks lookups for other sessions.
    """
    with _user_states_lock:
        state = _user_states.get(session_id)
        if state is None:
            state = UserState(session_id)
            _user_states[session_id] = state
            logger.info(f"Created user state: {session_id}")
        return state
def cleanup_user_states():
    """Evict cached ``UserState`` objects once more than ``config.max_sessions`` exist.

    Removes at least enough entries to get back down to the limit, and at
    least 10% of the cache. The original ``len // 10`` alone removed nothing
    when fewer than 10 states were cached. Evicted users are rebuilt lazily by
    ``get_user_state``, whose ``UserState`` constructor restores from the
    session manager.
    """
    with _user_states_lock:
        excess = len(_user_states) - config.max_sessions
        if excess <= 0:
            return
        n_evict = max(excess, len(_user_states) // 10)
        # NOTE(review): kg.version is kept as the "age" ordering from the
        # original code; presumably it counts graph modifications rather than
        # last access time — confirm this is the intended eviction order.
        victims = sorted(
            _user_states.items(),
            key=lambda item: item[1].kg.version,
        )[:n_evict]
        for sid, _ in victims:
            del _user_states[sid]
# ============================================================================
# DASH APPLICATION
# ============================================================================
app = dash.Dash(
    __name__,
    title="HITL-KG Medical Reasoning",
    suppress_callback_exceptions=True,
    external_stylesheets=[
        dbc.themes.DARKLY,
        "https://fonts.googleapis.com/css2?family=DM+Sans:wght@400;500;600&display=swap",
    ],
)
# Underlying Flask server, exposed for WSGI deployment.
server = app.server
# ============================================================================
# LAYOUT COMPONENTS
# ============================================================================
def create_header():
    """Build the top navbar: branding, provider status, language badge, Help/Reset buttons."""
    provider_status = "🟡 Local" if not OPENAI_API_KEY else "🟢 OpenAI"

    branding_row = dbc.Row([
        dbc.Col(
            [
                html.Span("⚕️", style={"fontSize": "1.5rem", "marginRight": "10px"}),
                html.Span("HITL-KG", style={"fontWeight": "700", "fontSize": "1.2rem"}),
                dbc.Badge("Medical Reasoning", color="info", className="ms-2"),
            ],
            className="d-flex align-items-center",
        ),
    ])
    status_col = dbc.Col([
        html.Span(provider_status, className="me-3", style={"fontSize": "0.85rem"}),
        html.Span(id="language-indicator", children="🌐 EN"),
    ])
    buttons_col = dbc.Col(
        [
            dbc.Button("❓ Help", id="btn-help", size="sm", outline=True, className="me-2"),
            dbc.Button("↺ Reset", id="btn-reset", size="sm", outline=True),
        ],
        width="auto",
    )
    controls_row = dbc.Row([status_col, buttons_col], className="g-0")

    return dbc.Navbar(
        dbc.Container([branding_row, controls_row], fluid=True),
        dark=True,
        className="mb-3",
        style={"backgroundColor": "#1e293b"},
    )
def create_chat_panel():
    """Build the left card: a Chat tab (provider, log, examples, input) and a History tab."""
    provider_picker = html.Div([
        html.Label("AI Model", className="small mt-2", style={"color": "#94a3b8"}),
        dcc.Dropdown(
            id="provider-select",
            options=AVAILABLE_PROVIDERS,
            value=DEFAULT_PROVIDER,
            clearable=False,
            className="mb-3",
        ),
    ])

    chat_log = html.Div(
        id="chat-history",
        className="chat-container",
        style={
            "height": "180px", "overflowY": "auto",
            "backgroundColor": "#0f172a", "borderRadius": "8px",
            "padding": "10px", "marginBottom": "15px",
        },
        children=[create_welcome_message()],
    )

    example_buttons = [
        html.Button(
            [html.Span(example["icon"], className="me-1"), example["text"]],
            id={"type": "example", "index": position},
            className="btn btn-outline-secondary btn-sm me-2 mb-2",
        )
        for position, example in enumerate(EXAMPLE_QUERIES["en"])
    ]
    examples_section = html.Div(
        [
            html.Label("Quick examples:", className="small mb-2", style={"color": "#94a3b8"}),
            html.Div(id="example-queries", children=example_buttons, className="d-flex flex-wrap"),
        ],
        className="mb-3",
    )

    message_box = dbc.Textarea(
        id="chat-input",
        placeholder="Describe your symptoms...",
        style={"height": "60px", "resize": "none"},
        className="mb-2",
    )
    send_row = html.Div([
        dbc.Button("🔍 Analyze", id="btn-send", color="primary", className="me-2"),
        dbc.Button("🗑️ Clear", id="btn-clear", color="secondary", outline=True, size="sm"),
    ])

    chat_tab = dbc.Tab(
        label="💬 Chat",
        tab_id="tab-chat",
        children=[html.Div([provider_picker, chat_log, examples_section, message_box, send_row])],
    )

    saved_list = html.Div(
        id="session-history-list",
        style={
            "height": "350px", "overflowY": "auto",
            "backgroundColor": "#0f172a", "borderRadius": "8px",
            "padding": "10px",
        },
        children=[html.P("No saved sessions yet.", className="text-muted small text-center mt-3")],
    )
    history_buttons = html.Div([
        dbc.Button("💾 Save Session", id="btn-save-session",
                   color="success", size="sm", className="mt-2 me-2"),
        dbc.Button("🗑️ Clear History", id="btn-clear-history",
                   color="danger", size="sm", className="mt-2", outline=True),
    ])
    history_tab = dbc.Tab(
        label="📜 History",
        tab_id="tab-history",
        children=[html.Div([
            html.Label("Session History", className="small mt-2 mb-2", style={"color": "#94a3b8"}),
            saved_list,
            history_buttons,
        ])],
    )

    card_header = dbc.CardHeader(
        [
            html.Div([
                html.Span("💬", className="me-2"),
                html.Span("Symptom Analysis", style={"fontWeight": "600"}),
            ]),
            dbc.Button("+ New", id="btn-new-chat", size="sm", color="primary", className="float-end"),
        ],
        className="d-flex justify-content-between align-items-center",
    )
    card_body = dbc.CardBody([
        dbc.Tabs([chat_tab, history_tab], id="chat-tabs", active_tab="tab-chat"),
    ])
    return dbc.Card([card_header, card_body], style={"backgroundColor": "#1e293b"})
def create_graph_panel():
    """Build the centre card: zoom/layout toolbar, Cytoscape graph, node legend, stats line."""
    def toolbar_button(label, button_id, **extra):
        # All toolbar buttons share the same small outlined look.
        return dbc.Button(label, id=button_id, size="sm", outline=True, **extra)

    layout_group = dbc.ButtonGroup(
        [
            toolbar_button("⇄", "btn-layout-dag", active=True),
            toolbar_button("◎", "btn-layout-force"),
            toolbar_button("◉", "btn-layout-radial"),
        ],
        size="sm",
        className="ms-2",
    )
    toolbar = html.Div(
        [
            toolbar_button("−", "btn-zoom-out"),
            toolbar_button("⟲", "btn-zoom-fit", className="mx-1"),
            toolbar_button("+", "btn-zoom-in"),
            layout_group,
        ],
        className="d-flex",
    )
    card_header = dbc.CardHeader(
        [
            html.Div([
                html.Span("🧠", className="me-2"),
                html.Span("Reasoning Graph", style={"fontWeight": "600"}),
            ]),
            toolbar,
        ],
        className="d-flex justify-content-between align-items-center",
    )

    graph = cyto.Cytoscape(
        id="reasoning-graph",
        elements=[],
        layout=LAYOUT_CONFIGS.get("hierarchical", {"name": "dagre"}),
        style={"width": "100%", "height": "350px", "backgroundColor": "#0a1020"},
        stylesheet=CYTOSCAPE_STYLESHEET,
        boxSelectionEnabled=True,
        minZoom=0.2,
        maxZoom=3.0,
    )

    # Legend shows the first six node types.
    legend_items = [
        html.Div(
            [
                html.Span("●", style={"color": type_info["color"], "marginRight": "4px"}),
                html.Span(type_info["name"], style={"fontSize": "0.75rem", "marginRight": "10px"}),
            ],
            className="d-inline-block",
        )
        for type_info in list(NODE_TYPE_INFO.values())[:6]
    ]
    stats_line = html.Div(
        id="stats-display",
        children="Ready — Enter symptoms to begin",
        className="mt-2 text-center",
        style={"color": "#94a3b8", "fontSize": "0.85rem"},
    )

    card_body = dbc.CardBody([
        graph,
        html.Div(legend_items, className="mt-2 text-center"),
        stats_line,
    ])
    return dbc.Card([card_header, card_body], style={"backgroundColor": "#1e293b"})
def create_control_panel():
    """Build the right card: selected-node info, RLHF feedback, display filters, actions, fact injection."""
    def section_label(text):
        return html.Label(text, className="small", style={"color": "#94a3b8"})

    def divider():
        return html.Hr(style={"borderColor": "#475569"})

    selection_section = html.Div(
        [
            section_label("Selected Node"),
            html.Div(
                id="selected-node-info",
                children=[html.P("👆 Click a node", className="text-muted small")],
                style={"minHeight": "80px"},
            ),
        ],
        className="mb-3",
    )

    feedback_section = html.Div(
        [
            section_label("Feedback (RLHF)"),
            html.Div(
                [
                    dbc.Button("✓ Correct", id="btn-correct", color="success",
                               size="sm", disabled=True, className="me-2"),
                    dbc.Button("✗ Incorrect", id="btn-incorrect", color="danger",
                               size="sm", disabled=True),
                ],
                className="mb-2",
            ),
            html.Div(id="feedback-status"),
        ],
        className="mb-3",
    )

    display_section = html.Div(
        [
            section_label("Display"),
            dbc.Checklist(
                id="display-options",
                options=[{"label": " Show pruned paths", "value": "ghosts"}],
                value=[],
                switch=True,
                className="mb-2",
            ),
            html.Label("Confidence threshold", className="small",
                       style={"color": "#64748b", "fontSize": "0.75rem"}),
            dcc.Slider(
                id="confidence-slider",
                min=0, max=1, step=0.1, value=0,
                marks={0: "0%", 0.5: "50%", 1: "100%"},
            ),
        ],
        className="mb-3",
    )

    actions_section = html.Div(
        [
            section_label("Actions"),
            dbc.ButtonGroup(
                [
                    dbc.Button("✂️ Prune", id="btn-prune", color="danger",
                               size="sm", disabled=True),
                    dbc.Button("👻 Resurrect", id="btn-resurrect", color="warning",
                               size="sm", disabled=True),
                ],
                className="w-100 mb-2",
            ),
            dbc.Button("🚀 Start reasoning from here", id="btn-branch",
                       color="info", size="sm", disabled=True, className="w-100"),
        ],
        className="mb-3",
    )

    inject_section = html.Div([
        section_label("Inject Fact"),
        dbc.InputGroup(
            [
                dbc.Input(id="fact-input", placeholder="e.g., Patient has diabetes", size="sm"),
                dbc.Button("➕", id="btn-inject", color="success", size="sm", disabled=True),
            ],
            size="sm",
        ),
    ])

    card_body = dbc.CardBody(
        [
            selection_section, divider(),
            feedback_section, divider(),
            display_section, divider(),
            actions_section, divider(),
            inject_section,
        ],
        style={"maxHeight": "calc(100vh - 200px)", "overflowY": "auto"},
    )
    card_header = dbc.CardHeader([
        html.Span("🎛️", className="me-2"),
        html.Span("Controls", style={"fontWeight": "600"}),
    ])
    return dbc.Card([card_header, card_body], style={"backgroundColor": "#1e293b"})
def create_welcome_message():
    """Return the placeholder shown in the chat log when there is no history."""
    centered = "text-center"
    lines = [
        html.Div("👋", className=centered, style={"fontSize": "1.5rem"}),
        html.P("Welcome to HITL-KG Medical Reasoning",
               className=f"{centered} mb-1", style={"fontWeight": "600"}),
        html.P("Describe symptoms to see AI reasoning visualized.",
               className=f"{centered} small text-muted"),
        html.P("🌐 Type in any language!",
               className=f"{centered} small", style={"fontStyle": "italic", "color": "#64748b"}),
    ]
    return html.Div(lines, className="p-3")
def create_help_modal():
    """Build the (initially closed) help modal describing node types and interactions."""
    node_type_items = [
        html.Li([
            html.Strong(" ".join((type_info["icon"], type_info["name"])) + ": "),
            type_info["description"],
        ])
        for type_info in list(NODE_TYPE_INFO.values())[:6]
    ]
    interaction_items = [
        html.Li(text) for text in (
            "Click nodes to select and view details",
            "Prune to remove incorrect reasoning paths",
            "Resurrect to restore pruned nodes",
            "Inject facts to add medical information",
        )
    ]
    body = dbc.ModalBody([
        html.H6("🎯 Overview"),
        html.P("HITL-KG visualizes AI medical reasoning as an interactive graph."),
        html.H6("🔵 Node Types", className="mt-3"),
        html.Ul(node_type_items),
        html.H6("🎮 Interactions", className="mt-3"),
        html.Ul(interaction_items),
        html.H6("🌐 Languages", className="mt-3"),
        html.P("Supports English, Ukrainian, Russian, Spanish, German, French and more."),
    ])
    return dbc.Modal(
        [
            dbc.ModalHeader(dbc.ModalTitle("📖 How to Use")),
            body,
            dbc.ModalFooter(dbc.Button("Got it!", id="btn-close-help", color="primary")),
        ],
        id="help-modal",
        size="lg",
        is_open=False,
    )
# Main layout: client-side stores, header, help modal, three-panel body,
# a loading target and the periodic cleanup timer.
app.layout = html.Div(
    [
        # session-id lives in browser session storage so it survives reloads
        # within the same tab.
        dcc.Store(id="session-id", storage_type="session"),
        dcc.Store(id="selected-node-id", data=None),
        dcc.Store(id="graph-version", data=0),
        create_header(),
        create_help_modal(),
        dbc.Container(
            [
                dbc.Row(
                    [
                        dbc.Col(build_panel(), lg=width, md=12, className="mb-3")
                        for build_panel, width in (
                            (create_chat_panel, 3),
                            (create_graph_panel, 6),
                            (create_control_panel, 3),
                        )
                    ],
                    className="g-3",
                ),
            ],
            fluid=True,
        ),
        dcc.Loading(id="loading", type="circle", fullscreen=False,
                    children=html.Div(id="loading-target")),
        # Cleanup tick every 5 minutes (dcc.Interval uses milliseconds).
        dcc.Interval(id="cleanup-interval", interval=5 * 60 * 1000),
    ],
    style={"minHeight": "100vh", "backgroundColor": "#0f172a"},
)
# ============================================================================
# CALLBACKS
# ============================================================================
@callback(
    Output("session-id", "data"),
    Input("session-id", "data"),
)
def init_session(existing_id):
    """Keep an existing session id, or mint a new 12-character one from a UUID4."""
    return existing_id if existing_id else str(uuid.uuid4())[:12]
@callback(
    Output("help-modal", "is_open"),
    [Input("btn-help", "n_clicks"), Input("btn-close-help", "n_clicks")],
    State("help-modal", "is_open"),
)
def toggle_help(open_clicks, close_clicks, is_open):
    """Flip the help modal's visibility once either button has been clicked."""
    any_clicked = bool(open_clicks or close_clicks)
    return (not is_open) if any_clicked else is_open
@callback(
    Output("loading-target", "children"),
    Input("provider-select", "value"),
    State("session-id", "data"),
)
def switch_provider(provider, session_id):
    """Apply the dropdown's provider choice to this session's UserState."""
    if session_id and provider:
        get_user_state(session_id).set_provider(provider)
    return ""
@callback(
    Output("chat-input", "value", allow_duplicate=True),
    Input({"type": "example", "index": ALL}, "n_clicks"),
    prevent_initial_call=True
)
def fill_example(clicks):
    """Copy the clicked (English) example query into the chat input."""
    if not any(clicks):
        raise PreventUpdate
    clicked = ctx.triggered_id
    if not (clicked and isinstance(clicked, dict)):
        raise PreventUpdate
    position = clicked.get("index", 0)
    examples = EXAMPLE_QUERIES["en"]
    if position >= len(examples):
        raise PreventUpdate
    return examples[position]["text"]
@callback(
    [
        Output("reasoning-graph", "elements"),
        Output("chat-history", "children"),
        Output("stats-display", "children"),
        Output("chat-input", "value"),
        Output("graph-version", "data"),
        Output("language-indicator", "children"),
    ],
    [
        Input("btn-send", "n_clicks"),
        Input("btn-clear", "n_clicks"),
        Input("btn-reset", "n_clicks"),
        Input("btn-new-chat", "n_clicks"),
    ],
    [
        State("chat-input", "value"),
        State("selected-node-id", "data"),
        State("display-options", "value"),
        State("confidence-slider", "value"),
        State("graph-version", "data"),
        State("session-id", "data"),
    ],
    prevent_initial_call=True
)
def handle_main_actions(send_clicks, clear_clicks, reset_clicks, new_clicks,
                        input_text, selected_node, options, conf_threshold,
                        version, session_id):
    """Handle Analyze plus the Clear / Reset / New-chat buttons.

    Clear, Reset and New chat all call ``UserState.reset`` and return the UI
    to its initial state. Analyze detects the input language, runs the
    reasoning engine anchored at the selected node (if any), records the
    exchange in the chat history and re-renders graph, chat and stats.

    Returns:
        (graph elements, chat children, stats text, cleared input value,
        next graph version, language indicator text)

    Raises:
        PreventUpdate: without a session, on blank input, while a branch
            anchor is pending for this session (``handle_branch_send`` owns
            that click), or for any other trigger.
    """
    if not session_id:
        raise PreventUpdate
    state = get_user_state(session_id)
    triggered = ctx.triggered_id

    if triggered in ("btn-clear", "btn-reset", "btn-new-chat"):
        state.reset()
        # The anchor node was just cleared from the graph; don't let the next
        # Analyze try to branch from it.
        _branch_anchor_store.pop(session_id, None)
        return (
            [],
            [create_welcome_message()],
            "Ready — Enter symptoms to begin",
            "",
            0,
            "🌐 EN",
        )

    if triggered != "btn-send":
        raise PreventUpdate
    if not input_text or not input_text.strip():
        raise PreventUpdate
    # handle_branch_send also listens on btn-send and generates from the
    # pending anchor; generating here as well would run the analysis twice.
    # NOTE(review): both callbacks are dispatched for the same click, so if
    # the branch callback consumes the anchor before this check runs the
    # click is still handled twice. Merging the two callbacks removes the race.
    if session_id in _branch_anchor_store:
        raise PreventUpdate

    text = input_text.strip()
    lang = detect_language(text)
    state.language = lang
    state.add_message("user", text)

    try:
        context = state.engine.build_context(text, selected_node)
        context.language = lang
        # Named gen_config so the module-level ``config`` is not shadowed.
        gen_config = GenerationConfig(
            model="gpt-4o-mini" if state.provider == "openai" else "local",
            language=lang,
        )
        response_content = ""
        node_count = 0
        for node in state.engine.generate(context, gen_config):
            node_count += 1
            if node.node_type == NodeType.CONCLUSION:
                response_content = node.content
        gen_stats = state.kg.get_stats()
        logger.info(
            "Generation complete: %d nodes generated, graph has %s nodes and %s edges",
            node_count, gen_stats["nodes"], gen_stats["edges"],
        )
        state.add_message(
            "assistant",
            response_content or "Analysis complete. See the reasoning graph."
        )
    except Exception as e:
        logger.exception("Generation error: %s", e)
        state.add_message("error", f"Analysis failed: {str(e)}")

    state.save()

    chat_display = build_chat_display(state.get_chat_history())
    elements = state.kg.to_cytoscape_elements(
        include_ghosts="ghosts" in (options or []),
        confidence_threshold=conf_threshold,
    )
    stats = state.kg.get_stats()
    stats_text = f"📊 {stats['nodes']} nodes • {stats['edges']} edges"
    return (
        elements,
        chat_display,
        stats_text,
        "",
        (version or 0) + 1,
        f"🌐 {lang.upper()}",
    )
@callback(
    Output("reasoning-graph", "elements", allow_duplicate=True),
    [Input("display-options", "value"), Input("confidence-slider", "value")],
    State("session-id", "data"),
    prevent_initial_call=True
)
def update_display(options, threshold, session_id):
    """Re-render the graph when the ghost toggle or confidence threshold changes."""
    if not session_id:
        raise PreventUpdate
    show_ghosts = "ghosts" in (options or [])
    graph = get_user_state(session_id).kg
    return graph.to_cytoscape_elements(
        include_ghosts=show_ghosts,
        confidence_threshold=threshold,
    )
@callback(
    [
        Output("selected-node-info", "children"),
        Output("selected-node-id", "data"),
        Output("btn-prune", "disabled"),
        Output("btn-resurrect", "disabled"),
        Output("btn-inject", "disabled"),
        Output("btn-correct", "disabled"),
        Output("btn-incorrect", "disabled"),
        Output("btn-branch", "disabled"),
    ],
    Input("reasoning-graph", "tapNodeData"),
)
def handle_node_click(node_data):
    """Show details for the tapped node and enable the actions valid for its type.

    Returns:
        (info children, selected node id, and the ``disabled`` flags for
        prune, resurrect, inject, correct, incorrect and branch buttons).
        With no node tapped, everything is disabled.
    """
    if not node_data:
        return (
            html.P("👆 Click a node", className="text-muted small"),
            None, True, True, True, True, True, True
        )

    node_id = node_data.get("id")
    # ``or`` fallbacks guard against keys that are present but null, which
    # would otherwise crash .upper(), the % format and the slice below.
    node_type = str(node_data.get("type") or "unknown")
    confidence = node_data.get("confidence") or 0
    content = node_data.get("content") or node_data.get("full_label") or ""

    fallback_info = {"icon": "●", "name": "Unknown", "color": "#64748b"}
    try:
        info = NODE_TYPE_INFO.get(NodeType(node_type), fallback_info)
    except ValueError:
        # NodeType(...) rejects values it doesn't define (including the
        # "unknown" default above); presumably NodeType is an Enum — confirm.
        info = fallback_info

    node_info = html.Div([
        html.Div([
            html.Span(info["icon"], className="me-2"),
            dbc.Badge(node_type.upper(), style={"backgroundColor": info["color"]}),
            html.Span(f" {confidence:.0%}", className="ms-2",
                      style={"color": "#34d399" if confidence > 0.7 else "#facc15"}),
        ], className="mb-2"),
        html.Div(content[:200], style={"fontSize": "0.85rem", "color": "#e2e8f0"}),
    ])

    is_ghost = node_type == "ghost"
    can_prune = node_type not in ("query", "ghost")
    can_feedback = node_type in ("hypothesis", "conclusion", "reasoning")
    can_branch = node_type in ("query", "hypothesis", "reasoning", "fact")
    return (
        node_info,
        node_id,
        not can_prune,
        not is_ghost,
        node_id is None,
        not can_feedback,
        not can_feedback,
        not can_branch,
    )
@callback(
    [
        Output("reasoning-graph", "elements", allow_duplicate=True),
        Output("stats-display", "children", allow_duplicate=True),
        Output("feedback-status", "children"),
        Output("fact-input", "value"),
    ],
    [
        Input("btn-prune", "n_clicks"),
        Input("btn-resurrect", "n_clicks"),
        Input("btn-inject", "n_clicks"),
        Input("btn-correct", "n_clicks"),
        Input("btn-incorrect", "n_clicks"),
    ],
    [
        State("selected-node-id", "data"),
        State("fact-input", "value"),
        State("display-options", "value"),
        State("confidence-slider", "value"),
        State("session-id", "data"),
    ],
    prevent_initial_call=True
)
def handle_actions(prune, resurrect, inject, correct, incorrect,
                   selected_node, fact_text, options, threshold, session_id):
    """Apply a steering action (prune / resurrect / inject / feedback) to the selected node.

    Each action goes through the user's GraphSynchronizer and is recorded
    with the session manager. The graph is then saved and re-rendered with
    the current display options.

    Returns:
        (graph elements, stats text, feedback status, fact-input value);
        the fact input is cleared only after a successful injection.

    Raises:
        PreventUpdate: without a session or selected node, or when Inject is
            clicked with an empty or whitespace-only fact.
    """
    if not session_id:
        raise PreventUpdate
    state = get_user_state(session_id)
    triggered = ctx.triggered_id
    feedback_status = dash.no_update
    clear_fact_input = dash.no_update
    sm = get_session_manager()
    fact = (fact_text or "").strip()
    logger.info("Action triggered: %s, selected_node: %s", triggered, selected_node)

    if not selected_node:
        raise PreventUpdate

    if triggered == "btn-prune":
        result = state.synchronizer.prune_node(selected_node)
        sm.record_interaction(session_id, 'prune', node_id=selected_node)
        logger.info("Pruned node %s: %s", selected_node, result)
        feedback_status = html.Small("✂️ Pruned node", style={"color": "#f87171"})
    elif triggered == "btn-resurrect":
        result = state.synchronizer.resurrect_node(selected_node)
        sm.record_interaction(session_id, 'resurrect', node_id=selected_node)
        logger.info("Resurrected node %s: %s", selected_node, result)
        feedback_status = html.Small("👻 Resurrected", style={"color": "#facc15"})
    elif triggered == "btn-inject" and fact:
        # Strip so a whitespace-only input is not injected as an empty fact.
        result = state.synchronizer.inject_fact(selected_node, fact)
        sm.record_interaction(session_id, 'inject', node_id=selected_node, content=fact)
        logger.info("Injected fact to %s: %s", selected_node, result)
        feedback_status = html.Small("➕ Fact injected", style={"color": "#4ade80"})
        clear_fact_input = ""
    elif triggered == "btn-correct":
        state.synchronizer.record_feedback(selected_node, "correct")
        sm.add_feedback(session_id, selected_node, "correct")
        feedback_status = html.Small("✓ Marked correct", style={"color": "#4ade80"})
    elif triggered == "btn-incorrect":
        state.synchronizer.record_feedback(selected_node, "incorrect")
        sm.add_feedback(session_id, selected_node, "incorrect")
        feedback_status = html.Small("✗ Marked incorrect", style={"color": "#f87171"})
    else:
        raise PreventUpdate

    state.save()
    elements = state.kg.to_cytoscape_elements(
        include_ghosts="ghosts" in (options or []),
        confidence_threshold=threshold,
    )
    stats = state.kg.get_stats()
    return (
        elements,
        f"📊 {stats['nodes']} nodes • {stats['edges']} edges",
        feedback_status,
        clear_fact_input,
    )
# Store for branch anchor node
_branch_anchor_store: Dict[str, str] = {}
@callback(
    [Output("chat-input", "placeholder"),
     Output("chat-input", "value", allow_duplicate=True)],
    Input("btn-branch", "n_clicks"),
    [State("selected-node-id", "data"),
     State("session-id", "data")],
    prevent_initial_call=True
)
def set_branch_anchor(n_clicks, selected_node, session_id):
    """Remember the selected node as this session's branch point and prompt for input."""
    ready = n_clicks and selected_node and session_id
    if not ready:
        raise PreventUpdate
    _branch_anchor_store[session_id] = selected_node
    logger.info(f"Set branch anchor for session {session_id}: {selected_node}")
    placeholder = "Enter new reasoning to branch from selected node..."
    return placeholder, ""
@callback(
    Output("reasoning-graph", "elements", allow_duplicate=True),
    Output("chat-history", "children", allow_duplicate=True),
    Output("stats-display", "children", allow_duplicate=True),
    Output("chat-input", "placeholder", allow_duplicate=True),
    Output("chat-input", "value", allow_duplicate=True),
    Output("language-indicator", "children", allow_duplicate=True),
    Input("btn-send", "n_clicks"),
    [State("chat-input", "value"),
     State("display-options", "value"),
     State("confidence-slider", "value"),
     State("session-id", "data")],
    prevent_initial_call=True
)
def handle_branch_send(n_clicks, input_text, options, threshold, session_id):
    """Run reasoning from the pending branch anchor when Analyze is clicked.

    Only acts when ``set_branch_anchor`` stored an anchor for this session.
    The anchor is consumed, and generation runs with ``context.is_branching``
    set. Afterwards the graph, chat and stats are refreshed, the input is
    cleared, and the default placeholder and detected language are restored.

    Raises:
        PreventUpdate: on blank or whitespace-only input (the anchor is kept),
            without a session, or when no anchor is pending.
    """
    if not n_clicks or not session_id or not input_text or not input_text.strip():
        raise PreventUpdate

    anchor_node = _branch_anchor_store.pop(session_id, None)
    if not anchor_node:
        # No anchor: this is a normal send, handled by handle_main_actions.
        raise PreventUpdate

    text = input_text.strip()
    state = get_user_state(session_id)
    lang = detect_language(text)
    state.language = lang
    state.add_message("user", f"[Branching from node] {text}")

    try:
        context = state.engine.build_context(text, anchor_node)
        context.language = lang
        context.is_branching = True
        # Named gen_config so the module-level ``config`` is not shadowed.
        gen_config = GenerationConfig(
            model="gpt-4o-mini" if state.provider == "openai" else "local",
            language=lang,
        )
        response_content = ""
        for node in state.engine.generate(context, gen_config):
            if node.node_type == NodeType.CONCLUSION:
                response_content = node.content
        state.add_message(
            "assistant",
            response_content or "Branch analysis complete. See the reasoning graph."
        )
    except Exception as e:
        logger.exception("Branch generation error: %s", e)
        state.add_message("error", f"Branch failed: {str(e)}")

    state.save()

    chat_display = build_chat_display(state.get_chat_history())
    elements = state.kg.to_cytoscape_elements(
        include_ghosts="ghosts" in (options or []),
        confidence_threshold=threshold,
    )
    stats = state.kg.get_stats()
    return (
        elements,
        chat_display,
        f"📊 {stats['nodes']} nodes • {stats['edges']} edges",
        "Describe your symptoms...",
        "",
        f"🌐 {lang.upper()}",
    )
@callback(
    Output("reasoning-graph", "layout"),
    [Input("btn-layout-dag", "n_clicks"),
     Input("btn-layout-force", "n_clicks"),
     Input("btn-layout-radial", "n_clicks")],
    prevent_initial_call=True
)
def change_layout(dag, force, radial):
    """Switch the Cytoscape layout to match the clicked layout button.

    ``LAYOUT_CONFIGS`` may come from ``src.styles`` and might not define
    every key, so a missing entry falls back to the hierarchical layout and
    finally to plain dagre, instead of sending ``None`` to Cytoscape.
    """
    button_to_layout = {
        "btn-layout-dag": "hierarchical",
        "btn-layout-force": "force",
        "btn-layout-radial": "radial",
    }
    layout_key = button_to_layout.get(ctx.triggered_id, "hierarchical")
    return (
        LAYOUT_CONFIGS.get(layout_key)
        or LAYOUT_CONFIGS.get("hierarchical")
        or {"name": "dagre"}
    )
@callback(
    Output("reasoning-graph", "zoom"),
    [Input("btn-zoom-in", "n_clicks"),
     Input("btn-zoom-out", "n_clicks"),
     Input("btn-zoom-fit", "n_clicks")],
    State("reasoning-graph", "zoom"),
    prevent_initial_call=True
)
def handle_zoom(zoom_in, zoom_out, fit, current):
    """Step the zoom by 1.3x in / 0.7x out (clamped to 0.2-3.0); the fit button resets it to 1.0."""
    level = current or 1.0
    button = ctx.triggered_id
    if button == "btn-zoom-out":
        return max(level * 0.7, 0.2)
    if button == "btn-zoom-in":
        return min(level * 1.3, 3.0)
    return 1.0
@callback(
    Output("cleanup-interval", "disabled"),
    Input("cleanup-interval", "n_intervals"),
)
def periodic_cleanup(n):
    """On every interval tick, trim cached user states and stale sessions; keep the timer running."""
    cleanup_user_states()
    session_manager = get_session_manager()
    session_manager.cleanup_stale_sessions()
    keep_running = False
    return keep_running
def build_chat_display(history: List[Dict]) -> List:
    """Render chat history dicts as styled bubbles.

    ``"user"`` and ``"assistant"`` roles get their own prefix and colour; any
    other role is shown as a warning bubble. An empty history renders the
    welcome message.
    """
    if not history:
        return [create_welcome_message()]

    def bubble(prefix, text, background):
        return html.Div(
            [prefix, text],
            className="mb-2 p-2",
            style={"backgroundColor": background, "borderRadius": "8px"},
        )

    rendered = []
    for message in history:
        role = message.get("role", "user")
        text = message.get("content", "")
        if role == "user":
            prefix = html.Span("You: ", style={"fontWeight": "600", "color": "#a5b4fc"})
            rendered.append(bubble(prefix, text, "#1e3a5f"))
        elif role == "assistant":
            rendered.append(bubble(html.Span("🤖 ", className="me-1"), text, "#1e293b"))
        else:
            rendered.append(bubble(html.Span("⚠️ ", className="me-1"), text, "#450a0a"))
    return rendered
# Session history storage (simple in-memory for now)
_session_history_storage: Dict[str, List[Dict]] = {}
@callback(
    Output("session-history-list", "children"),
    [Input("btn-save-session", "n_clicks"),
     Input("btn-clear-history", "n_clicks"),
     Input({"type": "load-session", "index": ALL}, "n_clicks")],
    [State("session-id", "data"),
     State("chat-history", "children")],
    prevent_initial_call=True
)
def handle_session_history(save_clicks, clear_clicks, load_clicks, session_id, chat_children):
    """Save or clear session snapshots and render the saved-session list.

    Save stores a snapshot of the current chat history and graph state
    (newest first, at most 10). Clear empties this user's list. A Load click
    only re-renders the list here; restoring is done by ``load_saved_session``.

    Returns:
        The children for the history list (entries with Load buttons, or a
        placeholder message).
    """
    if not session_id:
        raise PreventUpdate
    triggered = ctx.triggered_id
    saved_sessions = _session_history_storage.setdefault(session_id, [])

    if triggered == "btn-clear-history":
        saved_sessions.clear()
        return [html.P("History cleared.", className="text-muted small text-center mt-3")]

    if triggered == "btn-save-session":
        state = get_user_state(session_id)
        history = state.get_chat_history()
        if history:
            first_text = history[0].get("content") or "Empty"
            # Only mark the preview with an ellipsis when it was actually cut.
            preview = first_text if len(first_text) <= 50 else first_text[:50] + "..."
            saved_sessions.insert(0, {
                "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M"),
                "preview": preview,
                "messages": history,
                "graph_state": state.kg.get_state(),
            })
            del saved_sessions[10:]  # keep only the 10 most recent snapshots

    if not saved_sessions:
        return [html.P("No saved sessions yet.", className="text-muted small text-center mt-3")]

    return [
        html.Div(
            [
                html.Div([
                    html.Small(snapshot["timestamp"], className="text-muted"),
                    html.Div(snapshot["preview"], style={"fontSize": "0.85rem"}),
                ], style={"flex": "1"}),
                dbc.Button("Load", id={"type": "load-session", "index": i},
                           size="sm", color="info", outline=True),
            ],
            className="d-flex justify-content-between align-items-center p-2 mb-2",
            style={"backgroundColor": "#1e3a5f", "borderRadius": "6px"},
        )
        for i, snapshot in enumerate(saved_sessions)
    ]
@callback(
    [Output("reasoning-graph", "elements", allow_duplicate=True),
     Output("chat-history", "children", allow_duplicate=True),
     Output("chat-tabs", "active_tab")],
    Input({"type": "load-session", "index": ALL}, "n_clicks"),
    [State("session-id", "data"),
     State("display-options", "value"),
     State("confidence-slider", "value")],
    prevent_initial_call=True
)
def load_saved_session(clicks, session_id, options=None, threshold=0):
    """Restore a saved snapshot (chat history + graph) and switch to the Chat tab.

    The snapshot's messages replace the session's chat history. Its graph
    state (if any) is restored and persisted through ``state.save()``, so the
    session manager no longer holds the pre-load graph. The graph is rendered
    with the current ghost/confidence display options, like every other render.
    """
    if not any(clicks or []) or not session_id:
        raise PreventUpdate
    triggered = ctx.triggered_id
    if not isinstance(triggered, dict):
        raise PreventUpdate
    index = triggered.get("index")
    sessions = _session_history_storage.get(session_id, [])
    if index is None or index >= len(sessions):
        raise PreventUpdate

    saved = sessions[index]
    state = get_user_state(session_id)

    # Replace chat history with the snapshot's messages.
    session = get_session_manager().get_or_create(session_id)
    session.chat_history.clear()
    for msg in saved["messages"]:
        state.add_message(msg["role"], msg["content"])

    if saved.get("graph_state"):
        state.kg.restore_state(saved["graph_state"])
    state.save()

    chat_display = build_chat_display(saved["messages"])
    elements = state.kg.to_cytoscape_elements(
        include_ghosts="ghosts" in (options or []),
        confidence_threshold=threshold,
    )
    return elements, chat_display, "tab-chat"
# ============================================================================
# MAIN
# ============================================================================
if __name__ == "__main__":
    rule = "=" * 60
    startup_banner = [
        rule,
        " ⚕️ HITL-KG Medical Reasoning System",
        rule,
        f" 🔑 OpenAI: {'✅' if OPENAI_API_KEY else '❌ Local mode'}",
        " 🌐 Embeddings: Multilingual (50+ languages)",
        f" 📊 Default provider: {DEFAULT_PROVIDER}",
        f" 🚀 Starting at http://localhost:{config.port}",
        rule,
    ]
    for banner_line in startup_banner:
        print(banner_line)
    app.run(debug=config.debug, host=config.host, port=config.port)