|
|
""" |
|
|
Knowledge Graph Core Module (Refactored) |
|
|
|
|
|
Simplified knowledge graph implementation focusing on: |
|
|
1. Clean data structures |
|
|
2. Proper CRUD operations |
|
|
3. Cytoscape visualization support |
|
|
4. Session serialization |
|
|
|
|
|
Removed over-engineered TreeCache in favor of simpler state management. |
|
|
""" |
|
|
|
|
|
import uuid |
|
|
import json |
|
|
import logging |
|
|
from enum import Enum |
|
|
from datetime import datetime |
|
|
from dataclasses import dataclass, field, asdict |
|
|
from typing import Dict, List, Optional, Any, Tuple, Set |
|
|
|
|
|
import networkx as nx |
|
|
|
|
|
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
class NodeType(str, Enum):
    """Types of nodes in the reasoning graph.

    Members subclass ``str``, so each compares equal to its string value
    (e.g. ``NodeType.QUERY == "query"``). Display metadata (icon, name,
    colour, description) for every member is kept in ``NODE_TYPE_INFO``.
    """

    QUERY = "query"  # the user's input question or symptom description
    FACT = "fact"  # verified fact from the knowledge base
    REASONING = "reasoning"  # logical inference step
    HYPOTHESIS = "hypothesis"  # candidate diagnosis being considered
    CONCLUSION = "conclusion"  # final diagnostic conclusion
    EVIDENCE = "evidence"  # supporting evidence
    CONSTRAINT = "constraint"  # limitation or warning
    GHOST = "ghost"  # soft-pruned node (see KnowledgeGraph.prune_branch)
|
|
|
|
|
|
|
|
class EdgeType(str, Enum):
    """Types of relationships between nodes in the reasoning graph.

    Members subclass ``str``, so each compares equal to its string value.
    ``LEADS_TO`` is the type assumed by ``ReasoningEdge.from_dict`` when a
    serialized edge carries no ``edge_type``.
    """

    LEADS_TO = "leads_to"
    SUPPORTS = "supports"
    CONTRADICTS = "contradicts"
    REQUIRES = "requires"
    ALTERNATIVE = "alternative"
    FOLLOW_UP = "follow_up"
    # CAUSES / TREATS share their string values with the knowledge-base
    # relation names that KnowledgeGraph queries ("causes", "treats").
    CAUSES = "causes"
    TREATS = "treats"
    INDICATES = "indicates"
|
|
|
|
|
|
|
|
class EntityCategory(str, Enum):
    """Categories of medical knowledge-base entities.

    ``FINDING`` is the fallback used by ``Entity.from_dict`` when the
    serialized category is missing or not one of these values.
    """

    SYMPTOM = "symptom"
    DISEASE = "disease"
    TREATMENT = "treatment"
    MEDICATION = "medication"
    PROCEDURE = "procedure"
    FINDING = "finding"
    ANATOMY = "anatomy"
|
|
|
|
|
|
|
|
|
|
|
# Per-NodeType display metadata (icon, human-readable name, hex colour and a
# short description). ReasoningNode.to_cytoscape reads "icon" and "name".
NODE_TYPE_INFO = {
    NodeType.QUERY: dict(
        icon="❓",
        name="Query",
        color="#38bdf8",
        description="Your input question or symptom description",
    ),
    NodeType.FACT: dict(
        icon="📋",
        name="Fact",
        color="#4ade80",
        description="Verified medical fact from knowledge base",
    ),
    NodeType.REASONING: dict(
        icon="🔍",
        name="Reasoning",
        color="#818cf8",
        description="Logical inference step",
    ),
    NodeType.HYPOTHESIS: dict(
        icon="💡",
        name="Hypothesis",
        color="#fbbf24",
        description="Potential diagnosis being considered",
    ),
    NodeType.CONCLUSION: dict(
        icon="✅",
        name="Conclusion",
        color="#f472b6",
        description="Final diagnostic conclusion",
    ),
    NodeType.EVIDENCE: dict(
        icon="📊",
        name="Evidence",
        color="#2dd4bf",
        description="Supporting medical evidence",
    ),
    NodeType.CONSTRAINT: dict(
        icon="⚠️",
        name="Constraint",
        color="#fb7185",
        description="Limitation or warning",
    ),
    NodeType.GHOST: dict(
        icon="👻",
        name="Ghost",
        color="#94a3b8",
        description="Pruned reasoning path",
    ),
}
|
|
|
|
|
|
|
|
def create_node_id() -> str:
    """Return a fresh reasoning-node identifier.

    The identifier is ``"n_"`` followed by the first 8 hex digits of a random
    UUID4 (32 random bits), e.g. ``"n_3f9a0c1b"``.
    """
    random_hex = uuid.uuid4().hex
    return "n_" + random_hex[:8]
|
|
|
|
|
|
|
|
@dataclass
class ReasoningNode:
    """A node in the reasoning graph.

    Attributes:
        id: Unique node identifier.
        label: Human-readable label; truncated to 60 characters for display.
        node_type: Role of the node (a ``NodeType``; plain strings are
            tolerated by the serializers).
        content: Full text content of the node.
        confidence: Confidence score; defaults to 1.0.
        kg_entity_id: Optional ID of a linked knowledge-base entity.
        metadata: Free-form extra data (e.g. ``original_type`` is stored here
            when a node is soft-pruned).
        timestamp: Creation time as a naive local ISO-8601 string.
        language: Language code of the content; defaults to "en".
    """

    id: str
    label: str
    node_type: NodeType
    content: str
    confidence: float = 1.0
    kg_entity_id: Optional[str] = None
    metadata: Dict[str, Any] = field(default_factory=dict)
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
    language: str = "en"

    def to_dict(self) -> Dict:
        """Serialize to a plain dictionary (the node type becomes its string value).

        ``metadata`` is shallow-copied so a serialized snapshot (e.g. from a
        saved session) is not silently changed when the live node's metadata
        is mutated later, such as when the node is pruned.
        """
        return {
            "id": self.id,
            "label": self.label,
            "node_type": self.node_type.value if isinstance(self.node_type, NodeType) else self.node_type,
            "content": self.content,
            "confidence": self.confidence,
            "kg_entity_id": self.kg_entity_id,
            "metadata": dict(self.metadata or {}),
            "timestamp": self.timestamp,
            "language": self.language,
        }

    def to_cytoscape(self) -> Dict:
        """Convert to a Cytoscape element (``data`` dict plus ``classes``).

        Node types that are not ``NodeType`` values are rendered with a generic
        "●" icon and the name "Unknown" instead of raising.
        """
        type_val = self.node_type.value if isinstance(self.node_type, NodeType) else self.node_type
        fallback = {"icon": "●", "name": "Unknown"}
        try:
            type_info = NODE_TYPE_INFO.get(NodeType(type_val), fallback)
        except ValueError:
            # NodeType(...) raises for values outside the enum; without this
            # the fallback above could never be used for unknown types.
            type_info = fallback

        # Long labels are shortened for the graph; the untruncated text is
        # still available as "full_label".
        display_label = self.label[:60] + "..." if len(self.label) > 60 else self.label

        return {
            "data": {
                "id": self.id,
                "label": display_label,
                "full_label": self.label,
                "type": type_val,
                "content": self.content,
                "confidence": self.confidence,
                "kg_entity_id": self.kg_entity_id or "",
                "timestamp": self.timestamp,
                "type_icon": type_info["icon"],
                "type_name": type_info["name"],
                "language": self.language,
            },
            "classes": type_val
        }

    @classmethod
    def from_dict(cls, data: Dict) -> "ReasoningNode":
        """Deserialize from a dictionary shaped like ``to_dict`` output.

        ``id``, ``label`` and ``content`` are required (``KeyError`` otherwise).
        ``node_type`` defaults to "reasoning"; an unrecognised type string
        raises ``ValueError``. ``metadata`` is shallow-copied so the node does
        not share (and later mutate) the caller's dictionary.
        """
        node_type = data.get("node_type", "reasoning")
        if isinstance(node_type, str):
            node_type = NodeType(node_type)

        return cls(
            id=data["id"],
            label=data["label"],
            node_type=node_type,
            content=data["content"],
            confidence=data.get("confidence", 1.0),
            kg_entity_id=data.get("kg_entity_id"),
            metadata=dict(data.get("metadata") or {}),
            timestamp=data.get("timestamp", datetime.now().isoformat()),
            language=data.get("language", "en"),
        )
|
|
|
|
|
|
|
|
@dataclass
class ReasoningEdge:
    """A directed edge in the reasoning graph.

    Attributes:
        source: ID of the node the edge starts from.
        target: ID of the node the edge points to.
        edge_type: Kind of relationship (an ``EdgeType``; plain strings are
            tolerated by the serializers).
        weight: Edge weight; defaults to 1.0.
        label: Optional display label; when empty, Cytoscape output uses the
            edge type with underscores replaced by spaces, title-cased.
        metadata: Free-form extra data.
    """

    source: str
    target: str
    edge_type: EdgeType
    weight: float = 1.0
    label: str = ""
    metadata: Dict[str, Any] = field(default_factory=dict)

    @property
    def id(self) -> str:
        """Edge identifier ``"<source>-<target>"``.

        The ID depends only on the endpoints, so a second edge between the
        same ordered pair of nodes replaces the first wherever edges are
        keyed by ID.
        """
        return f"{self.source}-{self.target}"

    def to_dict(self) -> Dict:
        """Serialize to a plain dictionary (the edge type becomes its string value).

        ``metadata`` is shallow-copied so serialized snapshots are isolated
        from later mutation of the live edge.
        """
        return {
            "source": self.source,
            "target": self.target,
            "edge_type": self.edge_type.value if isinstance(self.edge_type, EdgeType) else self.edge_type,
            "weight": self.weight,
            "label": self.label,
            "metadata": dict(self.metadata or {}),
        }

    def to_cytoscape(self) -> Dict:
        """Convert to a Cytoscape element (``data`` dict plus ``classes``)."""
        type_val = self.edge_type.value if isinstance(self.edge_type, EdgeType) else self.edge_type

        return {
            "data": {
                "id": self.id,
                "source": self.source,
                "target": self.target,
                "type": type_val,
                "weight": self.weight,
                # e.g. "leads_to" -> "Leads To" when no explicit label is set
                "label": self.label or type_val.replace("_", " ").title(),
            },
            "classes": type_val
        }

    @classmethod
    def from_dict(cls, data: Dict) -> "ReasoningEdge":
        """Deserialize from a dictionary shaped like ``to_dict`` output.

        ``source`` and ``target`` are required (``KeyError`` otherwise).
        ``edge_type`` defaults to "leads_to"; an unrecognised type string
        raises ``ValueError``. ``metadata`` is shallow-copied so the edge does
        not share the caller's dictionary.
        """
        edge_type = data.get("edge_type", "leads_to")
        if isinstance(edge_type, str):
            edge_type = EdgeType(edge_type)

        return cls(
            source=data["source"],
            target=data["target"],
            edge_type=edge_type,
            weight=data.get("weight", 1.0),
            label=data.get("label", ""),
            metadata=dict(data.get("metadata") or {}),
        )
|
|
|
|
|
|
|
|
@dataclass
class Entity:
    """Knowledge-base entity (symptom, disease, treatment, etc.).

    Attributes:
        id: Unique entity identifier.
        name: Display name; the first element of the embedding text.
        category: Entity category (rendered as its string value by ``to_dict``).
        description: Optional free-text description.
        synonyms: Alternative names, appended to the embedding text.
        properties: Arbitrary additional properties.
        xrefs: Cross-reference mapping from string keys to string values.
    """

    id: str
    name: str
    category: EntityCategory
    description: str = ""
    synonyms: List[str] = field(default_factory=list)
    properties: Dict[str, Any] = field(default_factory=dict)
    xrefs: Dict[str, str] = field(default_factory=dict)

    def to_dict(self) -> Dict:
        """Return the entity as a plain dict; an enum category becomes its value."""
        category_value = (
            self.category.value
            if isinstance(self.category, EntityCategory)
            else self.category
        )
        return dict(
            id=self.id,
            name=self.name,
            category=category_value,
            description=self.description,
            synonyms=self.synonyms,
            properties=self.properties,
            xrefs=self.xrefs,
        )

    def to_embedding_text(self) -> str:
        """Return name, description (only if non-empty) and synonyms joined by spaces."""
        leading = (self.name, self.description) if self.description else (self.name,)
        return " ".join((*leading, *self.synonyms))

    @classmethod
    def from_dict(cls, data: Dict) -> "Entity":
        """Build an entity from a dict shaped like ``to_dict`` output.

        ``id`` and ``name`` are required. A missing category defaults to
        "finding"; an unrecognised category string also degrades to
        ``EntityCategory.FINDING``; non-string categories pass through as-is.
        """
        raw_category = data.get("category", "finding")
        if not isinstance(raw_category, str):
            category = raw_category
        else:
            try:
                category = EntityCategory(raw_category)
            except ValueError:
                category = EntityCategory.FINDING

        optional_fields = {
            "description": data.get("description", ""),
            "synonyms": data.get("synonyms", []),
            "properties": data.get("properties", {}),
            "xrefs": data.get("xrefs", {}),
        }
        return cls(id=data["id"], name=data["name"], category=category, **optional_fields)
|
|
|
|
|
|
|
|
class KnowledgeGraph:
    """
    Core Knowledge Graph managing:

    1. Static knowledge base: ``entities`` (Entity objects by ID) plus
       ``kb_graph``, a directed graph whose edges carry a ``relation`` name
       and a ``weight``.
    2. Dynamic reasoning graph: ``nodes`` / ``edges`` (dataclass objects by
       ID), mirrored into ``reasoning_graph`` for traversal.

    ``version`` is bumped by every reasoning-graph mutation method, reset to 0
    by ``clear_reasoning`` and taken from the snapshot by ``restore_state``;
    knowledge-base mutations do not change it.

    Simplified from original - removed TreeCache, streamlined operations.
    """

    def __init__(self):
        # --- static knowledge base ---
        self.entities: Dict[str, Entity] = {}
        self.kb_graph = nx.DiGraph()

        # --- dynamic reasoning graph ---
        self.nodes: Dict[str, ReasoningNode] = {}
        self.edges: Dict[str, ReasoningEdge] = {}
        self.reasoning_graph = nx.DiGraph()

        # --- bookkeeping ---
        self.version = 0
        # Most recently added node that is neither GHOST nor EVIDENCE.
        self._last_node_id: Optional[str] = None

    @staticmethod
    def _enum_value(value: Any) -> Any:
        """Return ``value.value`` for Enum members and ``value`` unchanged otherwise."""
        return value.value if isinstance(value, Enum) else value

    # ------------------------------------------------------------------
    # Knowledge base
    # ------------------------------------------------------------------

    def add_entity(self, entity: Entity):
        """Add (or replace) an entity in the knowledge base.

        The entity is stored in ``entities`` and mirrored as a ``kb_graph``
        node whose attributes are the entity's serialized fields (minus ``id``).
        """
        self.entities[entity.id] = entity
        attributes = entity.to_dict()
        attributes.pop("id", None)
        # to_dict() already renders the category as its string value, which
        # also works when ``category`` was supplied as a plain string.
        self.kb_graph.add_node(entity.id, **attributes)

    def add_relation(
        self,
        source_id: str,
        target_id: str,
        relation_type: str,
        weight: float = 1.0,
        **properties
    ):
        """Add a directed ``relation_type`` edge from ``source_id`` to ``target_id``.

        Extra keyword arguments are stored as edge attributes. Adding a second
        relation between the same ordered pair updates the existing edge.
        """
        self.kb_graph.add_edge(
            source_id, target_id,
            relation=relation_type,
            weight=weight,
            **properties
        )

    def get_entity(self, entity_id: str) -> Optional[Entity]:
        """Return the entity with ``entity_id``, or None if unknown."""
        return self.entities.get(entity_id)

    def get_related_entities(
        self,
        entity_id: str,
        relation_type: Optional[str] = None
    ) -> List[Tuple[str, str, Dict]]:
        """List outgoing relations of an entity.

        Args:
            entity_id: Source entity ID.
            relation_type: If given, only edges with this ``relation`` are kept.

        Returns:
            ``(target_id, relation, edge_data)`` tuples; empty if the entity is
            not in the knowledge graph.
        """
        if entity_id not in self.kb_graph:
            return []

        return [
            (target, data.get("relation"), data)
            for _, target, data in self.kb_graph.out_edges(entity_id, data=True)
            if relation_type is None or data.get("relation") == relation_type
        ]

    def get_entities_by_category(self, category: EntityCategory) -> List[Entity]:
        """Return all entities whose category equals ``category``."""
        return [e for e in self.entities.values() if e.category == category]

    def get_diseases_for_symptoms(
        self,
        symptom_ids: List[str]
    ) -> List[Tuple[Entity, float]]:
        """Rank diseases that have a "causes" edge into any of the given symptoms.

        For each such DISEASE entity the weights of its matching "causes" edges
        are summed, then divided by the number of SYMPTOM entities the disease
        causes; the resulting score is capped at 1.0.

        Args:
            symptom_ids: Symptom entity IDs. Duplicates are counted once and
                IDs absent from the knowledge graph are ignored.

        Returns:
            ``(disease, score)`` pairs ordered by descending raw weight sum
            (not by the normalized score that is returned).
        """
        disease_scores: Dict[str, float] = {}

        # dict.fromkeys de-duplicates while keeping order, so repeating a
        # symptom ID cannot inflate a disease's score.
        for symptom_id in dict.fromkeys(symptom_ids):
            # Guard: for a non-node string, networkx's in_edges treats the
            # argument as an iterable of nodes and would probe its characters.
            if symptom_id not in self.kb_graph:
                continue
            for source, _, data in self.kb_graph.in_edges(symptom_id, data=True):
                if data.get("relation") != "causes":
                    continue
                entity = self.entities.get(source)
                if entity and entity.category == EntityCategory.DISEASE:
                    weight = data.get("weight", 1.0)
                    disease_scores[source] = disease_scores.get(source, 0) + weight

        results = []
        for disease_id, score in sorted(disease_scores.items(), key=lambda x: -x[1]):
            entity = self.entities.get(disease_id)
            if entity:
                # Normalize by how many symptoms the disease causes overall.
                total_symptoms = len(self.get_symptoms_for_disease(disease_id))
                normalized_score = score / max(total_symptoms, 1)
                results.append((entity, min(normalized_score, 1.0)))

        return results

    def get_symptoms_for_disease(self, disease_id: str) -> List[Entity]:
        """Return SYMPTOM entities targeted by the disease's "causes" edges."""
        symptoms = []
        for target, _, _ in self.get_related_entities(disease_id, "causes"):
            entity = self.entities.get(target)
            if entity and entity.category == EntityCategory.SYMPTOM:
                symptoms.append(entity)
        return symptoms

    def get_treatments_for_disease(self, disease_id: str) -> List[Entity]:
        """Return entities linked to the disease by a "treats" edge in either direction.

        Outgoing ``disease -> X`` edges are listed first, then incoming
        ``X -> disease`` edges; each entity appears at most once. Returns an
        empty list when the disease is not in the knowledge graph.
        """
        if disease_id not in self.kb_graph:
            return []

        treatments: List[Entity] = []
        seen: Set[str] = set()

        def _collect(entity_id: str) -> None:
            # Append the entity once, skipping unknown IDs.
            entity = self.entities.get(entity_id)
            if entity and entity_id not in seen:
                seen.add(entity_id)
                treatments.append(entity)

        for target, _, _ in self.get_related_entities(disease_id, "treats"):
            _collect(target)

        for source, _, data in self.kb_graph.in_edges(disease_id, data=True):
            if data.get("relation") == "treats":
                _collect(source)

        return treatments

    # ------------------------------------------------------------------
    # Reasoning graph
    # ------------------------------------------------------------------

    def add_node(self, node: ReasoningNode) -> str:
        """Add (or replace) a node in the reasoning graph and return its ID.

        GHOST and EVIDENCE nodes never become the "last active" node.
        """
        self.nodes[node.id] = node
        self.reasoning_graph.add_node(node.id, **node.to_dict())

        if node.node_type not in (NodeType.GHOST, NodeType.EVIDENCE):
            self._last_node_id = node.id

        self.version += 1
        return node.id

    def add_edge(self, edge: ReasoningEdge) -> str:
        """Add an edge between two existing nodes and return its ID.

        Returns an empty string (and logs a warning) when either endpoint is
        not a known node; nothing is modified in that case.
        """
        if edge.source not in self.nodes:
            logger.warning("Edge source node %s not found in graph - edge not created", edge.source)
            return ""
        if edge.target not in self.nodes:
            logger.warning("Edge target node %s not found in graph - edge not created", edge.target)
            return ""

        self.edges[edge.id] = edge
        self.reasoning_graph.add_edge(edge.source, edge.target, **edge.to_dict())
        # Lazy %-formatting, and _enum_value tolerates a plain-string edge_type
        # (an eager ``edge_type.value`` would raise after the edge was stored).
        logger.debug(
            "Created edge: %s... --[%s]--> %s...",
            edge.source[:8], self._enum_value(edge.edge_type), edge.target[:8],
        )
        self.version += 1
        return edge.id

    def update_node(self, node_id: str, **updates) -> bool:
        """Set attributes on an existing node.

        Keys that are not attributes of the node are ignored. ``id`` is never
        changed, because it keys both ``nodes`` and ``reasoning_graph``. A
        ``node_type`` given as a valid string is converted to ``NodeType``;
        unknown type strings are stored as given.

        Returns:
            False if ``node_id`` is unknown, True otherwise.
        """
        node = self.nodes.get(node_id)
        if node is None:
            return False

        for key, value in updates.items():
            if key == "id":
                logger.warning("update_node: refusing to change the id of node %s", node_id)
                continue
            if not hasattr(node, key):
                continue
            if key == "node_type" and isinstance(value, str):
                try:
                    value = NodeType(value)
                except ValueError:
                    pass  # unknown type string: stored unchanged
            setattr(node, key, value)

        # add_node on an existing node just refreshes its attributes.
        self.reasoning_graph.add_node(node_id, **node.to_dict())
        self.version += 1
        return True

    def delete_node(self, node_id: str) -> List[str]:
        """Delete a node and every edge touching it.

        Returns:
            IDs of the deleted edges; empty if the node is unknown.
        """
        if node_id not in self.nodes:
            return []

        deleted_edges = [
            edge_id for edge_id, edge in self.edges.items()
            if node_id in (edge.source, edge.target)
        ]
        for edge_id in deleted_edges:
            edge = self.edges.pop(edge_id)
            if self.reasoning_graph.has_edge(edge.source, edge.target):
                self.reasoning_graph.remove_edge(edge.source, edge.target)

        if node_id in self.reasoning_graph:
            self.reasoning_graph.remove_node(node_id)
        del self.nodes[node_id]

        # Don't leave a dangling reference in serialized state.
        if self._last_node_id == node_id:
            self._last_node_id = None

        self.version += 1
        return deleted_edges

    def prune_branch(self, node_id: str) -> Dict[str, List[str]]:
        """
        Soft prune: Convert node and descendants to GHOST type.
        Preserves reasoning history for RLHF and allows resurrection.

        Each newly ghosted node records its previous type in
        ``metadata["original_type"]`` and has its confidence multiplied by 0.3.
        Nodes that are already ghosts are left untouched.

        Returns:
            ``{"nodes": [...], "edges": [...]}`` with the branch's node IDs
            (root first) and the IDs of edges touching any of them.
        """
        if node_id not in self.nodes:
            return {"nodes": [], "edges": []}

        try:
            descendants = list(nx.descendants(self.reasoning_graph, node_id))
        except nx.NetworkXError:
            # Node is registered in ``nodes`` but missing from reasoning_graph.
            descendants = []

        all_nodes = [node_id] + descendants
        branch = set(all_nodes)

        for nid in all_nodes:
            node = self.nodes.get(nid)
            # Re-pruning a ghost would overwrite its saved original_type with
            # "ghost" (making resurrection impossible) and decay its
            # confidence a second time, so ghosts are skipped.
            if node is None or node.node_type == NodeType.GHOST:
                continue
            node.metadata["original_type"] = self._enum_value(node.node_type)
            node.node_type = NodeType.GHOST
            node.confidence *= 0.3
            self.reasoning_graph.add_node(nid, **node.to_dict())

        affected_edges = [
            edge_id for edge_id, edge in self.edges.items()
            if edge.source in branch or edge.target in branch
        ]

        self.version += 1
        return {"nodes": all_nodes, "edges": affected_edges}

    def resurrect_node(self, node_id: str) -> bool:
        """Restore a single ghost node to its pre-prune type.

        The type comes from ``metadata["original_type"]``; a missing, unknown
        or "ghost" value falls back to HYPOTHESIS. Confidence becomes
        ``max(confidence * 2, 0.6)``.

        Returns:
            False if the node is unknown or not a ghost, True otherwise.
        """
        node = self.nodes.get(node_id)
        if node is None or node.node_type != NodeType.GHOST:
            return False

        original_type = node.metadata.get("original_type", "hypothesis")
        try:
            restored = NodeType(original_type)
        except ValueError:
            restored = NodeType.HYPOTHESIS
        if restored == NodeType.GHOST:
            # Left behind by re-pruning an existing ghost in older state.
            restored = NodeType.HYPOTHESIS
        node.node_type = restored

        node.confidence = max(node.confidence * 2, 0.6)
        self.reasoning_graph.add_node(node_id, **node.to_dict())

        self.version += 1
        return True

    def get_last_active_node(self) -> Optional[ReasoningNode]:
        """Return the most recent node that is neither GHOST nor EVIDENCE.

        Uses the tracked last node when it is still present and active,
        otherwise the active node with the latest timestamp; None if none.
        """
        inactive = (NodeType.GHOST, NodeType.EVIDENCE)

        if self._last_node_id and self._last_node_id in self.nodes:
            node = self.nodes[self._last_node_id]
            if node.node_type not in inactive:
                return node

        candidates = [n for n in self.nodes.values() if n.node_type not in inactive]
        # ISO-8601 strings from the same clock sort chronologically.
        return max(candidates, key=lambda x: x.timestamp, default=None)

    def get_node_children(self, node_id: str) -> List[ReasoningNode]:
        """Return nodes reached by edges leaving ``node_id``."""
        return [
            self.nodes[edge.target]
            for edge in self.edges.values()
            if edge.source == node_id and edge.target in self.nodes
        ]

    def get_node_parents(self, node_id: str) -> List[ReasoningNode]:
        """Return nodes with an edge pointing at ``node_id``."""
        return [
            self.nodes[edge.source]
            for edge in self.edges.values()
            if edge.target == node_id and edge.source in self.nodes
        ]

    # ------------------------------------------------------------------
    # Visualization & stats
    # ------------------------------------------------------------------

    def to_cytoscape_elements(
        self,
        include_ghosts: bool = False,
        confidence_threshold: float = 0.0
    ) -> List[Dict]:
        """Convert the reasoning graph to a list of Cytoscape elements.

        Nodes are filtered (ghosts unless ``include_ghosts``; anything below
        ``confidence_threshold``), and only edges whose both endpoints survive
        the filter are included. Node elements come before edge elements.
        """
        elements = []

        for node in self.nodes.values():
            if node.node_type == NodeType.GHOST and not include_ghosts:
                continue
            if node.confidence < confidence_threshold:
                continue
            elements.append(node.to_cytoscape())

        visible_node_ids = {e["data"]["id"] for e in elements}

        for edge in self.edges.values():
            if edge.source not in visible_node_ids or edge.target not in visible_node_ids:
                continue
            elements.append(edge.to_cytoscape())

        return elements

    def get_stats(self) -> Dict[str, int]:
        """Return node/edge/entity/ghost counts and the current version."""
        return {
            "nodes": len(self.nodes),
            "edges": len(self.edges),
            "entities": len(self.entities),
            "version": self.version,
            "ghosts": sum(1 for n in self.nodes.values() if n.node_type == NodeType.GHOST),
        }

    def clear_reasoning(self):
        """Clear the reasoning graph while keeping the knowledge base; resets version to 0."""
        self.nodes.clear()
        self.edges.clear()
        self.reasoning_graph.clear()
        self._last_node_id = None
        self.version = 0

    # ------------------------------------------------------------------
    # Serialization
    # ------------------------------------------------------------------

    def get_state(self) -> Dict:
        """Return the reasoning graph as JSON-friendly data (knowledge base excluded)."""
        return {
            "nodes": [n.to_dict() for n in self.nodes.values()],
            "edges": [e.to_dict() for e in self.edges.values()],
            "version": self.version,
            "last_node_id": self._last_node_id,
        }

    def restore_state(self, state: Dict):
        """Replace the reasoning graph with one produced by ``get_state``.

        Edges whose endpoints are not among the restored nodes are skipped
        with a warning (matching ``add_edge``); adding them would make
        networkx create attribute-less nodes unknown to ``nodes``.
        """
        self.clear_reasoning()

        for node_data in state.get("nodes", []):
            node = ReasoningNode.from_dict(node_data)
            self.nodes[node.id] = node
            self.reasoning_graph.add_node(node.id, **node.to_dict())

        for edge_data in state.get("edges", []):
            edge = ReasoningEdge.from_dict(edge_data)
            if edge.source not in self.nodes or edge.target not in self.nodes:
                logger.warning("Skipping edge %s during restore: endpoint node missing", edge.id)
                continue
            self.edges[edge.id] = edge
            self.reasoning_graph.add_edge(edge.source, edge.target, **edge.to_dict())

        self.version = state.get("version", 0)
        self._last_node_id = state.get("last_node_id")

    def export_json(self) -> str:
        """Export the reasoning-graph state as an indented JSON string."""
        return json.dumps(self.get_state(), indent=2)

    def get_entity_dict_for_embedding(self) -> Dict[str, Dict]:
        """Return ``{entity_id: {id, name, category, description, synonyms}}`` for embedding."""
        return {
            entity_id: {
                "id": entity.id,
                "name": entity.name,
                "category": self._enum_value(entity.category),
                "description": entity.description,
                "synonyms": entity.synonyms,
            }
            for entity_id, entity in self.entities.items()
        }
|
|
|