Geoffrey Kip committed
Commit 7b7595c · Parent(s): fc58768

Fix: Cache embedding model initialization to prevent concurrency crashes

Files changed (2):
  1. ct_agent_app.py   +10 -1
  2. modules/utils.py  +15 -9
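For context, here is the pattern this commit introduces, consolidated from the diffs below into one self-contained sketch. Streamlit's st.cache_resource executes the decorated body once per process and shares the cached result across all sessions and reruns, so concurrent sessions reuse a single embedding model instead of racing to construct their own:

import streamlit as st
from llama_index.core import Settings
from llama_index.embeddings.huggingface import HuggingFaceEmbedding

@st.cache_resource  # body runs once per process; subsequent calls hit the cache
def init_embedding_model() -> None:
    # The cached value is the None return; the useful work is the side effect
    # of assigning the global embed model exactly once.
    Settings.embed_model = HuggingFaceEmbedding(
        model_name="pritamdeka/S-PubMedBert-MS-MARCO",
        device="cpu",
    )

init_embedding_model()  # cheap after the first call, so safe on every rerun

Caching a side effect works here because the decorator still guarantees single execution; returning the HuggingFaceEmbedding instance and assigning it at the call site would be an equivalent design.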
ct_agent_app.py CHANGED
@@ -22,7 +22,13 @@ logging.getLogger("langchain_google_genai._function_utils").setLevel(logging.ERR
 load_dotenv()
 
 # Module Imports
-from modules.utils import load_index, setup_llama_index
+from modules.utils import (
+    load_environment,
+    load_index,
+    setup_llama_index,
+    init_embedding_model,
+    get_hybrid_retriever,
+)
 from modules.constants import COUNTRY_COORDINATES, STATE_COORDINATES
 
 # ... (imports)
@@ -68,6 +74,9 @@ st.markdown(
     unsafe_allow_html=True,
 )
 
+# Initialize global resources (Embeddings) once
+init_embedding_model()
+
 st.title("🧬 Clinical Trial Inspector Agent")
 
 # 1. Setup LLM & LlamaIndex Settings
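One way to verify the single-load behavior after this change (a hypothetical check, not part of the commit) is to log from inside the cached function; with @st.cache_resource the message appears once per server process rather than on every rerun:

import logging
import streamlit as st

logger = logging.getLogger(__name__)

@st.cache_resource
def init_embedding_model() -> None:
    # Fires once per process under @st.cache_resource; without the decorator
    # it would fire on every Streamlit script rerun.
    logger.info("Initializing S-PubMedBert embedding model...")
    ...  # model construction as in modules/utils.py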
modules/utils.py CHANGED
@@ -118,29 +118,35 @@ def load_environment():
 
 
 # --- Configuration ---
+@st.cache_resource
+def init_embedding_model():
+    """Initializes and caches the embedding model globally."""
+    Settings.embed_model = HuggingFaceEmbedding(
+        model_name="pritamdeka/S-PubMedBert-MS-MARCO",
+        device="cpu"
+    )
+
 def setup_llama_index(api_key: Optional[str] = None):
     """
-    Configures global LlamaIndex settings (LLM and Embeddings).
+    Configures global LlamaIndex settings (LLM).
+    Embedding model is handled by init_embedding_model().
     """
+    # Ensure embedding model is loaded
+    init_embedding_model()
+
     # Use passed key, or fallback to env var
     final_key = api_key or os.environ.get("GOOGLE_API_KEY")
 
     if not final_key:
-        # App handles prompting for key, so we just return or log warning
-        pass
+        return
 
     try:
         # Pass the key explicitly if available
        Settings.llm = Gemini(model="models/gemini-2.5-flash", temperature=0, api_key=final_key)
     except Exception as e:
-        print(f"⚠️ LLM initialization failed (likely missing API key): {e}")
-        print("⚠️ Using MockLLM for testing/fallback.")
+        print(f"⚠️ LLM initialization failed: {e}")
         from llama_index.core.llms import MockLLM
         Settings.llm = MockLLM()
-
-    Settings.embed_model = HuggingFaceEmbedding(
-        model_name="pritamdeka/S-PubMedBert-MS-MARCO"
-    )
 
 
 @st.cache_resource
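The trimmed error path above still falls back to MockLLM when Gemini initialization raises. A minimal illustration of that fallback in isolation (the prompt and max_tokens value are illustrative; MockLLM is llama_index's built-in stub):

from llama_index.core.llms import MockLLM

# MockLLM returns deterministic placeholder text, so code that expects an
# LLM keeps working when no GOOGLE_API_KEY is configured.
llm = MockLLM(max_tokens=8)
response = llm.complete("List recruiting oncology trials.")
print(response.text)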