""" Embedding Service Module Provides multilingual semantic search using sentence-transformers. Uses paraphrase-multilingual-MiniLM-L12-v2 by default which supports 50+ languages including English, Ukrainian, Russian, Spanish, German, French, etc. References: - Reimers & Gurevych (2019): Sentence-BERT - Reimers & Gurevych (2020): Making Monolingual Sentence Embeddings Multilingual """ import os import json import logging import numpy as np from pathlib import Path from typing import Dict, List, Optional, Tuple, Any from dataclasses import dataclass import hashlib logger = logging.getLogger(__name__) @dataclass class SearchResult: """Result from semantic search.""" entity_id: str score: float entity_data: Dict[str, Any] class EmbeddingService: """ Multilingual embedding service for semantic search. Replaces keyword-based matching with embedding similarity, enabling language-agnostic symptom/entity matching. """ def __init__( self, model_name: str = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2", cache_dir: str = "./data/embeddings", device: str = "cpu" ): self.model_name = model_name self.cache_dir = Path(cache_dir) self.cache_dir.mkdir(parents=True, exist_ok=True) self.device = device self._model = None self._entity_embeddings: Dict[str, np.ndarray] = {} self._entity_data: Dict[str, Dict] = {} self._embedding_dim: int = 384 # Index for fast similarity search self._index = None self._index_ids: List[str] = [] @property def model(self): """Lazy load the embedding model.""" if self._model is None: try: from sentence_transformers import SentenceTransformer logger.info(f"Loading embedding model: {self.model_name}") self._model = SentenceTransformer(self.model_name, device=self.device) self._embedding_dim = self._model.get_sentence_embedding_dimension() logger.info(f"Model loaded. Embedding dimension: {self._embedding_dim}") except ImportError: logger.error( "sentence-transformers not installed. " "Run: pip install sentence-transformers" ) raise return self._model def encode(self, texts: List[str], batch_size: int = 32) -> np.ndarray: """ Encode texts to embeddings. Args: texts: List of text strings to encode batch_size: Batch size for encoding Returns: numpy array of shape (len(texts), embedding_dim) """ if not texts: return np.array([]) embeddings = self.model.encode( texts, batch_size=batch_size, show_progress_bar=len(texts) > 100, convert_to_numpy=True, normalize_embeddings=True # For cosine similarity via dot product ) return embeddings def encode_single(self, text: str) -> np.ndarray: """Encode a single text string.""" return self.encode([text])[0] def index_entities( self, entities: Dict[str, Dict[str, Any]], text_fields: List[str] = ["name", "description", "synonyms"] ): """ Build search index from entities. Args: entities: Dict of entity_id -> entity_data text_fields: Fields to combine for embedding text """ logger.info(f"Indexing {len(entities)} entities for semantic search") # Check cache cache_key = self._compute_cache_key(entities) if self._load_from_cache(cache_key): logger.info("Loaded embeddings from cache") return # Prepare texts for embedding texts = [] entity_ids = [] for entity_id, entity in entities.items(): # Combine relevant text fields text_parts = [] for field in text_fields: value = entity.get(field) if value: if isinstance(value, list): text_parts.extend(value) else: text_parts.append(str(value)) if text_parts: combined_text = " ".join(text_parts) texts.append(combined_text) entity_ids.append(entity_id) self._entity_data[entity_id] = entity if not texts: logger.warning("No texts to index") return # Compute embeddings logger.info(f"Computing embeddings for {len(texts)} entities...") embeddings = self.encode(texts) # Store embeddings for entity_id, embedding in zip(entity_ids, embeddings): self._entity_embeddings[entity_id] = embedding self._index_ids = entity_ids # Build FAISS index if available, else use numpy self._build_index(embeddings) # Save to cache self._save_to_cache(cache_key) logger.info(f"Indexed {len(self._entity_embeddings)} entities") def _build_index(self, embeddings: np.ndarray): """Build search index from embeddings.""" try: import faiss # Use IndexFlatIP for inner product (cosine similarity with normalized vectors) self._index = faiss.IndexFlatIP(self._embedding_dim) self._index.add(embeddings.astype(np.float32)) logger.info("Built FAISS index for fast similarity search") except ImportError: # Fallback to numpy-based search logger.info("FAISS not available, using numpy for similarity search") self._index = None self._embedding_matrix = embeddings def search( self, query: str, top_k: int = 10, threshold: float = 0.3, category_filter: Optional[str] = None ) -> List[SearchResult]: """ Search for entities similar to query. Args: query: Search query (any language) top_k: Maximum number of results threshold: Minimum similarity score (0-1) category_filter: Optional category to filter by Returns: List of SearchResult sorted by score descending """ if not self._entity_embeddings: logger.warning("No entities indexed. Call index_entities first.") return [] # Encode query query_embedding = self.encode_single(query) # Search if self._index is not None: # FAISS search scores, indices = self._index.search( query_embedding.reshape(1, -1).astype(np.float32), min(top_k * 2, len(self._index_ids)) # Get more for filtering ) scores = scores[0] indices = indices[0] else: # Numpy fallback scores = np.dot(self._embedding_matrix, query_embedding) indices = np.argsort(scores)[::-1][:top_k * 2] scores = scores[indices] # Build results with filtering results = [] for score, idx in zip(scores, indices): if score < threshold: continue if idx < 0 or idx >= len(self._index_ids): continue entity_id = self._index_ids[idx] entity_data = self._entity_data.get(entity_id, {}) # Apply category filter if category_filter and entity_data.get("category") != category_filter: continue results.append(SearchResult( entity_id=entity_id, score=float(score), entity_data=entity_data )) if len(results) >= top_k: break return results def search_multiple( self, queries: List[str], top_k_per_query: int = 5, threshold: float = 0.3, deduplicate: bool = True ) -> List[SearchResult]: """ Search with multiple queries, combining results. Useful for extracting multiple symptoms from a single user query. """ all_results: Dict[str, SearchResult] = {} for query in queries: results = self.search(query, top_k=top_k_per_query, threshold=threshold) for result in results: if result.entity_id not in all_results: all_results[result.entity_id] = result else: # Keep highest score if result.score > all_results[result.entity_id].score: all_results[result.entity_id] = result # Sort by score return sorted(all_results.values(), key=lambda x: x.score, reverse=True) def extract_entities_from_text( self, text: str, category: Optional[str] = None, top_k: int = 5, threshold: float = 0.4 ) -> List[SearchResult]: """ Extract relevant entities from free-form text. This is the main method for symptom extraction from user queries. Works across all supported languages. Args: text: User input text (any language) category: Filter by category (e.g., "symptom", "disease") top_k: Maximum entities to return threshold: Minimum similarity threshold """ # Direct search on full text results = self.search( query=text, top_k=top_k, threshold=threshold, category_filter=category ) # Also try splitting into phrases (helps with multiple symptoms) # Split on common separators import re phrases = re.split(r'[,;.]|\band\b|\bwith\b|\balso\b|\bі\b|\bта\b|\bи\b', text) phrases = [p.strip() for p in phrases if p.strip() and len(p.strip()) > 2] if len(phrases) > 1: phrase_results = self.search_multiple( phrases, top_k_per_query=3, threshold=threshold ) # Merge results seen_ids = {r.entity_id for r in results} for pr in phrase_results: if pr.entity_id not in seen_ids: results.append(pr) seen_ids.add(pr.entity_id) # Sort and limit results.sort(key=lambda x: x.score, reverse=True) return results[:top_k] def _compute_cache_key(self, entities: Dict) -> str: """Compute cache key from entities.""" # Hash based on entity IDs and model name entity_str = json.dumps(sorted(entities.keys())) key_str = f"{self.model_name}:{entity_str}" return hashlib.md5(key_str.encode()).hexdigest()[:16] def _load_from_cache(self, cache_key: str) -> bool: """Try to load embeddings from cache.""" embeddings_path = self.cache_dir / f"{cache_key}_embeddings.npy" metadata_path = self.cache_dir / f"{cache_key}_metadata.json" if not embeddings_path.exists() or not metadata_path.exists(): return False try: # Load metadata with open(metadata_path) as f: metadata = json.load(f) # Verify model matches if metadata.get("model") != self.model_name: return False # Load embeddings embeddings = np.load(embeddings_path) # Restore state self._index_ids = metadata["entity_ids"] self._entity_data = metadata["entity_data"] for i, entity_id in enumerate(self._index_ids): self._entity_embeddings[entity_id] = embeddings[i] # Rebuild index self._build_index(embeddings) return True except Exception as e: logger.warning(f"Failed to load from cache: {e}") return False def _save_to_cache(self, cache_key: str): """Save embeddings to cache.""" try: embeddings_path = self.cache_dir / f"{cache_key}_embeddings.npy" metadata_path = self.cache_dir / f"{cache_key}_metadata.json" # Prepare embeddings array embeddings = np.array([ self._entity_embeddings[eid] for eid in self._index_ids ]) # Save embeddings np.save(embeddings_path, embeddings) # Save metadata metadata = { "model": self.model_name, "entity_ids": self._index_ids, "entity_data": self._entity_data, "embedding_dim": self._embedding_dim } with open(metadata_path, "w") as f: json.dump(metadata, f) logger.info(f"Saved embeddings cache: {cache_key}") except Exception as e: logger.warning(f"Failed to save cache: {e}") def clear_cache(self): """Clear all cached embeddings.""" import shutil if self.cache_dir.exists(): shutil.rmtree(self.cache_dir) self.cache_dir.mkdir(parents=True, exist_ok=True) self._entity_embeddings.clear() self._entity_data.clear() self._index = None self._index_ids = [] # Global instance (lazy initialized) _embedding_service: Optional[EmbeddingService] = None def get_embedding_service() -> EmbeddingService: """Get the global embedding service instance.""" global _embedding_service if _embedding_service is None: from .config import get_config config = get_config() _embedding_service = EmbeddingService( model_name=config.embedding.model_name, cache_dir=config.embedding.cache_dir, device=config.embedding.device ) return _embedding_service