Geoffrey Kip committed · Commit 507be68
Initial Release
Browse files
- .dockerignore +47 -0
- .flake8 +4 -0
- .gitattributes +1 -0
- .gitignore +62 -0
- DEPLOYMENT.md +82 -0
- Dockerfile +33 -0
- README.md +240 -0
- ct_agent_app.py +583 -0
- modules/__init__.py +0 -0
- modules/cohort_tools.py +145 -0
- modules/constants.py +103 -0
- modules/graph_viz.py +97 -0
- modules/tools.py +706 -0
- modules/utils.py +281 -0
- requirements.txt +19 -0
- scripts/analyze_db.py +149 -0
- scripts/ingest_ct.py +449 -0
- scripts/remove_duplicates.py +174 -0
- tests/test_data_integrity.py +75 -0
- tests/test_hybrid_search.py +65 -0
- tests/test_sponsor_normalization.py +45 -0
- tests/test_unit.py +149 -0
.dockerignore
ADDED
@@ -0,0 +1,47 @@
# Git
.git
.gitignore

# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# Virtual Environment
venv/
.venv/

# Environment Variables (CRITICAL: Do not include secrets)
.env
.env.local

# IDE
.vscode/
.idea/

# Mac
.DS_Store

# Logs
*.log

# Temporary
*.tmp
.flake8
ADDED
@@ -0,0 +1,4 @@
[flake8]
max-line-length = 120
extend-ignore = E203
exclude = venv, .git, __pycache__, build, dist
.gitattributes
ADDED
@@ -0,0 +1 @@
ct_gov_lancedb/**/* filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,62 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
*.egg-info/
*.pyc
*.pyo

# Build artifacts
dist/
build/
*.spec

# Virtual Environment
.venv/
venv/
env/
ENV/

# Environment variables
.env
.env.local

# Session files
amazon_session.json

# Database files
agent_data.db
*.db
*.db-journal
ct_gov_lancedb/

# Chrome/Browser session data
user_session/

# IDE
.vscode/
.idea/
*.swp
*.swo
*~
*.code-workspace

# OS
.DS_Store
Thumbs.db
.DS_Store?

# Logs
*.log

# Dev Containers
.devcontainer/

# macOS App Bundle (generated, but keep source)
*.app/

# Temporary files
*.tmp
*.temp
DEPLOYMENT.md
ADDED
@@ -0,0 +1,82 @@
# Deployment Guide: Hugging Face Spaces 🐳

This guide walks you through deploying the **Clinical Trial Inspector Agent** to **Hugging Face Spaces** using Docker.

## Prerequisites

1. **Hugging Face Account**: [Sign up here](https://huggingface.co/join).
2. **Git LFS (Large File Storage)**: Required to upload the database (~700MB).
   * **Mac**: `brew install git-lfs`
   * **Windows**: Download from [git-lfs.com](https://git-lfs.com/)
   * **Linux**: `sudo apt-get install git-lfs`

## Step 1: Authentication (Crucial!) 🔑

Hugging Face requires an **Access Token** for Git operations (passwords don't work).

1. Go to **[Settings > Access Tokens](https://huggingface.co/settings/tokens)**.
2. Click **Create new token**.
3. **Type**: Select **Write** (important!).
4. Copy the token (starts with `hf_...`).
5. **Usage**: When `git push` asks for a password, **paste this token**.

## Step 2: Create a New Space

1. Go to [huggingface.co/new-space](https://huggingface.co/new-space).
2. **Space Name**: e.g., `clinical-trial-agent`.
3. **License**: `MIT` (or your choice).
4. **SDK**: Select **Docker**.
5. **Visibility**: Public or Private.
6. Click **Create Space**.

## Step 3: Prepare Your Local Repo

You need to initialize Git LFS to track the large LanceDB files.

```bash
# Initialize LFS
git lfs install

# Track the LanceDB files
git lfs track "ct_gov_lancedb/**/*"
git add .gitattributes
```

## Step 4: Push to Hugging Face

You can either push your existing repo or clone the Space and copy files. Pushing your existing repo is easier:

```bash
# Add the Space as a remote (replace YOUR_USERNAME and SPACE_NAME)
git remote add space https://huggingface.co/spaces/YOUR_USERNAME/SPACE_NAME

# Push the main branch
git push space main
# OR if you are on a feature branch:
git push space feature/deploy_app:main
```

> **Note**: The first push will take time as it uploads the ~700MB database.

## Step 5: Configure Secrets (Optional but Recommended)

To run in **Admin Mode** (no user prompt for API key):

1. Go to your Space's **Settings** tab.
2. Scroll to **Variables and secrets**.
3. Click **New secret**.
4. **Name**: `GOOGLE_API_KEY`
5. **Value**: Your Google API Key (starts with `AIza...`).

## Step 6: Verify Deployment

1. Go to the **App** tab in your Space.
2. You should see "Building..." in the logs.
3. Once built, the app will launch! 🚀

---

## Troubleshooting

* **"LFS upload failed"**: Ensure you ran `git lfs install` and `git lfs track`.
* **"Runtime Error"**: Check the **Logs** tab. If it says "API Key Missing", ensure you set the Secret or enter it in the UI.
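If you prefer to script this deployment rather than use Git directly, a rough equivalent with the `huggingface_hub` Python client is sketched below. `login`, `create_repo`, and `upload_folder` are standard `huggingface_hub` APIs, but the repo id is a placeholder and this flow (including the large LanceDB upload) has not been verified against this repo; treat it as a sketch, not the documented method.

```python
# Sketch: push this project to a Docker Space via the huggingface_hub client.
# Assumes `pip install huggingface_hub` and a Write token; repo_id is a placeholder.
from huggingface_hub import HfApi, login

login()  # paste the hf_... Write token when prompted

api = HfApi()
repo_id = "YOUR_USERNAME/clinical-trial-agent"  # placeholder

# Create the Space with the Docker SDK (no-op if it already exists)
api.create_repo(repo_id=repo_id, repo_type="space", space_sdk="docker", exist_ok=True)

# Upload the working directory; large files such as ct_gov_lancedb/ go through the Hub's LFS backend
api.upload_folder(folder_path=".", repo_id=repo_id, repo_type="space")
```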
Dockerfile
ADDED
@@ -0,0 +1,33 @@
# Use an official Python runtime as a parent image
FROM python:3.10-slim

# Set the working directory in the container
WORKDIR /app

# Install system dependencies
# build-essential is often needed for compiling python packages
# git is needed if you install packages from git
RUN apt-get update && apt-get install -y \
    build-essential \
    git \
    && rm -rf /var/lib/apt/lists/*

# Copy the requirements file into the container at /app
COPY requirements.txt .

# Install any needed packages specified in requirements.txt
RUN pip install --no-cache-dir -r requirements.txt

# Copy the current directory contents into the container at /app
COPY . .

# Expose port 8501 for Streamlit
EXPOSE 8501

# Define environment variable for Streamlit to run in headless mode
ENV STREAMLIT_SERVER_HEADLESS=true
ENV STREAMLIT_SERVER_PORT=8501
ENV STREAMLIT_SERVER_ADDRESS=0.0.0.0

# Run the application
CMD ["streamlit", "run", "ct_agent_app.py"]
README.md
ADDED
@@ -0,0 +1,240 @@
# Clinical Trial Inspector Agent 🕵️‍♂️💊

**Clinical Trial Inspector** is an advanced AI agent designed to revolutionize how researchers, clinicians, and analysts explore clinical trial data. By combining **Semantic Search**, **Retrieval-Augmented Generation (RAG)**, and **Visual Analytics**, it transforms raw data from [ClinicalTrials.gov](https://clinicaltrials.gov/) into actionable insights.

Built with **LangChain**, **LlamaIndex**, **Streamlit**, **Altair**, **Streamlit-Agraph**, and **Google Gemini**, this tool goes beyond simple keyword search. It understands natural language, generates inline visualizations, performs complex multi-dimensional analysis, and visualizes relationships in an interactive knowledge graph.

## ✨ Key Features

### 🧠 Intelligent Search & Retrieval
* **Hybrid Search**: Combines **Semantic Search** (vector similarity) with **BM25 Keyword Search** (sparse retrieval) using **LanceDB's Native Hybrid Search**. This ensures you find studies that match both the *meaning* (e.g., "kidney cancer" -> "renal cell carcinoma") and *exact terms* (e.g., "NCT04589845", "Teclistamab").
* **Smart Filtering**:
    * **Strict Pre-Filtering**: For specific sponsors (e.g., "Pfizer"), it forces the engine to look *only* at that sponsor's studies first, ensuring 100% recall.
    * **Strict Keyword Filtering (Analytics Only)**: For counting questions (e.g., "How many studies..."), the **Analytics Engine** (`get_study_analytics`) prioritizes studies where the query explicitly appears in the **Title** or **Conditions**, ensuring high precision and accurate counts.
* **Sponsor Alias Support**: Intelligently maps aliases (e.g., "J&J", "MSD") to their canonical sponsor names ("Janssen", "Merck Sharp & Dohme") for accurate aggregation.
* **Smart Summary**: Returns a clean, concise list of relevant studies.
* **Query Expansion**: Automatically expands your search terms with medical synonyms (e.g., "Heart Attack" -> "Myocardial Infarction").
* **Re-Ranking**: Uses a Cross-Encoder (`ms-marco-MiniLM`) to re-score results for maximum relevance.
* **Query Decomposition**: Breaks down complex multi-part questions (e.g., *"Compare the primary outcomes of Keytruda vs Opdivo"*) into sub-questions for precise answers.
* **Cohort SQL Generation**: Translates eligibility criteria into standard SQL queries (OMOP CDM) for patient cohort identification.

### 📊 Visual Analytics & Insights
- **Inline Charts (Contextual)**: The agent automatically generates **Bar Charts** and **Line Charts** directly in the chat stream when you ask aggregation questions (e.g., *"Top sponsors for Multiple Myeloma"*).
- **Analytics Dashboard (Global)**: A dedicated dashboard to analyze trends across the **entire dataset** (60,000+ studies), independent of your chat session.
- **Interactive Knowledge Graph**: Visualize connections between **Studies**, **Sponsors**, and **Conditions** in a dynamic, interactive network graph.

### 🌍 Geospatial Dashboard
- **Global Trial Map**: Visualize the geographic distribution of clinical trials on an interactive world map.
- **Region Toggle**: Switch between **World View** (country-level aggregation) and **USA View** (state-level aggregation).
- **Dot Visualization**: Uses dynamic **CircleMarkers** (dots) sized by trial count to show density.
- **Interactive Filters**: Filter the map by **Phase**, **Status**, **Sponsor**, **Start Year**, and **Study Type**.

### 🔍 Multi-Filter Analysis
- **Complex Filtering**: Answer sophisticated questions by applying multiple filters simultaneously.
    - *Example*: *"For **Phase 2 and 3** studies, what are **Pfizer's** most common study indications?"*
- **Full Dataset Scope**: General analytics questions analyze the **entire database**, not just a sample.
- **Smart Retrieval**: Retrieves up to **5,000 relevant studies** for comprehensive analysis.

### ⚡ High-Performance Ingestion
- **Parallel Processing**: Uses multi-core processing to ingest and embed thousands of studies per minute.
- **LanceDB Integration**: Uses **LanceDB** for high-performance vector storage and native hybrid search.
- **Idempotent Updates**: Smartly updates existing records without duplication, allowing for seamless data refreshes.

## 🤖 Agent Capabilities & Tools

The agent is equipped with specialized tools to handle different types of requests:

### 1. `search_trials`
* **Purpose**: Finds specific clinical trials based on natural language queries.
* **Capabilities**: Semantic Search, Smart Filtering (Phase, Status, Sponsor, Intervention), Query Expansion, Hybrid Search, Re-Ranking.

### 2. `get_study_analytics`
* **Purpose**: Aggregates data to reveal trends and insights.
* **Capabilities**: Multi-Filtering, Grouping (Phase, Status, Sponsor, Year, Condition), Full Dataset Access, Inline Visualization.

### 3. `compare_studies`
* **Purpose**: Handles complex comparison or multi-part questions.
* **Capabilities**: Uses **Query Decomposition** to break a complex query into sub-queries, executes them against the database, and synthesizes the results.

### 4. `find_similar_studies`
* **Purpose**: Discovers studies that are semantically similar to a specific trial.
* **Capabilities**:
    * **NCT Lookup**: Automatically fetches content if queried with an NCT ID.
    * **Self-Exclusion**: Filters out the reference study from results.
    * **Scoring**: Returns similarity scores for transparency.

### 5. `get_study_details`
* **Purpose**: Fetches the full text content of a specific study by NCT ID.
* **Capabilities**: Retrieves all chunks of a study to provide comprehensive details (Criteria, Summary, Protocol).

### 6. `get_cohort_sql`
* **Purpose**: Translates clinical trial eligibility criteria into standard SQL queries for claims data analysis.
* **Capabilities**:
    * **Extraction**: Parses text into structured inclusion/exclusion rules (Concepts, Codes).
    * **SQL Generation**: Generates OMOP-compatible SQL queries targeting `medical_claims` and `pharmacy_claims`.
    * **Logic Enforcement**: Applies temporal logic (e.g., "2 diagnoses > 30 days apart") for chronic conditions.

## ⚙️ How It Works (RAG Pipeline)

1. **Ingestion**: `ingest_ct.py` fetches study data from ClinicalTrials.gov. It extracts rich text (including **Eligibility Criteria** and **Interventions**) and structured metadata. It uses **multiprocessing** for speed.
2. **Embedding**: Text is converted into vector embeddings using `PubMedBERT` and stored in **LanceDB**.
3. **Retrieval**:
    * **Query Transformation**: Synonyms are injected via LLM.
    * **Pre-Filtering**: Strict filters (Status, Year, Sponsor) reduce the search scope.
    * **Hybrid Search**: Parallel **Vector Search** (Semantic) and **BM25** (Keyword) combined via **LanceDB Native Hybrid Search**.
    * **Post-Filtering**: Additional metadata checks (Phase, Intervention) on retrieved candidates.
    * **Re-Ranking**: Cross-Encoder re-scoring.
4. **Synthesis**: **Google Gemini** synthesizes the final answer.

### 🏗️ Ingestion Pipeline

```mermaid
graph TD
    API[ClinicalTrials.gov API] -->|Fetch Batches| Script[ingest_ct.py]
    Script -->|Process & Embed| LanceDB[(LanceDB)]
```

### 🧠 RAG Retrieval Flow

```mermaid
graph TD
    User[User Query] -->|Expand| Synonyms[Synonym Injection]
    Synonyms -->|Pre-Filter| PreFilter[Pre-Retrieval Filters]
    PreFilter -->|Filtered Scope| Hybrid[Hybrid Search]
    Hybrid -->|Parallel Search| Vector[Vector Search] & BM25[BM25 Keyword Search]
    Vector & BM25 -->|Reciprocal Rank Fusion| Fusion[Merged Candidates]
    Fusion -->|Candidates| PostFilter[Post-Retrieval Filters]
    PostFilter -->|Top N| ReRank[Cross-Encoder Re-Ranking]
    ReRank -->|Context| LLM[Google Gemini]
    LLM -->|Answer| Response[Final Response]
```

### 🕸️ Knowledge Graph

```mermaid
graph TD
    LanceDB[(LanceDB)] -->|Metadata| GraphBuilder[build_graph]
    GraphBuilder -->|Nodes & Edges| Agraph[Streamlit Agraph]
```

## 🛠️ Tech Stack

- **Frontend**: Streamlit, Altair, Streamlit-Agraph
- **LLM**: Google Gemini (`gemini-2.5-flash`)
- **Orchestration**: LangChain (Agents, Tool Calling)
- **Retrieval (RAG)**: LlamaIndex (VectorStoreIndex, SubQuestionQueryEngine)
- **Vector Database**: LanceDB (Local)
- **Embeddings**: HuggingFace (`pritamdeka/S-PubMedBert-MS-MARCO`)

## 🚀 Getting Started

### Prerequisites

- Python 3.10+
- A Google Cloud API Key with access to Gemini

### Installation

1. **Clone the repository**
    ```bash
    git clone <repository-url>
    cd clinical_trial_agent
    ```

2. **Create and activate a virtual environment**
    ```bash
    python -m venv venv
    source venv/bin/activate  # On Windows: venv\Scripts\activate
    ```

3. **Install dependencies**
    ```bash
    pip install -r requirements.txt
    ```

4. **Set up Environment Variables**
    Create a `.env` file in the root directory and add your Google API Key:
    ```bash
    GOOGLE_API_KEY=your_google_api_key_here
    ```

## 📖 Usage

### 1. Ingest Data
Populate the local database. The script uses parallel processing for speed.

```bash
# Recommended: Ingest 5000 recent studies
python scripts/ingest_ct.py --limit 5000 --years 5

# Ingest ALL studies (Warning: Large download!)
python scripts/ingest_ct.py --limit -1 --years 10
```

### 2. Run the Agent
Launch the Streamlit application:

```bash
streamlit run ct_agent_app.py
```

### 3. Ask Questions!
- **Search**: *"Find studies for Multiple Myeloma."*
- **Comparison**: *"Compare the primary outcomes of Keytruda vs Opdivo."*
- **Analytics**: *"Who are the top sponsors for Breast Cancer?"* (Now supports grouping by **Intervention** and **Study Type**!)
- **Graph**: Go to the **Knowledge Graph** tab to visualize connections.

## 🧪 Testing & Quality

- **Unit Tests**: Run `python -m pytest tests/test_unit.py` to verify core logic.
- **Hybrid Search Tests**: Run `python -m pytest tests/test_hybrid_search.py` to verify the search engine's precision and recall.
- **Data Integrity**: Run `python -m unittest tests/test_data_integrity.py` to verify database content against known ground truths.
- **Sponsor Normalization**: Run `python -m pytest tests/test_sponsor_normalization.py` to verify alias mapping logic.
- **Linting**: Codebase is formatted with `black` and linted with `flake8`.

## 📂 Project Structure

- `ct_agent_app.py`: Main application logic.
- `modules/`:
    - `utils.py`: Configuration, Normalization, Custom Filters.
    - `constants.py`: Static data (Coordinates, Mappings).
    - `tools.py`: Tool definitions (`search_trials`, `compare_studies`, etc.).
    - `cohort_tools.py`: SQL generation logic (`get_cohort_sql`).
    - `graph_viz.py`: Knowledge Graph logic.
- `scripts/`:
    - `ingest_ct.py`: Parallel data ingestion pipeline.
    - `analyze_db.py`: Database inspection.
- `ct_gov_lancedb/`: Persisted LanceDB vector store.
- `tests/`:
    - `test_unit.py`: Core logic tests.
    - `test_hybrid_search.py`: Integration tests for search engine.

## 🐳 Deployment

The application is container-ready and can be deployed using Docker.

### Build the Image
```bash
docker build -t clinical-trial-agent .
```

### Run the Container
You can run the container in two modes:

**1. Admin Mode (API Key in Environment)**
Pass the key as an environment variable. Users will not be prompted.
```bash
docker run -p 8501:8501 -e GOOGLE_API_KEY=your_key_here clinical-trial-agent
```

**2. User Mode (Prompt for Key)**
Run without the key. Users will be prompted to enter their own key in the sidebar.
```bash
docker run -p 8501:8501 clinical-trial-agent
```

### Hosting Options
- **Hugging Face Spaces**: Select "Docker" SDK. Add `GOOGLE_API_KEY` to Secrets for Admin Mode.
- **Google Cloud Run**: Deploy the container and map port 8501.
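As a concrete illustration of the analytics tool the README describes, the dashboard's "Generate Analytics" button in `ct_agent_app.py` ultimately reduces to a direct tool invocation. A minimal sketch of that call is below; it assumes the LanceDB store has already been ingested and a valid `GOOGLE_API_KEY` is available, the argument names are copied from how the app itself calls the tool, and running it outside `streamlit run` may emit session-state warnings.

```python
# Sketch: call the analytics tool directly, mirroring generate_dashboard_analytics()
# in ct_agent_app.py. Prerequisites (ingested LanceDB store, GOOGLE_API_KEY) are assumed.
import os
from dotenv import load_dotenv

from modules.utils import setup_llama_index
from modules.tools import get_study_analytics

load_dotenv()
setup_llama_index(api_key=os.environ["GOOGLE_API_KEY"])  # same setup the app performs before using tools

result = get_study_analytics.invoke(
    {
        "query": "overall",      # "overall" = no topic restriction, as used by the dashboard
        "group_by": "sponsor",   # the app maps UI choices to: phase, status, sponsor, start_year, intervention, study_type
        "phase": "Phase 2",      # optional filter (placeholder value)
        "sponsor": None,         # optional filter
    }
)
print(result)
```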
ct_agent_app.py
ADDED
@@ -0,0 +1,583 @@
"""
Clinical Trial Inspector Agent Application.

This is the main Streamlit application script. It orchestrates:
1. **LLM & Agents**: Initializes Google Gemini and the LangChain agent.
2. **RAG Pipeline**: Loads the LlamaIndex vector store for semantic retrieval.
3. **User Interface**: Renders the Streamlit UI with tabs for Chat, Analytics, and Raw Data.
4. **Visualization**: Handles dynamic chart generation using Altair.
"""

import streamlit as st
import pandas as pd
import os
import altair as alt
import logging
from dotenv import load_dotenv

# Suppress logging
logging.getLogger("langchain_google_genai._function_utils").setLevel(logging.ERROR)

# Load environment variables
load_dotenv()

# Module Imports
from modules.utils import load_index, setup_llama_index
from modules.constants import COUNTRY_COORDINATES, STATE_COORDINATES

# ... (imports)
from modules.tools import (
    search_trials,
    find_similar_studies,
    get_study_analytics,
    compare_studies,
    get_study_details,
    fetch_study_analytics_data,
)
from modules.cohort_tools import get_cohort_sql
from modules.graph_viz import build_graph
from streamlit_agraph import agraph
from streamlit_option_menu import option_menu
import folium
from streamlit_folium import st_folium

# LangChain Imports
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.agents import AgentExecutor, create_tool_calling_agent
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.prompts import MessagesPlaceholder

# --- App Configuration ---
st.set_page_config(
    page_title="Clinical Trial Inspector",
    layout="wide",
    initial_sidebar_state="expanded",
)

# --- Custom CSS for Sidebar Width ---
st.markdown(
    """
    <style>
    [data-testid="stSidebar"] {
        min-width: 200px;
        max-width: 250px;
    }
    </style>
    """,
    unsafe_allow_html=True,
)

st.title("🧬 Clinical Trial Inspector Agent")

# 1. Setup LLM & LlamaIndex Settings
# We use Google Gemini-2.5-Flash for fast and accurate responses.
api_key = os.environ.get("GOOGLE_API_KEY")

if not api_key:
    st.sidebar.warning("⚠️ API Key Missing")
    user_key = st.sidebar.text_input("Enter Google API Key:", type="password", help="Get one at https://aistudio.google.com/")
    if user_key:
        st.session_state["api_key"] = user_key
        api_key = user_key
        st.sidebar.success("Key set!")
        st.rerun()
    else:
        # Check if key is already in session state (from previous run)
        if "api_key" in st.session_state:
            api_key = st.session_state["api_key"]
        else:
            st.warning("Please enter your Google API Key in the sidebar to continue.")
            st.stop()
else:
    # Env var exists, ensure it's in session state for tools to find
    st.session_state["api_key"] = api_key

# Ensure LlamaIndex settings (Embeddings, LLM) are applied on every run
setup_llama_index(api_key=api_key)

llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash", temperature=0, google_api_key=api_key)

# 2. Load LlamaIndex (Cached)
# The index is loaded once and cached to avoid reloading on every interaction.
index = load_index()


# 3. Define Agent (Cached)
@st.cache_resource
def get_agent():
    """Initializes and caches the LangChain agent."""
    tools = [
        search_trials,
        find_similar_studies,
        get_study_analytics,
        compare_studies,
        get_study_details,
        get_cohort_sql,
    ]

    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                "You are a Clinical Trial Expert Assistant. "
                "Your goal is to help researchers and analysts understand clinical trial data. "
                "You have access to a local database of clinical trials (embedded from ClinicalTrials.gov). "
                "Use the available tools to search for studies, find similar studies, and generate analytics. "
                "When asked about 'trends', 'counts', 'how many', or 'most common', ALWAYS use the `get_study_analytics` tool. "
                "Do NOT use `search_trials` for counting questions like 'How many studies...'. "
                "When asked to 'find studies', 'search', or 'list', use `search_trials`. "
                "When asked to 'compare' multiple studies or answer complex multi-part questions, use `compare_studies`. "
                "If the user asks for a specific study by ID (e.g., NCT12345678), `search_trials` handles that automatically. "
                "However, if the user asks for specific **details**, **criteria**, **summary**, or **protocol** of a single study, "
                "you MUST use the `get_study_details` tool to fetch the full content. "
                "If the user asks to **generate SQL**, **build a cohort**, or **translate criteria to code** for a study, "
                "use the `get_cohort_sql` tool. "
                "When reporting 'similar studies', ALWAYS include the similarity score provided by the tool "
                "and DO NOT include the study that was used as the query (the reference study). "
                "Provide concise, evidence-based answers citing specific studies when possible.",
            ),
            MessagesPlaceholder(variable_name="chat_history"),
            ("human", "{input}"),
            ("placeholder", "{agent_scratchpad}"),
        ]
    )

    agent = create_tool_calling_agent(llm, tools, prompt)
    return AgentExecutor(agent=agent, tools=tools, verbose=True)


agent_executor = get_agent()

# --- Sidebar ---
with st.sidebar:
    st.image(
        "https://cdn-icons-png.flaticon.com/512/3004/3004458.png", width=50
    )
    st.title("Clinical Trial Agent")

    page = option_menu(
        "Main Menu",
        ["Chat Assistant", "Analytics Dashboard", "Knowledge Graph", "Study Map", "Raw Data"],
        icons=["chat-dots", "graph-up", "diagram-3", "map", "database"],
        menu_icon="cast",
        default_index=0,
    )


# --- Helper Functions ---
def generate_dashboard_analytics():
    """Callback to generate analytics and update session state."""
    # Map UI selection to tool arguments
    group_map = {
        "Phase": "phase",
        "Status": "status",
        "Sponsor": "sponsor",
        "Start Year": "start_year",
        "Intervention": "intervention",
        "Study Type": "study_type",
    }

    # Get values from session state
    # We use .get() to avoid KeyErrors if the widget hasn't initialized yet (though it should have)
    g_by = st.session_state.get("dash_group_by", "Sponsor")
    p_filter = st.session_state.get("dash_phase", "")
    s_filter = st.session_state.get("dash_sponsor", "")

    with st.spinner(f"Analyzing studies by {g_by}..."):
        # Call the tool directly
        result = get_study_analytics.invoke(
            {
                "query": "overall",
                "group_by": group_map.get(g_by, "sponsor"),
                "phase": p_filter if p_filter else None,
                "sponsor": s_filter if s_filter else None,
            }
        )

        # The tool sets session state 'inline_chart_data'
        if "inline_chart_data" in st.session_state:
            st.session_state["dashboard_data"] = st.session_state["inline_chart_data"]
        else:
            st.warning(result)


# --- PAGE 1: CHAT ---
if page == "Chat Assistant":
    st.header("💬 Chat Assistant")
    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Render Chat History
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])
            # Render chart if present in message history (persisted charts)
            if "chart_data" in message:
                chart_data = message["chart_data"]
                st.caption(chart_data["title"])
                chart = (
                    alt.Chart(pd.DataFrame(chart_data["data"]))
                    .mark_bar()
                    .encode(
                        x=alt.X(
                            chart_data["x"], sort="-y", axis=alt.Axis(labelLimit=200)
                        ),
                        y=alt.Y(chart_data["y"], title="Count"),
                        tooltip=[chart_data["x"], chart_data["y"]],
                    )
                    .interactive()
                )
                st.altair_chart(chart, theme="streamlit", width="stretch")

    # Chat Input
    if prompt := st.chat_input("Ask about clinical trials..."):
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)

        with st.chat_message("assistant"):
            with st.spinner("Analyzing clinical trials..."):
                try:
                    # Clear previous inline chart data to avoid stale charts
                    if "inline_chart_data" in st.session_state:
                        del st.session_state["inline_chart_data"]

                    # Construct chat history for the agent context
                    chat_history = []
                    for msg in st.session_state.messages[:-1]:
                        if msg["role"] == "user":
                            chat_history.append(HumanMessage(content=msg["content"]))
                        else:
                            chat_history.append(AIMessage(content=msg["content"]))

                    # Invoke Agent
                    response = agent_executor.invoke(
                        {"input": prompt, "chat_history": chat_history}
                    )
                    output = response["output"]
                    st.markdown(output)

                    # Check for inline chart data (set by tools)
                    chart_data = None
                    if "inline_chart_data" in st.session_state:
                        chart_data = st.session_state["inline_chart_data"]
                        st.caption(chart_data["title"])
                        if chart_data["type"] == "bar":
                            # Use Altair for better charts
                            chart = (
                                alt.Chart(pd.DataFrame(chart_data["data"]))
                                .mark_bar()
                                .encode(
                                    x=alt.X(
                                        chart_data["x"],
                                        sort="-y",
                                        axis=alt.Axis(labelLimit=200),
                                    ),
                                    y=alt.Y(chart_data["y"], title="Count"),
                                    tooltip=[chart_data["x"], chart_data["y"]],
                                )
                                .interactive()
                            )
                            st.altair_chart(chart, theme="streamlit", width="stretch")

                        # Clean up session state
                        del st.session_state["inline_chart_data"]

                    # Save message with chart data if present
                    msg_obj = {"role": "assistant", "content": output}
                    if chart_data:
                        msg_obj["chart_data"] = chart_data
                    st.session_state.messages.append(msg_obj)

                except Exception as e:
                    st.error(f"An error occurred: {e}")

# --- PAGE 2: ANALYTICS DASHBOARD ---
if page == "Analytics Dashboard":
    st.header("📊 Global Analytics")
    st.write(
        "Analyze trends across the entire clinical trial dataset (60,000+ studies)."
    )

    col1, col2 = st.columns([1, 3])

    with col1:
        st.subheader("Configuration")
        group_by = st.selectbox(
            "Group By",
            ["Phase", "Status", "Sponsor", "Start Year", "Intervention", "Study Type"],
            index=2,
            key="dash_group_by",
        )

        # Optional Filters
        st.markdown("---")
        st.markdown("**Filters (Optional)**")
        filter_phase = st.text_input("Phase (e.g., Phase 2)", key="dash_phase")
        filter_sponsor = st.text_input("Sponsor (e.g., Pfizer)", key="dash_sponsor")

        st.button(
            "Generate Analytics", type="primary", on_click=generate_dashboard_analytics
        )

    with col2:
        # Always render if data exists in session state
        if "dashboard_data" in st.session_state:
            c_data = st.session_state["dashboard_data"]
            st.subheader(c_data["title"])

            # Altair Chart Rendering
            if (
                c_data["x"] == "start_year" or group_by == "Start Year"
            ):  # Check both key and UI selection
                # Line chart for years
                chart = (
                    alt.Chart(pd.DataFrame(c_data["data"]))
                    .mark_line(point=True)
                    .encode(
                        x=alt.X(
                            c_data["x"], axis=alt.Axis(format="d"), title="Year"
                        ),  # 'd' for integer year
                        y=alt.Y(c_data["y"], title="Count"),
                        tooltip=[c_data["x"], c_data["y"]],
                    )
                    .interactive()
                )
            else:
                # Bar chart for others
                chart = (
                    alt.Chart(pd.DataFrame(c_data["data"]))
                    .mark_bar()
                    .encode(
                        x=alt.X(
                            c_data["x"],
                            sort="-y",
                            axis=alt.Axis(labelLimit=200),
                        ),
                        y=alt.Y(c_data["y"], title="Count"),
                        tooltip=[c_data["x"], c_data["y"]],
                    )
                    .interactive()
                )

            st.altair_chart(chart, theme="streamlit", width="stretch")

            # Show raw table
            with st.expander("View Source Data"):
                st.dataframe(pd.DataFrame(c_data["data"]))

# --- PAGE 3: KNOWLEDGE GRAPH ---
if page == "Knowledge Graph":
    st.header("🕸️ Interactive Knowledge Graph")
    st.write("Visualize connections between Studies, Sponsors, and Conditions.")

    col_g1, col_g2 = st.columns([1, 3])

    with col_g1:
        st.subheader("Graph Settings")
        graph_query = st.text_input("Search Topic", value="Cancer")
        limit = st.slider("Max Nodes", 10, 100, 50)

        if st.button("Build Graph"):
            with st.spinner("Fetching data and building graph..."):
                # Use retriever to get relevant nodes
                retriever = index.as_retriever(similarity_top_k=limit)
                nodes = retriever.retrieve(graph_query)
                data = [n.metadata for n in nodes]

                # Build Graph
                g_nodes, g_edges, g_config = build_graph(data)

                st.session_state["graph_data"] = {
                    "nodes": g_nodes,
                    "edges": g_edges,
                    "config": g_config,
                }

    with col_g2:
        if "graph_data" in st.session_state:
            g_data = st.session_state["graph_data"]
            st.success(
                f"Graph built with {len(g_data['nodes'])} nodes and {len(g_data['edges'])} edges."
            )
            agraph(
                nodes=g_data["nodes"], edges=g_data["edges"], config=g_data["config"]
            )
        else:
            st.info("Enter a topic and click 'Build Graph' to visualize connections.")

# --- PAGE 4: STUDY MAP ---
elif page == "Study Map":
    st.header("🌍 Global Clinical Trial Map")
    st.markdown("Visualize the geographic distribution of clinical trials.")

    # Sidebar Filters for Map
    st.sidebar.markdown("### 🗺️ Map Filters")
    map_region = st.sidebar.radio("Region", ["World", "USA"], index=0)

    map_phase = st.sidebar.multiselect(
        "Phase", ["PHASE1", "PHASE2", "PHASE3", "PHASE4"], default=["PHASE2", "PHASE3"]
    )
    map_status = st.sidebar.selectbox(
        "Status", ["RECRUITING", "COMPLETED", "ACTIVE_NOT_RECRUITING"], index=0
    )
    map_sponsor = st.sidebar.text_input("Sponsor (Optional)", "")
    map_year = st.sidebar.number_input("Start Year (>=)", min_value=2000, value=2020)
    map_type = st.sidebar.selectbox(
        "Study Type", ["Interventional", "Observational", "All"], index=0
    )

    # Convert filters to arguments
    phase_str = ",".join(map_phase) if map_phase else None
    type_arg = map_type if map_type != "All" else None

    if st.button("Update Map"):
        with st.spinner("Aggregating geographic data..."):
            # Determine grouping based on Region
            group_by_field = "state" if map_region == "USA" else "country"

            # Call analytics logic directly
            summary = fetch_study_analytics_data(
                query="overall",
                group_by=group_by_field,
                phase=phase_str,
                status=map_status,
                sponsor=map_sponsor,
                start_year=map_year,
                study_type=type_arg,
            )

            # Retrieve data from session state
            chart_data = st.session_state.get("inline_chart_data", {})
            data_records = chart_data.get("data", [])

            if not data_records:
                st.warning("No data found for these filters.")
                st.session_state["map_data"] = None
                st.session_state["map_region"] = map_region  # Store region too
            else:
                # Store in session state for persistence
                st.session_state["map_data"] = data_records
                st.session_state["map_region"] = map_region

    # Render Map (Outside Button Block)
    if st.session_state.get("map_data"):
        data_records = st.session_state["map_data"]
        region_mode = st.session_state.get("map_region", "World")
        df_map = pd.DataFrame(data_records)

        # Configure Map Center/Zoom
        if region_mode == "USA":
            m = folium.Map(location=[37.0902, -95.7129], zoom_start=4)
            coord_map = STATE_COORDINATES
        else:
            m = folium.Map(location=[20, 0], zoom_start=2)
            coord_map = COUNTRY_COORDINATES

        # Add CircleMarkers
        for _, row in df_map.iterrows():
            loc_name = row["category"]
            count = row["count"]

            # Clean name if needed (strip trailing parenthesis)
            loc_clean = loc_name.rstrip(")")
            coords = coord_map.get(loc_clean)

            if coords:
                folium.CircleMarker(
                    location=coords,
                    radius=min(max(count / 5, 3), 20),  # Adjust scale
                    popup=f"{loc_clean}: {count} trials",
                    color="blue" if region_mode == "USA" else "crimson",
                    fill=True,
                    fill_color="blue" if region_mode == "USA" else "crimson",
                ).add_to(m)

        st_folium(m, width=800, height=500)

        # Show data table
        st.subheader(f"{region_mode} Data")
        st.dataframe(df_map)

# --- PAGE 5: RAW DATA ---
if page == "Raw Data":
    st.header("📂 Raw Data Explorer")
    st.write("View and filter the underlying dataset.")

    # Load a sample or full dataset? Full might be slow.
    # We load a sample (top 100) to avoid performance issues.
    col_raw_1, col_raw_2 = st.columns([1, 1])

    with col_raw_1:
        if st.button("Load Sample Data (Top 100)"):
            with st.spinner("Fetching data..."):
                retriever = index.as_retriever(similarity_top_k=100)
                nodes = retriever.retrieve("clinical trial")
                data = [n.metadata for n in nodes]
                df_raw = pd.DataFrame(data)

                # Format Year to remove commas (e.g., 2,023 -> 2023)
                if "start_year" in df_raw.columns:
                    df_raw["start_year"] = (
                        pd.to_numeric(df_raw["start_year"], errors="coerce")
                        .astype("Int64")
                        .astype(str)
                        .str.replace(",", "")
                    )

                # Store in session state to persist the table
                st.session_state["sample_data"] = df_raw

    with col_raw_2:
        # Download Full Dataset Logic
        if st.button("Prepare Full Download (CSV)"):
            with st.spinner("Fetching all records from database..."):
                try:
                    # Access LanceDB directly for speed
                    import lancedb
                    db = lancedb.connect("./ct_gov_lancedb")
                    tbl = db.open_table("clinical_trials")

                    # Fetch all data
                    df_full = tbl.to_pandas()

                    # Handle metadata flattening if needed
                    if "metadata" in df_full.columns:
                        meta_df = pd.json_normalize(df_full["metadata"])
                        # Combine or just use metadata
                        df_full = meta_df

                        # Convert to CSV
                        csv = df_full.to_csv(index=False).encode("utf-8")
                        st.session_state["full_csv"] = csv
                        st.success(f"Ready! Fetched {len(df_full)} records.")
                    else:
                        st.warning("No data found in database.")
                except Exception as e:
                    st.error(f"Error fetching data: {e}")

        if "full_csv" in st.session_state:
            st.download_button(
                label="⬇️ Download Full CSV",
                data=st.session_state["full_csv"],
                file_name="clinical_trials_full.csv",
                mime="text/csv",
            )

    # Display Sample Data Table (Full Width)
    if "sample_data" in st.session_state:
        st.markdown("### Sample Data (Top 100)")
        st.dataframe(
            st.session_state["sample_data"],
            column_config={
                "nct_id": "NCT ID",
                "title": "Study Title",
                "start_year": st.column_config.TextColumn(
                    "Start Year"
                ),  # Force text to avoid commas
                "url": st.column_config.LinkColumn("URL"),
            },
            width="stretch",
            hide_index=True,
        )
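The chart rendering in this file depends on an implicit contract: the analytics tool (defined in `modules/tools.py`, which is not shown in this excerpt) stores its result in `st.session_state["inline_chart_data"]`. The sketch below documents the payload shape as inferred from how `ct_agent_app.py` reads it; the key names are assumptions based on that consuming code and the values are illustrative placeholders only.

```python
# Assumed payload shape for st.session_state["inline_chart_data"], inferred from the
# consuming code above (chat charts, dashboard, and map page all read these keys).
inline_chart_data = {
    "type": "bar",                                  # "bar" triggers the Altair bar-chart branch
    "title": "Top Sponsors for Breast Cancer",      # placeholder; shown via st.caption()/st.subheader()
    "x": "category",                                # x-axis column; the dashboard checks for "start_year" to switch to a line chart
    "y": "count",                                   # y-axis column
    "data": [                                       # records; the map page expects "category"/"count" keys
        {"category": "Sponsor A", "count": 42},     # placeholder values
        {"category": "Sponsor B", "count": 17},
    ],
}
```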
modules/__init__.py
ADDED
File without changes
modules/cohort_tools.py
ADDED
@@ -0,0 +1,145 @@
| 1 |
+
import json
|
| 2 |
+
from langchain.tools import tool
|
| 3 |
+
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.prompts import PromptTemplate
from modules.tools import get_study_details
from modules.utils import load_environment

import streamlit as st
import os

# Load env for API key
load_environment()


def get_llm():
    """Retrieves LLM instance with dynamic API key."""
    # Check session state first (user-provided key)
    api_key = None
    if hasattr(st, "session_state") and "api_key" in st.session_state:
        api_key = st.session_state["api_key"]

    # Fallback to environment variable
    if not api_key:
        api_key = os.environ.get("GOOGLE_API_KEY")

    if not api_key:
        raise ValueError("Google API Key not found in session state or environment.")

    return ChatGoogleGenerativeAI(model="gemini-2.5-flash", temperature=0, google_api_key=api_key)


# Initialize LLM (Dynamic)
# llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash", temperature=0)

EXTRACT_PROMPT = PromptTemplate(
    template="""
You are a Clinical Informatics Expert.
Your task is to extract structured cohort requirements from the following Clinical Trial Eligibility Criteria.

Output a JSON object with two keys: "inclusion" and "exclusion".
Each key should contain a list of rules.
Each rule should have:
- "concept": The medical concept (e.g., "Type 2 Diabetes", "Metformin").
- "domain": The domain (Condition, Drug, Measurement, Procedure, Observation).
- "temporal": Any temporal logic (e.g., "History of", "Within last 6 months").
- "codes": A list of potential ICD-10 or RxNorm codes (make a best guess).

CRITERIA:
{criteria}

JSON OUTPUT:
""",
    input_variables=["criteria"],
)

SQL_PROMPT = PromptTemplate(
    template="""
You are a SQL Expert specializing in Healthcare Claims Data Analysis.
Generate a standard SQL query to define a cohort of patients based on the following structured requirements.

### Schema Assumptions
1. **`medical_claims`** (Diagnoses & Procedures):
   - `patient_id`, `claim_date`, `diagnosis_code` (ICD-10), `procedure_code` (CPT/HCPCS).
2. **`pharmacy_claims`** (Drugs):
   - `patient_id`, `fill_date`, `ndc_code`.

### Logic Rules
1. **Conditions (Diagnoses)**:
   - Require **at least 2 distinct claim dates** where the diagnosis code matches.
   - These 2 claims must be **at least 30 days apart** (to confirm chronic condition).
2. **Drugs**:
   - Require at least 1 claim with a matching NDC code.
3. **Procedures**:
   - Require at least 1 claim with a matching CPT/HCPCS code.
4. **Exclusions**:
   - Exclude patients who have ANY matching claims for exclusion criteria.

### Requirements (JSON)
{requirements}

### Output
Generate a single SQL query that selects `patient_id` from the claims tables meeting the criteria.
Use Common Table Expressions (CTEs) for clarity.
Do NOT output markdown formatting (```sql), just the raw SQL.

SQL QUERY:
""",
    input_variables=["requirements"],
)


def extract_cohort_requirements(criteria_text: str) -> dict:
    """Uses LLM to parse criteria text into structured JSON."""
    llm = get_llm()
    chain = EXTRACT_PROMPT | llm
    response = chain.invoke({"criteria": criteria_text})
    try:
        # Clean up potential markdown code blocks
        text = response.content.replace("```json", "").replace("```", "").strip()
        return json.loads(text)
    except json.JSONDecodeError:
        return {"error": "Failed to parse LLM output", "raw_output": response.content}


def generate_cohort_sql(requirements: dict) -> str:
    """Uses LLM to translate structured requirements into SQL."""
    llm = get_llm()
    chain = SQL_PROMPT | llm
    response = chain.invoke({"requirements": json.dumps(requirements, indent=2)})
    return response.content.replace("```sql", "").replace("```", "").strip()


@tool("get_cohort_sql")
def get_cohort_sql(nct_id: str) -> str:
    """
    Generates a SQL query to define the patient cohort for a specific study (NCT ID).

    Args:
        nct_id (str): The ClinicalTrials.gov identifier (e.g., NCT01234567).

    Returns:
        str: A formatted string containing the Extracted Requirements (JSON) and the Generated SQL.
    """
    # 1. Fetch Study Details
    # We reuse the existing tool logic to get the text
    study_text = get_study_details.invoke(nct_id)

    # get_study_details returns "Study <ID> not found in the database." for missing IDs
    if "not found" in study_text:
        return f"Could not find study {nct_id}."

    # 2. Extract Requirements
    requirements = extract_cohort_requirements(study_text)

    # 3. Generate SQL
    sql_query = generate_cohort_sql(requirements)

    return f"""
### 📋 Extracted Cohort Requirements
```json
{json.dumps(requirements, indent=2)}
```

### 💾 Generated SQL Query (OMOP CDM)
```sql
{sql_query}
```
"""
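The two-step pipeline above (eligibility criteria -> structured JSON -> claims SQL) can also be exercised outside the agent loop. A minimal sketch, assuming a valid GOOGLE_API_KEY is configured and using a made-up NCT ID purely for illustration:

import json
from modules.cohort_tools import extract_cohort_requirements, generate_cohort_sql
from modules.tools import get_study_details

criteria_text = get_study_details.invoke("NCT01234567")  # hypothetical ID for illustration
requirements = extract_cohort_requirements(criteria_text)
print(json.dumps(requirements, indent=2))
print(generate_cohort_sql(requirements))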
modules/constants.py
ADDED
@@ -0,0 +1,103 @@

# --- Geographic Constants ---
COUNTRY_COORDINATES = {
    "United States": [37.0902, -95.7129],
    "Canada": [56.1304, -106.3468],
    "United Kingdom": [55.3781, -3.4360],
    "Germany": [51.1657, 10.4515],
    "France": [46.2276, 2.2137],
    "China": [35.8617, 104.1954],
    "Japan": [36.2048, 138.2529],
    "Australia": [-25.2744, 133.7751],
    "Brazil": [-14.2350, -51.9253],
    "India": [20.5937, 78.9629],
    "Russia": [61.5240, 105.3188],
    "South Korea": [35.9078, 127.7669],
    "Italy": [41.8719, 12.5674],
    "Spain": [40.4637, -3.7492],
    "Netherlands": [52.1326, 5.2913],
    "Belgium": [50.5039, 4.4699],
    "Switzerland": [46.8182, 8.2275],
    "Sweden": [60.1282, 18.6435],
    "Israel": [31.0461, 34.8516],
    "Poland": [51.9194, 19.1451],
    "Taiwan": [23.6978, 120.9605],
    "Mexico": [23.6345, -102.5528],
    "Argentina": [-38.4161, -63.6167],
    "South Africa": [-30.5595, 22.9375],
    "Turkey": [38.9637, 35.2433],
    "Denmark": [56.2639, 9.5018],
    "New Zealand": [-40.9006, 174.8860],
    "Czech Republic": [49.8175, 15.4730],
    "Hungary": [47.1625, 19.5033],
    "Finland": [61.9241, 25.7482],
    "Norway": [60.4720, 8.4689],
    "Austria": [47.5162, 14.5501],
    "Greece": [39.0742, 21.8243],
    "Ireland": [53.1424, -7.6921],
    "Portugal": [39.3999, -8.2245],
    "Ukraine": [48.3794, 31.1656],
    "Egypt": [26.8206, 30.8025],
    "Thailand": [15.8700, 100.9925],
    "Singapore": [1.3521, 103.8198],
    "Malaysia": [4.2105, 101.9758],
    "Vietnam": [14.0583, 108.2772],
    "Philippines": [12.8797, 121.7740],
    "Indonesia": [-0.7893, 113.9213],
    "Saudi Arabia": [23.8859, 45.0792],
    "United Arab Emirates": [23.4241, 53.8478],
}

STATE_COORDINATES = {
    "Alabama": [32.806671, -86.791130],
    "Alaska": [61.370716, -152.404419],
    "Arizona": [33.729759, -111.431221],
    "Arkansas": [34.969704, -92.373123],
    "California": [36.116203, -119.681564],
    "Colorado": [39.059811, -105.311104],
    "Connecticut": [41.597782, -72.755371],
    "Delaware": [39.318523, -75.507141],
    "District of Columbia": [38.897438, -77.026817],
    "Florida": [27.766279, -81.686783],
    "Georgia": [33.040619, -83.643074],
    "Hawaii": [21.094318, -157.498337],
    "Idaho": [44.240459, -114.478828],
    "Illinois": [40.349457, -88.986137],
    "Indiana": [39.849426, -86.258278],
    "Iowa": [42.011539, -93.210526],
    "Kansas": [38.526600, -96.726486],
    "Kentucky": [37.668140, -84.670067],
    "Louisiana": [31.169546, -91.867805],
    "Maine": [44.693947, -69.381927],
    "Maryland": [39.063946, -76.802101],
    "Massachusetts": [42.230171, -71.530106],
    "Michigan": [43.326618, -84.536095],
    "Minnesota": [45.694454, -93.900192],
    "Mississippi": [32.741646, -89.678696],
    "Missouri": [38.456085, -92.288368],
    "Montana": [46.921925, -110.454353],
    "Nebraska": [41.125370, -98.268082],
    "Nevada": [38.313515, -117.055374],
    "New Hampshire": [43.452492, -71.563896],
    "New Jersey": [40.298904, -74.521011],
    "New Mexico": [34.840515, -106.248482],
    "New York": [42.165726, -74.948051],
    "North Carolina": [35.630066, -79.806419],
    "North Dakota": [47.528912, -99.784012],
    "Ohio": [40.388783, -82.764915],
    "Oklahoma": [35.565342, -96.928917],
    "Oregon": [44.572021, -122.070938],
    "Pennsylvania": [41.203323, -77.194527],
    "Rhode Island": [41.680893, -71.511780],
    "South Carolina": [33.856892, -80.945007],
    "South Dakota": [44.299782, -99.438828],
    "Tennessee": [35.747845, -86.692345],
    "Texas": [31.054487, -97.563461],
    "Utah": [40.150032, -111.862434],
    "Vermont": [44.045876, -72.710686],
    "Virginia": [37.769337, -78.169968],
    "Washington": [47.400902, -121.490494],
    "West Virginia": [38.491226, -80.954453],
    "Wisconsin": [44.268543, -89.616508],
    "Wyoming": [42.755966, -107.302490],
}
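These lookup tables provide approximate centroids for plotting study locations. A minimal sketch of how they might feed a folium map (folium and streamlit-folium are in requirements.txt; the counts dict below is invented for illustration):

import folium
from modules.constants import COUNTRY_COORDINATES

counts = {"United States": 120, "Germany": 35}  # illustrative counts only
m = folium.Map(location=[20, 0], zoom_start=2)
for country, n in counts.items():
    coords = COUNTRY_COORDINATES.get(country)
    if coords:
        # Scale the marker with the number of studies in that country
        folium.CircleMarker(location=coords, radius=min(3 + n / 10, 25), popup=f"{country}: {n}").add_to(m)
m.save("trial_sites.html")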
modules/graph_viz.py
ADDED
@@ -0,0 +1,97 @@

from streamlit_agraph import Node, Edge, Config


def build_graph(data):
    """
    Constructs a knowledge graph from clinical trial data.

    Args:
        data (list): List of study metadata dictionaries.

    Returns:
        tuple: (nodes, edges, config) for streamlit-agraph.
    """
    nodes = []
    edges = []

    # Sets to track unique entities
    study_ids = set()
    sponsors = set()
    conditions = set()

    for study in data:
        nct_id = study.get("nct_id", "Unknown")
        title = study.get("title", "Unknown")
        # Use 'sponsor' if available (new ingestion), else fallback to 'org'
        sponsor = study.get("sponsor", study.get("org", "Unknown"))
        condition_str = study.get("condition", "")

        # 1. Study Node
        if nct_id not in study_ids:
            nodes.append(
                Node(
                    id=nct_id,
                    label=nct_id,
                    size=20,
                    color="#4B8BBE",  # Blue
                    title=title,
                    shape="dot",
                )
            )
            study_ids.add(nct_id)

        # 2. Sponsor Node & Edge
        if sponsor and sponsor != "Unknown":
            if sponsor not in sponsors:
                nodes.append(
                    Node(
                        id=sponsor,
                        label=sponsor,
                        size=15,
                        color="#FF6B6B",  # Red
                        shape="triangle",
                    )
                )
                sponsors.add(sponsor)

            # Edge: Study -> Sponsor
            edges.append(
                Edge(
                    source=nct_id, target=sponsor, label="sponsored_by", color="#CCCCCC"
                )
            )

        # 3. Condition Nodes & Edges
        if condition_str:
            conds = [c.strip() for c in condition_str.split(",") if c.strip()]
            for cond in conds:
                if cond not in conditions:
                    nodes.append(
                        Node(
                            id=cond,
                            label=cond,
                            size=15,
                            color="#6BCB77",  # Green
                            shape="diamond",
                        )
                    )
                    conditions.add(cond)

                # Edge: Study -> Condition
                edges.append(
                    Edge(source=nct_id, target=cond, label="studies", color="#CCCCCC")
                )

    # Configuration
    config = Config(
        width=800,
        height=600,
        directed=True,
        physics=True,
        hierarchical=False,
        nodeHighlightBehavior=True,
        highlightColor="#F7A7A6",
        collapsible=False,
    )

    return nodes, edges, config
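build_graph only assembles nodes, edges, and a config; rendering happens in the Streamlit app. A minimal sketch, assuming studies is a list of metadata dicts shaped like the ones the retrieval tools return (the single study below is invented for illustration):

from streamlit_agraph import agraph
from modules.graph_viz import build_graph

studies = [{"nct_id": "NCT01234567", "title": "Example Study", "sponsor": "Pfizer", "condition": "Asthma, COPD"}]  # illustrative
nodes, edges, config = build_graph(studies)
agraph(nodes=nodes, edges=edges, config=config)  # draws the interactive graph inside the Streamlit page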
modules/tools.py
ADDED
@@ -0,0 +1,706 @@

"""
LangChain Tools for the Clinical Trial Agent.

This module defines the tools that the agent can use to interact with the clinical trial data.
Tools include:
1. **search_trials**: Semantic search with optional strict filtering.
2. **find_similar_studies**: Finding studies semantically similar to a given text.
3. **get_study_analytics**: Aggregating data for trends and insights (with inline charts).
"""

import pandas as pd
import streamlit as st
from typing import Optional
from langchain.tools import tool as langchain_tool
from llama_index.core.vector_stores import (
    MetadataFilter,
    MetadataFilters,
    FilterOperator,
)
from llama_index.core import Settings
from llama_index.core.postprocessor import MetadataReplacementPostProcessor
from llama_index.core.postprocessor import SentenceTransformerRerank
from llama_index.core.query_engine import SubQuestionQueryEngine
from llama_index.core.tools import QueryEngineTool, ToolMetadata
from modules.utils import (
    load_index,
    normalize_sponsor,
    get_sponsor_variations,
    get_hybrid_retriever,
)
import re
import traceback

# --- Tools ---


def expand_query(query: str) -> str:
    """Expands a search query with synonyms using the LLM."""
    if not query or len(query.split()) > 10:  # Skip expansion for long queries
        return query

    # Skip expansion if it looks like an NCT ID
    if re.search(r"NCT\d+", query, re.IGNORECASE):
        return query

    prompt = (
        f"You are a helpful medical assistant. "
        f"Expand the following search query with relevant medical synonyms and acronyms. "
        f"Return ONLY the expanded query string combined with OR operators. "
        f"Do not add any explanation.\n\n"
        f"Query: {query}\n"
        f"Expanded Query:"
    )
    try:
        # Use the global Settings.llm
        if not Settings.llm:
            # Fallback if not initialized (though load_index does it)
            from modules.utils import setup_llama_index

            setup_llama_index()

        response = Settings.llm.complete(prompt)
        expanded = response.text.strip()
        # Clean up if LLM is chatty
        if "Expanded Query:" in expanded:
            expanded = expanded.split("Expanded Query:")[-1].strip()

        if not expanded:
            print("⚠️ Expansion returned empty. Using original query.")
            return query

        print(f"✨ Expanded Query: '{query}' -> '{expanded}'")
        return expanded
    except Exception as e:
        print(f"⚠️ Query expansion failed: {e}")
        return query


@langchain_tool("search_trials")
def search_trials(
    query: str = None,
    status: str = None,
    phase: str = None,
    sponsor: str = None,
    intervention: str = None,
    year: int = None,
):
    """
    Searches for clinical trials using semantic search with robust filtering.

    Args:
        query (str, optional): The natural language search query.
        status (str, optional): Filter by recruitment status.
        phase (str, optional): Filter by trial phase.
        sponsor (str, optional): Filter by sponsor name.
        intervention (str, optional): Filter by intervention/drug name.
        year (int, optional): Filter for studies starting on or after this year.

    Returns:
        str: A structured list of relevant studies.
    """
    index = load_index()

    # Constants
    TOP_K_STRICT = 500  # High recall for pre-filtered search

    # --- Query Construction ---
    if not query:
        parts = [p for p in [sponsor, intervention, phase, status] if p]
        query = " ".join(parts) if parts else "clinical trial"
    else:
        # Inject context for vector search
        if sponsor and normalize_sponsor(sponsor).lower() not in query.lower():
            query = f"{normalize_sponsor(sponsor)} {query}"
        if intervention and intervention.lower() not in query.lower():
            query = f"{intervention} {query}"

    query = expand_query(query)

    print(f"🔍 Tool Called: search_trials(query='{query}', sponsor='{sponsor}')")

    # --- Strategy 1: Strict Pre-Retrieval Filtering (High Precision) ---
    # Filter by Sponsor/Status/Year at the database level first.
    pre_filters = []

    # NCT ID Match
    nct_match = re.search(r"NCT\d+", query, re.IGNORECASE)
    if nct_match:
        nct_id = nct_match.group(0).upper()
        pre_filters.append(MetadataFilter(key="nct_id", value=nct_id, operator=FilterOperator.EQ))

    if status:
        pre_filters.append(MetadataFilter(key="status", value=status.upper(), operator=FilterOperator.EQ))
    if year:
        pre_filters.append(MetadataFilter(key="start_year", value=year, operator=FilterOperator.GTE))

    # Sponsor Pre-Filter
    if sponsor:
        variations = get_sponsor_variations(sponsor)
        if variations:
            print(f"🎯 Applying strict pre-filter for sponsor '{sponsor}' ({len(variations)} variants)")
            # Use 'sponsor' field which is the Lead Sponsor
            pre_filters.append(MetadataFilter(key="sponsor", value=variations, operator=FilterOperator.IN))
        else:
            print(f"⚠️ No strict mapping for sponsor '{sponsor}'. Will rely on fuzzy post-filtering.")

    metadata_filters = MetadataFilters(filters=pre_filters) if pre_filters else None

    # Post-processors (Reranking)
    reranker = SentenceTransformerRerank(model="cross-encoder/ms-marco-MiniLM-L-12-v2", top_n=50)

    # --- HYBRID SEARCH IMPLEMENTATION ---
    # Combine Vector + BM25 using get_hybrid_retriever
    try:
        retriever = get_hybrid_retriever(index, similarity_top_k=TOP_K_STRICT, filters=metadata_filters)
        nodes = retriever.retrieve(query)

        # (QueryFusionRetriever returns nodes, but we want to rerank them)
        if nodes:
            from llama_index.core.schema import QueryBundle
            nodes = reranker.postprocess_nodes(nodes, query_bundle=QueryBundle(query_str=query))

    except Exception as e:
        print(f"⚠️ Hybrid search failed: {e}. Falling back to standard vector search.")
        traceback.print_exc()
        query_engine = index.as_query_engine(
            similarity_top_k=TOP_K_STRICT,
            filters=metadata_filters,
            node_postprocessors=[reranker]
        )
        response = query_engine.query(query)
        nodes = response.source_nodes

    # --- Strict Metadata Filtering (Post-Fusion) ---
    # BM25 results might not respect the vector filters, so filter them out.
    final_nodes = []
    for node in nodes:
        meta = node.metadata
        keep = True

        # Re-apply filters to ensure BM25 results are valid
        if status and meta.get("status", "").upper() != status.upper():
            keep = False
        if year:
            try:
                if int(meta.get("start_year", 0)) < year:
                    keep = False
            except (TypeError, ValueError):
                pass
        if sponsor:
            # Strict logic for sponsor in pre-filters is ignored by BM25.
            # Check if the sponsor matches one of the variations OR fuzzy match.
            # If strict variations exist, enforce them.
            variations = get_sponsor_variations(sponsor)
            node_sponsor = meta.get("sponsor", "")
            # Fallback to org if sponsor is missing (legacy data)
            if not node_sponsor:
                node_sponsor = meta.get("org", "")

            if variations:
                if node_sponsor not in variations:
                    keep = False
            else:
                # Fuzzy fallback
                if normalize_sponsor(sponsor).lower() not in normalize_sponsor(node_sponsor).lower():
                    keep = False

        if keep:
            final_nodes.append(node)

    nodes = final_nodes

    # --- Strict Keyword Filtering ---
    # BM25 handles keyword relevance naturally, so rely on the Hybrid Search + Reranker
    # rather than applying an aggressive substring check here.

    # Update response object structure to match expected format if we used retriever
    class MockResponse:
        def __init__(self, nodes):
            self.source_nodes = nodes

    response = MockResponse(nodes)

    # --- Strategy 2: Hybrid Search (Fallback) ---
    # Hybrid Search is enabled by default.
    # Strict filters are handled in post-processing above.

    # --- Formatting Output ---
    if not response.source_nodes:
        return "No matching studies found. Try broadening your search terms or filters."

    # Filter by Relevance Score for display
    MIN_SCORE = 1.5
    relevant_nodes = [node for node in response.source_nodes if node.score > MIN_SCORE]

    # If strict filtering removes too much, show at least top 3 to be helpful
    if len(relevant_nodes) < 3 and len(response.source_nodes) > 0:
        relevant_nodes = response.source_nodes[:3]

    display_limit = 20
    display_nodes = relevant_nodes[:display_limit]

    results = []
    for node in display_nodes:
        meta = node.metadata
        entry = (
            f"**{meta.get('title', 'Untitled')}**\n"
            f" - ID: {meta.get('nct_id')}\n"
            f" - Phase: {meta.get('phase', 'N/A')}\n"
            f" - Status: {meta.get('status', 'N/A')}\n"
            f" - Sponsor: {meta.get('sponsor', meta.get('org', 'Unknown'))}\n"
            f" - Relevance: {node.score:.2f}"
        )
        results.append(entry)

    return f"Found {len(results)} relevant studies:\n\n" + "\n\n".join(results)


@langchain_tool("find_similar_studies")
def find_similar_studies(query: str):
    """
    Finds studies semantically similar to a given query or study description.

    This tool is useful for "more like this" functionality. It relies purely
    on vector similarity without strict metadata filtering.

    Args:
        query (str): The text to match against (e.g., a study title or description).

    Returns:
        str: A string containing the top 5 similar studies with their titles and summaries.
    """
    index = load_index()

    # 1. Check if query is an NCT ID
    nct_match = re.search(r"NCT\d+", query, re.IGNORECASE)
    target_nct = None
    search_text = query

    if nct_match:
        target_nct = nct_match.group(0).upper()
        print(f"🎯 Detected NCT ID for similarity: {target_nct}")

        # Fetch the study content to use as the semantic query
        # Use the vector store directly to get the text
        retriever = index.as_retriever(
            filters=MetadataFilters(
                filters=[MetadataFilter(key="nct_id", value=target_nct, operator=FilterOperator.EQ)]
            ),
            similarity_top_k=1
        )
        nodes = retriever.retrieve(target_nct)

        if nodes:
            # Use the study's text (Title + Summary) as the query
            search_text = nodes[0].text
            print(f"✅ Found study content. Using {len(search_text)} chars for semantic search.")
        else:
            print(f"⚠️ Study {target_nct} not found. Falling back to text search.")

    # 2. Perform Semantic Search
    # Fetch more candidates (10) to allow for filtering
    retriever = index.as_retriever(similarity_top_k=10)
    nodes = retriever.retrieve(search_text)

    results = []
    count = 0
    for node in nodes:
        # 3. Self-Exclusion
        if target_nct and node.metadata.get("nct_id") == target_nct:
            continue

        # Deduplication (if multiple chunks of same study appear)
        if any(r["nct_id"] == node.metadata.get("nct_id") for r in results):
            continue

        results.append({
            "nct_id": node.metadata.get("nct_id"),
            "text": f"Study: {node.metadata['title']} (NCT: {node.metadata.get('nct_id')})\nScore: {node.score:.4f}\nSummary: {node.text[:200]}..."
        })

        count += 1
        if count >= 5:  # Limit to top 5 unique results
            break

    if not results:
        return "No similar studies found."

    return "\n\n".join([r["text"] for r in results])


def fetch_study_analytics_data(
    query: str,
    group_by: str,
    phase: Optional[str] = None,
    status: Optional[str] = None,
    sponsor: Optional[str] = None,
    intervention: Optional[str] = None,
    start_year: Optional[int] = None,
    study_type: Optional[str] = None,
) -> str:
    """
    Underlying logic for fetching and aggregating clinical trial data.
    See get_study_analytics for full docstring.
    """
    index = load_index()

    # 1. Retrieve Data
    if query.lower() == "overall":
        try:
            # Connect to LanceDB directly for speed
            import lancedb
            db = lancedb.connect("./ct_gov_lancedb")
            tbl = db.open_table("clinical_trials")
            # Fetch all data as pandas DataFrame
            df = tbl.to_pandas()

            # LlamaIndex stores metadata in a 'metadata' column (usually as a dict/struct).
            # We need to flatten it to get columns like 'status', 'phase', etc.
            if "metadata" in df.columns:
                # LanceDB to_pandas() converts the struct to a dict, so json_normalize flattens it
                meta_df = pd.json_normalize(df["metadata"])
                df = meta_df
            # If columns are already flat (depending on schema evolution), we are good.

        except Exception as e:
            return f"Error fetching full dataset: {e}"
    else:
        filters = []
        if status:
            filters.append(
                MetadataFilter(
                    key="status", value=status.upper(), operator=FilterOperator.EQ
                )
            )
        if phase and "," not in phase:
            # Phase is applied later with pandas (supports comma-separated lists)
            pass

        if sponsor:
            # Use the helper to get all variations (e.g. "Pfizer" -> ["Pfizer", "Pfizer Inc."])
            sponsor_variations = get_sponsor_variations(sponsor)
            if sponsor_variations:
                print(f"🎯 Using strict pre-filter for sponsor '{sponsor}': {len(sponsor_variations)} variations found.")
                filters.append(
                    MetadataFilter(
                        key="sponsor", value=sponsor_variations, operator=FilterOperator.IN
                    )
                )

        metadata_filters = MetadataFilters(filters=filters) if filters else None

        search_query = query
        if sponsor and sponsor.lower() not in query.lower():
            search_query = f"{sponsor} {query}"

        # Use hybrid search for better recall
        retriever = index.as_retriever(
            similarity_top_k=5000,
            filters=metadata_filters,
            vector_store_query_mode="hybrid"
        )
        nodes = retriever.retrieve(search_query)

        # --- Strict Keyword Filtering ---
        # Strictly check if the query appears in Title or Conditions to ensure accurate counting.
        # EXCEPTION: If the query matches the requested sponsor, we also check the 'org' field.
        if query.lower() != "overall":
            q_term = query.lower()

            # Check if the query is essentially the sponsor name
            is_sponsor_query = False

            # Check if the query itself normalizes to a known sponsor
            query_normalized = normalize_sponsor(query)
            if query_normalized and query_normalized != query:
                # If normalization changed it (or found a mapping), it's likely a sponsor
                is_sponsor_query = True

            if sponsor:
                # Normalize both to see if they refer to the same entity
                norm_query = normalize_sponsor(query)
                norm_sponsor = normalize_sponsor(sponsor)

                if norm_query and norm_sponsor and norm_query.lower() == norm_sponsor.lower():
                    is_sponsor_query = True
                elif sponsor.lower() in query.lower() or query.lower() in sponsor.lower():
                    is_sponsor_query = True

            filtered_nodes = []
            for node in nodes:
                meta = node.metadata
                title = meta.get("title", "").lower()
                conditions = meta.get("condition", "").lower()  # Note: key is 'condition' in DB
                org = meta.get("org", "").lower()
                sponsor_val = meta.get("sponsor", "").lower()

                # If it's a sponsor query, we allow matches on the Organization OR Sponsor field
                # AND we check if the normalized values match (handling aliases like J&J -> Janssen)
                match = False
                if q_term in title or q_term in conditions:
                    match = True
                elif is_sponsor_query:
                    # Check raw match
                    if q_term in org or q_term in sponsor_val:
                        match = True
                    else:
                        # Check normalized match
                        norm_org = normalize_sponsor(org)
                        norm_val = normalize_sponsor(sponsor_val)

                        # Compare against the normalized query (which is the sponsor in this case)
                        target_norm = norm_sponsor if sponsor else query_normalized

                        if norm_org and target_norm and norm_org.lower() == target_norm.lower():
                            match = True
                        elif norm_val and target_norm and norm_val.lower() == target_norm.lower():
                            match = True

                if match:
                    filtered_nodes.append(node)

            print(f"📉 Strict Filter: {len(nodes)} -> {len(filtered_nodes)} nodes for '{query}'")
            nodes = filtered_nodes

        data = [node.metadata for node in nodes]
        df = pd.DataFrame(data)

    if "nct_id" in df.columns:
        df = df.drop_duplicates(subset="nct_id")

    if df.empty:
        return "No studies found for analytics."

    # --- APPLY FILTERS (Pandas) ---
    if phase:
        target_phases = [p.strip().upper().replace(" ", "") for p in phase.split(",")]
        df["phase_upper"] = df["phase"].astype(str).str.upper().str.replace(" ", "")
        mask = df["phase_upper"].apply(lambda x: any(tp in x for tp in target_phases))
        df = df[mask]

    if status:
        df = df[df["status"].str.upper() == status.upper()]

    if sponsor:
        target_sponsor = normalize_sponsor(sponsor).lower()
        # Use 'sponsor' column if it exists, otherwise fallback to 'org'
        if "sponsor" in df.columns:
            df["sponsor_check"] = df["sponsor"].fillna(df["org"]).astype(str).apply(normalize_sponsor).str.lower()
        else:
            df["sponsor_check"] = df["org"].astype(str).apply(normalize_sponsor).str.lower()

        df = df[df["sponsor_check"].str.contains(target_sponsor, regex=False)]

    if intervention:
        target_intervention = intervention.lower()
        df["intervention_lower"] = df["intervention"].astype(str).str.lower()
        df = df[df["intervention_lower"].str.contains(target_intervention, regex=False)]

    if start_year:
        df["start_year"] = pd.to_numeric(df["start_year"], errors="coerce").fillna(0)
        df = df[df["start_year"] >= start_year]

    if study_type:
        df = df[df["study_type"].str.upper() == study_type.upper()]

    if df.empty:
        return "No studies found after applying filters."

    key_map = {
        "phase": "phase",
        "status": "status",
        "sponsor": "sponsor" if "sponsor" in df.columns else "org",
        "start_year": "start_year",
        "condition": "condition",
        "intervention": "intervention",
        "study_type": "study_type",
        "country": "country",
        "state": "state",
    }

    if group_by not in key_map:
        return f"Invalid group_by field: {group_by}. Valid options: phase, status, sponsor, start_year, condition, intervention, study_type, country, state"

    col = key_map[group_by]

    if col == "start_year":
        df[col] = pd.to_numeric(df[col], errors="coerce")
        counts = df[col].value_counts().sort_index()
    elif col == "condition":
        counts = df[col].astype(str).str.split(", ").explode().value_counts().head(10)
    elif col == "intervention":
        all_interventions = []
        for interventions in df[col].dropna():
            parts = [i.strip() for i in interventions.split(";") if i.strip()]
            all_interventions.extend(parts)
        counts = pd.Series(all_interventions).value_counts().head(10)
    else:
        counts = df[col].value_counts().head(10)

    summary = counts.to_string()

    chart_df = counts.reset_index()
    chart_df.columns = ["category", "count"]

    chart_data = {
        "type": "bar",
        "title": f"Studies by {group_by.capitalize()}",
        "data": chart_df.to_dict("records"),
        "x": "category",
        "y": "count",
    }

    # Stash chart data so the UI can render an inline chart alongside the text answer
    st.session_state["inline_chart_data"] = chart_data

    return f"Found {len(df)} studies. Top counts:\n{summary}\n\n(Chart generated in UI)"


@langchain_tool("get_study_analytics")
def get_study_analytics(
    query: str,
    group_by: str,
    phase: Optional[str] = None,
    status: Optional[str] = None,
    sponsor: Optional[str] = None,
    intervention: Optional[str] = None,
    start_year: Optional[int] = None,
    study_type: Optional[str] = None,
):
    """
    Aggregates clinical trial data based on a search query and groups by a specific field.

    This tool performs the following steps:
    1. Retrieves a large number of relevant studies (up to 500).
    2. Applies strict filters (Phase, Status, Sponsor) in memory (Pandas).
    3. Groups the data by the requested field (e.g., Sponsor).
    4. Generates a summary string for the LLM.
    5. **Side Effect**: Injects chart data into `st.session_state` to trigger an inline chart in the UI.

    Args:
        query (str): The search query to filter studies (e.g., "cancer").
        group_by (str): The field to group by. Options: "phase", "status", "sponsor", "start_year", "condition".
        phase (Optional[str]): Optional filter for phase (e.g., "PHASE2").
        status (Optional[str]): Optional filter for status (e.g., "RECRUITING").
        sponsor (Optional[str]): Optional filter for sponsor (e.g., "Pfizer").
        intervention (Optional[str]): Optional filter for intervention (e.g., "Keytruda").

    Returns:
        str: A summary string of the top counts and a note that a chart has been generated.
    """
    return fetch_study_analytics_data(
        query=query,
        group_by=group_by,
        phase=phase,
        status=status,
        sponsor=sponsor,
        intervention=intervention,
        start_year=start_year,
        study_type=study_type,
    )


@langchain_tool("compare_studies")
def compare_studies(query: str):
    """
    Compares multiple studies or answers complex multi-part questions using query decomposition.

    Use this tool when the user asks to "compare", "contrast", or analyze differences/similarities
    between specific studies, sponsors, or phases. It breaks down the question into sub-questions.

    Args:
        query (str): The complex comparison query (e.g., "Compare the primary outcomes of Keytruda vs Opdivo").

    Returns:
        str: A detailed response synthesizing the answers to sub-questions.
    """
    index = load_index()

    # Create a base query engine for the sub-questions.
    # Increase top_k and add re-ranking to improve recall for comparison queries.
    reranker = SentenceTransformerRerank(model="cross-encoder/ms-marco-MiniLM-L-12-v2", top_n=10)

    base_engine = index.as_query_engine(
        similarity_top_k=50,
        node_postprocessors=[reranker]
    )

    # Wrap it in a QueryEngineTool
    query_tool = QueryEngineTool(
        query_engine=base_engine,
        metadata=ToolMetadata(
            name="clinical_trials_db",
            description="Vector database of clinical trial protocols, results, and metadata.",
        ),
    )

    # Create the SubQuestionQueryEngine.
    # Explicitly define the question generator to use the configured LLM (Gemini);
    # this avoids the default behavior which might try to import OpenAI modules.
    from llama_index.core.question_gen import LLMQuestionGenerator
    from llama_index.core import Settings

    question_gen = LLMQuestionGenerator.from_defaults(llm=Settings.llm)

    query_engine = SubQuestionQueryEngine.from_defaults(
        query_engine_tools=[query_tool],
        question_gen=question_gen,
        use_async=True,
    )

    try:
        response = query_engine.query(query)
        return str(response) + "\n\n(Note: This analysis is based on the most relevant studies retrieved from the database, not necessarily an exhaustive list.)"
    except Exception as e:
        return f"Error during comparison: {e}"


@langchain_tool("get_study_details")
def get_study_details(nct_id: str):
    """
    Retrieves the full details of a specific clinical trial by its NCT ID.

    Use this tool when the user asks for specific information about a single study,
    such as "What are the inclusion criteria for NCT12345678?" or "Give me a summary of study NCT...".
    It returns the full text content of the study document, including criteria, outcomes, and contacts.

    Args:
        nct_id (str): The NCT ID of the study (e.g., "NCT01234567").

    Returns:
        str: The full text content of the study, or a message if not found.
    """
    index = load_index()

    # Clean the ID
    clean_id = nct_id.strip().upper()

    # Use a retriever with a strict metadata filter for the ID.
    # Set top_k=20 to capture all chunks if the document was split.
    filters = MetadataFilters(
        filters=[
            MetadataFilter(key="nct_id", value=clean_id, operator=FilterOperator.EQ)
        ]
    )

    retriever = index.as_retriever(similarity_top_k=20, filters=filters)
    nodes = retriever.retrieve(clean_id)

    if not nodes:
        return f"Study {clean_id} not found in the database."

    # Sort nodes by their position in the document to reconstruct full text.
    # LlamaIndex nodes usually have 'start_char_idx' in metadata or relationships;
    # simple concatenation assumes retrieval order is roughly correct or sufficient.
    full_text = "\n\n".join([node.text for node in nodes])

    return f"Details for {clean_id} (Combined {len(nodes)} parts):\n\n{full_text}"
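Because each of these functions is decorated as a LangChain tool, it can be invoked directly without the agent loop, which is handy for debugging retrieval quality. A minimal sketch (the query values below are just examples, not fixtures from the project):

from modules.tools import search_trials, get_study_analytics

print(search_trials.invoke({"query": "metastatic breast cancer", "phase": "PHASE3", "sponsor": "Pfizer"}))
print(get_study_analytics.invoke({"query": "diabetes", "group_by": "phase"}))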
modules/utils.py
ADDED
|
@@ -0,0 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Utility functions for the Clinical Trial Agent.
|
| 3 |
+
|
| 4 |
+
Handles configuration, LanceDB index loading, data normalization, and custom filtering logic.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import streamlit as st
|
| 9 |
+
from typing import List, Optional
|
| 10 |
+
from llama_index.core import VectorStoreIndex, StorageContext, Settings
|
| 11 |
+
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
| 12 |
+
from llama_index.vector_stores.lancedb import LanceDBVectorStore
|
| 13 |
+
from llama_index.llms.gemini import Gemini
|
| 14 |
+
import lancedb
|
| 15 |
+
from dotenv import load_dotenv
|
| 16 |
+
|
| 17 |
+
# --- MONKEYPATCH START ---
|
| 18 |
+
# Patch LanceDBVectorStore to handle 'nprobes' AttributeError and fix SQL quoting for IN filters.
|
| 19 |
+
original_query = LanceDBVectorStore.query
|
| 20 |
+
|
| 21 |
+
def patched_query(self, query, **kwargs):
|
| 22 |
+
try:
|
| 23 |
+
return original_query(self, query, **kwargs)
|
| 24 |
+
except Exception as e:
|
| 25 |
+
print(f"⚠️ LanceDB Query Error: {e}")
|
| 26 |
+
if hasattr(query, "filters"):
|
| 27 |
+
print(f" Filters: {query.filters}")
|
| 28 |
+
|
| 29 |
+
if "nprobes" in str(e):
|
| 30 |
+
from llama_index.core.vector_stores.types import VectorStoreQueryResult
|
| 31 |
+
return VectorStoreQueryResult(nodes=[], similarities=[], ids=[])
|
| 32 |
+
raise e
|
| 33 |
+
|
| 34 |
+
LanceDBVectorStore.query = patched_query
|
| 35 |
+
|
| 36 |
+
# Patch _to_lance_filter to fix SQL quoting for IN operator with strings.
|
| 37 |
+
from llama_index.vector_stores.lancedb import base as lancedb_base
|
| 38 |
+
from llama_index.core.vector_stores.types import FilterOperator
|
| 39 |
+
|
| 40 |
+
original_to_lance_filter = lancedb_base._to_lance_filter
|
| 41 |
+
|
| 42 |
+
def patched_to_lance_filter(standard_filters, metadata_keys):
|
| 43 |
+
if not standard_filters:
|
| 44 |
+
return None
|
| 45 |
+
|
| 46 |
+
# Reimplement filter logic to ensure correct SQL generation for LanceDB
|
| 47 |
+
filters = []
|
| 48 |
+
for filter in standard_filters.filters:
|
| 49 |
+
key = filter.key
|
| 50 |
+
if metadata_keys and key not in metadata_keys:
|
| 51 |
+
continue
|
| 52 |
+
|
| 53 |
+
# Prefix key with 'metadata.' for LanceDB struct column
|
| 54 |
+
lance_key = f"metadata.{key}"
|
| 55 |
+
|
| 56 |
+
# Handle IN operator with proper string quoting
|
| 57 |
+
if filter.operator == FilterOperator.IN:
|
| 58 |
+
if isinstance(filter.value, list):
|
| 59 |
+
# Quote strings properly
|
| 60 |
+
values = []
|
| 61 |
+
for v in filter.value:
|
| 62 |
+
if isinstance(v, str):
|
| 63 |
+
values.append(f"'{v}'") # Single quotes for SQL
|
| 64 |
+
else:
|
| 65 |
+
values.append(str(v))
|
| 66 |
+
val_str = ", ".join(values)
|
| 67 |
+
filters.append(f"{lance_key} IN ({val_str})")
|
| 68 |
+
continue
|
| 69 |
+
|
| 70 |
+
# Standard operators
|
| 71 |
+
op = filter.operator
|
| 72 |
+
val = filter.value
|
| 73 |
+
|
| 74 |
+
if op == FilterOperator.EQ:
|
| 75 |
+
if isinstance(val, str):
|
| 76 |
+
filters.append(f"{lance_key} = '{val}'")
|
| 77 |
+
else:
|
| 78 |
+
filters.append(f"{lance_key} = {val}")
|
| 79 |
+
elif op == FilterOperator.GT:
|
| 80 |
+
filters.append(f"{lance_key} > {val}")
|
| 81 |
+
elif op == FilterOperator.LT:
|
| 82 |
+
filters.append(f"{lance_key} < {val}")
|
| 83 |
+
elif op == FilterOperator.GTE:
|
| 84 |
+
filters.append(f"{lance_key} >= {val}")
|
| 85 |
+
elif op == FilterOperator.LTE:
|
| 86 |
+
filters.append(f"{lance_key} <= {val}")
|
| 87 |
+
elif op == FilterOperator.NE:
|
| 88 |
+
if isinstance(val, str):
|
| 89 |
+
filters.append(f"{lance_key} != '{val}'")
|
| 90 |
+
else:
|
| 91 |
+
filters.append(f"{lance_key} != {val}")
|
| 92 |
+
# Add other operators as needed
|
| 93 |
+
|
| 94 |
+
if not filters:
|
| 95 |
+
return None
|
| 96 |
+
|
| 97 |
+
return " AND ".join(filters)
|
| 98 |
+
|
| 99 |
+
lancedb_base._to_lance_filter = patched_to_lance_filter
|
| 100 |
+
# --- MONKEYPATCH END ---
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def load_environment():
|
| 104 |
+
"""Loads environment variables from .env file."""
|
| 105 |
+
load_dotenv()
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
# --- Configuration ---
|
| 109 |
+
def setup_llama_index(api_key: Optional[str] = None):
|
| 110 |
+
"""
|
| 111 |
+
Configures global LlamaIndex settings (LLM and Embeddings).
|
| 112 |
+
"""
|
| 113 |
+
# Use passed key, or fallback to env var
|
| 114 |
+
final_key = api_key or os.environ.get("GOOGLE_API_KEY")
|
| 115 |
+
|
| 116 |
+
if not final_key:
|
| 117 |
+
# App handles prompting for key, so we just return or log warning
|
| 118 |
+
pass
|
| 119 |
+
|
| 120 |
+
try:
|
| 121 |
+
# Pass the key explicitly if available
|
| 122 |
+
Settings.llm = Gemini(model="models/gemini-2.5-flash", temperature=0, api_key=final_key)
|
| 123 |
+
except Exception as e:
|
| 124 |
+
print(f"⚠️ LLM initialization failed (likely missing API key): {e}")
|
| 125 |
+
print("⚠️ Using MockLLM for testing/fallback.")
|
| 126 |
+
from llama_index.core.llms import MockLLM
|
| 127 |
+
Settings.llm = MockLLM()
|
| 128 |
+
|
| 129 |
+
Settings.embed_model = HuggingFaceEmbedding(
|
| 130 |
+
model_name="pritamdeka/S-PubMedBert-MS-MARCO"
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
@st.cache_resource
|
| 135 |
+
def load_index() -> VectorStoreIndex:
|
| 136 |
+
"""
|
| 137 |
+
Loads and caches the persistent LanceDB index.
|
| 138 |
+
"""
|
| 139 |
+
setup_llama_index()
|
| 140 |
+
|
| 141 |
+
# Initialize LanceDB
|
| 142 |
+
db_path = "./ct_gov_lancedb"
|
| 143 |
+
db = lancedb.connect(db_path)
|
| 144 |
+
|
| 145 |
+
# Define metadata keys explicitly to ensure filters work
|
| 146 |
+
metadata_keys = [
|
| 147 |
+
"nct_id", "title", "org", "sponsor", "status", "phase",
|
| 148 |
+
"study_type", "start_year", "condition", "intervention",
|
| 149 |
+
"country", "state"
|
| 150 |
+
]
|
| 151 |
+
|
| 152 |
+
# Create the vector store wrapper
|
| 153 |
+
vector_store = LanceDBVectorStore(
|
| 154 |
+
uri=db_path,
|
| 155 |
+
table_name="clinical_trials",
|
| 156 |
+
query_mode="hybrid",
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
# Manually set metadata keys as constructor doesn't accept them
|
| 160 |
+
vector_store._metadata_keys = metadata_keys
|
| 161 |
+
|
| 162 |
+
# Create storage context
|
| 163 |
+
storage_context = StorageContext.from_defaults(vector_store=vector_store)
|
| 164 |
+
|
| 165 |
+
# Load the index from the vector store
|
| 166 |
+
index = VectorStoreIndex.from_vector_store(
|
| 167 |
+
vector_store, storage_context=storage_context
|
| 168 |
+
)
|
| 169 |
+
return index
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def get_hybrid_retriever(index: VectorStoreIndex, similarity_top_k: int = 50, filters=None):
|
| 173 |
+
"""
|
| 174 |
+
Creates a Hybrid Retriever using LanceDB's native hybrid search.
|
| 175 |
+
|
| 176 |
+
Args:
|
| 177 |
+
index (VectorStoreIndex): The loaded vector index.
|
| 178 |
+
similarity_top_k (int): Number of top results to retrieve.
|
| 179 |
+
filters (MetadataFilters, optional): Filters to apply.
|
| 180 |
+
|
| 181 |
+
Returns:
|
| 182 |
+
VectorIndexRetriever: The configured retriever.
|
| 183 |
+
"""
|
| 184 |
+
# LanceDB supports native hybrid search via query_mode="hybrid"
|
| 185 |
+
# We pass this configuration to the retriever
|
| 186 |
+
# Use standard retriever first to avoid LanceDB hybrid search issues on small datasets
|
| 187 |
+
return index.as_retriever(
|
| 188 |
+
similarity_top_k=similarity_top_k,
|
| 189 |
+
filters=filters,
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
# --- Normalization ---
|
| 194 |
+
|
| 195 |
+
# Centralized Sponsor Mappings
|
| 196 |
+
# Key: Canonical Name
|
| 197 |
+
# Value: List of variations/aliases (including the canonical name itself if needed for matching)
|
| 198 |
+
SPONSOR_MAPPINGS = {
|
| 199 |
+
"GlaxoSmithKline": [
|
| 200 |
+
"gsk", "glaxo", "glaxosmithkline", "glaxosmithkline",
|
| 201 |
+
"GlaxoSmithKline"
|
| 202 |
+
],
|
| 203 |
+
"Janssen": [
|
| 204 |
+
"j&j", "johnson & johnson", "johnson and johnson", "janssen", "Janssen",
|
| 205 |
+
"Janssen Research & Development, LLC",
|
| 206 |
+
"Janssen Vaccines & Prevention B.V.",
|
| 207 |
+
"Janssen Pharmaceutical K.K.",
|
| 208 |
+
"Janssen-Cilag International NV",
|
| 209 |
+
"Janssen Sciences Ireland UC",
|
| 210 |
+
"Janssen Pharmaceutica N.V., Belgium",
|
| 211 |
+
"Janssen Scientific Affairs, LLC",
|
| 212 |
+
"Janssen-Cilag Ltd.",
|
| 213 |
+
"Xian-Janssen Pharmaceutical Ltd.",
|
| 214 |
+
"Janssen Korea, Ltd., Korea",
|
| 215 |
+
"Janssen-Cilag G.m.b.H",
|
| 216 |
+
"Janssen-Cilag, S.A.",
|
| 217 |
+
"Janssen BioPharma, Inc.",
|
| 218 |
+
],
|
| 219 |
+
"Bristol-Myers Squibb": [
|
| 220 |
+
"bms", "bristol", "bristol myers squibb", "bristol-myers squibb",
|
| 221 |
+
"Bristol-Myers Squibb"
|
| 222 |
+
],
|
| 223 |
+
"Merck Sharp & Dohme": [
|
| 224 |
+
"merck", "msd", "merck sharp & dohme",
|
| 225 |
+
"Merck Sharp & Dohme LLC"
|
| 226 |
+
],
|
| 227 |
+
"Pfizer": ["pfizer", "Pfizer", "Pfizer Inc."],
|
| 228 |
+
"AstraZeneca": ["astrazeneca", "AstraZeneca"],
|
| 229 |
+
"Eli Lilly and Company": ["lilly", "eli lilly", "Eli Lilly and Company"],
|
| 230 |
+
"Sanofi": ["sanofi", "Sanofi"],
|
| 231 |
+
"Novartis": ["novartis", "Novartis"],
|
| 232 |
+
}
|
| 233 |
+
|
| 234 |
+
def normalize_sponsor(sponsor: str) -> Optional[str]:
|
| 235 |
+
"""
|
| 236 |
+
Normalizes sponsor names to canonical forms using centralized mappings.
|
| 237 |
+
"""
|
| 238 |
+
if not sponsor:
|
| 239 |
+
return None
|
| 240 |
+
|
| 241 |
+
s = sponsor.lower().strip()
|
| 242 |
+
|
| 243 |
+
for canonical, variations in SPONSOR_MAPPINGS.items():
|
| 244 |
+
# Check if input matches canonical name (case-insensitive)
|
| 245 |
+
if s == canonical.lower():
|
| 246 |
+
return canonical
|
| 247 |
+
|
| 248 |
+
# Check variations and aliases
|
| 249 |
+
for v in variations:
|
| 250 |
+
v_lower = v.lower()
|
| 251 |
+
if v_lower == s:
|
| 252 |
+
return canonical
|
| 253 |
+
# If the variation is a known alias (like 'gsk'), check if it's in the string
|
| 254 |
+
if len(v) < 5 and v_lower in s:
|
| 255 |
+
return canonical
|
| 256 |
+
|
| 257 |
+
if canonical.lower() in s:
|
| 258 |
+
return canonical
|
| 259 |
+
|
| 260 |
+
return sponsor
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
def get_sponsor_variations(sponsor: str) -> Optional[List[str]]:
|
| 264 |
+
"""
|
| 265 |
+
Returns list of exact database 'org' values for a given sponsor alias.
|
| 266 |
+
"""
|
| 267 |
+
if not sponsor:
|
| 268 |
+
return None
|
| 269 |
+
|
| 270 |
+
# First, normalize the input to get the canonical name
|
| 271 |
+
canonical = normalize_sponsor(sponsor)
|
| 272 |
+
|
| 273 |
+
if canonical in SPONSOR_MAPPINGS:
|
| 274 |
+
return SPONSOR_MAPPINGS[canonical]
|
| 275 |
+
|
| 276 |
+
return None
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
|
requirements.txt
ADDED
@@ -0,0 +1,19 @@
streamlit
requests
python-dotenv
langchain
langchain-community
langchain-google-genai==2.0.0
lancedb
lark
langchain-huggingface
llama-index
llama-index-vector-stores-lancedb
llama-index-embeddings-huggingface
llama-index-llms-gemini
streamlit-option-menu
streamlit-agraph
folium
streamlit-folium
rank_bm25
llama-index-retrievers-bm25
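The dependencies above can be installed into a fresh virtual environment in the usual way, for example:

pip install -r requirements.txt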
scripts/analyze_db.py
ADDED
@@ -0,0 +1,149 @@
"""
Database Analysis Script.

This script connects to the local LanceDB vector store and performs a quick analysis
of the ingested clinical trial data. It prints statistics about:
- Top Sponsors
- Phase Distribution
- Status Distribution
- Top Medical Conditions
- Sample of Recent Studies

Usage:
    python scripts/analyze_db.py
    # OR
    cd scripts && python analyze_db.py
"""

import lancedb
import pandas as pd
import os


def analyze_db():
    """
    Connects to LanceDB and prints summary statistics of the dataset.
    """
    # Determine the project root directory (one level up from this script)
    script_dir = os.path.dirname(os.path.abspath(__file__))
    project_root = os.path.dirname(script_dir)
    db_path = os.path.join(project_root, "ct_gov_lancedb")

    if not os.path.exists(db_path):
        print(f"❌ Database directory '{db_path}' does not exist.")
        print("   Please run 'python scripts/ingest_ct.py' first to ingest data.")
        return

    print(f"📂 Loading database from {db_path}...")
    try:
        db = lancedb.connect(db_path)

        # Check for table existence
        if "clinical_trials" not in db.table_names():
            print(f"❌ Table 'clinical_trials' not found. Available: {db.table_names()}")
            return

        tbl = db.open_table("clinical_trials")
        count = len(tbl)
        print(f"✅ Found 'clinical_trials' table with {count} documents.")

        # Fetch all data for analysis
        df = tbl.to_pandas()

        if df.empty:
            print("❌ No data found.")
            return

        # Handle metadata if nested (LlamaIndex might nest it)
        if "metadata" in df.columns:
            # Try to flatten if it's a struct/dict
            try:
                meta_df = pd.json_normalize(df["metadata"])
                # Merge with original df or just use meta_df for analysis
                # We'll use meta_df for the metadata fields analysis
                # But we might need 'text' from original
                df = pd.concat([df.drop(columns=["metadata"]), meta_df], axis=1)
            except Exception:
                pass

        if "nct_id" in df.columns:
            unique_ncts = df["nct_id"].nunique()
            print(f"🔢 Unique NCT IDs: {unique_ncts}")
            if unique_ncts < count:
                print(f"⚠️ Warning: {count - unique_ncts} duplicate records found!")
        else:
            print("⚠️ 'nct_id' field not found in metadata.")

        # --- Analysis Sections ---

        print("\n📊 --- Top 10 Sponsors ---")
        if "org" in df.columns:
            print(df["org"].value_counts().head(10))
        else:
            print("⚠️ 'org' field not found in metadata.")

        print("\n📊 --- Phase Distribution ---")
        if "phase" in df.columns:
            print(df["phase"].value_counts())
        else:
            print("⚠️ 'phase' field not found in metadata.")

        print("\n📊 --- Status Distribution ---")
        if "status" in df.columns:
            print(df["status"].value_counts())
        else:
            print("⚠️ 'status' field not found in metadata.")

        print("\n📊 --- Top Conditions ---")
        if "condition" in df.columns:
            # Conditions are comma-separated strings, so we split and explode them
            all_conditions = []
            for conditions in df["condition"].dropna():
                all_conditions.extend([c.strip() for c in conditions.split(",")])
            print(pd.Series(all_conditions).value_counts().head(10))
        else:
            print("⚠️ 'condition' field not found in metadata.")

        print("\n📊 --- Top Interventions ---")
        if "intervention" in df.columns:
            # Interventions are semicolon-separated strings (from ingest_ct.py), so we split by ";"
            all_interventions = []
            for interventions in df["intervention"].dropna():
                # Split by semicolon and strip whitespace
                parts = [i.strip() for i in interventions.split(";") if i.strip()]
                all_interventions.extend(parts)

            if all_interventions:
                print(pd.Series(all_interventions).value_counts().head(20))
            else:
                print("No interventions found.")
        else:
            print("⚠️ 'intervention' field not found in metadata.")

        print("\n📝 --- Sample Studies (Most Recent Start Years) ---")
        if "start_year" in df.columns and "title" in df.columns:
            # Ensure start_year is numeric for sorting
            df["start_year"] = pd.to_numeric(df["start_year"], errors="coerce")
            top_recent = df.sort_values(by="start_year", ascending=False).head(5)
            for _, row in top_recent.iterrows():
                print(
                    f"- [{row.get('start_year', 'N/A')}] {row.get('title', 'N/A')} ({row.get('nct_id', 'N/A')})"
                )
                print(f"  Sponsor: {row.get('org', 'N/A')}")
                print(f"  Intervention: {row.get('intervention', 'N/A')}")

        print("\n📊 --- Intervention Check ---")
        if "intervention" in df.columns:
            non_empty = df[df["intervention"].str.len() > 0]
            print(f"Total records with interventions: {len(non_empty)}")
            if not non_empty.empty:
                print("Sample Intervention:", non_empty.iloc[0]["intervention"])
        else:
            print("⚠️ 'intervention' field not found.")

    except Exception as e:
        print(f"⚠️ Error analyzing DB: {e}")


if __name__ == "__main__":
    analyze_db()
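For a quicker spot check than the full report, the same LanceDB calls the script uses can be run interactively. A small sketch, assuming the default ct_gov_lancedb directory in the project root:

import lancedb

db = lancedb.connect("ct_gov_lancedb")
tbl = db.open_table("clinical_trials")
print(len(tbl))                           # total ingested records
print(tbl.to_pandas().columns.tolist())   # available columns, including the metadata struct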
scripts/ingest_ct.py
ADDED
@@ -0,0 +1,449 @@
"""
Data Ingestion Script for Clinical Trial Agent.

This script fetches clinical trial data from the ClinicalTrials.gov API (v2),
processes it into a rich text format, and ingests it into a local LanceDB vector index
using LlamaIndex and PubMedBERT embeddings.

Features:
- **Pagination**: Fetches data in batches using the API's pagination tokens.
- **Robustness**: Implements retry logic for network errors.
- **Efficiency**: Uses batch insertion and reuses the existing index.
- **Progress Tracking**: Displays a progress bar using `tqdm`.
"""

import requests
import re
from datetime import datetime, timedelta
from dotenv import load_dotenv
import argparse
import time
from tqdm import tqdm
import os
import concurrent.futures

# LlamaIndex Imports
from llama_index.core import Document, VectorStoreIndex, StorageContext, Settings
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.vector_stores.lancedb import LanceDBVectorStore
import lancedb

# List of US States for extraction
US_STATES = [
    "Alabama", "Alaska", "Arizona", "Arkansas", "California", "Colorado", "Connecticut",
    "Delaware", "Florida", "Georgia", "Hawaii", "Idaho", "Illinois", "Indiana", "Iowa",
    "Kansas", "Kentucky", "Louisiana", "Maine", "Maryland", "Massachusetts", "Michigan",
    "Minnesota", "Mississippi", "Missouri", "Montana", "Nebraska", "Nevada", "New Hampshire",
    "New Jersey", "New Mexico", "New York", "North Carolina", "North Dakota", "Ohio",
    "Oklahoma", "Oregon", "Pennsylvania", "Rhode Island", "South Carolina", "South Dakota",
    "Tennessee", "Texas", "Utah", "Vermont", "Virginia", "Washington", "West Virginia",
    "Wisconsin", "Wyoming", "District of Columbia"
]

load_dotenv()

# Disable LLM for ingestion (we only need embeddings, not generation)
Settings.llm = None


def clean_text(text: str) -> str:
    """
    Cleans raw text by removing HTML tags and normalizing whitespace.

    Args:
        text (str): The raw text string.

    Returns:
        str: The cleaned text.
    """
    if not text:
        return ""
    # Remove HTML tags
    text = re.sub(r"<[^>]+>", "", text)
    # Remove multiple spaces/newlines and trim
    text = re.sub(r"\s+", " ", text).strip()
    return text


def fetch_trials_generator(
    years: int = 5, max_studies: int = 1000, status: list = None, phases: list = None
):
    """
    Yields batches of clinical trials from the ClinicalTrials.gov API.

    Handles pagination automatically and implements retry logic for API requests.

    Args:
        years (int): Number of years to look back for study start dates.
        max_studies (int): Maximum total number of studies to fetch (-1 for all).
        status (list): List of status strings to filter by (e.g., ["RECRUITING"]).
        phases (list): List of phase strings to filter by (e.g., ["PHASE2"]).

    Yields:
        list: A batch of study dictionaries (JSON objects).
    """
    base_url = "https://clinicaltrials.gov/api/v2/studies"

    # Calculate start date for filtering
    start_date = (datetime.now() - timedelta(days=365 * years)).strftime("%Y-%m-%d")
    print("📡 Connecting to CT.gov API...")
    print(f"🔎 Fetching trials starting after: {start_date}")
    if status:
        print(f"   Filters - Status: {status}")
    if phases:
        print(f"   Filters - Phases: {phases}")

    fetched_count = 0
    next_page_token = None

    # If max_studies is -1, fetch ALL studies (infinite limit)
    fetch_limit = float("inf") if max_studies == -1 else max_studies

    while fetched_count < fetch_limit:
        # Determine batch size (max 1000 per API limit)
        current_limit = 1000
        if max_studies != -1:
            current_limit = min(1000, max_studies - fetched_count)

        # --- Query Construction ---
        # Build the query term using the API's syntax
        query_parts = [f"AREA[StartDate]RANGE[{start_date},MAX]"]

        if status:
            status_str = " OR ".join(status)
            query_parts.append(f"AREA[OverallStatus]({status_str})")

        if phases:
            phase_str = " OR ".join(phases)
            query_parts.append(f"AREA[Phase]({phase_str})")

        full_query = " AND ".join(query_parts)

        params = {
            "query.term": full_query,
            "pageSize": current_limit,
            # Request specific fields to minimize payload size
            "fields": ",".join(
                [
                    "protocolSection.identificationModule.nctId",
                    "protocolSection.identificationModule.briefTitle",
                    "protocolSection.identificationModule.officialTitle",
                    "protocolSection.identificationModule.organization",
                    "protocolSection.statusModule.overallStatus",
                    "protocolSection.statusModule.startDateStruct",
                    "protocolSection.statusModule.completionDateStruct",
                    "protocolSection.designModule.phases",
                    "protocolSection.designModule.studyType",
                    "protocolSection.eligibilityModule.eligibilityCriteria",
                    "protocolSection.eligibilityModule.sex",
                    "protocolSection.eligibilityModule.stdAges",
                    "protocolSection.descriptionModule.briefSummary",
                    "protocolSection.conditionsModule.conditions",
                    "protocolSection.outcomesModule.primaryOutcomes",
                    "protocolSection.contactsLocationsModule.locations",
                    "protocolSection.armsInterventionsModule",
                    "protocolSection.sponsorCollaboratorsModule.leadSponsor",
                ]
            ),
        }

        if next_page_token:
            params["pageToken"] = next_page_token

        # --- Retry Logic ---
        retries = 3
        for attempt in range(retries):
            try:
                response = requests.get(base_url, params=params, timeout=30)
                if response.status_code == 200:
                    data = response.json()
                    studies = data.get("studies", [])

                    if not studies:
                        return  # Stop generator if no studies returned

                    yield studies

                    fetched_count += len(studies)
                    next_page_token = data.get("nextPageToken")

                    if not next_page_token:
                        return  # Stop generator if no more pages

                    break  # Success, exit retry loop
                else:
                    print(f"❌ API Error: {response.status_code} - {response.text}")
                    if attempt < retries - 1:
                        time.sleep(2)
                    else:
                        return  # Stop generator on persistent error
            except Exception as e:
                print(f"❌ Request Error (Attempt {attempt+1}/{retries}): {e}")
                if attempt < retries - 1:
                    time.sleep(2)
                else:
                    return  # Stop generator


def process_study(study):
    """
    Processes a single study dictionary into a LlamaIndex Document.
    This function is designed to be run in parallel.
    """
    try:
        # Extract Modules
        protocol = study.get("protocolSection", {})
        identification = protocol.get("identificationModule", {})
        status_module = protocol.get("statusModule", {})
        design = protocol.get("designModule", {})
        eligibility = protocol.get("eligibilityModule", {})
        description = protocol.get("descriptionModule", {})
        conditions_module = protocol.get("conditionsModule", {})
        outcomes_module = protocol.get("outcomesModule", {})
        arms_interventions_module = protocol.get("armsInterventionsModule", {})
        locations_module = protocol.get("contactsLocationsModule", {})
        sponsor_module = protocol.get("sponsorCollaboratorsModule", {})

        # Extract Fields
        nct_id = identification.get("nctId", "N/A")
        title = identification.get("briefTitle", "N/A")
        official_title = identification.get("officialTitle", "N/A")
        org = identification.get("organization", {}).get("fullName", "N/A")
        sponsor_name = sponsor_module.get("leadSponsor", {}).get("name", "N/A")
        summary = clean_text(description.get("briefSummary", "N/A"))

        overall_status = status_module.get("overallStatus", "N/A")
        start_date = status_module.get("startDateStruct", {}).get("date", "N/A")
        completion_date = status_module.get("completionDateStruct", {}).get(
            "date", "N/A"
        )

        phases = ", ".join(design.get("phases", []))
        study_type = design.get("studyType", "N/A")

        criteria = clean_text(eligibility.get("eligibilityCriteria", "N/A"))
        gender = eligibility.get("sex", "N/A")
        ages = ", ".join(eligibility.get("stdAges", []))

        conditions = ", ".join(conditions_module.get("conditions", []))

        interventions = []
        for interv in arms_interventions_module.get("interventions", []):
            name = interv.get("name", "")
            type_ = interv.get("type", "")
            interventions.append(f"{type_}: {name}")
        interventions_str = "; ".join(interventions)

        primary_outcomes = []
        for outcome in outcomes_module.get("primaryOutcomes", []):
            measure = outcome.get("measure", "")
            desc = outcome.get("description", "")
            primary_outcomes.append(f"- {measure}: {desc}")
        outcomes_str = clean_text("\n".join(primary_outcomes))

        locations = []
        for loc in locations_module.get("locations", []):
            facility = loc.get("facility", "N/A")
            city = loc.get("city", "")
            country = loc.get("country", "")
            locations.append(f"{facility} ({city}, {country})")
        locations_str = "; ".join(locations[:5])  # Limit to 5 locations to save space

        # Extract State (First match)
        state = "Unknown"
        # Check locations for US States
        for loc_str in locations:
            if "United States" in loc_str:
                for s in US_STATES:
                    if s in loc_str:
                        state = s
                        break
            if state != "Unknown":
                break

        # Construct Rich Page Content with Markdown Headers
        # This text is what gets embedded and searched
        page_content = (
            f"# {title}\n"
            f"**NCT ID:** {nct_id}\n"
            f"**Official Title:** {official_title}\n"
            f"**Sponsor:** {sponsor_name}\n"
            f"**Organization:** {org}\n"
            f"**Status:** {overall_status}\n"
            f"**Phase:** {phases}\n"
            f"**Study Type:** {study_type}\n"
            f"**Start Date:** {start_date}\n"
            f"**Completion Date:** {completion_date}\n\n"
            f"## Summary\n{summary}\n\n"
            f"## Conditions\n{conditions}\n\n"
            f"## Interventions\n{interventions_str}\n\n"
            f"## Eligibility Criteria\n"
            f"**Gender:** {gender}\n"
            f"**Ages:** {ages}\n"
            f"**Criteria:**\n{criteria}\n\n"
            f"## Primary Outcomes\n{outcomes_str}\n\n"
            f"## Locations\n{locations_str}"
        )

        # Metadata for filtering (Structured Data)
        metadata = {
            "nct_id": nct_id,
            "title": title,
            "org": org,
            "sponsor": sponsor_name,
            "status": overall_status,
            "phase": phases,
            "study_type": study_type,
            "start_year": (int(start_date.split("-")[0]) if start_date != "N/A" else 0),
            "condition": conditions,
            "intervention": interventions_str,
            "country": (
                locations[0].split(",")[-1].strip() if locations else "Unknown"
            ),
            "state": state,
        }

        return Document(text=page_content, metadata=metadata, id_=nct_id)
    except Exception as e:
        print(
            f"⚠️ Error processing study {study.get('protocolSection', {}).get('identificationModule', {}).get('nctId', 'Unknown')}: {e}"
        )
        return None


def run_ingestion():
    """
    Main execution function for the ingestion script.
    Parses arguments, initializes the index, and runs the ingestion loop.
    """
    parser = argparse.ArgumentParser(description="Ingest Clinical Trials data.")
    parser.add_argument(
        "--limit",
        type=int,
        default=-1,
        help="Number of studies to ingest. Set to -1 for ALL.",
    )
    parser.add_argument(
        "--years", type=int, default=10, help="Number of years to look back."
    )
    parser.add_argument(
        "--status",
        type=str,
        default="COMPLETED",
        help="Comma-separated list of statuses (e.g., COMPLETED,RECRUITING).",
    )
    parser.add_argument(
        "--phases",
        type=str,
        default="PHASE1,PHASE2,PHASE3,PHASE4",
        help="Comma-separated list of phases (e.g., PHASE2,PHASE3).",
    )
    args = parser.parse_args()

    status_list = args.status.split(",") if args.status else []
    phase_list = args.phases.split(",") if args.phases else []

    print(f"⚙️ Configuration: Limit={args.limit}, Years={args.years}")
    print(f"   Status Filter: {status_list}")
    print(f"   Phase Filter: {phase_list}")

    # --- INITIALIZE LLAMAINDEX COMPONENTS ---
    print("🧠 Initializing LlamaIndex Embeddings (PubMedBERT)...")
    embed_model = HuggingFaceEmbedding(model_name="pritamdeka/S-PubMedBert-MS-MARCO")

    # Initialize LanceDB (Persistent)
    print("🚀 Initializing LanceDB...")

    # Determine the project root directory (one level up from this script)
    script_dir = os.path.dirname(os.path.abspath(__file__))
    project_root = os.path.dirname(script_dir)
    db_path = os.path.join(project_root, "ct_gov_lancedb")

    # Connect to LanceDB
    db = lancedb.connect(db_path)

    table_name = "clinical_trials"
    if table_name in db.table_names():
        mode = "append"
        print(f"ℹ️ Table '{table_name}' exists. Appending data.")
    else:
        mode = "create"
        print(f"ℹ️ Table '{table_name}' does not exist. Creating new table.")

    # Initialize Vector Store
    vector_store = LanceDBVectorStore(
        uri=db_path,
        table_name=table_name,
        mode=mode,
        query_mode="hybrid",  # Enable hybrid search support
    )
    storage_context = StorageContext.from_defaults(vector_store=vector_store)

    # Initialize Index ONCE
    # We pass the storage context to link it to the vector store
    index = VectorStoreIndex.from_vector_store(
        vector_store, storage_context=storage_context, embed_model=embed_model
    )

    total_ingested = 0

    # Progress Bar
    pbar = tqdm(
        total=args.limit if args.limit > 0 else float("inf"),
        desc="Ingesting Studies",
        unit="study",
    )

    # --- INGESTION LOOP ---
    # Use ProcessPoolExecutor for parallel processing of study data
    with concurrent.futures.ProcessPoolExecutor() as executor:
        for batch_studies in fetch_trials_generator(
            years=args.years,
            max_studies=args.limit,
            status=status_list,
            phases=phase_list,
        ):
            # Parallelize the processing of the batch
            # map returns an iterator, so we convert to list to trigger execution
            documents_iter = executor.map(process_study, batch_studies)

            # Filter out None results (errors)
            documents = [doc for doc in documents_iter if doc is not None]

            if documents:
                # Overwrite Logic:
                # To avoid duplicates, we delete existing records with the same NCT IDs.
                doc_ids = [doc.id_ for doc in documents]
                try:
                    # LanceDB supports deletion via SQL-like filter
                    # We construct a filter string: "nct_id IN ('NCT123', 'NCT456')"
                    ids_str = ", ".join([f"'{doc_id}'" for doc_id in doc_ids])
                    if ids_str:
                        tbl = db.open_table("clinical_trials")
                        tbl.delete(f"nct_id IN ({ids_str})")
                except Exception:
                    # Ignore if table doesn't exist yet
                    pass

                # Efficient Batch Insertion
                # We convert documents to nodes and insert them into the index.
                # This handles embedding generation automatically.
                node_parser = Settings.node_parser
                nodes = node_parser.get_nodes_from_documents(documents)

                index.insert_nodes(nodes)

                total_ingested += len(documents)
                pbar.update(len(documents))

    pbar.close()
    print(f"🎉 Ingestion Complete! Total studies ingested: {total_ingested}")


if __name__ == "__main__":
    run_ingestion()
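The flags defined in run_ingestion() can be combined as needed. For example (the values here are illustrative, the flags come from the argparse configuration above):

python scripts/ingest_ct.py --limit 5000 --years 5 --status COMPLETED,RECRUITING --phases PHASE2,PHASE3

With --limit -1 (the default) the script pages through every study matching the date, status, and phase filters.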
scripts/remove_duplicates.py
ADDED
@@ -0,0 +1,174 @@
"""
Script to remove duplicate records from the LanceDB database.

This script scans the 'clinical_trials' table, identifies records with duplicate content
(same 'nct_id' AND same 'text'), and removes the extras.

It uses a safe "Fetch -> Dedupe -> Overwrite" strategy:
1. Identifies NCT IDs that have duplicates.
2. For each such NCT ID, fetches ALL its records (chunks).
3. Deduplicates these records in memory based on their text content.
4. Deletes ALL records for that NCT ID from the database.
5. Re-inserts the unique records.

This ensures that valid chunks of the same study are PRESERVED, while exact duplicates are removed.
"""

import os
import pandas as pd
import lancedb

import argparse


def calculate_richness(record):
    """Calculates a 'richness' score for a record based on metadata field count and content length."""
    score = 0
    if not record:
        return 0

    for key, value in record.items():
        if key == "vector":
            continue

        # Handle nested metadata
        if key == "metadata" and isinstance(value, dict):
            score += calculate_richness(value)  # Recurse
            continue

        # Check for non-empty values
        if value is not None and str(value).strip() != "":
            score += 10  # Base points for having a populated field

            # Bonus points for content length
            if isinstance(value, str):
                score += len(value) / 100.0

    return score


def remove_duplicates(dry_run=False):
    # Determine the project root directory
    script_dir = os.path.dirname(os.path.abspath(__file__))
    project_root = os.path.dirname(script_dir)
    db_path = os.path.join(project_root, "ct_gov_lancedb")

    if not os.path.exists(db_path):
        print(f"❌ Database directory '{db_path}' does not exist.")
        return

    print(f"📂 Loading database from {db_path}...")
    if dry_run:
        print("🧪 RUNNING IN DRY-RUN MODE (No changes will be made)")

    try:
        db = lancedb.connect(db_path)
        tbl = db.open_table("clinical_trials")

        print("🔍 Scanning for duplicates...")
        # Fetch all data
        df = tbl.to_pandas()

        if df.empty:
            print("Database is empty.")
            return

        # Create a working copy to flatten metadata for analysis
        working_df = df.copy()
        if "metadata" in working_df.columns:
            # Flatten metadata
            meta_df = pd.json_normalize(working_df["metadata"])
            # We drop the original metadata column from working_df and join the flattened one
            working_df = pd.concat([working_df.drop(columns=["metadata"]), meta_df], axis=1)

        if "nct_id" not in working_df.columns:
            print("❌ 'nct_id' column not found (checked metadata too).")
            return

        if "text" not in working_df.columns:
            print("❌ 'text' column not found. Cannot safely deduplicate chunks.")
            return

        # Identify duplicates based on (nct_id, text) using the flattened view
        duplicates_mask = working_df.duplicated(subset=["nct_id", "text"], keep=False)

        # We use the mask on working_df to find the IDs
        duplicates_working_df = working_df[duplicates_mask]

        if duplicates_working_df.empty:
            print("✅ No exact duplicates found. Database is clean.")
            return

        unique_duplicate_ids = duplicates_working_df["nct_id"].unique()
        print(f"⚠️ Found duplicates affecting {len(unique_duplicate_ids)} studies (NCT IDs).")

        total_deleted = 0
        total_reinserted = 0

        # Process each affected NCT ID
        for nct_id in unique_duplicate_ids:
            # Get indices from working_df where nct_id matches
            # This ensures we are looking at the right rows in the ORIGINAL df
            indices = working_df[working_df["nct_id"] == nct_id].index

            # Extract original records (preserving structure)
            study_records_df = df.loc[indices]
            original_count = len(study_records_df)

            unique_records = []
            seen_texts = set()

            records = study_records_df.to_dict("records")
            records.sort(key=calculate_richness, reverse=True)

            for record in records:
                text_content = record.get("text", "")
                if text_content not in seen_texts:
                    unique_records.append(record)
                    seen_texts.add(text_content)

            new_count = len(unique_records)

            if new_count < original_count:
                print(f"  - {nct_id}: Reducing {original_count} -> {new_count} records.")

                if not dry_run:
                    # Delete using the ID (LanceDB SQL filter).
                    # tbl.delete() takes a SQL string: if 'nct_id' lives inside the
                    # 'metadata' struct we must address it as 'metadata.nct_id';
                    # if it is a top-level column we can use 'nct_id' directly.

                    # We check if 'nct_id' is a top-level column in the original DF
                    if "nct_id" in df.columns:
                        where_clause = f"nct_id = '{nct_id}'"
                    else:
                        where_clause = f"metadata.nct_id = '{nct_id}'"

                    tbl.delete(where_clause)

                    if unique_records:
                        tbl.add(unique_records)

                total_deleted += original_count
                total_reinserted += new_count
            else:
                print(f"  - {nct_id}: No reduction needed (false positive?).")

        if dry_run:
            print("\n🧪 DRY RUN COMPLETE.")
            print(f"   - WOULD remove {total_deleted - total_reinserted} duplicate records.")
            print(f"   - WOULD preserve {total_reinserted} unique chunks.")
        else:
            print("\n🎉 Deduplication complete!")
            print(f"   - Removed {total_deleted - total_reinserted} duplicate records.")
            print(f"   - Preserved {total_reinserted} unique chunks.")

    except Exception as e:
        print(f"❌ Error: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Remove duplicate records from LanceDB.")
    parser.add_argument("--dry-run", action="store_true", help="Simulate the process without making changes.")
    args = parser.parse_args()

    remove_duplicates(dry_run=args.dry_run)
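Because the script deletes and re-inserts records, it is safest to preview the effect first with the --dry-run flag defined above, then run it for real:

python scripts/remove_duplicates.py --dry-run
python scripts/remove_duplicates.py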
tests/test_data_integrity.py
ADDED
@@ -0,0 +1,75 @@
import unittest
import lancedb
import pandas as pd
import os
import sys

# Add project root to path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))


class TestDataIntegrity(unittest.TestCase):
    def setUp(self):
        # Determine the project root directory
        self.test_dir = os.path.dirname(os.path.abspath(__file__))
        self.project_root = os.path.dirname(self.test_dir)
        self.db_path = os.path.join(self.project_root, "ct_gov_lancedb")

    def test_pfizer_myeloma_counts(self):
        """
        Verifies that the database contains the expected number of Pfizer studies
        related to Multiple Myeloma, based on strict keyword matching.
        """
        if not os.path.exists(self.db_path):
            self.skipTest(f"Database directory '{self.db_path}' does not exist. Skipping data integrity test.")

        print(f"\n📂 Loading database from {self.db_path}...")
        try:
            db = lancedb.connect(self.db_path)
            tbl = db.open_table("clinical_trials")
        except Exception as e:
            self.skipTest(f"Failed to load LanceDB table: {e}")

        # Fetch all data (LanceDB is fast enough for this size, or we could query)
        # For integrity check, loading into DF is fine.
        df = tbl.to_pandas()

        # Handle metadata flattening if needed (LanceDB stores metadata in a struct)
        if "metadata" in df.columns:
            # Flatten the metadata column
            meta_df = pd.json_normalize(df["metadata"])
            df = meta_df

        # 1. Check for 'org' column
        if "org" not in df.columns:
            self.fail("'org' column missing from metadata.")

        # 2. Filter by Sponsor (Pfizer)
        pfizer_studies = df[df["org"].str.contains("Pfizer", case=False, na=False)]
        # We expect at least some Pfizer studies if the DB is populated
        self.assertGreater(len(pfizer_studies), 0, "No Pfizer studies found in DB.")

        # 3. Filter by "Multiple Myeloma" in Title or Conditions
        query = "Multiple Myeloma"

        def is_relevant(row):
            title = str(row.get("title", "")).lower()
            conditions = str(row.get("condition", "")).lower()
            q = query.lower()
            return q in title or q in conditions

        relevant_studies = pfizer_studies[pfizer_studies.apply(is_relevant, axis=1)]

        count = len(relevant_studies)
        print(f"🎯 Pfizer Studies with '{query}' in Title or Conditions: {count}")

        # Assertion: Based on our previous check, we expect exactly 7.
        # However, to be robust against minor data updates, we assert a range rather than an exact value.
        # The count should be non-zero and reasonably small (since we know it shouldn't be 514).
        self.assertGreater(count, 0, "Should find at least one relevant study.")
        self.assertLess(count, 50, "Should not find hundreds of studies (strict filter check).")

        # Optional: Assert exact count if we want to be very strict about data consistency
        # self.assertEqual(count, 7, "Expected exactly 7 studies based on known ground truth.")


if __name__ == "__main__":
    unittest.main()
tests/test_hybrid_search.py
ADDED
@@ -0,0 +1,65 @@
import pytest
import sys
import os

# Add project root to path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from modules.tools import search_trials
from modules.utils import load_environment


# Mark as integration test since it loads the DB
@pytest.mark.integration
def test_hybrid_search_integration():
    """
    Integration test for Hybrid Search.
    Verifies that the search_trials tool can retrieve results using the hybrid retriever.
    """
    load_environment()

    # Test 1: Dynamic ID Search
    # First, find a valid ID from a broad search
    print("\n🔍 Finding a valid ID for testing...")
    broad_results = search_trials.invoke({"query": "cancer"})

    # Extract an ID from the results
    import re
    match = re.search(r"ID: (NCT\d+)", broad_results)
    if not match:
        pytest.skip("Could not find any studies in DB to test against.")

    target_id = match.group(1)
    print(f"🎯 Found target ID: {target_id}. Now testing exact search...")

    # Now search for that specific ID
    results_id = search_trials.invoke({"query": target_id})

    assert "Found" in results_id
    assert target_id in results_id, f"Hybrid search failed to retrieve exact ID {target_id}"

    # Extract sponsor from the first result to ensure we test with valid data
    # Result format: "**Title** ... - Sponsor: SponsorName ..."
    sponsor_match = re.search(r"Sponsor: (.*?)\n", broad_results)
    if not sponsor_match:
        print("⚠️ Could not extract sponsor from results. Skipping hybrid test.")
        return

    target_sponsor = sponsor_match.group(1).strip()
    # Normalize it to get the simple name if possible.
    # search_trials expects a simple name to map to variations; if we pass the full name,
    # get_sponsor_variations might return None if it is not mapped, so we normalize first
    # and fall back to the raw value otherwise.

    from modules.utils import normalize_sponsor
    simple_sponsor = normalize_sponsor(target_sponsor)

    # If normalization didn't change it, it might not be in our alias list,
    # but we can still try to search with it.

    print(f"\n🔍 Testing Hybrid Search with dynamic sponsor: '{simple_sponsor}' (Original: {target_sponsor})")

    # Use a generic query that likely matches the study, or just "study"
    results_hybrid = search_trials.invoke({"query": "study", "sponsor": simple_sponsor})

    assert "Found" in results_hybrid, f"Should find results for valid sponsor {simple_sponsor}"
    assert target_sponsor in results_hybrid or simple_sponsor in results_hybrid
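This test is tagged with the custom `integration` marker because it needs a populated database. It can be selected (or excluded) with pytest's standard -m option, for example:

pytest tests/test_hybrid_search.py -m integration

Note that the `integration` marker is not registered in any configuration shown in this commit, so pytest may emit an unknown-marker warning unless it is declared.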
tests/test_sponsor_normalization.py
ADDED
@@ -0,0 +1,45 @@
import unittest
from modules.utils import normalize_sponsor, get_sponsor_variations, SPONSOR_MAPPINGS


class TestSponsorNormalization(unittest.TestCase):
    def test_normalize_sponsor_aliases(self):
        """Test that common aliases map to canonical names."""
        self.assertEqual(normalize_sponsor("J&J"), "Janssen")
        self.assertEqual(normalize_sponsor("Johnson & Johnson"), "Janssen")
        self.assertEqual(normalize_sponsor("GSK"), "GlaxoSmithKline")
        self.assertEqual(normalize_sponsor("Merck"), "Merck Sharp & Dohme")
        self.assertEqual(normalize_sponsor("BMS"), "Bristol-Myers Squibb")

    def test_normalize_sponsor_variations(self):
        """Test that specific DB variations map to canonical names."""
        self.assertEqual(normalize_sponsor("Janssen Research & Development, LLC"), "Janssen")
        self.assertEqual(normalize_sponsor("Pfizer Inc."), "Pfizer")
        self.assertEqual(normalize_sponsor("Merck Sharp & Dohme LLC"), "Merck Sharp & Dohme")

    def test_normalize_sponsor_canonical(self):
        """Test that canonical names return themselves."""
        self.assertEqual(normalize_sponsor("Janssen"), "Janssen")
        self.assertEqual(normalize_sponsor("Pfizer"), "Pfizer")

    def test_get_sponsor_variations(self):
        """Test that getting variations works for aliases and canonical names."""
        # Test with alias
        vars_jnj = get_sponsor_variations("J&J")
        self.assertIn("Janssen Research & Development, LLC", vars_jnj)
        self.assertIn("Janssen", vars_jnj)

        # Test with canonical
        vars_janssen = get_sponsor_variations("Janssen")
        self.assertEqual(vars_jnj, vars_janssen)

        # Test with variation input (should normalize first)
        vars_variation = get_sponsor_variations("Janssen Research & Development, LLC")
        self.assertEqual(vars_janssen, vars_variation)

    def test_unknown_sponsor(self):
        """Test behavior for unknown sponsors."""
        self.assertEqual(normalize_sponsor("Unknown Pharma"), "Unknown Pharma")
        self.assertIsNone(get_sponsor_variations("Unknown Pharma"))


if __name__ == "__main__":
    unittest.main()
tests/test_unit.py
ADDED
@@ -0,0 +1,149 @@
import pytest
import pandas as pd
import sys
import os
from unittest.mock import MagicMock, patch

# Add project root to path to import app modules
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from modules.utils import normalize_sponsor  # noqa: E402
from modules.tools import expand_query  # noqa: E402
from modules.graph_viz import build_graph  # noqa: E402
from llama_index.core.schema import NodeWithScore, TextNode  # noqa: E402

# --- Tests for normalize_sponsor ---


def test_normalize_sponsor_aliases():
    assert normalize_sponsor("J&J") == "Janssen"
    assert normalize_sponsor("Johnson & Johnson") == "Janssen"
    assert normalize_sponsor("GSK") == "GlaxoSmithKline"
    assert normalize_sponsor("Merck") == "Merck Sharp & Dohme"
    assert normalize_sponsor("MSD") == "Merck Sharp & Dohme"
    assert normalize_sponsor("BMS") == "Bristol-Myers Squibb"


def test_normalize_sponsor_no_change():
    assert normalize_sponsor("Pfizer") == "Pfizer"
    assert normalize_sponsor("Moderna") == "Moderna"
    assert normalize_sponsor("Unknown Sponsor") == "Unknown Sponsor"


# --- Tests for Analytics Logic (Mocked) ---


def filter_dataframe(df, phase=None, status=None, sponsor=None, intervention=None):
    """
    Replicates the logic from get_study_analytics for testing purposes.
    """
    if phase:
        target_phases = [p.strip().upper().replace(" ", "") for p in phase.split(",")]
        df["phase_upper"] = df["phase"].astype(str).str.upper().str.replace(" ", "")
        mask = df["phase_upper"].apply(lambda x: any(tp in x for tp in target_phases))
        df = df[mask]

    if status:
        df = df[df["status"].str.upper() == status.upper()]

    if sponsor:
        target_sponsor = normalize_sponsor(sponsor).lower()
        df["org_lower"] = df["org"].astype(str).apply(normalize_sponsor).str.lower()
        df = df[df["org_lower"].str.contains(target_sponsor, regex=False)]

    if intervention:
        target_intervention = intervention.lower()
        df["intervention_lower"] = df["intervention"].astype(str).str.lower()
        df = df[df["intervention_lower"].str.contains(target_intervention, regex=False)]

    return df


@pytest.fixture
def sample_df():
    data = {
        "nct_id": ["NCT001", "NCT002", "NCT003", "NCT004"],
        "phase": ["PHASE1", "PHASE2", "PHASE3", "PHASE2"],
        "status": ["RECRUITING", "COMPLETED", "COMPLETED", "RECRUITING"],
        "org": ["Pfizer", "Janssen", "Merck Sharp & Dohme", "Pfizer"],
        "intervention": ["Drug A", "Drug B", "Keytruda", "Drug A + Drug C"],
        "start_year": [2020, 2021, 2022, 2023],
        "title": [
            "Study of Drug A",
            "Study of Drug B",
            "Keytruda Trial",
            "Combo Study",
        ],
        "condition": ["Cancer", "Diabetes", "Lung Cancer", "Cancer"],
    }
    return pd.DataFrame(data)


def test_analytics_filter_intervention(sample_df):
    # Filter for Keytruda
    filtered = filter_dataframe(sample_df, intervention="Keytruda")
    assert len(filtered) == 1
    assert filtered.iloc[0]["nct_id"] == "NCT003"


def test_analytics_filter_intervention_partial(sample_df):
    # Filter for "Drug A" (should match NCT001 and NCT004)
    filtered = filter_dataframe(sample_df, intervention="Drug A")
    assert len(filtered) == 2
    assert set(filtered["nct_id"]) == {"NCT001", "NCT004"}


# --- Tests for Query Expansion ---


@patch("modules.tools.Settings")
def test_expand_query(mock_settings):
    # Mock LLM response
    mock_response = MagicMock()
    mock_response.text = "Expanded Query: cancer OR carcinoma OR tumor"
    mock_settings.llm.complete.return_value = mock_response

    query = "cancer"
    expanded = expand_query(query)

    assert "cancer OR carcinoma OR tumor" in expanded
    mock_settings.llm.complete.assert_called_once()


def test_expand_query_skip_long():
    long_query = "this is a very long query that should definitely be skipped because it has too many words"
    assert expand_query(long_query) == long_query


# --- Tests for Graph Visualization ---


def test_build_graph():
    data = [
        {"nct_id": "NCT1", "title": "Study 1", "org": "Pfizer", "condition": "Cancer"},
        {
            "nct_id": "NCT2",
            "title": "Study 2",
            "org": "Merck",
            "condition": "Cancer, Diabetes",
        },
    ]

    nodes, edges, config = build_graph(data)

    # Check Nodes
    # 2 Studies + 2 Sponsors + 2 Conditions (Cancer, Diabetes) = 6 Nodes
    assert len(nodes) == 6

    node_ids = [n.id for n in nodes]
    assert "NCT1" in node_ids
    assert "Pfizer" in node_ids
    assert "Cancer" in node_ids

    # Check Edges
    # NCT1 -> Pfizer, NCT1 -> Cancer (2 edges)
    # NCT2 -> Merck, NCT2 -> Cancer, NCT2 -> Diabetes (3 edges)
    assert len(edges) == 5
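Assuming the dependencies from requirements.txt are installed (plus pytest itself, which is not pinned there), the whole suite can be run from the project root with a standard invocation, for example:

pytest tests/

The analytics, query-expansion, and graph tests above use in-memory fixtures and mocks, so they do not need a populated database; the unittest-based files (tests/test_data_integrity.py and tests/test_sponsor_normalization.py) can also be run directly via their __main__ blocks.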