"""LLM-backed helpers that turn clinical-trial eligibility criteria into cohort SQL.

The flow is: fetch the study text, ask the LLM to extract structured
inclusion/exclusion rules as JSON, then ask the LLM to translate those rules
into a SQL query against a claims-style schema.
"""

import json
import os

import streamlit as st
from langchain.prompts import PromptTemplate
from langchain.tools import tool
from langchain_google_genai import ChatGoogleGenerativeAI

from modules.tools import get_study_details
from modules.utils import load_environment

# Load env for API key
load_environment()


def get_llm() -> ChatGoogleGenerativeAI:
    """Retrieve an LLM instance, resolving the API key at call time.

    The key is read from Streamlit session state first (user-provided key) and
    falls back to the ``GOOGLE_API_KEY`` environment variable.

    Returns:
        A ``ChatGoogleGenerativeAI`` client for ``gemini-2.5-flash`` with
        temperature 0.

    Raises:
        ValueError: If neither source yields a non-empty key.
    """
    api_key = None

    # Check session state first (user-provided key).
    if hasattr(st, "session_state") and "api_key" in st.session_state:
        api_key = st.session_state["api_key"]

    # Fall back to the environment variable.
    if not api_key:
        api_key = os.environ.get("GOOGLE_API_KEY")

    if not api_key:
        raise ValueError("Google API Key not found in session state or environment.")

    return ChatGoogleGenerativeAI(
        model="gemini-2.5-flash",
        temperature=0,
        google_api_key=api_key,
    )


# NOTE(review): this prompt asks for ICD-10 / RxNorm codes, while SQL_PROMPT
# below matches drugs on `ndc_code`. Confirm whether a code mapping step is
# intended between extraction and SQL generation.
EXTRACT_PROMPT = PromptTemplate(
    template="""
    You are a Clinical Informatics Expert.
    Your task is to extract structured cohort requirements from the following Clinical Trial Eligibility Criteria.

    Output a JSON object with two keys: "inclusion" and "exclusion".
    Each key should contain a list of rules.
    Each rule should have:
    - "concept": The medical concept (e.g., "Type 2 Diabetes", "Metformin").
    - "domain": The domain (Condition, Drug, Measurement, Procedure, Observation).
    - "temporal": Any temporal logic (e.g., "History of", "Within last 6 months").
    - "codes": A list of potential ICD-10 or RxNorm codes (make a best guess).

    CRITERIA:
    {criteria}

    JSON OUTPUT:
    """,
    input_variables=["criteria"],
)

SQL_PROMPT = PromptTemplate(
    template="""
    You are a SQL Expert specializing in Healthcare Claims Data Analysis.
    Generate a standard SQL query to define a cohort of patients based on the following structured requirements.

    ### Schema Assumptions
    1. **`medical_claims`** (Diagnoses & Procedures):
       - `patient_id`, `claim_date`, `diagnosis_code` (ICD-10), `procedure_code` (CPT/HCPCS).
    2. **`pharmacy_claims`** (Drugs):
       - `patient_id`, `fill_date`, `ndc_code`.

    ### Logic Rules
    1. **Conditions (Diagnoses)**:
       - Require **at least 2 distinct claim dates** where the diagnosis code matches.
       - These 2 claims must be **at least 30 days apart** (to confirm chronic condition).
    2. **Drugs**:
       - Require at least 1 claim with a matching NDC code.
    3. **Procedures**:
       - Require at least 1 claim with a matching CPT/HCPCS code.
    4. **Exclusions**:
       - Exclude patients who have ANY matching claims for exclusion criteria.

    ### Requirements (JSON)
    {requirements}

    ### Output
    Generate a single SQL query that selects `patient_id` from the claims tables meeting the criteria.
    Use Common Table Expressions (CTEs) for clarity.
    Do NOT output markdown formatting (```sql), just the raw SQL.

    SQL QUERY:
    """,
    input_variables=["requirements"],
)


def _message_text(response) -> str:
    """Return the textual content of an LLM response message.

    LangChain message content may be a plain string or a list of parts
    (strings or dicts carrying a ``"text"`` entry); list content is flattened
    so the callers' string clean-up does not fail on it.
    """
    content = response.content
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        pieces = []
        for part in content:
            if isinstance(part, str):
                pieces.append(part)
            elif isinstance(part, dict) and isinstance(part.get("text"), str):
                pieces.append(part["text"])
        return "".join(pieces)
    return str(content)


def _strip_code_fences(text: str, language: str) -> str:
    """Remove markdown code-fence markers (with or without *language*) and trim."""
    return text.replace(f"```{language}", "").replace("```", "").strip()


def _parse_json_object(text: str):
    """Parse *text* as a JSON object, returning ``None`` when that is impossible.

    If the text as a whole is not valid JSON (e.g. the model added prose around
    it), the span from the first ``{`` to the last ``}`` is tried as a fallback.
    """
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        start = text.find("{")
        end = text.rfind("}")
        if start == -1 or end <= start:
            return None
        try:
            parsed = json.loads(text[start:end + 1])
        except json.JSONDecodeError:
            return None
    # Callers index the result by key, so anything but an object is a failure.
    return parsed if isinstance(parsed, dict) else None


def extract_cohort_requirements(criteria_text: str) -> dict:
    """Use the LLM to parse criteria text into structured JSON.

    Args:
        criteria_text: Free-text eligibility criteria for a study.

    Returns:
        The parsed JSON object produced by the model, or
        ``{"error": "Failed to parse LLM output", "raw_output": <model text>}``
        when the model output cannot be parsed as a JSON object.
    """
    llm = get_llm()
    chain = EXTRACT_PROMPT | llm
    response = chain.invoke({"criteria": criteria_text})

    raw_output = _message_text(response)
    # Clean up potential markdown code blocks before parsing.
    parsed = _parse_json_object(_strip_code_fences(raw_output, "json"))
    if parsed is None:
        return {"error": "Failed to parse LLM output", "raw_output": raw_output}
    return parsed


def generate_cohort_sql(requirements: dict) -> str:
    """Use the LLM to translate structured requirements into SQL.

    Args:
        requirements: Structured rules, serialised to JSON for the prompt.

    Returns:
        The model's SQL text with any markdown code fences removed.
    """
    llm = get_llm()
    chain = SQL_PROMPT | llm
    response = chain.invoke({"requirements": json.dumps(requirements, indent=2)})
    return _strip_code_fences(_message_text(response), "sql")


@tool("get_cohort_sql")
def get_cohort_sql(nct_id: str) -> str:
    """
    Generates a SQL query to define the patient cohort for a specific study (NCT ID).

    Args:
        nct_id (str): The ClinicalTrials.gov identifier (e.g., NCT01234567).

    Returns:
        str: A formatted string containing the Extracted Requirements (JSON) and the Generated SQL.
    """
    # 1. Fetch Study Details
    # Reuse the existing tool logic to get the text.
    study_text = get_study_details.invoke(nct_id)
    # An empty/None result is treated like "not found" so the membership test
    # cannot raise and the LLM is never prompted with no criteria.
    if not study_text or "No study found" in study_text:
        return f"Could not find study {nct_id}."

    # 2. Extract Requirements
    requirements = extract_cohort_requirements(study_text)
    # Stop here on a parse failure: generating SQL from the error payload would
    # only produce a meaningless query.
    if "error" in requirements:
        return (
            f"Could not extract structured cohort requirements for study {nct_id}: "
            f"{requirements['error']}."
        )

    # 3. Generate SQL
    sql_query = generate_cohort_sql(requirements)

    # The heading names the claims schema because SQL_PROMPT targets the
    # `medical_claims` / `pharmacy_claims` tables, not OMOP CDM tables.
    return f"""
### 📋 Extracted Cohort Requirements
```json
{json.dumps(requirements, indent=2)}
```

### 💾 Generated SQL Query (Claims Schema)
```sql
{sql_query}
```
"""