""" Dataset Loader Module (Refactored) Generic dataset loading supporting multiple formats: - OBO (Open Biomedical Ontologies) - CSV/TSV - JSON/JSON-LD - Custom adapters Configuration-driven to support any domain, not just medical. """ import os import re import json import csv import logging import hashlib import urllib.request from pathlib import Path from abc import ABC, abstractmethod from datetime import datetime from typing import Dict, List, Optional, Tuple, Any, Type from dataclasses import dataclass, field from .knowledge_graph import Entity, EntityCategory, KnowledgeGraph from .config import DatasetConfig, get_config logger = logging.getLogger(__name__) @dataclass class OntologyTerm: """Generic ontology term representation.""" id: str name: str definition: str = "" synonyms: List[str] = field(default_factory=list) xrefs: Dict[str, str] = field(default_factory=dict) is_a: List[str] = field(default_factory=list) relationships: List[Tuple[str, str]] = field(default_factory=list) namespace: str = "" is_obsolete: bool = False def to_entity(self, category: EntityCategory) -> Entity: """Convert to Entity.""" return Entity( id=self.id, name=self.name, category=category, description=self.definition, synonyms=self.synonyms, xrefs=self.xrefs, properties={"is_a": self.is_a, "namespace": self.namespace} ) class DatasetAdapter(ABC): """Abstract base class for dataset adapters.""" @abstractmethod def parse(self, content: str) -> Dict[str, OntologyTerm]: """Parse content and return dictionary of terms.""" pass @abstractmethod def can_handle(self, source_type: str) -> bool: """Check if this adapter can handle the source type.""" pass class OBOAdapter(DatasetAdapter): """Parser for OBO (Open Biomedical Ontologies) format.""" def can_handle(self, source_type: str) -> bool: return source_type.lower() == "obo" def parse(self, content: str) -> Dict[str, OntologyTerm]: """Parse OBO format content.""" terms = {} # Split into stanzas stanzas = re.split(r'\n\[', content) for stanza in stanzas[1:]: # Skip header if stanza.startswith('Term]'): term = self._parse_term(stanza[5:]) if term and not term.is_obsolete: terms[term.id] = term logger.info(f"Parsed {len(terms)} terms from OBO content") return terms def _parse_term(self, stanza: str) -> Optional[OntologyTerm]: """Parse a single term stanza.""" data = { "id": "", "name": "", "definition": "", "synonyms": [], "xrefs": {}, "is_a": [], "relationships": [], "namespace": "", "is_obsolete": False } for line in stanza.split('\n'): line = line.strip() if not line or line.startswith('!') or ':' not in line: continue tag, _, value = line.partition(':') tag, value = tag.strip(), value.strip() if tag == 'id': data['id'] = value elif tag == 'name': data['name'] = value elif tag == 'def': match = re.match(r'"([^"]*)"', value) if match: data['definition'] = match.group(1) elif tag == 'synonym': match = re.match(r'"([^"]*)"', value) if match: data['synonyms'].append(match.group(1)) elif tag == 'xref': if ':' in value: xref_ns, _, xref_id = value.partition(':') xref_id = xref_id.split()[0] if ' ' in xref_id else xref_id data['xrefs'][xref_ns.strip()] = xref_id.strip() elif tag == 'is_a': parent_id = value.split('!')[0].strip() data['is_a'].append(parent_id) elif tag == 'relationship': parts = value.split() if len(parts) >= 2: data['relationships'].append((parts[0], parts[1])) elif tag == 'is_obsolete': data['is_obsolete'] = value.lower() == 'true' elif tag == 'namespace': data['namespace'] = value if data['id'] and data['name']: return OntologyTerm(**data) return None 


class CSVAdapter(DatasetAdapter):
    """Parser for CSV/TSV format datasets."""

    # Default column mappings
    DEFAULT_MAPPINGS = {
        "id": ["id", "ID", "identifier", "code"],
        "name": ["name", "Name", "label", "Label", "title"],
        "definition": ["definition", "description", "Description", "desc"],
        "synonyms": ["synonyms", "aliases", "alt_names"],
    }

    def __init__(self, column_mappings: Optional[Dict[str, str]] = None):
        self.column_mappings = column_mappings or {}

    def can_handle(self, source_type: str) -> bool:
        return source_type.lower() in ["csv", "tsv"]

    def parse(self, content: str) -> Dict[str, OntologyTerm]:
        """Parse CSV content."""
        terms = {}

        # Detect delimiter; fall back to standard comma-separated if the sample
        # is too uniform for the sniffer (e.g. a single-column file).
        try:
            dialect = csv.Sniffer().sniff(content[:1024])
        except csv.Error:
            dialect = csv.excel

        reader = csv.DictReader(content.splitlines(), dialect=dialect)

        # Map columns
        col_map = self._map_columns(reader.fieldnames or [])

        for row in reader:
            term = self._row_to_term(row, col_map)
            if term:
                terms[term.id] = term

        logger.info(f"Parsed {len(terms)} terms from CSV content")
        return terms

    def _map_columns(self, fieldnames: List[str]) -> Dict[str, str]:
        """Map fieldnames to standard term fields."""
        col_map = {}
        for term_field, possible_names in self.DEFAULT_MAPPINGS.items():
            # Check explicit mappings first
            if term_field in self.column_mappings:
                col_map[term_field] = self.column_mappings[term_field]
            else:
                # Try to auto-detect
                for name in possible_names:
                    if name in fieldnames:
                        col_map[term_field] = name
                        break
        return col_map

    def _row_to_term(self, row: Dict, col_map: Dict[str, str]) -> Optional[OntologyTerm]:
        """Convert CSV row to OntologyTerm."""
        term_id = row.get(col_map.get("id", ""), "")
        name = row.get(col_map.get("name", ""), "")

        if not term_id or not name:
            return None

        definition = row.get(col_map.get("definition", ""), "")

        # Parse synonyms (comma-separated or JSON array)
        synonyms_raw = row.get(col_map.get("synonyms", ""), "")
        if synonyms_raw.startswith("["):
            try:
                synonyms = json.loads(synonyms_raw)
            except json.JSONDecodeError:
                synonyms = []
        else:
            synonyms = [s.strip() for s in synonyms_raw.split(",") if s.strip()]

        return OntologyTerm(
            id=term_id,
            name=name,
            definition=definition,
            synonyms=synonyms,
        )


class JSONAdapter(DatasetAdapter):
    """Parser for JSON format datasets."""

    def __init__(self, terms_path: str = "terms",
                 id_field: str = "id", name_field: str = "name"):
        self.terms_path = terms_path
        self.id_field = id_field
        self.name_field = name_field

    def can_handle(self, source_type: str) -> bool:
        return source_type.lower() in ["json", "json-ld"]

    def parse(self, content: str) -> Dict[str, OntologyTerm]:
        """Parse JSON content."""
        terms = {}
        data = json.loads(content)

        # Navigate to terms array
        items = data
        if self.terms_path:
            for key in self.terms_path.split("."):
                if isinstance(items, dict):
                    items = items.get(key, [])
                else:
                    break

        if not isinstance(items, list):
            items = [items] if isinstance(items, dict) else []

        for item in items:
            term = self._item_to_term(item)
            if term:
                terms[term.id] = term

        logger.info(f"Parsed {len(terms)} terms from JSON content")
        return terms

    def _item_to_term(self, item: Dict) -> Optional[OntologyTerm]:
        """Convert JSON item to OntologyTerm."""
        term_id = item.get(self.id_field, "")
        name = item.get(self.name_field, "")

        if not term_id or not name:
            return None

        return OntologyTerm(
            id=term_id,
            name=name,
            definition=item.get("definition", item.get("description", "")),
            synonyms=item.get("synonyms", item.get("aliases", [])),
            xrefs=item.get("xrefs", {}),
            is_a=item.get("is_a", item.get("parents", [])),
        )
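

# Illustrative sketch (not used by the loader): CSVAdapter with an explicit
# mapping for non-standard "code"/"label" columns; the "description" and
# "aliases" columns are picked up automatically via DEFAULT_MAPPINGS.
# The data below is made up.
def _example_parse_csv() -> Dict[str, OntologyTerm]:
    sample = (
        "code,label,description,aliases\n"
        'EX:1,Example Finding,Just for illustration,"alias one, alias two"\n'
    )
    adapter = CSVAdapter(column_mappings={"id": "code", "name": "label"})
    # Returns {"EX:1": OntologyTerm(...)} with the two aliases split on commas.
    return adapter.parse(sample)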


class DatasetLoader:
    """
    Main dataset loader supporting multiple formats and sources.

    Usage:
        loader = DatasetLoader()
        loader.load_dataset(config)     # Single dataset
        loader.load_all_datasets()      # From config
    """

    def __init__(self, cache_dir: Optional[str] = None):
        self.cache_dir = Path(cache_dir or get_config().cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        # Register adapters
        self.adapters: List[DatasetAdapter] = [
            OBOAdapter(),
            CSVAdapter(),
            JSONAdapter(),
        ]

        # Loaded data
        self.datasets: Dict[str, Dict[str, OntologyTerm]] = {}

    def register_adapter(self, adapter: DatasetAdapter):
        """Register a custom adapter."""
        self.adapters.insert(0, adapter)  # Custom adapters take priority

    def get_adapter(self, source_type: str) -> Optional[DatasetAdapter]:
        """Get adapter for source type."""
        for adapter in self.adapters:
            if adapter.can_handle(source_type):
                return adapter
        return None

    def load_dataset(self, config: DatasetConfig) -> Dict[str, OntologyTerm]:
        """Load a single dataset based on configuration."""
        logger.info(f"Loading dataset: {config.name}")

        # Check cache
        if config.cache_enabled:
            cached = self._load_from_cache(config)
            if cached:
                self.datasets[config.name] = cached
                return cached

        # Get content
        content = self._get_content(config)
        if not content:
            logger.warning(f"No content for dataset: {config.name}")
            return {}

        # Parse with appropriate adapter
        adapter = self.get_adapter(config.source_type)
        if not adapter:
            logger.error(f"No adapter for source type: {config.source_type}")
            return {}

        terms = adapter.parse(content)

        # Cache results
        if config.cache_enabled:
            self._save_to_cache(config, terms)

        self.datasets[config.name] = terms
        return terms

    def load_all_datasets(self) -> Dict[str, Dict[str, OntologyTerm]]:
        """Load all datasets from configuration."""
        config = get_config()
        for dataset_config in config.datasets:
            self.load_dataset(dataset_config)
        return self.datasets

    def _get_content(self, config: DatasetConfig) -> Optional[str]:
        """Get content from URL or file path."""
        # Try URL first
        if config.source_url:
            try:
                logger.info(f"Downloading from: {config.source_url}")
                req = urllib.request.Request(
                    config.source_url,
                    headers={'User-Agent': 'HITL-KG/1.0'}
                )
                with urllib.request.urlopen(req, timeout=60) as response:
                    return response.read().decode('utf-8')
            except Exception as e:
                logger.warning(f"Download failed: {e}")

        # Try local file
        if config.source_path:
            path = Path(config.source_path)
            if path.exists():
                return path.read_text(encoding='utf-8')

        return None

    def _cache_path(self, config: DatasetConfig) -> Path:
        """Get cache file path for a dataset."""
        return self.cache_dir / f"{config.name}_cache.json"

    def _load_from_cache(self, config: DatasetConfig) -> Optional[Dict[str, OntologyTerm]]:
        """Load dataset from cache if valid."""
        cache_path = self._cache_path(config)
        if not cache_path.exists():
            return None

        # Check age
        mtime = datetime.fromtimestamp(cache_path.stat().st_mtime)
        age_days = (datetime.now() - mtime).days
        if age_days > config.cache_max_age_days:
            return None

        try:
            with open(cache_path) as f:
                data = json.load(f)

            terms = {}
            for term_id, term_data in data.get("terms", {}).items():
                terms[term_id] = OntologyTerm(
                    id=term_data["id"],
                    name=term_data["name"],
                    definition=term_data.get("definition", ""),
                    synonyms=term_data.get("synonyms", []),
                    xrefs=term_data.get("xrefs", {}),
                    is_a=term_data.get("is_a", []),
                    relationships=term_data.get("relationships", []),
                    namespace=term_data.get("namespace", ""),
                )

            logger.info(f"Loaded {len(terms)} terms from cache: {config.name}")
            return terms
        except Exception as e:
            logger.warning(f"Cache load failed: {e}")
            return None

    def _save_to_cache(self, config: DatasetConfig, terms: Dict[str, OntologyTerm]):
        """Save dataset to cache."""
        try:
            cache_path = self._cache_path(config)
            data = {
                "name": config.name,
                "source_type": config.source_type,
                "timestamp": datetime.now().isoformat(),
                "terms": {
                    tid: {
                        "id": t.id,
                        "name": t.name,
                        "definition": t.definition,
                        "synonyms": t.synonyms,
                        "xrefs": t.xrefs,
                        "is_a": t.is_a,
                        "relationships": t.relationships,
                        "namespace": t.namespace,
                    }
                    for tid, t in terms.items()
                },
            }

            with open(cache_path, 'w') as f:
                json.dump(data, f)

            logger.info(f"Cached {len(terms)} terms for: {config.name}")
        except Exception as e:
            logger.warning(f"Cache save failed: {e}")
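

# Illustrative sketch (not used by the loader): registering a custom adapter for
# a source type the built-in adapters do not cover. The "linelist" format and its
# 'ID|name' layout are hypothetical; a real adapter only needs can_handle() and
# parse(), as defined on DatasetAdapter above.
class _ExampleLineListAdapter(DatasetAdapter):
    """Parses one 'ID|name' pair per line (made-up format for demonstration)."""

    def can_handle(self, source_type: str) -> bool:
        return source_type.lower() == "linelist"

    def parse(self, content: str) -> Dict[str, OntologyTerm]:
        terms = {}
        for line in content.splitlines():
            term_id, _, name = line.partition("|")
            term_id, name = term_id.strip(), name.strip()
            if term_id and name:
                terms[term_id] = OntologyTerm(id=term_id, name=name)
        return terms


def _example_register_custom_adapter() -> DatasetLoader:
    loader = DatasetLoader()
    loader.register_adapter(_ExampleLineListAdapter())  # takes priority over built-ins
    return loader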


def build_knowledge_graph(loader: DatasetLoader) -> KnowledgeGraph:
    """
    Build a KnowledgeGraph from loaded datasets.

    This function:
    1. Converts OntologyTerms to Entities
    2. Creates relationships between entities

    Semantic-search indexing is handled separately by load_knowledge_graph().
    """
    kg = KnowledgeGraph()
    config = get_config()

    # Map dataset names to categories
    category_map = {
        ds.name: EntityCategory(ds.entity_category)
        for ds in config.datasets
        if ds.entity_category in [c.value for c in EntityCategory]
    }

    # Add entities from each dataset
    for dataset_name, terms in loader.datasets.items():
        category = category_map.get(dataset_name, EntityCategory.FINDING)
        for term_id, term in terms.items():
            entity = term.to_entity(category)
            kg.add_entity(entity)

    # Build relationships from curated associations and treatments
    _build_relationships(kg, loader)

    logger.info(f"Built KG with {len(kg.entities)} entities")
    return kg


def _build_relationships(kg: KnowledgeGraph, loader: DatasetLoader):
    """Build relationships between entities."""
    # Disease-symptom associations (curated mappings)
    disease_symptom_mappings = _get_disease_symptom_mappings()

    for disease_id, symptom_mappings in disease_symptom_mappings.items():
        if disease_id not in kg.entities:
            continue

        for symptom_name, confidence in symptom_mappings:
            # Find symptom entity by name or synonym (case-insensitive)
            symptom_entity = None
            for entity in kg.entities.values():
                if entity.category == EntityCategory.SYMPTOM:
                    if (entity.name.lower() == symptom_name.lower() or
                            symptom_name.lower() in [s.lower() for s in entity.synonyms]):
                        symptom_entity = entity
                        break

            if symptom_entity:
                kg.add_relation(disease_id, symptom_entity.id, "causes", confidence)

    # Add treatment relations
    _add_treatment_entities(kg)
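

# Illustrative sketch of the curated-mapping shape consumed by
# _build_relationships(): {disease ID: [(symptom name, confidence), ...]}.
# The entry below is hypothetical; real additions belong in
# _get_disease_symptom_mappings() below.
_EXAMPLE_EXTRA_MAPPING: Dict[str, List[Tuple[str, float]]] = {
    "DOID:0000000": [  # placeholder disease ID, not a real DOID
        ("fever", 0.50),
        ("fatigue", 0.40),
    ],
}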
""" return { "DOID:8469": [ # Influenza ("fever", 0.95), ("cough", 0.85), ("fatigue", 0.90), ("body aches", 0.85), ("headache", 0.80), ("chills", 0.75), ], "DOID:0080600": [ # COVID-19 ("fever", 0.80), ("cough", 0.85), ("fatigue", 0.90), ("shortness of breath", 0.70), ("headache", 0.60), ("loss of taste", 0.50), ("loss of smell", 0.50), ], "DOID:10459": [ # Common cold ("runny nose", 0.95), ("sore throat", 0.80), ("cough", 0.75), ("nasal congestion", 0.85), ("sneezing", 0.80), ], "DOID:552": [ # Pneumonia ("fever", 0.90), ("cough", 0.95), ("shortness of breath", 0.85), ("chest pain", 0.70), ("fatigue", 0.80), ], "DOID:6132": [ # Bronchitis ("cough", 0.95), ("fatigue", 0.60), ("shortness of breath", 0.50), ], "DOID:10534": [ # Strep throat ("sore throat", 0.98), ("fever", 0.80), ("headache", 0.50), ], "DOID:13084": [ # Sinusitis ("headache", 0.85), ("nasal congestion", 0.90), ("runny nose", 0.80), ], "DOID:8893": [ # Migraine ("headache", 0.99), ("nausea", 0.70), ], } def _add_treatment_entities(kg: KnowledgeGraph): """Add treatment entities and relationships.""" treatments = [ Entity("tx_rest", "Rest", EntityCategory.TREATMENT, "Physical and mental rest", ["bed rest"]), Entity("tx_fluids", "Fluid Intake", EntityCategory.TREATMENT, "Increased hydration", ["hydration"]), Entity("tx_acetaminophen", "Acetaminophen", EntityCategory.MEDICATION, "Pain and fever reducer", ["paracetamol", "Tylenol"]), Entity("tx_ibuprofen", "Ibuprofen", EntityCategory.MEDICATION, "NSAID for pain and inflammation", ["Advil", "Motrin"]), Entity("tx_antiviral", "Antiviral Medication", EntityCategory.MEDICATION, "Medications for viral infections", ["oseltamivir", "Tamiflu"]), Entity("tx_decongestant", "Decongestants", EntityCategory.MEDICATION, "Nasal congestion relief", ["pseudoephedrine"]), ] for tx in treatments: kg.add_entity(tx) # Treatment relationships treatment_map = { "DOID:8469": ["tx_rest", "tx_fluids", "tx_acetaminophen", "tx_antiviral"], "DOID:0080600": ["tx_rest", "tx_fluids", "tx_acetaminophen"], "DOID:10459": ["tx_rest", "tx_fluids", "tx_decongestant"], "DOID:552": ["tx_rest"], } for disease_id, treatment_ids in treatment_map.items(): if disease_id in kg.entities: for tx_id in treatment_ids: if tx_id in kg.entities: kg.add_relation(tx_id, disease_id, "treats", 0.8) def load_knowledge_graph(use_embeddings: bool = True) -> KnowledgeGraph: """ Main entry point: Load datasets and build knowledge graph. Args: use_embeddings: If True, also index entities for semantic search """ loader = DatasetLoader() loader.load_all_datasets() kg = build_knowledge_graph(loader) if use_embeddings: try: from .embedding_service import get_embedding_service embedding_service = get_embedding_service() embedding_service.index_entities(kg.get_entity_dict_for_embedding()) except Exception as e: logger.warning(f"Failed to initialize embeddings: {e}") return kg