Geoffrey Kip committed
Commit 507be68 · 0 Parent(s)

Initial Release
.dockerignore ADDED
@@ -0,0 +1,47 @@
+ # Git
+ .git
+ .gitignore
+
+ # Python
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ *.so
+ .Python
+ env/
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+
+ # Virtual Environment
+ venv/
+ .venv/
+
+ # Environment Variables (CRITICAL: Do not include secrets)
+ .env
+ .env.local
+
+ # IDE
+ .vscode/
+ .idea/
+
+ # Mac
+ .DS_Store
+
+ # Logs
+ *.log
+
+ # Temporary
+ *.tmp
.flake8 ADDED
@@ -0,0 +1,4 @@
+ [flake8]
+ max-line-length = 120
+ extend-ignore = E203
+ exclude = venv, .git, __pycache__, build, dist
.gitattributes ADDED
@@ -0,0 +1 @@
+ ct_gov_lancedb/**/* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,62 @@
+ # Python
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ *.so
+ .Python
+ *.egg-info/
+ *.pyc
+ *.pyo
+
+ # Build artifacts
+ dist/
+ build/
+ *.spec
+
+ # Virtual Environment
+ .venv/
+ venv/
+ env/
+ ENV/
+
+ # Environment variables
+ .env
+ .env.local
+
+ # Session files
+ amazon_session.json
+
+ # Database files
+ agent_data.db
+ *.db
+ *.db-journal
+ ct_gov_lancedb/
+
+ # Chrome/Browser session data
+ user_session/
+
+ # IDE
+ .vscode/
+ .idea/
+ *.swp
+ *.swo
+ *~
+ *.code-workspace
+
+ # OS
+ .DS_Store
+ Thumbs.db
+ .DS_Store?
+
+ # Logs
+ *.log
+
+ # Dev Containers
+ .devcontainer/
+
+ # macOS App Bundle (generated, but keep source)
+ *.app/
+
+ # Temporary files
+ *.tmp
+ *.temp
DEPLOYMENT.md ADDED
@@ -0,0 +1,82 @@
+ # Deployment Guide: Hugging Face Spaces 🐳
+
+ This guide walks you through deploying the **Clinical Trial Inspector Agent** to **Hugging Face Spaces** using Docker.
+
+ ## Prerequisites
+
+ 1. **Hugging Face Account**: [Sign up here](https://huggingface.co/join).
+ 2. **Git LFS (Large File Storage)**: Required to upload the database (~700MB).
+    * **Mac**: `brew install git-lfs`
+    * **Windows**: Download from [git-lfs.com](https://git-lfs.com/)
+    * **Linux**: `sudo apt-get install git-lfs`
+
+ ## Step 1: Authentication (Crucial!) 🔑
+
+ Hugging Face requires an **Access Token** for Git operations (passwords don't work).
+
+ 1. Go to **[Settings > Access Tokens](https://huggingface.co/settings/tokens)**.
+ 2. Click **Create new token**.
+ 3. **Type**: Select **Write** (important!).
+ 4. Copy the token (starts with `hf_...`).
+ 5. **Usage**: When `git push` asks for a password, **paste this token**.
+
+ ## Step 2: Create a New Space
+
+ 1. Go to [huggingface.co/new-space](https://huggingface.co/new-space).
+ 2. **Space Name**: e.g., `clinical-trial-agent`.
+ 3. **License**: `MIT` (or your choice).
+ 4. **SDK**: Select **Docker**.
+ 5. **Visibility**: Public or Private.
+ 6. Click **Create Space**.
+
+ ## Step 3: Prepare Your Local Repo
+
+ You need to initialize Git LFS to track the large LanceDB files.
+
+ ```bash
+ # Initialize LFS
+ git lfs install
+
+ # Track the LanceDB files
+ git lfs track "ct_gov_lancedb/**/*"
+ git add .gitattributes
+ ```
+
+ ## Step 4: Push to Hugging Face
+
+ You can either push your existing repo or clone the Space and copy files over. Pushing the existing repo is easier:
+
+ ```bash
+ # Add the Space as a remote (replace YOUR_USERNAME and SPACE_NAME)
+ git remote add space https://huggingface.co/spaces/YOUR_USERNAME/SPACE_NAME
+
+ # Push the main branch
+ git push space main
+ # OR if you are on a feature branch:
+ git push space feature/deploy_app:main
+ ```
+
+ > **Note**: The first push will take time as it uploads the ~700MB database.
+
+ ## Step 5: Configure Secrets (Optional but Recommended)
+
+ To run in **Admin Mode** (no user prompt for an API key):
+
+ 1. Go to your Space's **Settings** tab.
+ 2. Scroll to **Variables and secrets**.
+ 3. Click **New secret**.
+ 4. **Name**: `GOOGLE_API_KEY`
+ 5. **Value**: Your Google API Key (starts with `AIza...`).
+
+ ## Step 6: Verify Deployment
+
+ 1. Go to the **App** tab in your Space.
+ 2. You should see "Building..." in the logs.
+ 3. Once built, the app will launch! 🚀
+
+ ---
+
+ ## Troubleshooting
+
+ * **"LFS upload failed"**: Ensure you ran `git lfs install` and `git lfs track` before committing the database.
+ * **"Runtime Error"**: Check the **Logs** tab. If it says "API Key Missing", ensure you set the Secret or enter a key in the UI.
Dockerfile ADDED
@@ -0,0 +1,33 @@
+ # Use an official Python runtime as a parent image
+ FROM python:3.10-slim
+
+ # Set the working directory in the container
+ WORKDIR /app
+
+ # Install system dependencies
+ # build-essential is often needed for compiling python packages
+ # git is needed if you install packages from git
+ RUN apt-get update && apt-get install -y \
+     build-essential \
+     git \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Copy the requirements file into the container at /app
+ COPY requirements.txt .
+
+ # Install any needed packages specified in requirements.txt
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Copy the current directory contents into the container at /app
+ COPY . .
+
+ # Expose port 8501 for Streamlit
+ EXPOSE 8501
+
+ # Define environment variables for Streamlit to run in headless mode
+ ENV STREAMLIT_SERVER_HEADLESS=true
+ ENV STREAMLIT_SERVER_PORT=8501
+ ENV STREAMLIT_SERVER_ADDRESS=0.0.0.0
+
+ # Run the application
+ CMD ["streamlit", "run", "ct_agent_app.py"]
README.md ADDED
@@ -0,0 +1,240 @@
+ # Clinical Trial Inspector Agent 🕵️‍♂️💊
+
+ **Clinical Trial Inspector** is an advanced AI agent designed to revolutionize how researchers, clinicians, and analysts explore clinical trial data. By combining **Semantic Search**, **Retrieval-Augmented Generation (RAG)**, and **Visual Analytics**, it transforms raw data from [ClinicalTrials.gov](https://clinicaltrials.gov/) into actionable insights.
+
+ Built with **LangChain**, **LlamaIndex**, **Streamlit**, **Altair**, **Streamlit-Agraph**, and **Google Gemini**, this tool goes beyond simple keyword search. It understands natural language, generates inline visualizations, performs complex multi-dimensional analysis, and visualizes relationships in an interactive knowledge graph.
+
+ ## ✨ Key Features
+
+ ### 🧠 Intelligent Search & Retrieval
+ * **Hybrid Search**: Combines **Semantic Search** (vector similarity) with **BM25 Keyword Search** (sparse retrieval) using **LanceDB's Native Hybrid Search**. This ensures you find studies that match both the *meaning* (e.g., "kidney cancer" -> "renal cell carcinoma") and *exact terms* (e.g., "NCT04589845", "Teclistamab").
+ * **Smart Filtering**:
+     * **Strict Pre-Filtering**: For specific sponsors (e.g., "Pfizer"), it forces the engine to look *only* at that sponsor's studies first, ensuring 100% recall.
+     * **Strict Keyword Filtering (Analytics Only)**: For counting questions (e.g., "How many studies..."), the **Analytics Engine** (`get_study_analytics`) prioritizes studies where the query explicitly appears in the **Title** or **Conditions**, ensuring high precision and accurate counts.
+ * **Sponsor Alias Support**: Intelligently maps aliases (e.g., "J&J", "MSD") to their canonical sponsor names ("Janssen", "Merck Sharp & Dohme") for accurate aggregation.
+ * **Smart Summary**: Returns a clean, concise list of relevant studies.
+ * **Query Expansion**: Automatically expands your search terms with medical synonyms (e.g., "Heart Attack" -> "Myocardial Infarction").
+ * **Re-Ranking**: Uses a Cross-Encoder (`ms-marco-MiniLM`) to re-score results for maximum relevance.
+ * **Query Decomposition**: Breaks down complex multi-part questions (e.g., *"Compare the primary outcomes of Keytruda vs Opdivo"*) into sub-questions for precise answers.
+ * **Cohort SQL Generation**: Translates eligibility criteria into standard SQL queries (OMOP CDM) for patient cohort identification.
+
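The sponsor-alias mapping can be pictured as a plain lookup table. This is only an illustrative sketch — the real normalization lives in the app's `modules`, and the alias list and canonical names shown here are assumptions:

```python
# Hypothetical alias table; the app's actual mapping may be larger and fuzzier.
SPONSOR_ALIASES = {
    "j&j": "Janssen",
    "johnson & johnson": "Janssen",
    "msd": "Merck Sharp & Dohme",
}

def normalize_sponsor(name: str) -> str:
    """Map a sponsor alias to its canonical name (case-insensitive);
    unknown names pass through unchanged."""
    return SPONSOR_ALIASES.get(name.strip().lower(), name.strip())

print(normalize_sponsor("J&J"))     # -> Janssen
print(normalize_sponsor("Pfizer"))  # -> Pfizer (no alias, unchanged)
```

Normalizing before grouping is what keeps "J&J" and "Janssen" from being counted as two different sponsors in the analytics.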
+ ### 📊 Visual Analytics & Insights
+ - **Inline Charts (Contextual)**: The agent automatically generates **Bar Charts** and **Line Charts** directly in the chat stream when you ask aggregation questions (e.g., *"Top sponsors for Multiple Myeloma"*).
+ - **Analytics Dashboard (Global)**: A dedicated dashboard to analyze trends across the **entire dataset** (60,000+ studies), independent of your chat session.
+ - **Interactive Knowledge Graph**: Visualize connections between **Studies**, **Sponsors**, and **Conditions** in a dynamic, interactive network graph.
+
+ ### 🌍 Geospatial Dashboard
+ - **Global Trial Map**: Visualize the geographic distribution of clinical trials on an interactive world map.
+ - **Region Toggle**: Switch between **World View** (Country-level aggregation) and **USA View** (State-level aggregation).
+ - **Dot Visualization**: Uses dynamic **CircleMarkers** (dots) sized by trial count to show density.
+ - **Interactive Filters**: Filter the map by **Phase**, **Status**, **Sponsor**, **Start Year**, and **Study Type**.
+
+ ### 🔍 Multi-Filter Analysis
+ - **Complex Filtering**: Answer sophisticated questions by applying multiple filters simultaneously.
+     - *Example*: *"For **Phase 2 and 3** studies, what are **Pfizer's** most common study indications?"*
+ - **Full Dataset Scope**: General analytics questions analyze the **entire database**, not just a sample.
+ - **Smart Retrieval**: Retrieves up to **5,000 relevant studies** for comprehensive analysis.
+
+ ### ⚡ High-Performance Ingestion
+ - **Parallel Processing**: Uses multi-core processing to ingest and embed thousands of studies per minute.
+ - **LanceDB Integration**: Uses **LanceDB** for high-performance vector storage and native hybrid search.
+ - **Idempotent Updates**: Smartly updates existing records without duplication, allowing for seamless data refreshes.
+
+ ## 🤖 Agent Capabilities & Tools
+
+ The agent is equipped with specialized tools to handle different types of requests:
+
+ ### 1. `search_trials`
+ * **Purpose**: Finds specific clinical trials based on natural language queries.
+ * **Capabilities**: Semantic Search, Smart Filtering (Phase, Status, Sponsor, Intervention), Query Expansion, Hybrid Search, Re-Ranking.
+
+ ### 2. `get_study_analytics`
+ * **Purpose**: Aggregates data to reveal trends and insights.
+ * **Capabilities**: Multi-Filtering, Grouping (Phase, Status, Sponsor, Year, Condition), Full Dataset Access, Inline Visualization.
+
+ ### 3. `compare_studies`
+ * **Purpose**: Handles complex comparison or multi-part questions.
+ * **Capabilities**: Uses **Query Decomposition** to break a complex query into sub-queries, executes them against the database, and synthesizes the results.
+
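As a rough intuition for the decomposition step, a comparison question can be split into one sub-question per entity. The real tool delegates this to an LLM (LlamaIndex's `SubQuestionQueryEngine`); the regex below is only a toy sketch of the idea:

```python
import re

def decompose_comparison(query: str) -> list[str]:
    """Toy sketch: split 'Compare <aspect> of <A> vs <B>' into one
    sub-question per entity. Non-comparison queries pass through unchanged."""
    m = re.search(r"compare (.+?) of (.+)", query, re.IGNORECASE)
    if not m:
        return [query]
    aspect = m.group(1)
    entities = re.split(r"\s+(?:vs\.?|versus)\s+", m.group(2), flags=re.IGNORECASE)
    return [f"What are {aspect} of {entity.rstrip('?')}?" for entity in entities]

print(decompose_comparison("Compare the primary outcomes of Keytruda vs Opdivo"))
# -> ['What are the primary outcomes of Keytruda?',
#     'What are the primary outcomes of Opdivo?']
```

Each sub-question is then answered against the index independently, and the final answer synthesizes the per-entity results.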
+ ### 4. `find_similar_studies`
+ * **Purpose**: Discovers studies that are semantically similar to a specific trial.
+ * **Capabilities**:
+     * **NCT Lookup**: Automatically fetches content if queried with an NCT ID.
+     * **Self-Exclusion**: Filters out the reference study from results.
+     * **Scoring**: Returns similarity scores for transparency.
+
+ ### 5. `get_study_details`
+ * **Purpose**: Fetches the full text content of a specific study by NCT ID.
+ * **Capabilities**: Retrieves all chunks of a study to provide comprehensive details (Criteria, Summary, Protocol).
+
+ ### 6. `get_cohort_sql`
+ * **Purpose**: Translates clinical trial eligibility criteria into standard SQL queries for claims data analysis.
+ * **Capabilities**:
+     * **Extraction**: Parses text into structured inclusion/exclusion rules (Concepts, Codes).
+     * **SQL Generation**: Generates OMOP-compatible SQL queries targeting `medical_claims` and `pharmacy_claims`.
+     * **Logic Enforcement**: Applies temporal logic (e.g., "2 diagnoses > 30 days apart") for chronic conditions.
+
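To make the temporal rule concrete, here is a minimal sketch of the kind of SQL such a rule could compile to. The `medical_claims` table name comes from the description above, but the column names, SQL dialect, and the example ICD-10 code are assumptions for illustration:

```python
def diagnosis_cohort_sql(code: str, min_days_apart: int = 30) -> str:
    """Sketch of the chronic-condition rule: require two diagnosis claims
    more than `min_days_apart` days apart (self-join on patient)."""
    return f"""
SELECT a.patient_id
FROM medical_claims a
JOIN medical_claims b
  ON a.patient_id = b.patient_id
WHERE a.diagnosis_code = '{code}'
  AND b.diagnosis_code = '{code}'
  AND b.service_date > a.service_date + INTERVAL '{min_days_apart}' DAY
GROUP BY a.patient_id
""".strip()

sql = diagnosis_cohort_sql("C90.0")  # hypothetical ICD-10 code used as input
print(sql)
```

The self-join with a date gap is what distinguishes a confirmed chronic condition from a single rule-out diagnosis in claims data.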
+ ## ⚙️ How It Works (RAG Pipeline)
+
+ 1. **Ingestion**: `ingest_ct.py` fetches study data from ClinicalTrials.gov. It extracts rich text (including **Eligibility Criteria** and **Interventions**) and structured metadata. It uses **multiprocessing** for speed.
+ 2. **Embedding**: Text is converted into vector embeddings using `PubMedBERT` and stored in **LanceDB**.
+ 3. **Retrieval**:
+     * **Query Transformation**: Synonyms are injected via LLM.
+     * **Pre-Filtering**: Strict filters (Status, Year, Sponsor) reduce the search scope.
+     * **Hybrid Search**: Parallel **Vector Search** (Semantic) and **BM25** (Keyword) combined via **LanceDB Native Hybrid Search**.
+     * **Post-Filtering**: Additional metadata checks (Phase, Intervention) on retrieved candidates.
+     * **Re-Ranking**: Cross-Encoder re-scoring.
+ 4. **Synthesis**: **Google Gemini** synthesizes the final answer.
+
+ ### 🏗️ Ingestion Pipeline
+
+ ```mermaid
+ graph TD
+     API[ClinicalTrials.gov API] -->|Fetch Batches| Script[ingest_ct.py]
+     Script -->|Process & Embed| LanceDB[(LanceDB)]
+ ```
+
+ ### 🧠 RAG Retrieval Flow
+
+ ```mermaid
+ graph TD
+     User[User Query] -->|Expand| Synonyms[Synonym Injection]
+     Synonyms -->|Pre-Filter| PreFilter[Pre-Retrieval Filters]
+     PreFilter -->|Filtered Scope| Hybrid[Hybrid Search]
+     Hybrid -->|Parallel Search| Vector[Vector Search] & BM25[BM25 Keyword Search]
+     Vector & BM25 -->|Reciprocal Rank Fusion| Fusion[Merged Candidates]
+     Fusion -->|Candidates| PostFilter[Post-Retrieval Filters]
+     PostFilter -->|Top N| ReRank[Cross-Encoder Re-Ranking]
+     ReRank -->|Context| LLM[Google Gemini]
+     LLM -->|Answer| Response[Final Response]
+ ```
+
+
112
+ ### 🕸️ Knowledge Graph
113
+
114
+ ```mermaid
115
+ graph TD
116
+ LanceDB[(LanceDB)] -->|Metadata| GraphBuilder[build_graph]
117
+ GraphBuilder -->|Nodes & Edges| Agraph[Streamlit Agraph]
118
+ ```
119
+
120
+ ## 🛠️ Tech Stack
121
+
122
+ - **Frontend**: Streamlit, Altair, Streamlit-Agraph
123
+ - **LLM**: Google Gemini (`gemini-2.5-flash`)
124
+ - **Orchestration**: LangChain (Agents, Tool Calling)
125
+ - **Retrieval (RAG)**: LlamaIndex (VectorStoreIndex, SubQuestionQueryEngine)
126
+ - **Vector Database**: LanceDB (Local)
127
+ - **Embeddings**: HuggingFace (`pritamdeka/S-PubMedBert-MS-MARCO`)
128
+
129
+ ## 🚀 Getting Started
130
+
131
+ ### Prerequisites
132
+
133
+ - Python 3.10+
134
+ - A Google Cloud API Key with access to Gemini
135
+
136
+ ### Installation
137
+
138
+ 1. **Clone the repository**
139
+ ```bash
140
+ git clone <repository-url>
141
+ cd clinical_trial_agent
142
+ ```
143
+
144
+ 2. **Create and activate a virtual environment**
145
+ ```bash
146
+ python -m venv venv
147
+ source venv/bin/activate # On Windows: venv\Scripts\activate
148
+ ```
149
+
150
+ 3. **Install dependencies**
151
+ ```bash
152
+ pip install -r requirements.txt
153
+ ```
154
+
155
+ 4. **Set up Environment Variables**
156
+ Create a `.env` file in the root directory and add your Google API Key:
157
+ ```bash
158
+ GOOGLE_API_KEY=your_google_api_key_here
159
+ ```
160
+
161
+ ## 📖 Usage
162
+
163
+ ### 1. Ingest Data
164
+ Populate the local database. The script uses parallel processing for speed.
165
+
166
+ ```bash
167
+ # Recommended: Ingest 5000 recent studies
168
+ python scripts/ingest_ct.py --limit 5000 --years 5
169
+
170
+ # Ingest ALL studies (Warning: Large download!)
171
+ python scripts/ingest_ct.py --limit -1 --years 10
172
+ ```
173
+
174
+ ### 2. Run the Agent
175
+ Launch the Streamlit application:
176
+
177
+ ```bash
178
+ streamlit run ct_agent_app.py
179
+ ```
180
+
181
+ ### 3. Ask Questions!
182
+ - **Search**: *"Find studies for Multiple Myeloma."*
183
+ - **Comparison**: *"Compare the primary outcomes of Keytruda vs Opdivo."*
184
+ - **Analytics**: *"Who are the top sponsors for Breast Cancer?"* (Now supports grouping by **Intervention** and **Study Type**!)
185
+ - **Graph**: Go to the **Knowledge Graph** tab to visualize connections.
186
+
187
+ ## 🧪 Testing & Quality
188
+
189
+ - **Unit Tests**: Run `python -m pytest tests/test_unit.py` to verify core logic.
190
+ - **Hybrid Search Tests**: Run `python -m pytest tests/test_hybrid_search.py` to verify the search engine's precision and recall.
191
+ - **Data Integrity**: Run `python -m unittest tests/test_data_integrity.py` to verify database content against known ground truths.
192
+ - **Sponsor Normalization**: Run `python -m pytest tests/test_sponsor_normalization.py` to verify alias mapping logic.
193
+ - **Linting**: Codebase is formatted with `black` and linted with `flake8`.
194
+
195
+ ## 📂 Project Structure
196
+
197
+ - `ct_agent_app.py`: Main application logic.
198
+ - `modules/`:
199
+ - `utils.py`: Configuration, Normalization, Custom Filters.
200
+ - `constants.py`: Static data (Coordinates, Mappings).
201
+ - `tools.py`: Tool definitions (`search_trials`, `compare_studies`, etc.).
202
+ - `cohort_tools.py`: SQL generation logic (`get_cohort_sql`).
203
+ - `graph_viz.py`: Knowledge Graph logic.
204
+ - `scripts/`:
205
+ - `ingest_ct.py`: Parallel data ingestion pipeline.
206
+ - `analyze_db.py`: Database inspection.
207
+
208
+ - `ct_gov_lancedb/`: Persisted LanceDB vector store.
209
+ - `tests/`:
210
+ - `test_unit.py`: Core logic tests.
211
+ - `test_hybrid_search.py`: Integration tests for search engine.
212
+
213
+ ## 🐳 Deployment
214
+
215
+ The application is container-ready and can be deployed using Docker.
216
+
217
+ ### Build the Image
218
+ ```bash
219
+ docker build -t clinical-trial-agent .
220
+ ```
221
+
222
+ ### Run the Container
223
+ You can run the container in two modes:
224
+
225
+ **1. Admin Mode (API Key in Environment)**
226
+ Pass the key as an environment variable. Users will not be prompted.
227
+ ```bash
228
+ docker run -p 8501:8501 -e GOOGLE_API_KEY=your_key_here clinical-trial-agent
229
+ ```
230
+
231
+ **2. User Mode (Prompt for Key)**
232
+ Run without the key. Users will be prompted to enter their own key in the sidebar.
233
+ ```bash
234
+ docker run -p 8501:8501 clinical-trial-agent
235
+ ```
236
+
237
+ ### Hosting Options
238
+ - **Hugging Face Spaces**: Select "Docker" SDK. Add `GOOGLE_API_KEY` to Secrets for Admin Mode.
239
+ - **Google Cloud Run**: Deploy the container and map port 8501.
240
+
ct_agent_app.py ADDED
@@ -0,0 +1,583 @@
+ """
+ Clinical Trial Inspector Agent Application.
+
+ This is the main Streamlit application script. It orchestrates:
+ 1. **LLM & Agents**: Initializes Google Gemini and the LangChain agent.
+ 2. **RAG Pipeline**: Loads the LlamaIndex vector store for semantic retrieval.
+ 3. **User Interface**: Renders the Streamlit UI with tabs for Chat, Analytics, and Raw Data.
+ 4. **Visualization**: Handles dynamic chart generation using Altair.
+ """
+
+ import streamlit as st
+ import pandas as pd
+ import os
+ import altair as alt
+ import logging
+ from dotenv import load_dotenv
+
+ # Suppress logging
+ logging.getLogger("langchain_google_genai._function_utils").setLevel(logging.ERROR)
+
+ # Load environment variables
+ load_dotenv()
+
+ # Module Imports
+ from modules.utils import load_index, setup_llama_index
+ from modules.constants import COUNTRY_COORDINATES, STATE_COORDINATES
+ from modules.tools import (
+     search_trials,
+     find_similar_studies,
+     get_study_analytics,
+     compare_studies,
+     get_study_details,
+     fetch_study_analytics_data,
+ )
+ from modules.cohort_tools import get_cohort_sql
+ from modules.graph_viz import build_graph
+ from streamlit_agraph import agraph
+ from streamlit_option_menu import option_menu
+ import folium
+ from streamlit_folium import st_folium
+
+ # LangChain Imports
+ from langchain_google_genai import ChatGoogleGenerativeAI
+ from langchain.agents import AgentExecutor, create_tool_calling_agent
+ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
+ from langchain_core.messages import HumanMessage, AIMessage
+
+ # --- App Configuration ---
+ st.set_page_config(
+     page_title="Clinical Trial Inspector",
+     layout="wide",
+     initial_sidebar_state="expanded",
+ )
+
+ # --- Custom CSS for Sidebar Width ---
+ st.markdown(
+     """
+     <style>
+     [data-testid="stSidebar"] {
+         min-width: 200px;
+         max-width: 250px;
+     }
+     </style>
+     """,
+     unsafe_allow_html=True,
+ )
+
+ st.title("🧬 Clinical Trial Inspector Agent")
+
+ # 1. Setup LLM & LlamaIndex Settings
+ # We use Google Gemini-2.5-Flash for fast and accurate responses.
+ api_key = os.environ.get("GOOGLE_API_KEY")
+
+ if not api_key:
+     st.sidebar.warning("⚠️ API Key Missing")
+     user_key = st.sidebar.text_input(
+         "Enter Google API Key:",
+         type="password",
+         help="Get one at https://aistudio.google.com/",
+     )
+     if user_key:
+         st.session_state["api_key"] = user_key
+         api_key = user_key
+         st.sidebar.success("Key set!")
+         st.rerun()
+     else:
+         # Check if key is already in session state (from previous run)
+         if "api_key" in st.session_state:
+             api_key = st.session_state["api_key"]
+         else:
+             st.warning("Please enter your Google API Key in the sidebar to continue.")
+             st.stop()
+ else:
+     # Env var exists, ensure it's in session state for tools to find
+     st.session_state["api_key"] = api_key
+
+ # Ensure LlamaIndex settings (Embeddings, LLM) are applied on every run
+ setup_llama_index(api_key=api_key)
+
+ llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash", temperature=0, google_api_key=api_key)
+
+ # 2. Load LlamaIndex (Cached)
+ # The index is loaded once and cached to avoid reloading on every interaction.
+ index = load_index()
+
+
+ # 3. Define Agent (Cached)
+ @st.cache_resource
+ def get_agent():
+     """Initializes and caches the LangChain agent."""
+     tools = [
+         search_trials,
+         find_similar_studies,
+         get_study_analytics,
+         compare_studies,
+         get_study_details,
+         get_cohort_sql,
+     ]
+
+     prompt = ChatPromptTemplate.from_messages(
+         [
+             (
+                 "system",
+                 "You are a Clinical Trial Expert Assistant. "
+                 "Your goal is to help researchers and analysts understand clinical trial data. "
+                 "You have access to a local database of clinical trials (embedded from ClinicalTrials.gov). "
+                 "Use the available tools to search for studies, find similar studies, and generate analytics. "
+                 "When asked about 'trends', 'counts', 'how many', or 'most common', ALWAYS use the `get_study_analytics` tool. "
+                 "Do NOT use `search_trials` for counting questions like 'How many studies...'. "
+                 "When asked to 'find studies', 'search', or 'list', use `search_trials`. "
+                 "When asked to 'compare' multiple studies or answer complex multi-part questions, use `compare_studies`. "
+                 "If the user asks for a specific study by ID (e.g., NCT12345678), `search_trials` handles that automatically. "
+                 "However, if the user asks for specific **details**, **criteria**, **summary**, or **protocol** of a single study, "
+                 "you MUST use the `get_study_details` tool to fetch the full content. "
+                 "If the user asks to **generate SQL**, **build a cohort**, or **translate criteria to code** for a study, "
+                 "use the `get_cohort_sql` tool. "
+                 "When reporting 'similar studies', ALWAYS include the similarity score provided by the tool "
+                 "and DO NOT include the study that was used as the query (the reference study). "
+                 "Provide concise, evidence-based answers citing specific studies when possible.",
+             ),
+             MessagesPlaceholder(variable_name="chat_history"),
+             ("human", "{input}"),
+             ("placeholder", "{agent_scratchpad}"),
+         ]
+     )
+
+     agent = create_tool_calling_agent(llm, tools, prompt)
+     return AgentExecutor(agent=agent, tools=tools, verbose=True)
+
+
+ agent_executor = get_agent()
+
+ # --- Sidebar ---
+ with st.sidebar:
+     st.image("https://cdn-icons-png.flaticon.com/512/3004/3004458.png", width=50)
+     st.title("Clinical Trial Agent")
+
+     page = option_menu(
+         "Main Menu",
+         ["Chat Assistant", "Analytics Dashboard", "Knowledge Graph", "Study Map", "Raw Data"],
+         icons=["chat-dots", "graph-up", "diagram-3", "map", "database"],
+         menu_icon="cast",
+         default_index=0,
+     )
+
+
+ # --- Helper Functions ---
+ def generate_dashboard_analytics():
+     """Callback to generate analytics and update session state."""
+     # Map UI selection to tool arguments
+     group_map = {
+         "Phase": "phase",
+         "Status": "status",
+         "Sponsor": "sponsor",
+         "Start Year": "start_year",
+         "Intervention": "intervention",
+         "Study Type": "study_type",
+     }
+
+     # Get values from session state
+     # We use .get() to avoid KeyErrors if the widget hasn't initialized yet (though it should have)
+     g_by = st.session_state.get("dash_group_by", "Sponsor")
+     p_filter = st.session_state.get("dash_phase", "")
+     s_filter = st.session_state.get("dash_sponsor", "")
+
+     with st.spinner(f"Analyzing studies by {g_by}..."):
+         # Call the tool directly
+         result = get_study_analytics.invoke(
+             {
+                 "query": "overall",
+                 "group_by": group_map.get(g_by, "sponsor"),
+                 "phase": p_filter if p_filter else None,
+                 "sponsor": s_filter if s_filter else None,
+             }
+         )
+
+         # The tool sets session state 'inline_chart_data'
+         if "inline_chart_data" in st.session_state:
+             st.session_state["dashboard_data"] = st.session_state["inline_chart_data"]
+         else:
+             st.warning(result)
+
+
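The chart renderers below all read a payload from `st.session_state["inline_chart_data"]`. Based on the keys the rendering code accesses (`type`, `title`, `x`, `y`, `data`), a tool-side payload might be assembled like this — the sponsor values are made-up sample data standing in for `get_study_analytics`' real query results:

```python
from collections import Counter

# Sample aggregation; the real tool derives this from the LanceDB table.
sponsors = ["Pfizer", "Merck", "Pfizer", "Novartis", "Pfizer"]
counts = Counter(sponsors).most_common()

inline_chart_data = {
    "type": "bar",    # the chat renderer only draws inline charts of type "bar"
    "title": "Top Sponsors",
    "x": "sponsor",   # axis field names must match the keys in each data record
    "y": "count",
    "data": [{"sponsor": s, "count": c} for s, c in counts],
}
print(inline_chart_data["data"][0])  # -> {'sponsor': 'Pfizer', 'count': 3}
```

Keeping the payload as plain dicts/lists means it can be persisted inside the chat message history and re-rendered on every rerun without re-querying.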
+ # --- PAGE 1: CHAT ---
+ if page == "Chat Assistant":
+     st.header("💬 Chat Assistant")
+     if "messages" not in st.session_state:
+         st.session_state.messages = []
+
+     # Render Chat History
+     for message in st.session_state.messages:
+         with st.chat_message(message["role"]):
+             st.markdown(message["content"])
+             # Render chart if present in message history (persisted charts)
+             if "chart_data" in message:
+                 chart_data = message["chart_data"]
+                 st.caption(chart_data["title"])
+                 chart = (
+                     alt.Chart(pd.DataFrame(chart_data["data"]))
+                     .mark_bar()
+                     .encode(
+                         x=alt.X(chart_data["x"], sort="-y", axis=alt.Axis(labelLimit=200)),
+                         y=alt.Y(chart_data["y"], title="Count"),
+                         tooltip=[chart_data["x"], chart_data["y"]],
+                     )
+                     .interactive()
+                 )
+                 st.altair_chart(chart, theme="streamlit", width="stretch")
+
+     # Chat Input
+     if prompt := st.chat_input("Ask about clinical trials..."):
+         st.session_state.messages.append({"role": "user", "content": prompt})
+         with st.chat_message("user"):
+             st.markdown(prompt)
+
+         with st.chat_message("assistant"):
+             with st.spinner("Analyzing clinical trials..."):
+                 try:
+                     # Clear previous inline chart data to avoid stale charts
+                     if "inline_chart_data" in st.session_state:
+                         del st.session_state["inline_chart_data"]
+
+                     # Construct chat history for the agent context
+                     chat_history = []
+                     for msg in st.session_state.messages[:-1]:
+                         if msg["role"] == "user":
+                             chat_history.append(HumanMessage(content=msg["content"]))
+                         else:
+                             chat_history.append(AIMessage(content=msg["content"]))
+
+                     # Invoke Agent
+                     response = agent_executor.invoke(
+                         {"input": prompt, "chat_history": chat_history}
+                     )
+                     output = response["output"]
+                     st.markdown(output)
+
+                     # Check for inline chart data (set by tools)
+                     chart_data = None
+                     if "inline_chart_data" in st.session_state:
+                         chart_data = st.session_state["inline_chart_data"]
+                         st.caption(chart_data["title"])
+                         if chart_data["type"] == "bar":
+                             # Use Altair for better charts
+                             chart = (
+                                 alt.Chart(pd.DataFrame(chart_data["data"]))
+                                 .mark_bar()
+                                 .encode(
+                                     x=alt.X(
+                                         chart_data["x"],
+                                         sort="-y",
+                                         axis=alt.Axis(labelLimit=200),
+                                     ),
+                                     y=alt.Y(chart_data["y"], title="Count"),
+                                     tooltip=[chart_data["x"], chart_data["y"]],
+                                 )
+                                 .interactive()
+                             )
+                             st.altair_chart(chart, theme="streamlit", width="stretch")
+
+                         # Clean up session state
+                         del st.session_state["inline_chart_data"]
+
+                     # Save message with chart data if present
+                     msg_obj = {"role": "assistant", "content": output}
+                     if chart_data:
+                         msg_obj["chart_data"] = chart_data
+                     st.session_state.messages.append(msg_obj)
+
+                 except Exception as e:
+                     st.error(f"An error occurred: {e}")
+
+ # --- PAGE 2: ANALYTICS DASHBOARD ---
+ if page == "Analytics Dashboard":
+     st.header("📊 Global Analytics")
+     st.write(
+         "Analyze trends across the entire clinical trial dataset (60,000+ studies)."
+     )
+
+     col1, col2 = st.columns([1, 3])
+
+     with col1:
+         st.subheader("Configuration")
+         group_by = st.selectbox(
+             "Group By",
+             ["Phase", "Status", "Sponsor", "Start Year", "Intervention", "Study Type"],
+             index=2,
+             key="dash_group_by",
+         )
+
+         # Optional Filters
+         st.markdown("---")
+         st.markdown("**Filters (Optional)**")
+         filter_phase = st.text_input("Phase (e.g., Phase 2)", key="dash_phase")
+         filter_sponsor = st.text_input("Sponsor (e.g., Pfizer)", key="dash_sponsor")
+
+         st.button(
+             "Generate Analytics", type="primary", on_click=generate_dashboard_analytics
+         )
+
+     with col2:
+         # Always render if data exists in session state
+         if "dashboard_data" in st.session_state:
+             c_data = st.session_state["dashboard_data"]
+             st.subheader(c_data["title"])
+
+             # Altair Chart Rendering
+             # Check both the data key and the current UI selection
+             if c_data["x"] == "start_year" or group_by == "Start Year":
+                 # Line chart for years
+                 chart = (
+                     alt.Chart(pd.DataFrame(c_data["data"]))
+                     .mark_line(point=True)
+                     .encode(
+                         # 'd' format for integer years
+                         x=alt.X(c_data["x"], axis=alt.Axis(format="d"), title="Year"),
+                         y=alt.Y(c_data["y"], title="Count"),
+                         tooltip=[c_data["x"], c_data["y"]],
+                     )
+                     .interactive()
+                 )
+             else:
+                 # Bar chart for others
+                 chart = (
+                     alt.Chart(pd.DataFrame(c_data["data"]))
+                     .mark_bar()
+                     .encode(
+                         x=alt.X(
+                             c_data["x"],
+                             sort="-y",
+                             axis=alt.Axis(labelLimit=200),
+                         ),
+                         y=alt.Y(c_data["y"], title="Count"),
+                         tooltip=[c_data["x"], c_data["y"]],
+                     )
+                     .interactive()
+                 )
+
+             st.altair_chart(chart, theme="streamlit", width="stretch")
+
+             # Show raw table
+             with st.expander("View Source Data"):
+                 st.dataframe(pd.DataFrame(c_data["data"]))
+
+ # --- PAGE 3: KNOWLEDGE GRAPH ---
+ if page == "Knowledge Graph":
+     st.header("🕸️ Interactive Knowledge Graph")
+     st.write("Visualize connections between Studies, Sponsors, and Conditions.")
+
+     col_g1, col_g2 = st.columns([1, 3])
+
+     with col_g1:
+         st.subheader("Graph Settings")
+         graph_query = st.text_input("Search Topic", value="Cancer")
380
+ limit = st.slider("Max Nodes", 10, 100, 50)
381
+
382
+ if st.button("Build Graph"):
383
+ with st.spinner("Fetching data and building graph..."):
384
+ # Use retriever to get relevant nodes
385
+ retriever = index.as_retriever(similarity_top_k=limit)
386
+ nodes = retriever.retrieve(graph_query)
387
+ data = [n.metadata for n in nodes]
388
+
389
+ # Build Graph
390
+ g_nodes, g_edges, g_config = build_graph(data)
391
+
392
+ st.session_state["graph_data"] = {
393
+ "nodes": g_nodes,
394
+ "edges": g_edges,
395
+ "config": g_config,
396
+ }
397
+
398
+ with col_g2:
399
+ if "graph_data" in st.session_state:
400
+ g_data = st.session_state["graph_data"]
401
+ st.success(
402
+ f"Graph built with {len(g_data['nodes'])} nodes and {len(g_data['edges'])} edges."
403
+ )
404
+ agraph(
405
+ nodes=g_data["nodes"], edges=g_data["edges"], config=g_data["config"]
406
+ )
407
+ else:
408
+ st.info("Enter a topic and click 'Build Graph' to visualize connections.")
409
+
410
+ # --- PAGE# --- Study Map Tab ---
411
+ elif page == "Study Map":
412
+ st.header("🌍 Global Clinical Trial Map")
413
+ st.markdown("Visualize the geographic distribution of clinical trials.")
414
+
415
+ # Sidebar Filters for Map
416
+ st.sidebar.markdown("### 🗺️ Map Filters")
417
+ map_region = st.sidebar.radio("Region", ["World", "USA"], index=0)
418
+
419
+ map_phase = st.sidebar.multiselect(
420
+ "Phase", ["PHASE1", "PHASE2", "PHASE3", "PHASE4"], default=["PHASE2", "PHASE3"]
421
+ )
422
+ map_status = st.sidebar.selectbox(
423
+ "Status", ["RECRUITING", "COMPLETED", "ACTIVE_NOT_RECRUITING"], index=0
424
+ )
425
+ map_sponsor = st.sidebar.text_input("Sponsor (Optional)", "")
426
+ map_year = st.sidebar.number_input("Start Year (>=)", min_value=2000, value=2020)
427
+ map_type = st.sidebar.selectbox(
428
+ "Study Type", ["Interventional", "Observational", "All"], index=0
429
+ )
430
+
431
+ # Convert filters to arguments
432
+ phase_str = ",".join(map_phase) if map_phase else None
433
+ type_arg = map_type if map_type != "All" else None
434
+
435
+ if st.button("Update Map"):
436
+ with st.spinner("Aggregating geographic data..."):
437
+ # Determine grouping based on Region
438
+ group_by_field = "state" if map_region == "USA" else "country"
439
+
440
+ # Call analytics logic directly
441
+ summary = fetch_study_analytics_data(
442
+ query="overall",
443
+ group_by=group_by_field,
444
+ phase=phase_str,
445
+ status=map_status,
446
+ sponsor=map_sponsor,
447
+ start_year=map_year,
448
+ study_type=type_arg,
449
+ )
450
+
451
+ # Retrieve data from session state
452
+ chart_data = st.session_state.get("inline_chart_data", {})
453
+ data_records = chart_data.get("data", [])
454
+
455
+ if not data_records:
456
+ st.warning("No data found for these filters.")
457
+ st.session_state["map_data"] = None
458
+ st.session_state["map_region"] = map_region # Store region too
459
+ else:
460
+ # Store in session state for persistence
461
+ st.session_state["map_data"] = data_records
462
+ st.session_state["map_region"] = map_region
463
+
464
+ # Render Map (Outside Button Block)
465
+ if st.session_state.get("map_data"):
466
+ data_records = st.session_state["map_data"]
467
+ region_mode = st.session_state.get("map_region", "World")
468
+ df_map = pd.DataFrame(data_records)
469
+
470
+ # Configure Map Center/Zoom
471
+ if region_mode == "USA":
472
+ m = folium.Map(location=[37.0902, -95.7129], zoom_start=4)
473
+ coord_map = STATE_COORDINATES
474
+ else:
475
+ m = folium.Map(location=[20, 0], zoom_start=2)
476
+ coord_map = COUNTRY_COORDINATES
477
+
478
+ # Add CircleMarkers
479
+ for _, row in df_map.iterrows():
480
+ loc_name = row["category"]
481
+ count = row["count"]
482
+
483
+ # Clean name if needed (strip trailing parenthesis)
484
+ loc_clean = loc_name.rstrip(")")
485
+ coords = coord_map.get(loc_clean)
486
+
487
+ if coords:
488
+ folium.CircleMarker(
489
+ location=coords,
490
+ radius=min(max(count / 5, 3), 20), # Adjust scale
491
+ popup=f"{loc_clean}: {count} trials",
492
+ color="blue" if region_mode == "USA" else "crimson",
493
+ fill=True,
494
+ fill_color="blue" if region_mode == "USA" else "crimson",
495
+ ).add_to(m)
496
+
497
+ st_folium(m, width=800, height=500)
498
+
499
+ # Show data table
500
+ st.subheader(f"{region_mode} Data")
501
+ st.dataframe(df_map)
502
+
503
+ # --- PAGE 4: RAW DATA ---
504
+ if page == "Raw Data":
505
+ st.header("📂 Raw Data Explorer")
506
+ st.write("View and filter the underlying dataset.")
507
+
508
+ # Load a sample or full dataset? Full might be slow.
509
+ # We load a sample (top 100) to avoid performance issues.
510
+ col_raw_1, col_raw_2 = st.columns([1, 1])
511
+
512
+ with col_raw_1:
513
+ if st.button("Load Sample Data (Top 100)"):
514
+ with st.spinner("Fetching data..."):
515
+ retriever = index.as_retriever(similarity_top_k=100)
516
+ nodes = retriever.retrieve("clinical trial")
517
+ data = [n.metadata for n in nodes]
518
+ df_raw = pd.DataFrame(data)
519
+
520
+ # Format Year to remove commas (e.g., 2,023 -> 2023)
521
+ if "start_year" in df_raw.columns:
522
+ df_raw["start_year"] = (
523
+ pd.to_numeric(df_raw["start_year"], errors="coerce")
524
+ .astype("Int64")
525
+ .astype(str)
526
+ .str.replace(",", "")
527
+ )
528
+
529
+ # Store in session state to persist the table
530
+ st.session_state["sample_data"] = df_raw
531
+
532
+ with col_raw_2:
533
+ # Download Full Dataset Logic
534
+ if st.button("Prepare Full Download (CSV)"):
535
+ with st.spinner("Fetching all records from database..."):
536
+ try:
537
+ # Access LanceDB directly for speed
538
+ import lancedb
539
+ db = lancedb.connect("./ct_gov_lancedb")
540
+ tbl = db.open_table("clinical_trials")
541
+
542
+ # Fetch all data
543
+ df_full = tbl.to_pandas()
544
+
545
+ # Handle metadata flattening if needed
546
+ if "metadata" in df_full.columns:
547
+ meta_df = pd.json_normalize(df_full["metadata"])
548
+ # Combine or just use metadata
549
+ df_full = meta_df
550
+
551
+ # Convert to CSV
552
+ csv = df_full.to_csv(index=False).encode("utf-8")
553
+ st.session_state["full_csv"] = csv
554
+ st.success(f"Ready! Fetched {len(df_full)} records.")
555
+ else:
556
+ st.warning("No data found in database.")
557
+ except Exception as e:
558
+ st.error(f"Error fetching data: {e}")
559
+
560
+ if "full_csv" in st.session_state:
561
+ st.download_button(
562
+ label="⬇️ Download Full CSV",
563
+ data=st.session_state["full_csv"],
564
+ file_name="clinical_trials_full.csv",
565
+ mime="text/csv",
566
+ )
567
+
568
+ # Display Sample Data Table (Full Width)
569
+ if "sample_data" in st.session_state:
570
+ st.markdown("### Sample Data (Top 100)")
571
+ st.dataframe(
572
+ st.session_state["sample_data"],
573
+ column_config={
574
+ "nct_id": "NCT ID",
575
+ "title": "Study Title",
576
+ "start_year": st.column_config.TextColumn(
577
+ "Start Year"
578
+ ), # Force text to avoid commas
579
+ "url": st.column_config.LinkColumn("URL"),
580
+ },
581
+ width="stretch",
582
+ hide_index=True,
583
+ )
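The Study Map above clamps each circle marker's radius to a readable range and strips a trailing parenthesis before looking up coordinates. A minimal, dependency-free sketch of that logic (the function names are illustrative, not part of the app):

```python
from typing import Dict, List, Optional


def marker_radius(count: int, scale: float = 5.0, lo: float = 3.0, hi: float = 20.0) -> float:
    """Scale a trial count to a marker radius, clamped to [lo, hi]."""
    return min(max(count / scale, lo), hi)


def resolve_coords(name: str, coord_map: Dict[str, List[float]]) -> Optional[List[float]]:
    """Look up coordinates, tolerating a trailing ')' left by upstream formatting."""
    return coord_map.get(name.rstrip(")"))
```

The clamp keeps a location with a handful of trials visible while preventing a hub like the United States from covering the map.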
modules/__init__.py ADDED
File without changes
modules/cohort_tools.py ADDED
@@ -0,0 +1,145 @@
+ import json
+ import os
+
+ import streamlit as st
+ from langchain.tools import tool
+ from langchain_google_genai import ChatGoogleGenerativeAI
+ from langchain.prompts import PromptTemplate
+
+ from modules.tools import get_study_details
+ from modules.utils import load_environment
+
+ # Load env for API key
+ load_environment()
+
+
+ def get_llm():
+     """Retrieves an LLM instance with a dynamically resolved API key."""
+     # Check session state first (user-provided key)
+     api_key = None
+     if hasattr(st, "session_state") and "api_key" in st.session_state:
+         api_key = st.session_state["api_key"]
+
+     # Fall back to the environment variable
+     if not api_key:
+         api_key = os.environ.get("GOOGLE_API_KEY")
+
+     if not api_key:
+         raise ValueError("Google API Key not found in session state or environment.")
+
+     return ChatGoogleGenerativeAI(
+         model="gemini-2.5-flash", temperature=0, google_api_key=api_key
+     )
+
+
+ # Initialize LLM (Dynamic)
+ # llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash", temperature=0)
+
+ EXTRACT_PROMPT = PromptTemplate(
+     template="""
+ You are a Clinical Informatics Expert.
+ Your task is to extract structured cohort requirements from the following Clinical Trial Eligibility Criteria.
+
+ Output a JSON object with two keys: "inclusion" and "exclusion".
+ Each key should contain a list of rules.
+ Each rule should have:
+ - "concept": The medical concept (e.g., "Type 2 Diabetes", "Metformin").
+ - "domain": The domain (Condition, Drug, Measurement, Procedure, Observation).
+ - "temporal": Any temporal logic (e.g., "History of", "Within last 6 months").
+ - "codes": A list of potential ICD-10 or RxNorm codes (make a best guess).
+
+ CRITERIA:
+ {criteria}
+
+ JSON OUTPUT:
+ """,
+     input_variables=["criteria"],
+ )
+
+ SQL_PROMPT = PromptTemplate(
+     template="""
+ You are a SQL Expert specializing in Healthcare Claims Data Analysis.
+ Generate a standard SQL query to define a cohort of patients based on the following structured requirements.
+
+ ### Schema Assumptions
+ 1. **`medical_claims`** (Diagnoses & Procedures):
+    - `patient_id`, `claim_date`, `diagnosis_code` (ICD-10), `procedure_code` (CPT/HCPCS).
+ 2. **`pharmacy_claims`** (Drugs):
+    - `patient_id`, `fill_date`, `ndc_code`.
+
+ ### Logic Rules
+ 1. **Conditions (Diagnoses)**:
+    - Require **at least 2 distinct claim dates** where the diagnosis code matches.
+    - These 2 claims must be **at least 30 days apart** (to confirm a chronic condition).
+ 2. **Drugs**:
+    - Require at least 1 claim with a matching NDC code.
+ 3. **Procedures**:
+    - Require at least 1 claim with a matching CPT/HCPCS code.
+ 4. **Exclusions**:
+    - Exclude patients who have ANY matching claims for exclusion criteria.
+
+ ### Requirements (JSON)
+ {requirements}
+
+ ### Output
+ Generate a single SQL query that selects `patient_id` from the claims tables meeting the criteria.
+ Use Common Table Expressions (CTEs) for clarity.
+ Do NOT output markdown formatting (```sql), just the raw SQL.
+
+ SQL QUERY:
+ """,
+     input_variables=["requirements"],
+ )
+
+
+ def extract_cohort_requirements(criteria_text: str) -> dict:
+     """Uses the LLM to parse criteria text into structured JSON."""
+     llm = get_llm()
+     chain = EXTRACT_PROMPT | llm
+     response = chain.invoke({"criteria": criteria_text})
+     try:
+         # Clean up potential markdown code blocks
+         text = response.content.replace("```json", "").replace("```", "").strip()
+         return json.loads(text)
+     except json.JSONDecodeError:
+         return {"error": "Failed to parse LLM output", "raw_output": response.content}
+
+
+ def generate_cohort_sql(requirements: dict) -> str:
+     """Uses the LLM to translate structured requirements into SQL."""
+     llm = get_llm()
+     chain = SQL_PROMPT | llm
+     response = chain.invoke({"requirements": json.dumps(requirements, indent=2)})
+     return response.content.replace("```sql", "").replace("```", "").strip()
+
+
+ @tool("get_cohort_sql")
+ def get_cohort_sql(nct_id: str) -> str:
+     """
+     Generates a SQL query to define the patient cohort for a specific study (NCT ID).
+
+     Args:
+         nct_id (str): The ClinicalTrials.gov identifier (e.g., NCT01234567).
+
+     Returns:
+         str: A formatted string containing the extracted requirements (JSON) and the generated SQL.
+     """
+     # 1. Fetch study details (reusing the existing tool logic to get the text)
+     study_text = get_study_details.invoke(nct_id)
+
+     if "No study found" in study_text:
+         return f"Could not find study {nct_id}."
+
+     # 2. Extract requirements
+     requirements = extract_cohort_requirements(study_text)
+
+     # 3. Generate SQL
+     sql_query = generate_cohort_sql(requirements)
+
+     return f"""
+ ### 📋 Extracted Cohort Requirements
+ ```json
+ {json.dumps(requirements, indent=2)}
+ ```
+
+ ### 💾 Generated SQL Query (Claims Schema)
+ ```sql
+ {sql_query}
+ ```
+ """
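`extract_cohort_requirements` has to defend against the model wrapping its JSON in Markdown code fences. That fence-stripping-plus-fallback pattern, isolated as a standalone helper (the function name is illustrative):

```python
import json


def parse_llm_json(raw: str) -> dict:
    """Strip Markdown code fences an LLM may wrap around JSON, then parse.

    Returns a dict with an "error" key instead of raising, mirroring the
    fallback behaviour in extract_cohort_requirements.
    """
    text = raw.replace("```json", "").replace("```", "").strip()
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        return {"error": "Failed to parse LLM output", "raw_output": raw}
```

Returning an error dict rather than raising lets the agent surface the raw output to the user instead of crashing the tool call.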
modules/constants.py ADDED
@@ -0,0 +1,103 @@
+ # --- Geographic Constants ---
+ COUNTRY_COORDINATES = {
+     "United States": [37.0902, -95.7129],
+     "Canada": [56.1304, -106.3468],
+     "United Kingdom": [55.3781, -3.4360],
+     "Germany": [51.1657, 10.4515],
+     "France": [46.2276, 2.2137],
+     "China": [35.8617, 104.1954],
+     "Japan": [36.2048, 138.2529],
+     "Australia": [-25.2744, 133.7751],
+     "Brazil": [-14.2350, -51.9253],
+     "India": [20.5937, 78.9629],
+     "Russia": [61.5240, 105.3188],
+     "South Korea": [35.9078, 127.7669],
+     "Italy": [41.8719, 12.5674],
+     "Spain": [40.4637, -3.7492],
+     "Netherlands": [52.1326, 5.2913],
+     "Belgium": [50.5039, 4.4699],
+     "Switzerland": [46.8182, 8.2275],
+     "Sweden": [60.1282, 18.6435],
+     "Israel": [31.0461, 34.8516],
+     "Poland": [51.9194, 19.1451],
+     "Taiwan": [23.6978, 120.9605],
+     "Mexico": [23.6345, -102.5528],
+     "Argentina": [-38.4161, -63.6167],
+     "South Africa": [-30.5595, 22.9375],
+     "Turkey": [38.9637, 35.2433],
+     "Denmark": [56.2639, 9.5018],
+     "New Zealand": [-40.9006, 174.8860],
+     "Czech Republic": [49.8175, 15.4730],
+     "Hungary": [47.1625, 19.5033],
+     "Finland": [61.9241, 25.7482],
+     "Norway": [60.4720, 8.4689],
+     "Austria": [47.5162, 14.5501],
+     "Greece": [39.0742, 21.8243],
+     "Ireland": [53.1424, -7.6921],
+     "Portugal": [39.3999, -8.2245],
+     "Ukraine": [48.3794, 31.1656],
+     "Egypt": [26.8206, 30.8025],
+     "Thailand": [15.8700, 100.9925],
+     "Singapore": [1.3521, 103.8198],
+     "Malaysia": [4.2105, 101.9758],
+     "Vietnam": [14.0583, 108.2772],
+     "Philippines": [12.8797, 121.7740],
+     "Indonesia": [-0.7893, 113.9213],
+     "Saudi Arabia": [23.8859, 45.0792],
+     "United Arab Emirates": [23.4241, 53.8478],
+ }
+
+ STATE_COORDINATES = {
+     "Alabama": [32.806671, -86.791130],
+     "Alaska": [61.370716, -152.404419],
+     "Arizona": [33.729759, -111.431221],
+     "Arkansas": [34.969704, -92.373123],
+     "California": [36.116203, -119.681564],
+     "Colorado": [39.059811, -105.311104],
+     "Connecticut": [41.597782, -72.755371],
+     "Delaware": [39.318523, -75.507141],
+     "District of Columbia": [38.897438, -77.026817],
+     "Florida": [27.766279, -81.686783],
+     "Georgia": [33.040619, -83.643074],
+     "Hawaii": [21.094318, -157.498337],
+     "Idaho": [44.240459, -114.478828],
+     "Illinois": [40.349457, -88.986137],
+     "Indiana": [39.849426, -86.258278],
+     "Iowa": [42.011539, -93.210526],
+     "Kansas": [38.526600, -96.726486],
+     "Kentucky": [37.668140, -84.670067],
+     "Louisiana": [31.169546, -91.867805],
+     "Maine": [44.693947, -69.381927],
+     "Maryland": [39.063946, -76.802101],
+     "Massachusetts": [42.230171, -71.530106],
+     "Michigan": [43.326618, -84.536095],
+     "Minnesota": [45.694454, -93.900192],
+     "Mississippi": [32.741646, -89.678696],
+     "Missouri": [38.456085, -92.288368],
+     "Montana": [46.921925, -110.454353],
+     "Nebraska": [41.125370, -98.268082],
+     "Nevada": [38.313515, -117.055374],
+     "New Hampshire": [43.452492, -71.563896],
+     "New Jersey": [40.298904, -74.521011],
+     "New Mexico": [34.840515, -106.248482],
+     "New York": [42.165726, -74.948051],
+     "North Carolina": [35.630066, -79.806419],
+     "North Dakota": [47.528912, -99.784012],
+     "Ohio": [40.388783, -82.764915],
+     "Oklahoma": [35.565342, -96.928917],
+     "Oregon": [44.572021, -122.070938],
+     "Pennsylvania": [41.203323, -77.194527],
+     "Rhode Island": [41.680893, -71.511780],
+     "South Carolina": [33.856892, -80.945007],
+     "South Dakota": [44.299782, -99.438828],
+     "Tennessee": [35.747845, -86.692345],
+     "Texas": [31.054487, -97.563461],
+     "Utah": [40.150032, -111.862434],
+     "Vermont": [44.045876, -72.710686],
+     "Virginia": [37.769337, -78.169968],
+     "Washington": [47.400902, -121.490494],
+     "West Virginia": [38.491226, -80.954453],
+     "Wisconsin": [44.268543, -89.616508],
+     "Wyoming": [42.755966, -107.302490],
+ }
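Hand-maintained coordinate tables like these are easy to break with a swapped latitude/longitude pair. A small sanity check that could guard both dicts (a sketch; the function is mine and not part of the repo):

```python
from typing import Dict, List


def invalid_entries(coord_map: Dict[str, List[float]]) -> Dict[str, List[float]]:
    """Return entries whose [lat, lon] pair falls outside valid ranges."""
    return {
        name: pair
        for name, pair in coord_map.items()
        if not (-90 <= pair[0] <= 90 and -180 <= pair[1] <= 180)
    }
```

Running it over a table in a unit test catches a transposed pair (e.g. `[-97.5, 31.0]` for Texas) before it silently drops markers off the map.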
modules/graph_viz.py ADDED
@@ -0,0 +1,97 @@
+ from streamlit_agraph import Node, Edge, Config
+
+
+ def build_graph(data):
+     """
+     Constructs a knowledge graph from clinical trial data.
+
+     Args:
+         data (list): List of study metadata dictionaries.
+
+     Returns:
+         tuple: (nodes, edges, config) for streamlit-agraph.
+     """
+     nodes = []
+     edges = []
+
+     # Sets to track unique entities
+     study_ids = set()
+     sponsors = set()
+     conditions = set()
+
+     for study in data:
+         nct_id = study.get("nct_id", "Unknown")
+         title = study.get("title", "Unknown")
+         # Use 'sponsor' if available (new ingestion), else fall back to 'org'
+         sponsor = study.get("sponsor", study.get("org", "Unknown"))
+         condition_str = study.get("condition", "")
+
+         # 1. Study Node
+         if nct_id not in study_ids:
+             nodes.append(
+                 Node(
+                     id=nct_id,
+                     label=nct_id,
+                     size=20,
+                     color="#4B8BBE",  # Blue
+                     title=title,
+                     shape="dot",
+                 )
+             )
+             study_ids.add(nct_id)
+
+         # 2. Sponsor Node & Edge
+         if sponsor and sponsor != "Unknown":
+             if sponsor not in sponsors:
+                 nodes.append(
+                     Node(
+                         id=sponsor,
+                         label=sponsor,
+                         size=15,
+                         color="#FF6B6B",  # Red
+                         shape="triangle",
+                     )
+                 )
+                 sponsors.add(sponsor)
+
+             # Edge: Study -> Sponsor
+             edges.append(
+                 Edge(
+                     source=nct_id, target=sponsor, label="sponsored_by", color="#CCCCCC"
+                 )
+             )
+
+         # 3. Condition Nodes & Edges
+         if condition_str:
+             conds = [c.strip() for c in condition_str.split(",") if c.strip()]
+             for cond in conds:
+                 if cond not in conditions:
+                     nodes.append(
+                         Node(
+                             id=cond,
+                             label=cond,
+                             size=15,
+                             color="#6BCB77",  # Green
+                             shape="diamond",
+                         )
+                     )
+                     conditions.add(cond)
+
+                 # Edge: Study -> Condition
+                 edges.append(
+                     Edge(source=nct_id, target=cond, label="studies", color="#CCCCCC")
+                 )
+
+     # Configuration
+     config = Config(
+         width=800,
+         height=600,
+         directed=True,
+         physics=True,
+         hierarchical=False,
+         nodeHighlightBehavior=True,
+         highlightColor="#F7A7A6",
+         collapsible=False,
+     )
+
+     return nodes, edges, config
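The core of `build_graph` is entity de-duplication while still emitting one edge per relationship. The same shape without the `streamlit_agraph` dependency, using plain dicts and a single seen-set (a sketch; names beyond the metadata keys are mine):

```python
def build_graph_data(studies):
    """Dedupe studies/sponsors/conditions into (nodes, edges) tuples."""
    nodes, edges, seen = [], [], set()

    def add_node(node_id, kind):
        # Node IDs are global here, unlike the per-type sets in build_graph
        if node_id and node_id not in seen:
            nodes.append({"id": node_id, "kind": kind})
            seen.add(node_id)

    for s in studies:
        nct = s.get("nct_id", "Unknown")
        add_node(nct, "study")

        sponsor = s.get("sponsor", s.get("org", "Unknown"))
        if sponsor and sponsor != "Unknown":
            add_node(sponsor, "sponsor")
            edges.append((nct, sponsor, "sponsored_by"))  # one edge per study

        for cond in [c.strip() for c in s.get("condition", "").split(",") if c.strip()]:
            add_node(cond, "condition")
            edges.append((nct, cond, "studies"))
    return nodes, edges
```

Using one `seen` set also guards against a condition string that collides with a sponsor name, which the per-type sets in `build_graph` would add twice under the same node id.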
modules/tools.py ADDED
@@ -0,0 +1,706 @@
1
+ """
2
+ LangChain Tools for the Clinical Trial Agent.
3
+
4
+ This module defines the tools that the agent can use to interact with the clinical trial data.
5
+ Tools include:
6
+ 1. **search_trials**: Semantic search with optional strict filtering.
7
+ 2. **find_similar_studies**: Finding studies semantically similar to a given text.
8
+ 3. **get_study_analytics**: Aggregating data for trends and insights (with inline charts).
9
+ """
10
+
11
+ import pandas as pd
12
+ import streamlit as st
13
+ from typing import Optional
14
+ from langchain.tools import tool as langchain_tool
15
+ from llama_index.core.vector_stores import (
16
+ MetadataFilter,
17
+ MetadataFilters,
18
+ FilterOperator,
19
+ )
20
+ from llama_index.core import Settings
21
+ from llama_index.core.postprocessor import MetadataReplacementPostProcessor
22
+ from llama_index.core.postprocessor import SentenceTransformerRerank
23
+ from llama_index.core.query_engine import SubQuestionQueryEngine
24
+ from llama_index.core.tools import QueryEngineTool, ToolMetadata
25
+ from modules.utils import (
26
+ load_index,
27
+ normalize_sponsor,
28
+ get_sponsor_variations,
29
+ get_hybrid_retriever,
30
+ )
31
+ import re
32
+ import traceback
33
+
34
+ # --- Tools ---
35
+
36
+
37
+ def expand_query(query: str) -> str:
38
+ """Expands a search query with synonyms using the LLM."""
39
+ if not query or len(query.split()) > 10: # Skip expansion for long queries
40
+ return query
41
+
42
+ # Skip expansion if it looks like an NCT ID
43
+ if re.search(r"NCT\d+", query, re.IGNORECASE):
44
+ return query
45
+
46
+ prompt = (
47
+ f"You are a helpful medical assistant. "
48
+ f"Expand the following search query with relevant medical synonyms and acronyms. "
49
+ f"Return ONLY the expanded query string combined with OR operators. "
50
+ f"Do not add any explanation.\n\n"
51
+ f"Query: {query}\n"
52
+ f"Expanded Query:"
53
+ )
54
+ try:
55
+ # Use the global Settings.llm
56
+ if not Settings.llm:
57
+ # Fallback if not initialized (though load_index does it)
58
+ from modules.utils import setup_llama_index
59
+
60
+ setup_llama_index()
61
+
62
+ response = Settings.llm.complete(prompt)
63
+ expanded = response.text.strip()
64
+ # Clean up if LLM is chatty
65
+ if "Expanded Query:" in expanded:
66
+ expanded = expanded.split("Expanded Query:")[-1].strip()
67
+
68
+ if not expanded:
69
+ print(f"⚠️ Expansion returned empty. Using original query.")
70
+ return query
71
+
72
+ print(f"✨ Expanded Query: '{query}' -> '{expanded}'")
73
+ return expanded
74
+ except Exception as e:
75
+ print(f"⚠️ Query expansion failed: {e}")
76
+ return query
77
+
78
+
79
+ @langchain_tool("search_trials")
80
+ def search_trials(
81
+ query: str = None,
82
+ status: str = None,
83
+ phase: str = None,
84
+ sponsor: str = None,
85
+ intervention: str = None,
86
+ year: int = None,
87
+ ):
88
+ """
89
+ Searches for clinical trials using semantic search with robust filtering.
90
+
91
+ Args:
92
+ query (str, optional): The natural language search query.
93
+ status (str, optional): Filter by recruitment status.
94
+ phase (str, optional): Filter by trial phase.
95
+ sponsor (str, optional): Filter by sponsor name.
96
+ intervention (str, optional): Filter by intervention/drug name.
97
+ year (int, optional): Filter for studies starting on or after this year.
98
+
99
+ Returns:
100
+ str: A structured list of relevant studies.
101
+ """
102
+ index = load_index()
103
+
104
+ # Constants
105
+ TOP_K_STRICT = 500 # High recall for pre-filtered search
106
+
107
+ # --- Query Construction ---
108
+ if not query:
109
+ parts = [p for p in [sponsor, intervention, phase, status] if p]
110
+ query = " ".join(parts) if parts else "clinical trial"
111
+ else:
112
+ # Inject context for vector search
113
+ if sponsor and normalize_sponsor(sponsor).lower() not in query.lower():
114
+ query = f"{normalize_sponsor(sponsor)} {query}"
115
+ if intervention and intervention.lower() not in query.lower():
116
+ query = f"{intervention} {query}"
117
+
118
+ query = expand_query(query)
119
+
120
+ print(f"🔍 Tool Called: search_trials(query='{query}', sponsor='{sponsor}')")
121
+
122
+ # --- Strategy 1: Strict Pre-Retrieval Filtering (High Precision) ---
123
+ # Filter by Sponsor/Status/Year at the database level first.
124
+ pre_filters = []
125
+
126
+ # NCT ID Match
127
+ nct_match = re.search(r"NCT\d+", query, re.IGNORECASE)
128
+ if nct_match:
129
+ nct_id = nct_match.group(0).upper()
130
+ pre_filters.append(MetadataFilter(key="nct_id", value=nct_id, operator=FilterOperator.EQ))
131
+
132
+ if status:
133
+ pre_filters.append(MetadataFilter(key="status", value=status.upper(), operator=FilterOperator.EQ))
134
+ if year:
135
+ pre_filters.append(MetadataFilter(key="start_year", value=year, operator=FilterOperator.GTE))
136
+
137
+ # Sponsor Pre-Filter
138
+ if sponsor:
139
+ from modules.utils import get_sponsor_variations
140
+ variations = get_sponsor_variations(sponsor)
141
+ if variations:
142
+ print(f"🎯 Applying strict pre-filter for sponsor '{sponsor}' ({len(variations)} variants)")
143
+ # Use 'sponsor' field which is the Lead Sponsor
144
+ pre_filters.append(MetadataFilter(key="sponsor", value=variations, operator=FilterOperator.IN))
145
+ else:
146
+ print(f"⚠️ No strict mapping for sponsor '{sponsor}'. Will rely on fuzzy post-filtering.")
147
+
148
+ metadata_filters = MetadataFilters(filters=pre_filters) if pre_filters else None
149
+
150
+ # Post-processors (Reranking)
151
+ reranker = SentenceTransformerRerank(model="cross-encoder/ms-marco-MiniLM-L-12-v2", top_n=50)
152
+
153
+ # --- HYBRID SEARCH IMPLEMENTATION ---
154
+ # Combine Vector + BM25 using get_hybrid_retriever
155
+ try:
156
+ retriever = get_hybrid_retriever(index, similarity_top_k=TOP_K_STRICT, filters=metadata_filters)
157
+ nodes = retriever.retrieve(query)
158
+
159
+ # (QueryFusionRetriever returns nodes, but we want to rerank them)
160
+ if nodes:
161
+ from llama_index.core.schema import QueryBundle
162
+ nodes = reranker.postprocess_nodes(nodes, query_bundle=QueryBundle(query_str=query))
163
+
164
+ except Exception as e:
165
+ print(f"⚠️ Hybrid search failed: {e}. Falling back to standard vector search.")
166
+ traceback.print_exc()
167
+ query_engine = index.as_query_engine(
168
+ similarity_top_k=TOP_K_STRICT,
169
+ filters=metadata_filters,
170
+ node_postprocessors=[reranker]
171
+ )
172
+ response = query_engine.query(query)
173
+ nodes = response.source_nodes
174
+
175
+ # --- Strict Metadata Filtering (Post-Fusion) ---
176
+ # BM25 results might not respect the vector filters, so filter them out.
177
+ final_nodes = []
178
+ for node in nodes:
179
+ meta = node.metadata
180
+ keep = True
181
+
182
+ # Re-apply filters to ensure BM25 results are valid
183
+ if status and meta.get("status", "").upper() != status.upper():
184
+ keep = False
185
+ if year:
186
+ try:
187
+ if int(meta.get("start_year", 0)) < year:
188
+ keep = False
189
+ except:
190
+ pass
191
+ if sponsor:
192
+ # Strict logic for sponsor in pre-filters is ignored by BM25.
193
+ # Check if the sponsor matches one of the variations OR fuzzy match
194
+ # If strict variations exist, enforce them.
195
+ variations = get_sponsor_variations(sponsor)
196
+ node_sponsor = meta.get("sponsor", "")
197
+ # Fallback to org if sponsor is missing (legacy data)
198
+ if not node_sponsor:
199
+ node_sponsor = meta.get("org", "")
200
+
201
+ if variations:
202
+ if node_sponsor not in variations:
203
+ keep = False
204
+ else:
205
+ # Fuzzy fallback
206
+ if normalize_sponsor(sponsor).lower() not in normalize_sponsor(node_sponsor).lower():
207
+ keep = False
208
+
209
+ if keep:
210
+ final_nodes.append(node)
211
+
212
+ nodes = final_nodes
213
+
214
+ # --- Strict Keyword Filtering ---
215
+ # BM25 handles keyword relevance naturally, so rely on the Hybrid Search + Reranker
216
+ # rather than applying an aggressive substring check here.
217
+
218
+ # Update response object structure to match expected format if we used retriever
219
+ class MockResponse:
220
+ def __init__(self, nodes):
221
+ self.source_nodes = nodes
222
+
223
+ response = MockResponse(nodes)
224
+
225
+ # --- Strategy 2: Hybrid Search (Fallback) ---
226
+ # Hybrid Search is enabled by default.
227
+ # Strict filters are handled in post-processing above.
228
+
229
+
230
+ # --- Formatting Output ---
231
+ if not response.source_nodes:
232
+ return "No matching studies found. Try broadening your search terms or filters."
233
+
234
+ # Filter by Relevance Score for display
235
+ MIN_SCORE = 1.5
236
+ relevant_nodes = [node for node in response.source_nodes if node.score > MIN_SCORE]
237
+
238
+ # If strict filtering removes too much, show at least top 3 to be helpful
239
+ if len(relevant_nodes) < 3 and len(response.source_nodes) > 0:
240
+ relevant_nodes = response.source_nodes[:3]
241
+
242
+ display_limit = 20
243
+ display_nodes = relevant_nodes[:display_limit]
244
+
245
+ results = []
246
+ for node in display_nodes:
247
+ meta = node.metadata
248
+ entry = (
249
+ f"**{meta.get('title', 'Untitled')}**\n"
250
+ f" - ID: {meta.get('nct_id')}\n"
251
+ f" - Phase: {meta.get('phase', 'N/A')}\n"
252
+ f" - Status: {meta.get('status', 'N/A')}\n"
253
+ f" - Sponsor: {meta.get('sponsor', meta.get('org', 'Unknown'))}\n"
254
+ f" - Relevance: {node.score:.2f}"
255
+ )
256
+ results.append(entry)
257
+
258
+ return f"Found {len(results)} relevant studies:\n\n" + "\n\n".join(results)
259
+
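The score-threshold display logic above can be sketched in isolation (the `select_display` helper name is hypothetical; `MIN_SCORE` and the limits mirror the values used in the tool):

```python
# Sketch of the display filter: keep results scoring over MIN_SCORE,
# but fall back to the top 3 when the cutoff would leave fewer than 3.
MIN_SCORE = 1.5
DISPLAY_LIMIT = 20

def select_display(scored):
    """scored: list of (id, score) pairs, already sorted best-first."""
    relevant = [item for item in scored if item[1] > MIN_SCORE]
    if len(relevant) < 3 and scored:
        relevant = scored[:3]  # be helpful even when the cutoff is too strict
    return relevant[:DISPLAY_LIMIT]

print(select_display([("NCT1", 3.2), ("NCT2", 1.1), ("NCT3", 0.4)]))
# [('NCT1', 3.2), ('NCT2', 1.1), ('NCT3', 0.4)]  (top-3 fallback kicked in)
```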
260
+
261
+ @langchain_tool("find_similar_studies")
262
+ def find_similar_studies(query: str):
263
+ """
264
+ Finds studies semantically similar to a given query or study description.
265
+
266
+ This tool is useful for "more like this" functionality. It relies purely
267
+ on vector similarity without strict metadata filtering.
268
+
269
+ Args:
270
+ query (str): The text to match against (e.g., a study title or description).
271
+
272
+ Returns:
273
+ str: A string containing the top 5 similar studies with their titles and summaries.
274
+ """
275
+ index = load_index()
276
+
277
+ # 1. Check if query is an NCT ID
278
+ nct_match = re.search(r"NCT\d+", query, re.IGNORECASE)
279
+ target_nct = None
280
+ search_text = query
281
+
282
+ if nct_match:
283
+ target_nct = nct_match.group(0).upper()
284
+ print(f"🎯 Detected NCT ID for similarity: {target_nct}")
285
+
286
+ # Fetch the study content to use as the semantic query
287
+ # Use the vector store directly to get the text
288
+ retriever = index.as_retriever(
289
+ filters=MetadataFilters(
290
+ filters=[MetadataFilter(key="nct_id", value=target_nct, operator=FilterOperator.EQ)]
291
+ ),
292
+ similarity_top_k=1
293
+ )
294
+ nodes = retriever.retrieve(target_nct)
295
+
296
+ if nodes:
297
+ # Use the study's text (Title + Summary) as the query
298
+ search_text = nodes[0].text
299
+ print(f"✅ Found study content. Using {len(search_text)} chars for semantic search.")
300
+ else:
301
+ print(f"⚠️ Study {target_nct} not found. Falling back to text search.")
302
+
303
+ # 2. Perform Semantic Search
304
+ # Fetch more candidates (10) to allow for filtering
305
+ retriever = index.as_retriever(similarity_top_k=10)
306
+ nodes = retriever.retrieve(search_text)
307
+
308
+ results = []
309
+ count = 0
310
+ for node in nodes:
311
+ # 3. Self-Exclusion
312
+ if target_nct and node.metadata.get("nct_id") == target_nct:
313
+ continue
314
+
315
+ # Deduplication (if multiple chunks of same study appear)
316
+ if any(r["nct_id"] == node.metadata.get("nct_id") for r in results):
317
+ continue
318
+
319
+ results.append({
320
+ "nct_id": node.metadata.get("nct_id"),
321
+ "text": f"Study: {node.metadata['title']} (NCT: {node.metadata.get('nct_id')})\nScore: {node.score:.4f}\nSummary: {node.text[:200]}..."
322
+ })
323
+
324
+ count += 1
325
+ if count >= 5: # Limit to top 5 unique results
326
+ break
327
+
328
+ if not results:
329
+ return "No similar studies found."
330
+
331
+ return "\n\n".join([r["text"] for r in results])
332
+
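The self-exclusion and chunk-deduplication loop above can be sketched standalone (hypothetical `top_similar` helper; nodes reduced to plain dicts for illustration):

```python
# Sketch of the similarity post-processing: drop the query study itself,
# collapse multiple chunks of the same study, and cap at 5 unique results.
def top_similar(nodes, target_nct=None, limit=5):
    results, seen = [], set()
    for n in nodes:
        nct = n["nct_id"]
        if target_nct and nct == target_nct:
            continue  # self-exclusion: skip the study we queried with
        if nct in seen:
            continue  # dedupe extra chunks of the same study
        seen.add(nct)
        results.append(nct)
        if len(results) >= limit:
            break
    return results

print(top_similar(
    [{"nct_id": "NCT1"}, {"nct_id": "NCT2"}, {"nct_id": "NCT2"}, {"nct_id": "NCT3"}],
    target_nct="NCT1",
))  # ['NCT2', 'NCT3']
```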
333
+
334
+ def fetch_study_analytics_data(
335
+ query: str,
336
+ group_by: str,
337
+ phase: Optional[str] = None,
338
+ status: Optional[str] = None,
339
+ sponsor: Optional[str] = None,
340
+ intervention: Optional[str] = None,
341
+ start_year: Optional[int] = None,
342
+ study_type: Optional[str] = None,
343
+ ) -> str:
344
+ """
345
+ Underlying logic for fetching and aggregating clinical trial data.
346
+ See get_study_analytics for full docstring.
347
+ """
348
+ index = load_index()
349
+
350
+ # 1. Retrieve Data
351
+ if query.lower() == "overall":
352
+ try:
353
+ # Connect to LanceDB directly for speed
354
+ import lancedb
355
+ db = lancedb.connect("./ct_gov_lancedb")
356
+ tbl = db.open_table("clinical_trials")
357
+ # Fetch all data as pandas DataFrame
358
+ df = tbl.to_pandas()
359
+
360
+ # LlamaIndex stores metadata in a 'metadata' column (usually as a dict/struct)
361
+ # We need to flatten it to get columns like 'status', 'phase', etc.
362
+ if "metadata" in df.columns:
363
+ # Check if it's already a dict or needs parsing
364
+ # LanceDB to_pandas() converts struct to dict
365
+ meta_df = pd.json_normalize(df["metadata"])
366
+ df = meta_df
367
+
368
+ # If columns are already flat (depending on schema evolution), we are good.
369
+ # But usually it's nested.
370
+
371
+ except Exception as e:
372
+ return f"Error fetching full dataset: {e}"
373
+ else:
374
+ filters = []
375
+ if status:
376
+ filters.append(
377
+ MetadataFilter(
378
+ key="status", value=status.upper(), operator=FilterOperator.EQ
379
+ )
380
+ )
381
+ # Phase is intentionally not pre-filtered here; it is normalized and
+ # applied in the pandas post-filter below.
383
+
384
+ if sponsor:
385
+ # Use the helper to get all variations (e.g. "Pfizer" -> ["Pfizer", "Pfizer Inc."])
386
+ sponsor_variations = get_sponsor_variations(sponsor)
387
+ if sponsor_variations:
388
+ print(f"🎯 Using strict pre-filter for sponsor '{sponsor}': {len(sponsor_variations)} variations found.")
389
+ filters.append(
390
+ MetadataFilter(
391
+ key="sponsor", value=sponsor_variations, operator=FilterOperator.IN
392
+ )
393
+ )
394
+
395
+ metadata_filters = MetadataFilters(filters=filters) if filters else None
396
+
397
+ search_query = query
398
+ if sponsor and sponsor.lower() not in query.lower():
399
+ search_query = f"{sponsor} {query}"
400
+
401
+ # Use hybrid search for better recall
402
+ retriever = index.as_retriever(
403
+ similarity_top_k=5000,
404
+ filters=metadata_filters,
405
+ vector_store_query_mode="hybrid"
406
+ )
407
+ nodes = retriever.retrieve(search_query)
408
+
409
+ # --- Strict Keyword Filtering ---
410
+ # Strictly check if the query appears in Title or Conditions to ensure accurate counting.
411
+ # EXCEPTION: If the query matches the requested sponsor, we also check the 'org' field.
412
+ if query.lower() != "overall":
413
+ q_term = query.lower()
414
+
415
+ # Check if the query is essentially the sponsor name
416
+ is_sponsor_query = False
417
+
418
+ # Check if the query itself normalizes to a known sponsor
419
+ query_normalized = normalize_sponsor(query)
420
+ if query_normalized and query_normalized != query:
421
+ # If normalization changed it (or found a mapping), it's likely a sponsor
422
+ is_sponsor_query = True
423
+
424
+ if sponsor:
425
+ # Normalize both to see if they refer to the same entity
426
+ norm_query = normalize_sponsor(query)
427
+ norm_sponsor = normalize_sponsor(sponsor)
428
+
429
+ if norm_query and norm_sponsor and norm_query.lower() == norm_sponsor.lower():
430
+ is_sponsor_query = True
431
+ elif sponsor.lower() in query.lower() or query.lower() in sponsor.lower():
432
+ is_sponsor_query = True
433
+
434
+ filtered_nodes = []
435
+ for node in nodes:
436
+ meta = node.metadata
437
+ title = meta.get("title", "").lower()
438
+ conditions = meta.get("condition", "").lower() # Note: key is 'condition' in DB
439
+ org = meta.get("org", "").lower()
440
+ sponsor_val = meta.get("sponsor", "").lower()
441
+
442
+ # If it's a sponsor query, we allow matches on the Organization OR Sponsor field
443
+ # AND we check if the normalized values match (handling aliases like J&J -> Janssen)
444
+ match = False
445
+ if q_term in title or q_term in conditions:
446
+ match = True
447
+ elif is_sponsor_query:
448
+ # Check raw match
449
+ if q_term in org or q_term in sponsor_val:
450
+ match = True
451
+ else:
452
+ # Check normalized match
453
+ norm_org = normalize_sponsor(org)
454
+ norm_val = normalize_sponsor(sponsor_val)
455
+
456
+ # Compare against the normalized query (which is the sponsor in this case)
457
+ target_norm = norm_sponsor if sponsor else query_normalized
458
+
459
+ if norm_org and target_norm and norm_org.lower() == target_norm.lower():
460
+ match = True
461
+ elif norm_val and target_norm and norm_val.lower() == target_norm.lower():
462
+ match = True
463
+
464
+ if match:
465
+ filtered_nodes.append(node)
466
+
467
+ print(f"📉 Strict Filter: {len(nodes)} -> {len(filtered_nodes)} nodes for '{query}'")
468
+ nodes = filtered_nodes
469
+
470
+ data = [node.metadata for node in nodes]
471
+ df = pd.DataFrame(data)
472
+
473
+ if "nct_id" in df.columns:
474
+ df = df.drop_duplicates(subset="nct_id")
475
+
476
+ if df.empty:
477
+ return "No studies found for analytics."
478
+
479
+ # --- APPLY FILTERS (Pandas) ---
480
+ if phase:
481
+ target_phases = [p.strip().upper().replace(" ", "") for p in phase.split(",")]
482
+ df["phase_upper"] = df["phase"].astype(str).str.upper().str.replace(" ", "")
483
+ mask = df["phase_upper"].apply(lambda x: any(tp in x for tp in target_phases))
484
+ df = df[mask]
485
+
486
+ if status:
487
+ df = df[df["status"].str.upper() == status.upper()]
488
+
489
+ if sponsor:
490
+ target_sponsor = normalize_sponsor(sponsor).lower()
491
+ # Use 'sponsor' column if it exists, otherwise fallback to 'org'
492
+ if "sponsor" in df.columns:
493
+ df["sponsor_check"] = df["sponsor"].fillna(df["org"]).astype(str).apply(normalize_sponsor).str.lower()
494
+ else:
495
+ df["sponsor_check"] = df["org"].astype(str).apply(normalize_sponsor).str.lower()
496
+
497
+ df = df[df["sponsor_check"].str.contains(target_sponsor, regex=False)]
498
+
499
+ if intervention:
500
+ target_intervention = intervention.lower()
501
+ df["intervention_lower"] = df["intervention"].astype(str).str.lower()
502
+ df = df[df["intervention_lower"].str.contains(target_intervention, regex=False)]
503
+
504
+ if start_year:
505
+ df["start_year"] = pd.to_numeric(df["start_year"], errors="coerce").fillna(0)
506
+ df = df[df["start_year"] >= start_year]
507
+
508
+ if study_type:
509
+ df = df[df["study_type"].str.upper() == study_type.upper()]
510
+
511
+ if df.empty:
512
+ return "No studies found after applying filters."
513
+
514
+ key_map = {
515
+ "phase": "phase",
516
+ "status": "status",
517
+ "sponsor": "sponsor" if "sponsor" in df.columns else "org",
518
+ "start_year": "start_year",
519
+ "condition": "condition",
520
+ "intervention": "intervention",
521
+ "study_type": "study_type",
522
+ "country": "country",
523
+ "state": "state",
524
+ }
525
+
526
+ if group_by not in key_map:
527
+ return f"Invalid group_by field: {group_by}. Valid options: phase, status, sponsor, start_year, condition, intervention, study_type, country, state"
528
+
529
+ col = key_map[group_by]
530
+
531
+ if col == "start_year":
532
+ df[col] = pd.to_numeric(df[col], errors="coerce")
533
+ counts = df[col].value_counts().sort_index()
534
+ elif col == "condition":
535
+ counts = df[col].astype(str).str.split(", ").explode().value_counts().head(10)
536
+ elif col == "intervention":
537
+ all_interventions = []
538
+ for interventions in df[col].dropna():
539
+ parts = [i.strip() for i in interventions.split(";") if i.strip()]
540
+ all_interventions.extend(parts)
541
+ counts = pd.Series(all_interventions).value_counts().head(10)
542
+ else:
543
+ counts = df[col].value_counts().head(10)
544
+
545
+ summary = counts.to_string()
546
+
547
+ chart_df = counts.reset_index()
548
+ chart_df.columns = ["category", "count"]
549
+
550
+ chart_data = {
551
+ "type": "bar",
552
+ "title": f"Studies by {group_by.capitalize()}",
553
+ "data": chart_df.to_dict("records"),
554
+ "x": "category",
555
+ "y": "count",
556
+ }
557
+
558
+ if "inline_chart_data" not in st.session_state:
559
+ st.session_state["inline_chart_data"] = chart_data
560
+ else:
561
+ st.session_state["inline_chart_data"] = chart_data
562
+
563
+ return f"Found {len(df)} studies. Top counts:\n{summary}\n\n(Chart generated in UI)"
564
+
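The semicolon-split intervention counting above can be sketched without pandas (a pure-Python stand-in for the explode/`value_counts` pipeline):

```python
from collections import Counter

# Interventions are stored as semicolon-separated strings; split, strip,
# drop empties, then count occurrences across all rows.
rows = ["Drug A; Placebo", "Drug A", None, "Drug B;"]
all_interventions = []
for interventions in rows:
    if interventions is None:
        continue  # mirrors df[col].dropna()
    all_interventions.extend(i.strip() for i in interventions.split(";") if i.strip())
counts = Counter(all_interventions)
print(counts.most_common(3))  # [('Drug A', 2), ('Placebo', 1), ('Drug B', 1)]
```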
565
+
566
+ @langchain_tool("get_study_analytics")
567
+ def get_study_analytics(
568
+ query: str,
569
+ group_by: str,
570
+ phase: Optional[str] = None,
571
+ status: Optional[str] = None,
572
+ sponsor: Optional[str] = None,
573
+ intervention: Optional[str] = None,
574
+ start_year: Optional[int] = None,
575
+ study_type: Optional[str] = None,
576
+ ):
577
+ """
578
+ Aggregates clinical trial data based on a search query and groups by a specific field.
579
+
580
+ This tool performs the following steps:
581
+ 1. Retrieves a large number of relevant studies (up to 5000).
582
+ 2. Applies strict filters (Phase, Status, Sponsor) in memory (Pandas).
583
+ 3. Groups the data by the requested field (e.g., Sponsor).
584
+ 4. Generates a summary string for the LLM.
585
+ 5. **Side Effect**: Injects chart data into `st.session_state` to trigger an inline chart in the UI.
586
+
587
+ Args:
588
+ query (str): The search query to filter studies (e.g., "cancer").
589
+ group_by (str): The field to group by. Options: "phase", "status", "sponsor", "start_year", "condition", "intervention", "study_type", "country", "state".
590
+ phase (Optional[str]): Optional filter for phase (e.g., "PHASE2").
591
+ status (Optional[str]): Optional filter for status (e.g., "RECRUITING").
592
+ sponsor (Optional[str]): Optional filter for sponsor (e.g., "Pfizer").
593
+ intervention (Optional[str]): Optional filter for intervention (e.g., "Keytruda").
+ start_year (Optional[int]): Optional filter for minimum start year (e.g., 2020).
+ study_type (Optional[str]): Optional filter for study type (e.g., "INTERVENTIONAL").
594
+
595
+ Returns:
596
+ str: A summary string of the top counts and a note that a chart has been generated.
597
+ """
598
+ return fetch_study_analytics_data(
599
+ query=query,
600
+ group_by=group_by,
601
+ phase=phase,
602
+ status=status,
603
+ sponsor=sponsor,
604
+ intervention=intervention,
605
+ start_year=start_year,
606
+ study_type=study_type,
607
+ )
608
+
609
+
610
+ @langchain_tool("compare_studies")
611
+ def compare_studies(query: str):
612
+ """
613
+ Compares multiple studies or answers complex multi-part questions using query decomposition.
614
+
615
+ Use this tool when the user asks to "compare", "contrast", or analyze differences/similarities
616
+ between specific studies, sponsors, or phases. It breaks down the question into sub-questions.
617
+
618
+ Args:
619
+ query (str): The complex comparison query (e.g., "Compare the primary outcomes of Keytruda vs Opdivo").
620
+
621
+ Returns:
622
+ str: A detailed response synthesizing the answers to sub-questions.
623
+ """
624
+ index = load_index()
625
+
626
+ # Create a base query engine for the sub-questions
627
+ # Increase top_k and add re-ranking to improve recall for comparison queries
628
+ reranker = SentenceTransformerRerank(model="cross-encoder/ms-marco-MiniLM-L-12-v2", top_n=10)
629
+
630
+ base_engine = index.as_query_engine(
631
+ similarity_top_k=50,
632
+ node_postprocessors=[reranker]
633
+ )
634
+
635
+ # Wrap it in a QueryEngineTool
636
+ query_tool = QueryEngineTool(
637
+ query_engine=base_engine,
638
+ metadata=ToolMetadata(
639
+ name="clinical_trials_db",
640
+ description="Vector database of clinical trial protocols, results, and metadata.",
641
+ ),
642
+ )
643
+
644
+ # Create the SubQuestionQueryEngine
645
+ # Explicitly define the question generator to use the configured LLM (Gemini)
646
+ # This avoids the default behavior which might try to import OpenAI modules
647
+ from llama_index.core.question_gen import LLMQuestionGenerator
648
+ from llama_index.core import Settings
649
+
650
+ question_gen = LLMQuestionGenerator.from_defaults(llm=Settings.llm)
651
+
652
+ query_engine = SubQuestionQueryEngine.from_defaults(
653
+ query_engine_tools=[query_tool],
654
+ question_gen=question_gen,
655
+ use_async=True,
656
+ )
657
+
658
+ try:
659
+ response = query_engine.query(query)
660
+ return str(response) + "\n\n(Note: This analysis is based on the most relevant studies retrieved from the database, not necessarily an exhaustive list.)"
661
+ except Exception as e:
662
+ return f"Error during comparison: {e}"
663
+
664
+
665
+ @langchain_tool("get_study_details")
666
+ def get_study_details(nct_id: str):
667
+ """
668
+ Retrieves the full details of a specific clinical trial by its NCT ID.
669
+
670
+ Use this tool when the user asks for specific information about a single study,
671
+ such as "What are the inclusion criteria for NCT12345678?" or "Give me a summary of study NCT...".
672
+ It returns the full text content of the study document, including criteria, outcomes, and contacts.
673
+
674
+ Args:
675
+ nct_id (str): The NCT ID of the study (e.g., "NCT01234567").
676
+
677
+ Returns:
678
+ str: The full text content of the study, or a message if not found.
679
+ """
680
+ index = load_index()
681
+
682
+ # Clean the ID
683
+ clean_id = nct_id.strip().upper()
684
+
685
+ # Use a retriever with a strict metadata filter for the ID
686
+ # Set top_k=20 to capture all chunks if the document was split
687
+ filters = MetadataFilters(
688
+ filters=[
689
+ MetadataFilter(key="nct_id", value=clean_id, operator=FilterOperator.EQ)
690
+ ]
691
+ )
692
+
693
+ retriever = index.as_retriever(similarity_top_k=20, filters=filters)
694
+ nodes = retriever.retrieve(clean_id)
695
+
696
+ if not nodes:
697
+ return f"Study {clean_id} not found in the database."
698
+
699
+ # Sort nodes by their position in the document to reconstruct full text
700
+ # LlamaIndex nodes usually have 'start_char_idx' in metadata or relationships
701
+ # Try to sort by node ID or just concatenate them
702
+
703
+ # Simple concatenation (assuming retrieval order is roughly correct or sufficient)
704
+ full_text = "\n\n".join([node.text for node in nodes])
705
+
706
+ return f"Details for {clean_id} (Combined {len(nodes)} parts):\n\n{full_text}"
modules/utils.py ADDED
@@ -0,0 +1,281 @@
1
+ """
2
+ Utility functions for the Clinical Trial Agent.
3
+
4
+ Handles configuration, LanceDB index loading, data normalization, and custom filtering logic.
5
+ """
6
+
7
+ import os
8
+ import streamlit as st
9
+ from typing import List, Optional
10
+ from llama_index.core import VectorStoreIndex, StorageContext, Settings
11
+ from llama_index.embeddings.huggingface import HuggingFaceEmbedding
12
+ from llama_index.vector_stores.lancedb import LanceDBVectorStore
13
+ from llama_index.llms.gemini import Gemini
14
+ import lancedb
15
+ from dotenv import load_dotenv
16
+
17
+ # --- MONKEYPATCH START ---
18
+ # Patch LanceDBVectorStore to handle 'nprobes' AttributeError and fix SQL quoting for IN filters.
19
+ original_query = LanceDBVectorStore.query
20
+
21
+ def patched_query(self, query, **kwargs):
22
+ try:
23
+ return original_query(self, query, **kwargs)
24
+ except Exception as e:
25
+ print(f"⚠️ LanceDB Query Error: {e}")
26
+ if hasattr(query, "filters"):
27
+ print(f" Filters: {query.filters}")
28
+
29
+ if "nprobes" in str(e):
30
+ from llama_index.core.vector_stores.types import VectorStoreQueryResult
31
+ return VectorStoreQueryResult(nodes=[], similarities=[], ids=[])
32
+ raise e
33
+
34
+ LanceDBVectorStore.query = patched_query
35
+
36
+ # Patch _to_lance_filter to fix SQL quoting for IN operator with strings.
37
+ from llama_index.vector_stores.lancedb import base as lancedb_base
38
+ from llama_index.core.vector_stores.types import FilterOperator
39
+
40
+ original_to_lance_filter = lancedb_base._to_lance_filter
41
+
42
+ def patched_to_lance_filter(standard_filters, metadata_keys):
43
+ if not standard_filters:
44
+ return None
45
+
46
+ # Reimplement filter logic to ensure correct SQL generation for LanceDB
47
+ filters = []
48
+ for filter in standard_filters.filters:
49
+ key = filter.key
50
+ if metadata_keys and key not in metadata_keys:
51
+ continue
52
+
53
+ # Prefix key with 'metadata.' for LanceDB struct column
54
+ lance_key = f"metadata.{key}"
55
+
56
+ # Handle IN operator with proper string quoting
57
+ if filter.operator == FilterOperator.IN:
58
+ if isinstance(filter.value, list):
59
+ # Quote strings properly
60
+ values = []
61
+ for v in filter.value:
62
+ if isinstance(v, str):
63
+ values.append("'" + v.replace("'", "''") + "'")  # Quote for SQL, doubling embedded single quotes
64
+ else:
65
+ values.append(str(v))
66
+ val_str = ", ".join(values)
67
+ filters.append(f"{lance_key} IN ({val_str})")
68
+ continue
69
+
70
+ # Standard operators
71
+ op = filter.operator
72
+ val = filter.value
73
+
74
+ if op == FilterOperator.EQ:
75
+ if isinstance(val, str):
76
+ filters.append(f"{lance_key} = '{val}'")
77
+ else:
78
+ filters.append(f"{lance_key} = {val}")
79
+ elif op == FilterOperator.GT:
80
+ filters.append(f"{lance_key} > {val}")
81
+ elif op == FilterOperator.LT:
82
+ filters.append(f"{lance_key} < {val}")
83
+ elif op == FilterOperator.GTE:
84
+ filters.append(f"{lance_key} >= {val}")
85
+ elif op == FilterOperator.LTE:
86
+ filters.append(f"{lance_key} <= {val}")
87
+ elif op == FilterOperator.NE:
88
+ if isinstance(val, str):
89
+ filters.append(f"{lance_key} != '{val}'")
90
+ else:
91
+ filters.append(f"{lance_key} != {val}")
92
+ # Add other operators as needed
93
+
94
+ if not filters:
95
+ return None
96
+
97
+ return " AND ".join(filters)
98
+
99
+ lancedb_base._to_lance_filter = patched_to_lance_filter
100
+ # --- MONKEYPATCH END ---
101
+
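The SQL string the patched filter builder emits can be illustrated standalone (simplified: filters as `(key, op, value)` tuples instead of llama_index `MetadataFilter` objects; `to_lance_where` is a hypothetical stand-in):

```python
# Standalone sketch of the metadata-filter -> LanceDB SQL translation above.
def to_lance_where(filters):
    clauses = []
    for key, op, value in filters:
        col = f"metadata.{key}"  # metadata lives in a struct column
        if op == "IN":
            quoted = ", ".join(
                "'" + v + "'" if isinstance(v, str) else str(v) for v in value
            )
            clauses.append(f"{col} IN ({quoted})")
        elif op == "==":
            rhs = "'" + value + "'" if isinstance(value, str) else str(value)
            clauses.append(f"{col} = {rhs}")
    return " AND ".join(clauses) or None

where = to_lance_where([
    ("sponsor", "IN", ["Pfizer", "Pfizer Inc."]),
    ("status", "==", "RECRUITING"),
])
print(where)
# metadata.sponsor IN ('Pfizer', 'Pfizer Inc.') AND metadata.status = 'RECRUITING'
```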
102
+
103
+ def load_environment():
104
+ """Loads environment variables from .env file."""
105
+ load_dotenv()
106
+
107
+
108
+ # --- Configuration ---
109
+ def setup_llama_index(api_key: Optional[str] = None):
110
+ """
111
+ Configures global LlamaIndex settings (LLM and Embeddings).
112
+ """
113
+ # Use passed key, or fallback to env var
114
+ final_key = api_key or os.environ.get("GOOGLE_API_KEY")
115
+
116
+ if not final_key:
117
+ # App handles prompting for key, so we just return or log warning
118
+ pass
119
+
120
+ try:
121
+ # Pass the key explicitly if available
122
+ Settings.llm = Gemini(model="models/gemini-2.5-flash", temperature=0, api_key=final_key)
123
+ except Exception as e:
124
+ print(f"⚠️ LLM initialization failed (likely missing API key): {e}")
125
+ print("⚠️ Using MockLLM for testing/fallback.")
126
+ from llama_index.core.llms import MockLLM
127
+ Settings.llm = MockLLM()
128
+
129
+ Settings.embed_model = HuggingFaceEmbedding(
130
+ model_name="pritamdeka/S-PubMedBert-MS-MARCO"
131
+ )
132
+
133
+
134
+ @st.cache_resource
135
+ def load_index() -> VectorStoreIndex:
136
+ """
137
+ Loads and caches the persistent LanceDB index.
138
+ """
139
+ setup_llama_index()
140
+
141
+ # Initialize LanceDB
142
+ db_path = "./ct_gov_lancedb"
143
+ db = lancedb.connect(db_path)
144
+
145
+ # Define metadata keys explicitly to ensure filters work
146
+ metadata_keys = [
147
+ "nct_id", "title", "org", "sponsor", "status", "phase",
148
+ "study_type", "start_year", "condition", "intervention",
149
+ "country", "state"
150
+ ]
151
+
152
+ # Create the vector store wrapper
153
+ vector_store = LanceDBVectorStore(
154
+ uri=db_path,
155
+ table_name="clinical_trials",
156
+ query_mode="hybrid",
157
+ )
158
+
159
+ # Manually set metadata keys as constructor doesn't accept them
160
+ vector_store._metadata_keys = metadata_keys
161
+
162
+ # Create storage context
163
+ storage_context = StorageContext.from_defaults(vector_store=vector_store)
164
+
165
+ # Load the index from the vector store
166
+ index = VectorStoreIndex.from_vector_store(
167
+ vector_store, storage_context=storage_context
168
+ )
169
+ return index
170
+
171
+
172
+ def get_hybrid_retriever(index: VectorStoreIndex, similarity_top_k: int = 50, filters=None):
173
+ """
174
+ Creates a Hybrid Retriever using LanceDB's native hybrid search.
175
+
176
+ Args:
177
+ index (VectorStoreIndex): The loaded vector index.
178
+ similarity_top_k (int): Number of top results to retrieve.
179
+ filters (MetadataFilters, optional): Filters to apply.
180
+
181
+ Returns:
182
+ VectorIndexRetriever: The configured retriever.
183
+ """
184
+ # LanceDB supports native hybrid search via query_mode="hybrid"
185
+ # We pass this configuration to the retriever
186
+ # Use standard retriever first to avoid LanceDB hybrid search issues on small datasets
187
+ return index.as_retriever(
188
+ similarity_top_k=similarity_top_k,
189
+ filters=filters,
190
+ )
191
+
192
+
193
+ # --- Normalization ---
194
+
195
+ # Centralized Sponsor Mappings
196
+ # Key: Canonical Name
197
+ # Value: List of variations/aliases (including the canonical name itself if needed for matching)
198
+ SPONSOR_MAPPINGS = {
199
+ "GlaxoSmithKline": [
200
+ "gsk", "glaxo", "glaxosmithkline", "glaxosmithkline",
201
+ "GlaxoSmithKline"
202
+ ],
203
+ "Janssen": [
204
+ "j&j", "johnson & johnson", "johnson and johnson", "janssen", "Janssen",
205
+ "Janssen Research & Development, LLC",
206
+ "Janssen Vaccines & Prevention B.V.",
207
+ "Janssen Pharmaceutical K.K.",
208
+ "Janssen-Cilag International NV",
209
+ "Janssen Sciences Ireland UC",
210
+ "Janssen Pharmaceutica N.V., Belgium",
211
+ "Janssen Scientific Affairs, LLC",
212
+ "Janssen-Cilag Ltd.",
213
+ "Xian-Janssen Pharmaceutical Ltd.",
214
+ "Janssen Korea, Ltd., Korea",
215
+ "Janssen-Cilag G.m.b.H",
216
+ "Janssen-Cilag, S.A.",
217
+ "Janssen BioPharma, Inc.",
218
+ ],
219
+ "Bristol-Myers Squibb": [
220
+ "bms", "bristol", "bristol myers squibb", "bristol-myers squibb",
221
+ "Bristol-Myers Squibb"
222
+ ],
223
+ "Merck Sharp & Dohme": [
224
+ "merck", "msd", "merck sharp & dohme",
225
+ "Merck Sharp & Dohme LLC"
226
+ ],
227
+ "Pfizer": ["pfizer", "Pfizer", "Pfizer Inc."],
228
+ "AstraZeneca": ["astrazeneca", "AstraZeneca"],
229
+ "Eli Lilly and Company": ["lilly", "eli lilly", "Eli Lilly and Company"],
230
+ "Sanofi": ["sanofi", "Sanofi"],
231
+ "Novartis": ["novartis", "Novartis"],
232
+ }
233
+
234
+ def normalize_sponsor(sponsor: str) -> Optional[str]:
235
+ """
236
+ Normalizes sponsor names to canonical forms using centralized mappings.
237
+ """
238
+ if not sponsor:
239
+ return None
240
+
241
+ s = sponsor.lower().strip()
242
+
243
+ for canonical, variations in SPONSOR_MAPPINGS.items():
244
+ # Check if input matches canonical name (case-insensitive)
245
+ if s == canonical.lower():
246
+ return canonical
247
+
248
+ # Check variations and aliases
249
+ for v in variations:
250
+ v_lower = v.lower()
251
+ if v_lower == s:
252
+ return canonical
253
+ # If the variation is a known alias (like 'gsk'), check if it's in the string
254
+ if len(v) < 5 and v_lower in s:
255
+ return canonical
256
+
257
+ if canonical.lower() in s:
258
+ return canonical
259
+
260
+ return sponsor
261
+
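The alias-lookup behavior of `normalize_sponsor` can be exercised with a trimmed copy of `SPONSOR_MAPPINGS` (subset shown for illustration):

```python
# Trimmed mapping: canonical name -> known variations/aliases.
SPONSOR_MAPPINGS = {
    "GlaxoSmithKline": ["gsk", "glaxo", "glaxosmithkline"],
    "Pfizer": ["pfizer", "Pfizer", "Pfizer Inc."],
}

def normalize_sponsor(sponsor):
    if not sponsor:
        return None
    s = sponsor.lower().strip()
    for canonical, variations in SPONSOR_MAPPINGS.items():
        if s == canonical.lower():
            return canonical
        for v in variations:
            v_lower = v.lower()
            if v_lower == s:
                return canonical
            # Short aliases (like 'gsk') also match as substrings
            if len(v) < 5 and v_lower in s:
                return canonical
        if canonical.lower() in s:
            return canonical
    return sponsor  # unknown sponsors pass through unchanged

print(normalize_sponsor("GSK plc"))      # GlaxoSmithKline (substring alias)
print(normalize_sponsor("Pfizer Inc."))  # Pfizer (exact variation)
print(normalize_sponsor("Acme Biotech")) # Acme Biotech (no mapping)
```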
262
+
263
+ def get_sponsor_variations(sponsor: str) -> Optional[List[str]]:
264
+ """
265
+ Returns the list of known name variations for a sponsor alias, for use in strict IN pre-filters.
266
+ """
267
+ if not sponsor:
268
+ return None
269
+
270
+ # First, normalize the input to get the canonical name
271
+ canonical = normalize_sponsor(sponsor)
272
+
273
+ if canonical in SPONSOR_MAPPINGS:
274
+ return SPONSOR_MAPPINGS[canonical]
275
+
276
+ return None
277
+
278
+
279
+
280
+
281
+
requirements.txt ADDED
@@ -0,0 +1,19 @@
1
+ streamlit
2
+ requests
3
+ python-dotenv
4
+ langchain
5
+ langchain-community
6
+ langchain-google-genai==2.0.0
7
+ lancedb
8
+ lark
9
+ langchain-huggingface
10
+ llama-index
11
+ llama-index-vector-stores-lancedb
12
+ llama-index-embeddings-huggingface
13
+ llama-index-llms-gemini
14
+ streamlit-option-menu
15
+ streamlit-agraph
16
+ folium
17
+ streamlit-folium
18
+ rank_bm25
19
+ llama-index-retrievers-bm25
scripts/analyze_db.py ADDED
@@ -0,0 +1,149 @@
1
+ """
2
+ Database Analysis Script.
3
+
4
+ This script connects to the local LanceDB vector store and performs a quick analysis
5
+ of the ingested clinical trial data. It prints statistics about:
6
+ - Top Sponsors
7
+ - Phase Distribution
8
+ - Status Distribution
9
+ - Top Medical Conditions
10
+ - Sample of Recent Studies
11
+
12
+ Usage:
13
+ python scripts/analyze_db.py
14
+ # OR
15
+ cd scripts && python analyze_db.py
16
+ """
17
+
18
+ import lancedb
19
+ import pandas as pd
20
+ import os
21
+
22
+
23
+ def analyze_db():
24
+ """
25
+ Connects to LanceDB and prints summary statistics of the dataset.
26
+ """
27
+ # Determine the project root directory (one level up from this script)
28
+ script_dir = os.path.dirname(os.path.abspath(__file__))
29
+ project_root = os.path.dirname(script_dir)
30
+ db_path = os.path.join(project_root, "ct_gov_lancedb")
31
+
32
+ if not os.path.exists(db_path):
33
+ print(f"❌ Database directory '{db_path}' does not exist.")
34
+ print(" Please run 'python scripts/ingest_ct.py' first to ingest data.")
35
+ return
36
+
37
+ print(f"📂 Loading database from {db_path}...")
38
+ try:
39
+ db = lancedb.connect(db_path)
40
+
41
+ # Check for table existence
42
+ if "clinical_trials" not in db.table_names():
43
+ print(f"❌ Table 'clinical_trials' not found. Available: {db.table_names()}")
44
+ return
45
+
46
+ tbl = db.open_table("clinical_trials")
47
+ count = len(tbl)
48
+ print(f"✅ Found 'clinical_trials' table with {count} documents.")
49
+
50
+ # Fetch all data for analysis
51
+ df = tbl.to_pandas()
52
+
53
+ if df.empty:
54
+ print("❌ No data found.")
55
+ return
56
+
57
+ # Handle metadata if nested (LlamaIndex might nest it)
58
+ if "metadata" in df.columns:
59
+ # Try to flatten if it's a struct/dict
60
+ try:
61
+ meta_df = pd.json_normalize(df["metadata"])
62
+ # Merge with original df or just use meta_df for analysis
63
+ # We'll use meta_df for the metadata fields analysis
64
+ # But we might need 'text' from original
65
+ df = pd.concat([df.drop(columns=["metadata"]), meta_df], axis=1)
66
+ except Exception:
67
+ pass
68
+
69
+ if "nct_id" in df.columns:
70
+ unique_ncts = df["nct_id"].nunique()
71
+ print(f"🔢 Unique NCT IDs: {unique_ncts}")
72
+ if unique_ncts < count:
73
+ print(f"⚠️ Warning: {count - unique_ncts} duplicate records found!")
74
+ else:
75
+ print("⚠️ 'nct_id' field not found in metadata.")
76
+
77
+ # --- Analysis Sections ---
78
+
79
+ print("\n📊 --- Top 10 Sponsors ---")
80
+ if "org" in df.columns:
81
+ print(df["org"].value_counts().head(10))
82
+ else:
83
+ print("⚠️ 'org' field not found in metadata.")
84
+
85
+ print("\n📊 --- Phase Distribution ---")
86
+ if "phase" in df.columns:
87
+ print(df["phase"].value_counts())
88
+ else:
89
+ print("⚠️ 'phase' field not found in metadata.")
90
+
91
+ print("\n📊 --- Status Distribution ---")
92
+ if "status" in df.columns:
93
+ print(df["status"].value_counts())
94
+ else:
95
+ print("⚠️ 'status' field not found in metadata.")
96
+
97
+ print("\n📊 --- Top Conditions ---")
98
+ if "condition" in df.columns:
99
+ # Conditions are comma-separated strings, so we split and explode them
100
+ all_conditions = []
101
+ for conditions in df["condition"].dropna():
102
+ all_conditions.extend([c.strip() for c in conditions.split(",")])
103
+ print(pd.Series(all_conditions).value_counts().head(10))
104
+ else:
105
+ print("⚠️ 'condition' field not found in metadata.")
106
+
107
+ print("\n📊 --- Top Interventions ---")
108
+ if "intervention" in df.columns:
109
+ # Interventions are semicolon-separated strings (from ingest_ct.py), so we split by "; "
110
+ all_interventions = []
111
+ for interventions in df["intervention"].dropna():
112
+ # Split by semicolon and strip whitespace
113
+ parts = [i.strip() for i in interventions.split(";") if i.strip()]
114
+ all_interventions.extend(parts)
115
+
116
+ if all_interventions:
117
+ print(pd.Series(all_interventions).value_counts().head(20))
118
+ else:
119
+ print("No interventions found.")
120
+ else:
121
+ print("⚠️ 'intervention' field not found in metadata.")
122
+
123
+ print("\n📝 --- Sample Studies (Most Recent Start Years) ---")
124
+ if "start_year" in df.columns and "title" in df.columns:
125
+ # Ensure start_year is numeric for sorting
126
+ df["start_year"] = pd.to_numeric(df["start_year"], errors="coerce")
127
+ top_recent = df.sort_values(by="start_year", ascending=False).head(5)
128
+ for _, row in top_recent.iterrows():
129
+ print(
130
+ f"- [{row.get('start_year', 'N/A')}] {row.get('title', 'N/A')} ({row.get('nct_id', 'N/A')})"
131
+ )
132
+ print(f" Sponsor: {row.get('org', 'N/A')}")
133
+ print(f" Intervention: {row.get('intervention', 'N/A')}")
134
+
135
+ print("\n📊 --- Intervention Check ---")
136
+ if "intervention" in df.columns:
137
+ non_empty = df[df["intervention"].str.len() > 0]
138
+ print(f"Total records with interventions: {len(non_empty)}")
139
+ if not non_empty.empty:
140
+ print("Sample Intervention:", non_empty.iloc[0]["intervention"])
141
+ else:
142
+ print("⚠️ 'intervention' field not found.")
143
+
144
+ except Exception as e:
145
+ print(f"⚠️ Error analyzing DB: {e}")
146
+
147
+
148
+ if __name__ == "__main__":
149
+ analyze_db()
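The split-and-extend loops above (comma-separated conditions, semicolon-separated interventions) can also be written with pandas' `str.split`/`explode` idiom. A minimal sketch with made-up data:

```python
import pandas as pd

# Hypothetical sample mirroring the comma-separated 'condition' field
df = pd.DataFrame({"condition": ["Diabetes, Obesity", "Obesity", None]})

counts = (
    df["condition"]
    .dropna()
    .str.split(",")   # one list of conditions per row
    .explode()        # one condition per row
    .str.strip()      # trim stray whitespace
    .value_counts()
)
print(counts.head(10))
```

`explode()` avoids building the intermediate `all_conditions` list by hand.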
scripts/ingest_ct.py ADDED
@@ -0,0 +1,449 @@
1
+ """
2
+ Data Ingestion Script for Clinical Trial Agent.
3
+
4
+ This script fetches clinical trial data from the ClinicalTrials.gov API (v2),
5
+ processes it into a rich text format, and ingests it into a local LanceDB vector index
6
+ using LlamaIndex and PubMedBERT embeddings.
7
+
8
+ Features:
9
+ - **Pagination**: Fetches data in batches using the API's pagination tokens.
10
+ - **Robustness**: Implements retry logic for network errors.
11
+ - **Efficiency**: Uses batch insertion and reuses the existing index.
12
+ - **Progress Tracking**: Displays a progress bar using `tqdm`.
13
+ """
14
+
15
+ import requests
16
+ import re
17
+ from datetime import datetime, timedelta
18
+ from dotenv import load_dotenv
19
+ import argparse
20
+ import time
21
+ from tqdm import tqdm
22
+ import os
23
+ import concurrent.futures
24
+
25
+ # LlamaIndex Imports
26
+ from llama_index.core import Document, VectorStoreIndex, StorageContext, Settings
27
+ from llama_index.embeddings.huggingface import HuggingFaceEmbedding
28
+ from llama_index.vector_stores.lancedb import LanceDBVectorStore
29
+ import lancedb
30
+
31
+ # List of US States for extraction
32
+ US_STATES = [
33
+ "Alabama", "Alaska", "Arizona", "Arkansas", "California", "Colorado", "Connecticut",
34
+ "Delaware", "Florida", "Georgia", "Hawaii", "Idaho", "Illinois", "Indiana", "Iowa",
35
+ "Kansas", "Kentucky", "Louisiana", "Maine", "Maryland", "Massachusetts", "Michigan",
36
+ "Minnesota", "Mississippi", "Missouri", "Montana", "Nebraska", "Nevada", "New Hampshire",
37
+ "New Jersey", "New Mexico", "New York", "North Carolina", "North Dakota", "Ohio",
38
+ "Oklahoma", "Oregon", "Pennsylvania", "Rhode Island", "South Carolina", "South Dakota",
39
+ "Tennessee", "Texas", "Utah", "Vermont", "Virginia", "Washington", "West Virginia",
40
+ "Wisconsin", "Wyoming", "District of Columbia"
41
+ ]
42
+
43
+ load_dotenv()
44
+
45
+ # Disable LLM for ingestion (we only need embeddings, not generation)
46
+ Settings.llm = None
47
+
48
+
49
+ def clean_text(text: str) -> str:
50
+ """
51
+ Cleans raw text by removing HTML tags and normalizing whitespace.
52
+
53
+ Args:
54
+ text (str): The raw text string.
55
+
56
+ Returns:
57
+ str: The cleaned text.
58
+ """
59
+ if not text:
60
+ return ""
61
+ # Remove HTML tags
62
+ text = re.sub(r"<[^>]+>", "", text)
63
+ # Remove multiple spaces/newlines and trim
64
+ text = re.sub(r"\s+", " ", text).strip()
65
+ return text
66
+
67
+
68
+ def fetch_trials_generator(
69
+ years: int = 5, max_studies: int = 1000, status: list = None, phases: list = None
70
+ ):
71
+ """
72
+ Yields batches of clinical trials from the ClinicalTrials.gov API.
73
+
74
+ Handles pagination automatically and implements retry logic for API requests.
75
+
76
+ Args:
77
+ years (int): Number of years to look back for study start dates.
78
+ max_studies (int): Maximum total number of studies to fetch (-1 for all).
79
+ status (list): List of status strings to filter by (e.g., ["RECRUITING"]).
80
+ phases (list): List of phase strings to filter by (e.g., ["PHASE2"]).
81
+
82
+ Yields:
83
+ list: A batch of study dictionaries (JSON objects).
84
+ """
85
+ base_url = "https://clinicaltrials.gov/api/v2/studies"
86
+
87
+ # Calculate start date for filtering
88
+ start_date = (datetime.now() - timedelta(days=365 * years)).strftime("%Y-%m-%d")
89
+ print("📡 Connecting to CT.gov API...")
90
+ print(f"🔎 Fetching trials starting after: {start_date}")
91
+ if status:
92
+ print(f" Filters - Status: {status}")
93
+ if phases:
94
+ print(f" Filters - Phases: {phases}")
95
+
96
+ fetched_count = 0
97
+ next_page_token = None
98
+
99
+ # If max_studies is -1, fetch ALL studies (infinite limit)
100
+ fetch_limit = float("inf") if max_studies == -1 else max_studies
101
+
102
+ while fetched_count < fetch_limit:
103
+ # Determine batch size (max 1000 per API limit)
104
+ current_limit = 1000
105
+ if max_studies != -1:
106
+ current_limit = min(1000, max_studies - fetched_count)
107
+
108
+ # --- Query Construction ---
109
+ # Build the query term using the API's syntax
110
+ query_parts = [f"AREA[StartDate]RANGE[{start_date},MAX]"]
111
+
112
+ if status:
113
+ status_str = " OR ".join(status)
114
+ query_parts.append(f"AREA[OverallStatus]({status_str})")
115
+
116
+ if phases:
117
+ phase_str = " OR ".join(phases)
118
+ query_parts.append(f"AREA[Phase]({phase_str})")
119
+
120
+ full_query = " AND ".join(query_parts)
121
+
122
+ params = {
123
+ "query.term": full_query,
124
+ "pageSize": current_limit,
125
+ # Request specific fields to minimize payload size
126
+ "fields": ",".join(
127
+ [
128
+ "protocolSection.identificationModule.nctId",
129
+ "protocolSection.identificationModule.briefTitle",
130
+ "protocolSection.identificationModule.officialTitle",
131
+ "protocolSection.identificationModule.organization",
132
+ "protocolSection.statusModule.overallStatus",
133
+ "protocolSection.statusModule.startDateStruct",
134
+ "protocolSection.statusModule.completionDateStruct",
135
+ "protocolSection.designModule.phases",
136
+ "protocolSection.designModule.studyType",
137
+ "protocolSection.eligibilityModule.eligibilityCriteria",
138
+ "protocolSection.eligibilityModule.sex",
139
+ "protocolSection.eligibilityModule.stdAges",
140
+ "protocolSection.descriptionModule.briefSummary",
141
+ "protocolSection.conditionsModule.conditions",
142
+ "protocolSection.outcomesModule.primaryOutcomes",
143
+ "protocolSection.contactsLocationsModule.locations",
146
+ "protocolSection.armsInterventionsModule",
147
+ "protocolSection.sponsorCollaboratorsModule.leadSponsor",
148
+ ]
149
+ ),
150
+ }
151
+
152
+ if next_page_token:
153
+ params["pageToken"] = next_page_token
154
+
155
+ # --- Retry Logic ---
156
+ retries = 3
157
+ for attempt in range(retries):
158
+ try:
159
+ response = requests.get(base_url, params=params, timeout=30)
160
+ if response.status_code == 200:
161
+ data = response.json()
162
+ studies = data.get("studies", [])
163
+
164
+ if not studies:
165
+ return # Stop generator if no studies returned
166
+
167
+ yield studies
168
+
169
+ fetched_count += len(studies)
170
+ next_page_token = data.get("nextPageToken")
171
+
172
+ if not next_page_token:
173
+ return # Stop generator if no more pages
174
+
175
+ break # Success, exit retry loop
176
+ else:
177
+ print(f"❌ API Error: {response.status_code} - {response.text}")
178
+ if attempt < retries - 1:
179
+ time.sleep(2)
180
+ else:
181
+ return # Stop generator on persistent error
182
+ except Exception as e:
183
+ print(f"❌ Request Error (Attempt {attempt+1}/{retries}): {e}")
184
+ if attempt < retries - 1:
185
+ time.sleep(2)
186
+ else:
187
+ return # Stop generator
188
+
189
+
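For illustration, here is the `query.term` expression the loop above assembles, with hypothetical filter values (the start date and filter lists are assumptions, not values from the script):

```python
# Hypothetical inputs mirroring the generator's arguments
start_date = "2015-06-01"
status = ["COMPLETED", "RECRUITING"]
phases = ["PHASE2", "PHASE3"]

# Same construction as in fetch_trials_generator
query_parts = [f"AREA[StartDate]RANGE[{start_date},MAX]"]
if status:
    query_parts.append(f"AREA[OverallStatus]({' OR '.join(status)})")
if phases:
    query_parts.append(f"AREA[Phase]({' OR '.join(phases)})")
full_query = " AND ".join(query_parts)

print(full_query)
# AREA[StartDate]RANGE[2015-06-01,MAX] AND AREA[OverallStatus](COMPLETED OR RECRUITING) AND AREA[Phase](PHASE2 OR PHASE3)
```

Each `AREA[...]` clause targets one indexed field in the CT.gov query syntax, and clauses are ANDed together.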
190
+ def process_study(study):
191
+ """
192
+ Processes a single study dictionary into a LlamaIndex Document.
193
+ This function is designed to be run in parallel.
194
+ """
195
+ try:
196
+ # Extract Modules
197
+ protocol = study.get("protocolSection", {})
198
+ identification = protocol.get("identificationModule", {})
199
+ status_module = protocol.get("statusModule", {})
200
+ design = protocol.get("designModule", {})
201
+ eligibility = protocol.get("eligibilityModule", {})
202
+ description = protocol.get("descriptionModule", {})
203
+ conditions_module = protocol.get("conditionsModule", {})
204
+ outcomes_module = protocol.get("outcomesModule", {})
205
+ arms_interventions_module = protocol.get("armsInterventionsModule", {})
208
+ locations_module = protocol.get("contactsLocationsModule", {})
209
+ sponsor_module = protocol.get("sponsorCollaboratorsModule", {})
210
+
211
+ # Extract Fields
212
+ nct_id = identification.get("nctId", "N/A")
213
+ title = identification.get("briefTitle", "N/A")
214
+ official_title = identification.get("officialTitle", "N/A")
215
+ official_title = identification.get("officialTitle", "N/A")
216
+ org = identification.get("organization", {}).get("fullName", "N/A")
217
+ sponsor_name = sponsor_module.get("leadSponsor", {}).get("name", "N/A")
218
+ summary = clean_text(description.get("briefSummary", "N/A"))
219
+
220
+ overall_status = status_module.get("overallStatus", "N/A")
221
+ start_date = status_module.get("startDateStruct", {}).get("date", "N/A")
222
+ completion_date = status_module.get("completionDateStruct", {}).get(
223
+ "date", "N/A"
224
+ )
225
+
226
+ phases = ", ".join(design.get("phases", []))
227
+ study_type = design.get("studyType", "N/A")
228
+
229
+ criteria = clean_text(eligibility.get("eligibilityCriteria", "N/A"))
230
+ gender = eligibility.get("sex", "N/A")
231
+ ages = ", ".join(eligibility.get("stdAges", []))
232
+
233
+ conditions = ", ".join(conditions_module.get("conditions", []))
234
+
235
+ interventions = []
236
+ for interv in arms_interventions_module.get("interventions", []):
237
+ name = interv.get("name", "")
238
+ type_ = interv.get("type", "")
239
+ interventions.append(f"{type_}: {name}")
240
+ interventions_str = "; ".join(interventions)
241
+
242
+ primary_outcomes = []
243
+ for outcome in outcomes_module.get("primaryOutcomes", []):
244
+ measure = outcome.get("measure", "")
245
+ desc = outcome.get("description", "")
246
+ primary_outcomes.append(f"- {measure}: {desc}")
247
+ outcomes_str = clean_text("\n".join(primary_outcomes))
248
+
249
+ locations = []
250
+ for loc in locations_module.get("locations", []):
251
+ facility = loc.get("facility", "N/A")
252
+ city = loc.get("city", "")
253
+ loc_state = loc.get("state", "")
254
+ country = loc.get("country", "")
+ # Join non-empty parts; including the state lets the US-state scan below match it
+ place = ", ".join(p for p in (city, loc_state, country) if p)
+ locations.append(f"{facility} ({place})")
255
+ locations_str = "; ".join(locations[:5]) # Limit to 5 locations to save space
256
+
257
+ # Extract State (First match)
258
+ state = "Unknown"
259
+ # Check locations for US States
260
+ for loc_str in locations:
261
+ if "United States" in loc_str:
262
+ for s in US_STATES:
263
+ if s in loc_str:
264
+ state = s
265
+ break
266
+ if state != "Unknown":
267
+ break
268
+
269
+ # Construct Rich Page Content with Markdown Headers
270
+ # This text is what gets embedded and searched
271
+ page_content = (
272
+ f"# {title}\n"
273
+ f"**NCT ID:** {nct_id}\n"
274
+ f"**Official Title:** {official_title}\n"
275
+ f"**Sponsor:** {sponsor_name}\n"
276
+ f"**Organization:** {org}\n"
277
+ f"**Status:** {overall_status}\n"
278
+ f"**Phase:** {phases}\n"
279
+ f"**Study Type:** {study_type}\n"
280
+ f"**Start Date:** {start_date}\n"
281
+ f"**Completion Date:** {completion_date}\n\n"
282
+ f"## Summary\n{summary}\n\n"
283
+ f"## Conditions\n{conditions}\n\n"
284
+ f"## Interventions\n{interventions_str}\n\n"
285
+ f"## Eligibility Criteria\n"
286
+ f"**Gender:** {gender}\n"
287
+ f"**Ages:** {ages}\n"
288
+ f"**Criteria:**\n{criteria}\n\n"
289
+ f"## Primary Outcomes\n{outcomes_str}\n\n"
290
+ f"## Locations\n{locations_str}"
291
+ )
292
+
293
+ # Metadata for filtering (Structured Data)
294
+ metadata = {
295
+ "nct_id": nct_id,
296
+ "title": title,
297
+ "org": org,
298
+ "sponsor": sponsor_name,
299
+ "status": overall_status,
300
+ "phase": phases,
301
+ "study_type": study_type,
302
+ "start_year": (int(start_date.split("-")[0]) if start_date != "N/A" else 0),
303
+ "condition": conditions,
304
+ "intervention": interventions_str,
305
+ "country": (
306
+ locations[0].split(",")[-1].strip(" )") if locations else "Unknown"
307
+ ),
308
+ "state": state,
309
+ }
310
+
311
+ return Document(text=page_content, metadata=metadata, id_=nct_id)
312
+ except Exception as e:
313
+ print(
314
+ f"⚠️ Error processing study {study.get('protocolSection', {}).get('identificationModule', {}).get('nctId', 'Unknown')}: {e}"
315
+ )
316
+ return None
317
+
318
+
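The state-extraction pass inside `process_study` can be exercised in isolation. A minimal sketch with a trimmed state list and hypothetical location strings (it assumes the location string actually mentions the state):

```python
# Abbreviated state list for illustration; the script uses the full US_STATES table
US_STATES = ["California", "New York", "Texas"]

def extract_state(locations):
    # Return the first US state found in a "Facility (City, State, Country)" string
    for loc_str in locations:
        if "United States" in loc_str:
            for s in US_STATES:
                if s in loc_str:
                    return s
    return "Unknown"

locs = [
    "Charite (Berlin, Germany)",
    "UCSF Medical Center (San Francisco, California, United States)",
]
print(extract_state(locs))  # California
```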
319
+ def run_ingestion():
320
+ """
321
+ Main execution function for the ingestion script.
322
+ Parses arguments, initializes the index, and runs the ingestion loop.
323
+ """
324
+ parser = argparse.ArgumentParser(description="Ingest Clinical Trials data.")
325
+ parser.add_argument(
326
+ "--limit",
327
+ type=int,
328
+ default=-1,
329
+ help="Number of studies to ingest. Set to -1 for ALL.",
330
+ )
331
+ parser.add_argument(
332
+ "--years", type=int, default=10, help="Number of years to look back."
333
+ )
334
+ parser.add_argument(
335
+ "--status",
336
+ type=str,
337
+ default="COMPLETED",
338
+ help="Comma-separated list of statuses (e.g., COMPLETED,RECRUITING).",
339
+ )
340
+ parser.add_argument(
341
+ "--phases",
342
+ type=str,
343
+ default="PHASE1,PHASE2,PHASE3,PHASE4",
344
+ help="Comma-separated list of phases (e.g., PHASE2,PHASE3).",
345
+ )
346
+ args = parser.parse_args()
347
+
348
+ status_list = args.status.split(",") if args.status else []
349
+ phase_list = args.phases.split(",") if args.phases else []
350
+
351
+ print(f"⚙️ Configuration: Limit={args.limit}, Years={args.years}")
352
+ print(f" Status Filter: {status_list}")
353
+ print(f" Phase Filter: {phase_list}")
354
+
355
+ # --- INITIALIZE LLAMAINDEX COMPONENTS ---
356
+ print("🧠 Initializing LlamaIndex Embeddings (PubMedBERT)...")
357
+ embed_model = HuggingFaceEmbedding(model_name="pritamdeka/S-PubMedBert-MS-MARCO")
358
+
359
+ # Initialize LanceDB (Persistent)
360
+ print("🚀 Initializing LanceDB...")
361
+
362
+ # Determine the project root directory (one level up from this script)
363
+ script_dir = os.path.dirname(os.path.abspath(__file__))
364
+ project_root = os.path.dirname(script_dir)
365
+ db_path = os.path.join(project_root, "ct_gov_lancedb")
366
+
367
+ # Connect to LanceDB
368
+ db = lancedb.connect(db_path)
369
+
370
+ table_name = "clinical_trials"
371
+ if table_name in db.table_names():
372
+ mode = "append"
373
+ print(f"ℹ️ Table '{table_name}' exists. Appending data.")
374
+ else:
375
+ mode = "create"
376
+ print(f"ℹ️ Table '{table_name}' does not exist. Creating new table.")
377
+
378
+ # Initialize Vector Store
379
+ vector_store = LanceDBVectorStore(
380
+ uri=db_path,
381
+ table_name=table_name,
382
+ mode=mode,
383
+ query_mode="hybrid" # Enable hybrid search support
384
+ )
385
+ storage_context = StorageContext.from_defaults(vector_store=vector_store)
386
+
387
+ # Initialize Index ONCE
388
+ # We pass the storage context to link it to the vector store
389
+ index = VectorStoreIndex.from_vector_store(
390
+ vector_store, storage_context=storage_context, embed_model=embed_model
391
+ )
392
+
393
+ total_ingested = 0
394
+
395
+ # Progress Bar
396
+ pbar = tqdm(
397
+ total=args.limit if args.limit > 0 else float("inf"),
398
+ desc="Ingesting Studies",
399
+ unit="study",
400
+ )
401
+
402
+ # --- INGESTION LOOP ---
403
+ # Use ProcessPoolExecutor for parallel processing of study data
404
+ with concurrent.futures.ProcessPoolExecutor() as executor:
405
+ for batch_studies in fetch_trials_generator(
406
+ years=args.years,
407
+ max_studies=args.limit,
408
+ status=status_list,
409
+ phases=phase_list,
410
+ ):
411
+ # Parallelize the processing of the batch
412
+ # map returns an iterator, so we convert to list to trigger execution
413
+ documents_iter = executor.map(process_study, batch_studies)
414
+
415
+ # Filter out None results (errors)
416
+ documents = [doc for doc in documents_iter if doc is not None]
417
+
418
+ if documents:
419
+ # Overwrite Logic:
420
+ # To avoid duplicates, we delete existing records with the same NCT IDs.
421
+ doc_ids = [doc.id_ for doc in documents]
422
+ try:
423
+ # LanceDB supports deletion via SQL-like filter
424
+ # We construct a filter string: "nct_id IN ('NCT123', 'NCT456')"
425
+ ids_str = ", ".join(f"'{doc_id}'" for doc_id in doc_ids)
426
+ if ids_str:
427
+ tbl = db.open_table("clinical_trials")
428
+ tbl.delete(f"nct_id IN ({ids_str})")
429
+ except Exception:
430
+ # Ignore if table doesn't exist yet
431
+ pass
432
+
433
+ # Efficient Batch Insertion
434
+ # We convert documents to nodes and insert them into the index.
435
+ # This handles embedding generation automatically.
436
+ node_parser = Settings.node_parser  # avoid shadowing the argparse `parser` above
437
+ nodes = node_parser.get_nodes_from_documents(documents)
438
+
439
+ index.insert_nodes(nodes)
440
+
441
+ total_ingested += len(documents)
442
+ pbar.update(len(documents))
443
+
444
+ pbar.close()
445
+ print(f"🎉 Ingestion Complete! Studies ingested this run: {total_ingested}")
446
+
447
+
448
+ if __name__ == "__main__":
449
+ run_ingestion()
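The overwrite step in the ingestion loop builds a SQL-style `IN` filter for `tbl.delete()`. For illustration, with hypothetical NCT IDs:

```python
# Hypothetical IDs; mirrors the delete-filter construction in run_ingestion
doc_ids = ["NCT01234567", "NCT07654321"]

ids_str = ", ".join(f"'{doc_id}'" for doc_id in doc_ids)
where_clause = f"nct_id IN ({ids_str})"
print(where_clause)  # nct_id IN ('NCT01234567', 'NCT07654321')
```

Deleting matching rows before re-inserting is what makes repeated ingestion runs idempotent per study.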
scripts/remove_duplicates.py ADDED
@@ -0,0 +1,174 @@
1
+ """
2
+ Script to remove duplicate records from the LanceDB database.
3
+
4
+ This script scans the 'clinical_trials' table, identifies records with duplicate content
5
+ (same 'nct_id' AND same 'text'), and removes the extras.
6
+
7
+ It uses a safe "Fetch -> Dedupe -> Overwrite" strategy:
8
+ 1. Identifies NCT IDs that have duplicates.
9
+ 2. For each such NCT ID, fetches ALL its records (chunks).
10
+ 3. Deduplicates these records in memory based on their text content.
11
+ 4. Deletes ALL records for that NCT ID from the database.
12
+ 5. Re-inserts the unique records.
13
+
14
+ This ensures that valid chunks of the same study are PRESERVED, while exact duplicates are removed.
15
+ """
16
+
17
+ import os
18
+ import pandas as pd
19
+ import lancedb
20
+
21
+ import argparse
22
+
23
+ def calculate_richness(record):
24
+ """Calculates a 'richness' score for a record based on metadata field count and content length."""
25
+ score = 0
26
+ if not record:
27
+ return 0
28
+
29
+ for key, value in record.items():
30
+ if key == "vector": continue
31
+
32
+ # Handle nested metadata
33
+ if key == "metadata" and isinstance(value, dict):
34
+ score += calculate_richness(value) # Recurse
35
+ continue
36
+
37
+ # Check for non-empty values
38
+ if value is not None and str(value).strip() != "":
39
+ score += 10 # Base points for having a populated field
40
+
41
+ # Bonus points for content length
42
+ if isinstance(value, str):
43
+ score += len(value) / 100.0
44
+
45
+ return score
46
+
47
+ def remove_duplicates(dry_run=False):
48
+ # Determine the project root directory
49
+ script_dir = os.path.dirname(os.path.abspath(__file__))
50
+ project_root = os.path.dirname(script_dir)
51
+ db_path = os.path.join(project_root, "ct_gov_lancedb")
52
+
53
+ if not os.path.exists(db_path):
54
+ print(f"❌ Database directory '{db_path}' does not exist.")
55
+ return
56
+
57
+ print(f"📂 Loading database from {db_path}...")
58
+ if dry_run:
59
+ print("🧪 RUNNING IN DRY-RUN MODE (No changes will be made)")
60
+
61
+ try:
62
+ db = lancedb.connect(db_path)
63
+ tbl = db.open_table("clinical_trials")
64
+
65
+ print("🔍 Scanning for duplicates...")
66
+ # Fetch all data
67
+ df = tbl.to_pandas()
68
+
69
+ if df.empty:
70
+ print("Database is empty.")
71
+ return
72
+
73
+ # Create a working copy to flatten metadata for analysis
74
+ working_df = df.copy()
75
+ if "metadata" in working_df.columns:
76
+ # Flatten metadata
77
+ meta_df = pd.json_normalize(working_df["metadata"])
78
+ # We drop the original metadata column from working_df and join the flattened one
79
+ working_df = pd.concat([working_df.drop(columns=["metadata"]), meta_df], axis=1)
80
+
81
+ if "nct_id" not in working_df.columns:
82
+ print("❌ 'nct_id' column not found (checked metadata too).")
83
+ return
84
+
85
+ if "text" not in working_df.columns:
86
+ print("❌ 'text' column not found. Cannot safely deduplicate chunks.")
87
+ return
88
+
89
+ # Identify duplicates based on (nct_id, text) using the flattened view
90
+ duplicates_mask = working_df.duplicated(subset=["nct_id", "text"], keep=False)
91
+
92
+ # We use the mask on working_df to find the IDs
93
+ duplicates_working_df = working_df[duplicates_mask]
94
+
95
+ if duplicates_working_df.empty:
96
+ print("✅ No exact duplicates found. Database is clean.")
97
+ return
98
+
99
+ unique_duplicate_ids = duplicates_working_df["nct_id"].unique()
100
+ print(f"⚠️ Found duplicates affecting {len(unique_duplicate_ids)} studies (NCT IDs).")
101
+
102
+ total_deleted = 0
103
+ total_reinserted = 0
104
+
105
+ # Process each affected NCT ID
106
+ for nct_id in unique_duplicate_ids:
107
+ # Get indices from working_df where nct_id matches
108
+ # This ensures we are looking at the right rows in the ORIGINAL df
109
+ indices = working_df[working_df["nct_id"] == nct_id].index
110
+
111
+ # Extract original records (preserving structure)
112
+ study_records_df = df.loc[indices]
113
+ original_count = len(study_records_df)
114
+
115
+ unique_records = []
116
+ seen_texts = set()
117
+
118
+ records = study_records_df.to_dict("records")
119
+ records.sort(key=calculate_richness, reverse=True)
120
+
121
+ for record in records:
122
+ text_content = record.get("text", "")
123
+ if text_content not in seen_texts:
124
+ unique_records.append(record)
125
+ seen_texts.add(text_content)
126
+
127
+ new_count = len(unique_records)
128
+
129
+ if new_count < original_count:
130
+ print(f" - {nct_id}: Reducing {original_count} -> {new_count} records.")
131
+
132
+ if not dry_run:
133
+ # Delete using the ID (LanceDB SQL filter)
134
+ # Note: In LanceDB SQL, if nct_id is in metadata struct, we access it via metadata.nct_id
135
+ # But wait, tbl.delete() takes a SQL string.
136
+ # If the schema has 'metadata' struct, we must use 'metadata.nct_id'.
137
+ # If it was flattened (unlikely for the table itself), we use 'nct_id'.
138
+
139
+ # We check if 'nct_id' is a top-level column in the original DF
140
+ if "nct_id" in df.columns:
141
+ where_clause = f"nct_id = '{nct_id}'"
142
+ else:
143
+ where_clause = f"metadata.nct_id = '{nct_id}'"
144
+
145
+ tbl.delete(where_clause)
146
+
147
+ if unique_records:
148
+ tbl.add(unique_records)
149
+
150
+ total_deleted += original_count
151
+ total_reinserted += new_count
152
+ else:
153
+ print(f" - {nct_id}: No reduction needed (false positive?).")
154
+
155
+ if dry_run:
156
+ print(f"\n🧪 DRY RUN COMPLETE.")
157
+ print(f" - WOULD remove {total_deleted - total_reinserted} duplicate records.")
158
+ print(f" - WOULD preserve {total_reinserted} unique chunks.")
159
+ else:
160
+ print(f"\n🎉 Deduplication complete!")
161
+ print(f" - Removed {total_deleted - total_reinserted} duplicate records.")
162
+ print(f" - Preserved {total_reinserted} unique chunks.")
163
+
164
+ except Exception as e:
165
+ print(f"❌ Error: {e}")
166
+ import traceback
167
+ traceback.print_exc()
168
+
169
+ if __name__ == "__main__":
170
+ parser = argparse.ArgumentParser(description="Remove duplicate records from LanceDB.")
171
+ parser.add_argument("--dry-run", action="store_true", help="Simulate the process without making changes.")
172
+ args = parser.parse_args()
173
+
174
+ remove_duplicates(dry_run=args.dry_run)
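The duplicate scan relies on `DataFrame.duplicated(..., keep=False)`, which flags every row in a duplicate group (not just the extras), so all chunks of an affected study can be fetched and re-deduplicated. A sketch with hypothetical chunks:

```python
import pandas as pd

df = pd.DataFrame({
    "nct_id": ["NCT1", "NCT1", "NCT1", "NCT2"],
    "text": ["chunk A", "chunk A", "chunk B", "chunk C"],
})

# keep=False marks BOTH copies of ("NCT1", "chunk A"); distinct chunks stay unmarked
mask = df.duplicated(subset=["nct_id", "text"], keep=False)
print(df[mask]["nct_id"].unique())  # ['NCT1']
```

Note that `("NCT1", "chunk B")` is not flagged: different chunks of the same study are preserved.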
tests/test_data_integrity.py ADDED
@@ -0,0 +1,75 @@
1
+ import unittest
2
+ import lancedb
3
+ import pandas as pd
4
+ import os
5
+ import sys
6
+
7
+ # Add project root to path
8
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
9
+
10
+ class TestDataIntegrity(unittest.TestCase):
11
+ def setUp(self):
12
+ # Determine the project root directory
13
+ self.test_dir = os.path.dirname(os.path.abspath(__file__))
14
+ self.project_root = os.path.dirname(self.test_dir)
15
+ self.db_path = os.path.join(self.project_root, "ct_gov_lancedb")
16
+
17
+ def test_pfizer_myeloma_counts(self):
18
+ """
19
+ Verifies that the database contains the expected number of Pfizer studies
20
+ related to Multiple Myeloma, based on strict keyword matching.
21
+ """
22
+ if not os.path.exists(self.db_path):
23
+ self.skipTest(f"Database directory '{self.db_path}' does not exist. Skipping data integrity test.")
24
+
25
+ print(f"\n📂 Loading database from {self.db_path}...")
26
+ try:
27
+ db = lancedb.connect(self.db_path)
28
+ tbl = db.open_table("clinical_trials")
29
+ except Exception as e:
30
+ self.skipTest(f"Failed to load LanceDB table: {e}")
31
+
32
+ # Fetch all data (LanceDB is fast enough for this size, or we could query)
33
+ # For integrity check, loading into DF is fine.
34
+ df = tbl.to_pandas()
35
+
36
+ # Handle metadata flattening if needed (LanceDB stores metadata in a struct)
37
+ if "metadata" in df.columns:
38
+ # Flatten the metadata column
39
+ meta_df = pd.json_normalize(df["metadata"])
40
+ df = meta_df
41
+
42
+ # 1. Check for 'org' column
43
+ if "org" not in df.columns:
44
+ self.fail("'org' column missing from metadata.")
45
+
46
+ # 2. Filter by Sponsor (Pfizer)
47
+ pfizer_studies = df[df["org"].str.contains("Pfizer", case=False, na=False)]
48
+ # We expect at least some Pfizer studies if the DB is populated
49
+ self.assertGreater(len(pfizer_studies), 0, "No Pfizer studies found in DB.")
50
+
51
+ # 3. Filter by "Multiple Myeloma" in Title or Conditions
52
+ query = "Multiple Myeloma"
53
+
54
+ def is_relevant(row):
55
+ title = str(row.get("title", "")).lower()
56
+ conditions = str(row.get("condition", "")).lower()
57
+ q = query.lower()
58
+ return q in title or q in conditions
59
+
60
+ relevant_studies = pfizer_studies[pfizer_studies.apply(is_relevant, axis=1)]
61
+
62
+ count = len(relevant_studies)
63
+ print(f"🎯 Pfizer Studies with '{query}' in Title or Conditions: {count}")
64
+
65
+ # Ground truth from a manual check is 7 studies; assert a range rather than the
66
+ # exact value so the test survives minor data updates.
68
+ self.assertGreater(count, 0, "Should find at least one relevant study.")
69
+ self.assertLess(count, 50, "Should not find hundreds of studies (strict filter check).")
70
+
71
+ # Optional: Assert exact count if we want to be very strict about data consistency
72
+ # self.assertEqual(count, 7, "Expected exactly 7 studies based on known ground truth.")
73
+
74
+ if __name__ == "__main__":
75
+ unittest.main()
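The metadata flattening done in `setUp`'s test body uses `pd.json_normalize` on the struct column. A sketch with hypothetical rows shaped like LanceDB's `metadata` column:

```python
import pandas as pd

df = pd.DataFrame({
    "text": ["chunk A", "chunk B"],
    "metadata": [
        {"nct_id": "NCT1", "org": "Pfizer Inc."},
        {"nct_id": "NCT2", "org": "Novartis"},
    ],
})

# Flatten the struct column into regular columns
meta_df = pd.json_normalize(df["metadata"])
print(list(meta_df.columns))  # ['nct_id', 'org']
print(meta_df["org"].str.contains("Pfizer", case=False, na=False).sum())  # 1
```

`na=False` keeps missing sponsor values from poisoning the boolean filter.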
tests/test_hybrid_search.py ADDED
@@ -0,0 +1,65 @@
1
+ import pytest
2
+ import sys
3
+ import os
4
+
5
+ # Add project root to path
6
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
7
+
8
+ from modules.tools import search_trials
9
+ from modules.utils import load_environment
10
+
11
+ # Mark as integration test since it loads the DB
12
+ @pytest.mark.integration
13
+ def test_hybrid_search_integration():
14
+ """
15
+ Integration test for Hybrid Search.
16
+ Verifies that the search_trials tool can retrieve results using the hybrid retriever.
17
+ """
18
+ load_environment()
19
+
20
+ # Test 1: Dynamic ID Search
21
+ # First, find a valid ID from a broad search
22
+ print("\n🔍 Finding a valid ID for testing...")
23
+ broad_results = search_trials.invoke({"query": "cancer"})
24
+
25
+ # Extract an ID from the results
26
+ import re
27
+ match = re.search(r"ID: (NCT\d+)", broad_results)
28
+ if not match:
29
+ pytest.skip("Could not find any studies in DB to test against.")
30
+
31
+ target_id = match.group(1)
32
+ print(f"🎯 Found target ID: {target_id}. Now testing exact search...")
33
+
34
+ # Now search for that specific ID
35
+ results_id = search_trials.invoke({"query": target_id})
36
+
37
+ assert "Found" in results_id
38
+ assert target_id in results_id, f"Hybrid search failed to retrieve exact ID {target_id}"
39
+
40
+ # Extract sponsor from the first result to ensure we test with valid data
41
+ # Result format: "**Title** ... - Sponsor: SponsorName ..."
42
+ sponsor_match = re.search(r"Sponsor: (.*?)\n", broad_results)
43
+ if not sponsor_match:
44
+ print("⚠️ Could not extract sponsor from results. Skipping hybrid test.")
45
+ return
46
+
47
+ target_sponsor = sponsor_match.group(1).strip()
48
+ # Normalize it to get the simple name if possible, or just use it
49
+ # But search_trials expects a simple name to map to variations.
50
+ # If we pass the full name, get_sponsor_variations might return None if not mapped.
51
+ # So let's try to find a mapped sponsor if possible, or just skip if not mapped.
52
+
53
+ from modules.utils import normalize_sponsor
54
+ simple_sponsor = normalize_sponsor(target_sponsor)
55
+
56
+ # If normalization didn't change it, it might not be in our alias list.
57
+ # But we can still try to search with it.
58
+
59
+ print(f"\n🔍 Testing Hybrid Search with dynamic sponsor: '{simple_sponsor}' (Original: {target_sponsor})")
60
+
61
+ # Use a generic query that likely matches the study, or just "study"
62
+ results_hybrid = search_trials.invoke({"query": "study", "sponsor": simple_sponsor})
63
+
64
+ assert "Found" in results_hybrid, f"Should find results for valid sponsor {simple_sponsor}"
65
+ assert target_sponsor in results_hybrid or simple_sponsor in results_hybrid
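The test above parses tool output with two regexes. For illustration, against a hypothetical result string in the same shape (`ID: NCT...`, `Sponsor: ...`):

```python
import re

# Hypothetical tool output shaped like the strings parsed above
results = "Found 2 studies.\n1. **Some Trial** ID: NCT01234567\nSponsor: Pfizer\n"

match = re.search(r"ID: (NCT\d+)", results)
print(match.group(1))  # NCT01234567

sponsor_match = re.search(r"Sponsor: (.*?)\n", results)
print(sponsor_match.group(1).strip())  # Pfizer
```

The non-greedy `(.*?)` stops at the first newline, so only the sponsor name on that line is captured.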
tests/test_sponsor_normalization.py ADDED
@@ -0,0 +1,45 @@
+ import unittest
+ from modules.utils import normalize_sponsor, get_sponsor_variations, SPONSOR_MAPPINGS
+
+ class TestSponsorNormalization(unittest.TestCase):
+     def test_normalize_sponsor_aliases(self):
+         """Test that common aliases map to canonical names."""
+         self.assertEqual(normalize_sponsor("J&J"), "Janssen")
+         self.assertEqual(normalize_sponsor("Johnson & Johnson"), "Janssen")
+         self.assertEqual(normalize_sponsor("GSK"), "GlaxoSmithKline")
+         self.assertEqual(normalize_sponsor("Merck"), "Merck Sharp & Dohme")
+         self.assertEqual(normalize_sponsor("BMS"), "Bristol-Myers Squibb")
+
+     def test_normalize_sponsor_variations(self):
+         """Test that specific DB variations map to canonical names."""
+         self.assertEqual(normalize_sponsor("Janssen Research & Development, LLC"), "Janssen")
+         self.assertEqual(normalize_sponsor("Pfizer Inc."), "Pfizer")
+         self.assertEqual(normalize_sponsor("Merck Sharp & Dohme LLC"), "Merck Sharp & Dohme")
+
+     def test_normalize_sponsor_canonical(self):
+         """Test that canonical names return themselves."""
+         self.assertEqual(normalize_sponsor("Janssen"), "Janssen")
+         self.assertEqual(normalize_sponsor("Pfizer"), "Pfizer")
+
+     def test_get_sponsor_variations(self):
+         """Test that getting variations works for aliases and canonical names."""
+         # Test with alias
+         vars_jnj = get_sponsor_variations("J&J")
+         self.assertIn("Janssen Research & Development, LLC", vars_jnj)
+         self.assertIn("Janssen", vars_jnj)
+
+         # Test with canonical
+         vars_janssen = get_sponsor_variations("Janssen")
+         self.assertEqual(vars_jnj, vars_janssen)
+
+         # Test with variation input (should normalize first)
+         vars_variation = get_sponsor_variations("Janssen Research & Development, LLC")
+         self.assertEqual(vars_janssen, vars_variation)
+
+     def test_unknown_sponsor(self):
+         """Test behavior for unknown sponsors."""
+         self.assertEqual(normalize_sponsor("Unknown Pharma"), "Unknown Pharma")
+         self.assertIsNone(get_sponsor_variations("Unknown Pharma"))
+
+ if __name__ == "__main__":
+     unittest.main()
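The tests above pin down the expected behavior of `normalize_sponsor` and `get_sponsor_variations` without showing their implementation. A minimal sketch consistent with those assertions might look like the following — the mapping table here is an illustrative subset, and the real `SPONSOR_MAPPINGS` in `modules/utils.py` is presumably larger and may be structured differently:

```python
# Hypothetical sketch of modules/utils.py, inferred from the test expectations.
SPONSOR_MAPPINGS = {
    "Janssen": {
        "aliases": ["J&J", "Johnson & Johnson"],
        "variations": ["Janssen Research & Development, LLC"],
    },
    "Pfizer": {"aliases": [], "variations": ["Pfizer Inc."]},
    "Merck Sharp & Dohme": {
        "aliases": ["Merck", "MSD"],
        "variations": ["Merck Sharp & Dohme LLC"],
    },
    "GlaxoSmithKline": {"aliases": ["GSK"], "variations": []},
    "Bristol-Myers Squibb": {"aliases": ["BMS"], "variations": []},
}


def normalize_sponsor(name: str) -> str:
    """Map an alias or DB variation to its canonical sponsor name."""
    for canonical, entry in SPONSOR_MAPPINGS.items():
        if name == canonical or name in entry["aliases"] or name in entry["variations"]:
            return canonical
    return name  # unknown sponsors pass through unchanged


def get_sponsor_variations(name: str):
    """Return the canonical name plus all known DB variations, or None if unmapped."""
    canonical = normalize_sponsor(name)
    entry = SPONSOR_MAPPINGS.get(canonical)
    if entry is None:
        return None
    return [canonical] + entry["variations"]
```

With this shape, any of the three input forms (alias, canonical, DB variation) funnels through `normalize_sponsor` first, which is why `get_sponsor_variations` returns identical lists for "J&J", "Janssen", and "Janssen Research & Development, LLC" in the tests.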
tests/test_unit.py ADDED
@@ -0,0 +1,149 @@
+ import pytest
+ import pandas as pd
+ import sys
+ import os
+ from unittest.mock import MagicMock, patch
+
+ # Add project root to path to import app modules
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
+
+ from modules.utils import normalize_sponsor  # noqa: E402
+ from modules.tools import expand_query  # noqa: E402
+ from modules.graph_viz import build_graph  # noqa: E402
+ from llama_index.core.schema import NodeWithScore, TextNode  # noqa: E402
+
+ # --- Tests for normalize_sponsor ---
+
+
+ def test_normalize_sponsor_aliases():
+     assert normalize_sponsor("J&J") == "Janssen"
+     assert normalize_sponsor("Johnson & Johnson") == "Janssen"
+     assert normalize_sponsor("GSK") == "GlaxoSmithKline"
+     assert normalize_sponsor("Merck") == "Merck Sharp & Dohme"
+     assert normalize_sponsor("MSD") == "Merck Sharp & Dohme"
+     assert normalize_sponsor("BMS") == "Bristol-Myers Squibb"
+
+
+ def test_normalize_sponsor_no_change():
+     assert normalize_sponsor("Pfizer") == "Pfizer"
+     assert normalize_sponsor("Moderna") == "Moderna"
+     assert normalize_sponsor("Unknown Sponsor") == "Unknown Sponsor"
+
+
+ # --- Tests for Analytics Logic (Mocked) ---
+
+
+ def filter_dataframe(df, phase=None, status=None, sponsor=None, intervention=None):
+     """
+     Replicating the logic from get_study_analytics for testing purposes.
+     """
+     if phase:
+         target_phases = [p.strip().upper().replace(" ", "") for p in phase.split(",")]
+         df["phase_upper"] = df["phase"].astype(str).str.upper().str.replace(" ", "")
+         mask = df["phase_upper"].apply(lambda x: any(tp in x for tp in target_phases))
+         df = df[mask]
+
+     if status:
+         df = df[df["status"].str.upper() == status.upper()]
+
+     if sponsor:
+         target_sponsor = normalize_sponsor(sponsor).lower()
+         df["org_lower"] = df["org"].astype(str).apply(normalize_sponsor).str.lower()
+         df = df[df["org_lower"].str.contains(target_sponsor, regex=False)]
+
+     if intervention:
+         target_intervention = intervention.lower()
+         df["intervention_lower"] = df["intervention"].astype(str).str.lower()
+         df = df[df["intervention_lower"].str.contains(target_intervention, regex=False)]
+
+     return df
+
+
+ @pytest.fixture
+ def sample_df():
+     data = {
+         "nct_id": ["NCT001", "NCT002", "NCT003", "NCT004"],
+         "phase": ["PHASE1", "PHASE2", "PHASE3", "PHASE2"],
+         "status": ["RECRUITING", "COMPLETED", "COMPLETED", "RECRUITING"],
+         "org": ["Pfizer", "Janssen", "Merck Sharp & Dohme", "Pfizer"],
+         "intervention": ["Drug A", "Drug B", "Keytruda", "Drug A + Drug C"],
+         "start_year": [2020, 2021, 2022, 2023],
+         "title": [
+             "Study of Drug A",
+             "Study of Drug B",
+             "Keytruda Trial",
+             "Combo Study",
+         ],
+         "condition": ["Cancer", "Diabetes", "Lung Cancer", "Cancer"],
+     }
+     return pd.DataFrame(data)
+
+
+ def test_analytics_filter_intervention(sample_df):
+     # Filter for Keytruda
+     filtered = filter_dataframe(sample_df, intervention="Keytruda")
+     assert len(filtered) == 1
+     assert filtered.iloc[0]["nct_id"] == "NCT003"
+
+
+ def test_analytics_filter_intervention_partial(sample_df):
+     # Filter for "Drug A" (should match NCT001 and NCT004)
+     filtered = filter_dataframe(sample_df, intervention="Drug A")
+     assert len(filtered) == 2
+     assert set(filtered["nct_id"]) == {"NCT001", "NCT004"}
+
+
+ # --- Tests for Query Expansion ---
+
+
+ @patch("modules.tools.Settings")
+ def test_expand_query(mock_settings):
+     # Mock LLM response
+     mock_response = MagicMock()
+     mock_response.text = "Expanded Query: cancer OR carcinoma OR tumor"
+     mock_settings.llm.complete.return_value = mock_response
+
+     query = "cancer"
+     expanded = expand_query(query)
+
+     assert "cancer OR carcinoma OR tumor" in expanded
+     mock_settings.llm.complete.assert_called_once()
+
+
+ def test_expand_query_skip_long():
+     long_query = "this is a very long query that should definitely be skipped because it has too many words"
+     assert expand_query(long_query) == long_query
+
+
+
+
+
+ # --- Tests for Graph Visualization ---
+
+
+ def test_build_graph():
+     data = [
+         {"nct_id": "NCT1", "title": "Study 1", "org": "Pfizer", "condition": "Cancer"},
+         {
+             "nct_id": "NCT2",
+             "title": "Study 2",
+             "org": "Merck",
+             "condition": "Cancer, Diabetes",
+         },
+     ]
+
+     nodes, edges, config = build_graph(data)
+
+     # Check Nodes
+     # 2 Studies + 2 Sponsors + 2 Conditions (Cancer, Diabetes) = 6 Nodes
+     assert len(nodes) == 6
+
+     node_ids = [n.id for n in nodes]
+     assert "NCT1" in node_ids
+     assert "Pfizer" in node_ids
+     assert "Cancer" in node_ids
+
+     # Check Edges
+     # NCT1 -> Pfizer, NCT1 -> Cancer (2 edges)
+     # NCT2 -> Merck, NCT2 -> Cancer, NCT2 -> Diabetes (3 edges)
+     assert len(edges) == 5
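The node and edge counts asserted in `test_build_graph` follow from deduplicating nodes while emitting one edge per study-sponsor and study-condition pair. The sketch below is a hypothetical reconstruction of that logic — the real `modules/graph_viz.py` presumably returns streamlit-agraph `Node`/`Edge`/`Config` objects, so plain dataclasses stand in here to keep the sketch self-contained:

```python
from dataclasses import dataclass


# Stand-ins for the real graph-library node/edge types (assumption).
@dataclass(frozen=True)
class Node:
    id: str
    label: str


@dataclass(frozen=True)
class Edge:
    source: str
    target: str


def build_graph(data):
    """Build a study/sponsor/condition graph from a list of trial dicts."""
    nodes, edges, seen = [], [], set()

    def add_node(node_id, label):
        # Deduplicate: shared sponsors/conditions become a single node.
        if node_id not in seen:
            seen.add(node_id)
            nodes.append(Node(id=node_id, label=label))

    for study in data:
        add_node(study["nct_id"], study["title"])
        add_node(study["org"], study["org"])
        edges.append(Edge(source=study["nct_id"], target=study["org"]))
        # Conditions arrive as a comma-separated string; one edge each.
        for condition in study["condition"].split(","):
            condition = condition.strip()
            add_node(condition, condition)
            edges.append(Edge(source=study["nct_id"], target=condition))

    config = {"directed": True}  # placeholder for a real layout config
    return nodes, edges, config
```

Run against the test's two-study input, this yields 6 nodes (shared "Cancer" is deduplicated) and 5 edges, matching the assertions above.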