Spaces:
Sleeping
Sleeping
Geoffrey Kip
commited on
Commit
·
7b7595c
1
Parent(s):
fc58768
Fix: Cache embedding model initialization to prevent concurrency crashes
Browse files- ct_agent_app.py +10 -1
- modules/utils.py +15 -9
ct_agent_app.py
CHANGED
|
@@ -22,7 +22,13 @@ logging.getLogger("langchain_google_genai._function_utils").setLevel(logging.ERR
|
|
| 22 |
load_dotenv()
|
| 23 |
|
| 24 |
# Module Imports
|
| 25 |
-
from modules.utils import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
from modules.constants import COUNTRY_COORDINATES, STATE_COORDINATES
|
| 27 |
|
| 28 |
# ... (imports)
|
|
@@ -68,6 +74,9 @@ st.markdown(
|
|
| 68 |
unsafe_allow_html=True,
|
| 69 |
)
|
| 70 |
|
|
|
|
|
|
|
|
|
|
| 71 |
st.title("🧬 Clinical Trial Inspector Agent")
|
| 72 |
|
| 73 |
# 1. Setup LLM & LlamaIndex Settings
|
|
|
|
| 22 |
load_dotenv()
|
| 23 |
|
| 24 |
# Module Imports
|
| 25 |
+
from modules.utils import (
|
| 26 |
+
load_environment,
|
| 27 |
+
load_index,
|
| 28 |
+
setup_llama_index,
|
| 29 |
+
init_embedding_model,
|
| 30 |
+
get_hybrid_retriever,
|
| 31 |
+
)
|
| 32 |
from modules.constants import COUNTRY_COORDINATES, STATE_COORDINATES
|
| 33 |
|
| 34 |
# ... (imports)
|
|
|
|
| 74 |
unsafe_allow_html=True,
|
| 75 |
)
|
| 76 |
|
| 77 |
+
# Initialize global resources (Embeddings) once
|
| 78 |
+
init_embedding_model()
|
| 79 |
+
|
| 80 |
st.title("🧬 Clinical Trial Inspector Agent")
|
| 81 |
|
| 82 |
# 1. Setup LLM & LlamaIndex Settings
|
modules/utils.py
CHANGED
|
@@ -118,29 +118,35 @@ def load_environment():
|
|
| 118 |
|
| 119 |
|
| 120 |
# --- Configuration ---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
def setup_llama_index(api_key: Optional[str] = None):
|
| 122 |
"""
|
| 123 |
-
Configures global LlamaIndex settings (LLM
|
|
|
|
| 124 |
"""
|
|
|
|
|
|
|
|
|
|
| 125 |
# Use passed key, or fallback to env var
|
| 126 |
final_key = api_key or os.environ.get("GOOGLE_API_KEY")
|
| 127 |
|
| 128 |
if not final_key:
|
| 129 |
-
|
| 130 |
-
pass
|
| 131 |
|
| 132 |
try:
|
| 133 |
# Pass the key explicitly if available
|
| 134 |
Settings.llm = Gemini(model="models/gemini-2.5-flash", temperature=0, api_key=final_key)
|
| 135 |
except Exception as e:
|
| 136 |
-
print(f"⚠️ LLM initialization failed
|
| 137 |
-
print("⚠️ Using MockLLM for testing/fallback.")
|
| 138 |
from llama_index.core.llms import MockLLM
|
| 139 |
Settings.llm = MockLLM()
|
| 140 |
-
|
| 141 |
-
Settings.embed_model = HuggingFaceEmbedding(
|
| 142 |
-
model_name="pritamdeka/S-PubMedBert-MS-MARCO"
|
| 143 |
-
)
|
| 144 |
|
| 145 |
|
| 146 |
@st.cache_resource
|
|
|
|
| 118 |
|
| 119 |
|
| 120 |
# --- Configuration ---
|
| 121 |
+
@st.cache_resource
|
| 122 |
+
def init_embedding_model():
|
| 123 |
+
"""Initializes and caches the embedding model globally."""
|
| 124 |
+
Settings.embed_model = HuggingFaceEmbedding(
|
| 125 |
+
model_name="pritamdeka/S-PubMedBert-MS-MARCO",
|
| 126 |
+
device="cpu"
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
def setup_llama_index(api_key: Optional[str] = None):
|
| 130 |
"""
|
| 131 |
+
Configures global LlamaIndex settings (LLM).
|
| 132 |
+
Embedding model is handled by init_embedding_model().
|
| 133 |
"""
|
| 134 |
+
# Ensure embedding model is loaded
|
| 135 |
+
init_embedding_model()
|
| 136 |
+
|
| 137 |
# Use passed key, or fallback to env var
|
| 138 |
final_key = api_key or os.environ.get("GOOGLE_API_KEY")
|
| 139 |
|
| 140 |
if not final_key:
|
| 141 |
+
return
|
|
|
|
| 142 |
|
| 143 |
try:
|
| 144 |
# Pass the key explicitly if available
|
| 145 |
Settings.llm = Gemini(model="models/gemini-2.5-flash", temperature=0, api_key=final_key)
|
| 146 |
except Exception as e:
|
| 147 |
+
print(f"⚠️ LLM initialization failed: {e}")
|
|
|
|
| 148 |
from llama_index.core.llms import MockLLM
|
| 149 |
Settings.llm = MockLLM()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
|
| 151 |
|
| 152 |
@st.cache_resource
|