Spaces:
Sleeping
Sleeping
Merge pull request #1 from humandotlearning/jules/cred-db-mcp-server
Browse files- pyproject.toml +6 -2
- src/cred_db_mcp/db.py +0 -33
- src/cred_db_mcp/main.py +0 -98
- src/cred_db_mcp/mcp_tools.py +0 -185
- src/cred_db_mcp/models.py +0 -62
- src/cred_db_mcp/npi_client.py +0 -40
- src/cred_db_mcp/schemas.py +0 -79
- src/cred_db_mcp_server/__init__.py +4 -0
- src/cred_db_mcp_server/config.py +6 -0
- src/cred_db_mcp_server/main.py +75 -0
- src/cred_db_mcp_server/schemas.py +29 -0
- src/cred_db_mcp_server/tools.py +152 -0
- tests/test_cred_db_mcp_server.py +86 -0
- uv.lock +0 -0
pyproject.toml
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
[project]
|
| 2 |
-
name = "
|
| 3 |
version = "0.1.0"
|
| 4 |
description = "Credential Database MCP Server"
|
| 5 |
requires-python = ">=3.11"
|
|
@@ -9,7 +9,8 @@ dependencies = [
|
|
| 9 |
"sqlalchemy>=2.0.0",
|
| 10 |
"pydantic>=2.0.0",
|
| 11 |
"httpx>=0.24.0",
|
| 12 |
-
"python-dotenv>=1.0.0"
|
|
|
|
| 13 |
]
|
| 14 |
|
| 15 |
[project.optional-dependencies]
|
|
@@ -22,6 +23,9 @@ test = [
|
|
| 22 |
requires = ["hatchling"]
|
| 23 |
build-backend = "hatchling.build"
|
| 24 |
|
|
|
|
|
|
|
|
|
|
| 25 |
[tool.pytest.ini_options]
|
| 26 |
pythonpath = "src"
|
| 27 |
testpaths = ["tests"]
|
|
|
|
| 1 |
[project]
|
| 2 |
+
name = "cred_db_mcp_server"
|
| 3 |
version = "0.1.0"
|
| 4 |
description = "Credential Database MCP Server"
|
| 5 |
requires-python = ">=3.11"
|
|
|
|
| 9 |
"sqlalchemy>=2.0.0",
|
| 10 |
"pydantic>=2.0.0",
|
| 11 |
"httpx>=0.24.0",
|
| 12 |
+
"python-dotenv>=1.0.0",
|
| 13 |
+
"gradio>=5.0.0",
|
| 14 |
]
|
| 15 |
|
| 16 |
[project.optional-dependencies]
|
|
|
|
| 23 |
requires = ["hatchling"]
|
| 24 |
build-backend = "hatchling.build"
|
| 25 |
|
| 26 |
+
[tool.hatch.build.targets.wheel]
|
| 27 |
+
packages = ["src/cred_db_mcp_server"]
|
| 28 |
+
|
| 29 |
[tool.pytest.ini_options]
|
| 30 |
pythonpath = "src"
|
| 31 |
testpaths = ["tests"]
|
src/cred_db_mcp/db.py
DELETED
|
@@ -1,33 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
from sqlalchemy import create_engine
|
| 3 |
-
from sqlalchemy.orm import sessionmaker, DeclarativeBase
|
| 4 |
-
|
| 5 |
-
# Default to a local file for dev/test if not specified
|
| 6 |
-
DB_PATH = os.getenv("DB_PATH", "credentialwatch.db")
|
| 7 |
-
DATABASE_URL = f"sqlite:///{DB_PATH}"
|
| 8 |
-
|
| 9 |
-
# specific check for checks that might run in different environments
|
| 10 |
-
if DB_PATH == ":memory:":
|
| 11 |
-
DATABASE_URL = "sqlite:///:memory:"
|
| 12 |
-
|
| 13 |
-
engine = create_engine(
|
| 14 |
-
DATABASE_URL,
|
| 15 |
-
connect_args={"check_same_thread": False} # Needed for SQLite
|
| 16 |
-
)
|
| 17 |
-
|
| 18 |
-
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
| 19 |
-
|
| 20 |
-
class Base(DeclarativeBase):
|
| 21 |
-
pass
|
| 22 |
-
|
| 23 |
-
def get_db():
|
| 24 |
-
"""Dependency for FastAPI to get a DB session."""
|
| 25 |
-
db = SessionLocal()
|
| 26 |
-
try:
|
| 27 |
-
yield db
|
| 28 |
-
finally:
|
| 29 |
-
db.close()
|
| 30 |
-
|
| 31 |
-
def init_db():
|
| 32 |
-
"""Helper to initialize the DB (create tables)."""
|
| 33 |
-
Base.metadata.create_all(bind=engine)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/cred_db_mcp/main.py
DELETED
|
@@ -1,98 +0,0 @@
|
|
| 1 |
-
from fastapi import FastAPI, Depends, HTTPException, Body
|
| 2 |
-
from sqlalchemy.orm import Session
|
| 3 |
-
from typing import List, Union
|
| 4 |
-
|
| 5 |
-
from .db import get_db, init_db
|
| 6 |
-
from .mcp_tools import MCPTools
|
| 7 |
-
from .schemas import (
|
| 8 |
-
SyncProviderInput, AddCredentialInput, ListExpiringInput, GetSnapshotInput,
|
| 9 |
-
ProviderRead, CredentialRead, ExpiringCredentialItem, ProviderSnapshot, ToolError
|
| 10 |
-
)
|
| 11 |
-
|
| 12 |
-
app = FastAPI(title="CredentialWatch DB MCP", version="0.1.0")
|
| 13 |
-
|
| 14 |
-
# Initialize DB on startup (for simple deployments)
|
| 15 |
-
@app.on_event("startup")
|
| 16 |
-
def on_startup():
|
| 17 |
-
init_db()
|
| 18 |
-
|
| 19 |
-
def get_mcp_tools(db: Session = Depends(get_db)) -> MCPTools:
|
| 20 |
-
return MCPTools(db)
|
| 21 |
-
|
| 22 |
-
# --- MCP Tool Endpoints ---
|
| 23 |
-
# Pattern: /mcp/tools/{tool_name}
|
| 24 |
-
# We use POST for all of them as they are "function calls"
|
| 25 |
-
|
| 26 |
-
@app.post("/mcp/tools/sync_provider_from_npi", response_model=Union[ProviderRead, ToolError])
|
| 27 |
-
async def sync_provider_from_npi(
|
| 28 |
-
args: SyncProviderInput,
|
| 29 |
-
tools: MCPTools = Depends(get_mcp_tools)
|
| 30 |
-
):
|
| 31 |
-
"""
|
| 32 |
-
Syncs provider data from the NPI Registry via npi_mcp.
|
| 33 |
-
"""
|
| 34 |
-
result = await tools.sync_provider_from_npi(args.npi)
|
| 35 |
-
if isinstance(result, dict) and "error" in result:
|
| 36 |
-
# Return generic error structure instead of 400 for agent handling if preferred,
|
| 37 |
-
# but usually 400 is better for HTTP.
|
| 38 |
-
# The prompt says "return a structured error".
|
| 39 |
-
return ToolError(error=result["error"])
|
| 40 |
-
return result
|
| 41 |
-
|
| 42 |
-
@app.post("/mcp/tools/add_or_update_credential", response_model=Union[CredentialRead, ToolError])
|
| 43 |
-
def add_or_update_credential(
|
| 44 |
-
args: AddCredentialInput,
|
| 45 |
-
tools: MCPTools = Depends(get_mcp_tools)
|
| 46 |
-
):
|
| 47 |
-
"""
|
| 48 |
-
Adds or updates a credential record.
|
| 49 |
-
"""
|
| 50 |
-
result = tools.add_or_update_credential(
|
| 51 |
-
provider_id=args.provider_id,
|
| 52 |
-
type=args.type,
|
| 53 |
-
issuer=args.issuer,
|
| 54 |
-
number=args.number,
|
| 55 |
-
expiry_date=args.expiry_date
|
| 56 |
-
)
|
| 57 |
-
if isinstance(result, dict) and "error" in result:
|
| 58 |
-
return ToolError(error=result["error"])
|
| 59 |
-
return result
|
| 60 |
-
|
| 61 |
-
@app.post("/mcp/tools/list_expiring_credentials", response_model=List[ExpiringCredentialItem])
|
| 62 |
-
def list_expiring_credentials(
|
| 63 |
-
args: ListExpiringInput,
|
| 64 |
-
tools: MCPTools = Depends(get_mcp_tools)
|
| 65 |
-
):
|
| 66 |
-
"""
|
| 67 |
-
Lists credentials expiring within a window.
|
| 68 |
-
"""
|
| 69 |
-
return tools.list_expiring_credentials(
|
| 70 |
-
window_days=args.window_days,
|
| 71 |
-
dept=args.dept,
|
| 72 |
-
location=args.location
|
| 73 |
-
)
|
| 74 |
-
|
| 75 |
-
@app.post("/mcp/tools/get_provider_snapshot", response_model=Union[ProviderSnapshot, ToolError])
|
| 76 |
-
def get_provider_snapshot(
|
| 77 |
-
args: GetSnapshotInput,
|
| 78 |
-
tools: MCPTools = Depends(get_mcp_tools)
|
| 79 |
-
):
|
| 80 |
-
"""
|
| 81 |
-
Gets a full snapshot of a provider.
|
| 82 |
-
"""
|
| 83 |
-
result = tools.get_provider_snapshot(
|
| 84 |
-
provider_id=args.provider_id,
|
| 85 |
-
npi=args.npi
|
| 86 |
-
)
|
| 87 |
-
if isinstance(result, dict) and "error" in result:
|
| 88 |
-
return ToolError(error=result["error"])
|
| 89 |
-
return result
|
| 90 |
-
|
| 91 |
-
# Simple info endpoint
|
| 92 |
-
@app.get("/")
|
| 93 |
-
def read_root():
|
| 94 |
-
return {
|
| 95 |
-
"service": "cred_db_mcp",
|
| 96 |
-
"status": "running",
|
| 97 |
-
"docs": "/docs"
|
| 98 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/cred_db_mcp/mcp_tools.py
DELETED
|
@@ -1,185 +0,0 @@
|
|
| 1 |
-
from sqlalchemy.orm import Session
|
| 2 |
-
from sqlalchemy import select, and_, func
|
| 3 |
-
from datetime import date, timedelta, datetime
|
| 4 |
-
import json
|
| 5 |
-
|
| 6 |
-
from .models import Provider, Credential, Alert
|
| 7 |
-
from .schemas import (
|
| 8 |
-
ProviderRead, CredentialRead, AlertRead,
|
| 9 |
-
ExpiringCredentialItem, ProviderSnapshot
|
| 10 |
-
)
|
| 11 |
-
from .npi_client import NPIClient
|
| 12 |
-
|
| 13 |
-
class MCPTools:
|
| 14 |
-
def __init__(self, db: Session):
|
| 15 |
-
self.db = db
|
| 16 |
-
self.npi_client = NPIClient()
|
| 17 |
-
|
| 18 |
-
async def sync_provider_from_npi(self, npi: str):
|
| 19 |
-
# 1. Call npi_mcp
|
| 20 |
-
npi_data = await self.npi_client.get_provider_by_npi(npi)
|
| 21 |
-
|
| 22 |
-
if not npi_data:
|
| 23 |
-
return {"error": f"NPI {npi} not found in upstream registry"}
|
| 24 |
-
|
| 25 |
-
# Map NPI data to our schema.
|
| 26 |
-
# Assuming npi_data has keys: npi, first_name, last_name, taxonomy_desc, practice_address
|
| 27 |
-
# This mapping is hypothetical based on typical NPI registry fields.
|
| 28 |
-
|
| 29 |
-
full_name = f"{npi_data.get('first_name', '')} {npi_data.get('last_name', '')}".strip()
|
| 30 |
-
if not full_name:
|
| 31 |
-
full_name = npi_data.get('organization_name', 'Unknown')
|
| 32 |
-
|
| 33 |
-
specialty = npi_data.get('taxonomy_desc', 'Unknown')
|
| 34 |
-
address_obj = npi_data.get('practice_address', {})
|
| 35 |
-
# Simple string representation of location
|
| 36 |
-
location = f"{address_obj.get('city', '')}, {address_obj.get('state', '')}" if isinstance(address_obj, dict) else str(address_obj)
|
| 37 |
-
|
| 38 |
-
# 2. Update or Create in DB
|
| 39 |
-
provider = self.db.scalar(select(Provider).where(Provider.npi == npi))
|
| 40 |
-
|
| 41 |
-
if not provider:
|
| 42 |
-
provider = Provider(
|
| 43 |
-
npi=npi,
|
| 44 |
-
full_name=full_name,
|
| 45 |
-
primary_specialty=specialty,
|
| 46 |
-
location=location,
|
| 47 |
-
is_active=True
|
| 48 |
-
)
|
| 49 |
-
self.db.add(provider)
|
| 50 |
-
else:
|
| 51 |
-
provider.full_name = full_name
|
| 52 |
-
provider.primary_specialty = specialty
|
| 53 |
-
provider.location = location
|
| 54 |
-
# Don't overwrite dept if set locally? assuming upstream doesn't have dept.
|
| 55 |
-
|
| 56 |
-
self.db.commit()
|
| 57 |
-
self.db.refresh(provider)
|
| 58 |
-
|
| 59 |
-
return ProviderRead.model_validate(provider)
|
| 60 |
-
|
| 61 |
-
def add_or_update_credential(
|
| 62 |
-
self, provider_id: int, type: str, issuer: str, number: str, expiry_date: str
|
| 63 |
-
):
|
| 64 |
-
# Parse expiry_date (assuming YYYY-MM-DD)
|
| 65 |
-
try:
|
| 66 |
-
exp_date = datetime.strptime(expiry_date, "%Y-%m-%d").date()
|
| 67 |
-
except ValueError:
|
| 68 |
-
return {"error": "Invalid date format. Use YYYY-MM-DD"}
|
| 69 |
-
|
| 70 |
-
# Check provider exists
|
| 71 |
-
provider = self.db.get(Provider, provider_id)
|
| 72 |
-
if not provider:
|
| 73 |
-
return {"error": f"Provider with ID {provider_id} not found"}
|
| 74 |
-
|
| 75 |
-
# Find existing credential
|
| 76 |
-
stmt = select(Credential).where(
|
| 77 |
-
and_(
|
| 78 |
-
Credential.provider_id == provider_id,
|
| 79 |
-
Credential.type == type,
|
| 80 |
-
Credential.issuer == issuer,
|
| 81 |
-
Credential.number == number
|
| 82 |
-
)
|
| 83 |
-
)
|
| 84 |
-
credential = self.db.scalar(stmt)
|
| 85 |
-
|
| 86 |
-
if credential:
|
| 87 |
-
credential.expiry_date = exp_date
|
| 88 |
-
# Heuristic status update
|
| 89 |
-
credential.status = "active" if exp_date > date.today() else "expired"
|
| 90 |
-
credential.updated_at = datetime.now()
|
| 91 |
-
else:
|
| 92 |
-
credential = Credential(
|
| 93 |
-
provider_id=provider_id,
|
| 94 |
-
type=type,
|
| 95 |
-
issuer=issuer,
|
| 96 |
-
number=number,
|
| 97 |
-
expiry_date=exp_date,
|
| 98 |
-
status="active" if exp_date > date.today() else "expired",
|
| 99 |
-
created_at=datetime.now()
|
| 100 |
-
)
|
| 101 |
-
self.db.add(credential)
|
| 102 |
-
|
| 103 |
-
self.db.commit()
|
| 104 |
-
self.db.refresh(credential)
|
| 105 |
-
|
| 106 |
-
return CredentialRead.model_validate(credential)
|
| 107 |
-
|
| 108 |
-
def list_expiring_credentials(
|
| 109 |
-
self, window_days: int, dept: str | None = None, location: str | None = None
|
| 110 |
-
):
|
| 111 |
-
today = date.today()
|
| 112 |
-
cutoff = today + timedelta(days=window_days)
|
| 113 |
-
|
| 114 |
-
# Build Query
|
| 115 |
-
stmt = select(Credential, Provider).join(Provider).where(
|
| 116 |
-
and_(
|
| 117 |
-
Credential.status == "active",
|
| 118 |
-
Credential.expiry_date <= cutoff,
|
| 119 |
-
Credential.expiry_date >= today # Only future or today (past would be expired)
|
| 120 |
-
)
|
| 121 |
-
)
|
| 122 |
-
|
| 123 |
-
if dept:
|
| 124 |
-
stmt = stmt.where(Provider.dept == dept)
|
| 125 |
-
if location:
|
| 126 |
-
stmt = stmt.where(Provider.location.ilike(f"%{location}%"))
|
| 127 |
-
|
| 128 |
-
results = self.db.execute(stmt).all()
|
| 129 |
-
|
| 130 |
-
output = []
|
| 131 |
-
for cred, prov in results:
|
| 132 |
-
if not cred.expiry_date:
|
| 133 |
-
continue
|
| 134 |
-
|
| 135 |
-
days_to_expiry = (cred.expiry_date - today).days
|
| 136 |
-
|
| 137 |
-
# Risk Score Heuristic
|
| 138 |
-
# 3 for <30 days, 2 for 30–60, 1 for 60–90
|
| 139 |
-
if days_to_expiry < 30:
|
| 140 |
-
risk_score = 3
|
| 141 |
-
elif days_to_expiry < 60:
|
| 142 |
-
risk_score = 2
|
| 143 |
-
else:
|
| 144 |
-
risk_score = 1
|
| 145 |
-
|
| 146 |
-
output.append(ExpiringCredentialItem(
|
| 147 |
-
provider=ProviderRead.model_validate(prov),
|
| 148 |
-
credential=CredentialRead.model_validate(cred),
|
| 149 |
-
days_to_expiry=days_to_expiry,
|
| 150 |
-
risk_score=risk_score
|
| 151 |
-
))
|
| 152 |
-
|
| 153 |
-
return output
|
| 154 |
-
|
| 155 |
-
def get_provider_snapshot(
|
| 156 |
-
self, provider_id: int | None = None, npi: str | None = None
|
| 157 |
-
):
|
| 158 |
-
if not provider_id and not npi:
|
| 159 |
-
return {"error": "Must provide either provider_id or npi"}
|
| 160 |
-
|
| 161 |
-
stmt = select(Provider)
|
| 162 |
-
if provider_id:
|
| 163 |
-
stmt = stmt.where(Provider.id == provider_id)
|
| 164 |
-
else:
|
| 165 |
-
stmt = stmt.where(Provider.npi == npi)
|
| 166 |
-
|
| 167 |
-
provider = self.db.scalar(stmt)
|
| 168 |
-
if not provider:
|
| 169 |
-
return {"error": "Provider not found"}
|
| 170 |
-
|
| 171 |
-
# Get Credentials
|
| 172 |
-
creds = self.db.scalars(
|
| 173 |
-
select(Credential).where(Credential.provider_id == provider.id)
|
| 174 |
-
).all()
|
| 175 |
-
|
| 176 |
-
# Get Alerts (Optional)
|
| 177 |
-
alerts = self.db.scalars(
|
| 178 |
-
select(Alert).where(Alert.provider_id == provider.id)
|
| 179 |
-
).all()
|
| 180 |
-
|
| 181 |
-
return ProviderSnapshot(
|
| 182 |
-
provider=ProviderRead.model_validate(provider),
|
| 183 |
-
credentials=[CredentialRead.model_validate(c) for c in creds],
|
| 184 |
-
alerts=[AlertRead.model_validate(a) for a in alerts]
|
| 185 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/cred_db_mcp/models.py
DELETED
|
@@ -1,62 +0,0 @@
|
|
| 1 |
-
from datetime import datetime, date
|
| 2 |
-
from typing import Optional, List, Any
|
| 3 |
-
from sqlalchemy import (
|
| 4 |
-
String, Integer, Boolean, DateTime, Date, ForeignKey, JSON, func
|
| 5 |
-
)
|
| 6 |
-
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
| 7 |
-
|
| 8 |
-
from .db import Base
|
| 9 |
-
|
| 10 |
-
class Provider(Base):
|
| 11 |
-
__tablename__ = "providers"
|
| 12 |
-
|
| 13 |
-
id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True)
|
| 14 |
-
npi: Mapped[Optional[str]] = mapped_column(String, index=True, nullable=True)
|
| 15 |
-
full_name: Mapped[str] = mapped_column(String)
|
| 16 |
-
dept: Mapped[Optional[str]] = mapped_column(String, nullable=True)
|
| 17 |
-
location: Mapped[Optional[str]] = mapped_column(String, nullable=True)
|
| 18 |
-
primary_specialty: Mapped[Optional[str]] = mapped_column(String, nullable=True)
|
| 19 |
-
is_active: Mapped[bool] = mapped_column(Boolean, default=True)
|
| 20 |
-
|
| 21 |
-
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
|
| 22 |
-
updated_at: Mapped[datetime] = mapped_column(
|
| 23 |
-
DateTime, server_default=func.now(), onupdate=func.now()
|
| 24 |
-
)
|
| 25 |
-
|
| 26 |
-
credentials: Mapped[List["Credential"]] = relationship(
|
| 27 |
-
"Credential", back_populates="provider", cascade="all, delete-orphan"
|
| 28 |
-
)
|
| 29 |
-
|
| 30 |
-
class Credential(Base):
|
| 31 |
-
__tablename__ = "credentials"
|
| 32 |
-
|
| 33 |
-
id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True)
|
| 34 |
-
provider_id: Mapped[int] = mapped_column(
|
| 35 |
-
Integer, ForeignKey("providers.id"), nullable=False
|
| 36 |
-
)
|
| 37 |
-
type: Mapped[str] = mapped_column(String) # e.g. "state_license", "board_cert"
|
| 38 |
-
issuer: Mapped[str] = mapped_column(String)
|
| 39 |
-
number: Mapped[str] = mapped_column(String)
|
| 40 |
-
status: Mapped[str] = mapped_column(String) # "active", "expired", etc.
|
| 41 |
-
|
| 42 |
-
issue_date: Mapped[Optional[date]] = mapped_column(Date, nullable=True)
|
| 43 |
-
expiry_date: Mapped[Optional[date]] = mapped_column(Date, nullable=True)
|
| 44 |
-
last_verified_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
| 45 |
-
|
| 46 |
-
metadata_json: Mapped[Optional[Any]] = mapped_column(JSON, nullable=True)
|
| 47 |
-
|
| 48 |
-
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
|
| 49 |
-
updated_at: Mapped[datetime] = mapped_column(
|
| 50 |
-
DateTime, server_default=func.now(), onupdate=func.now()
|
| 51 |
-
)
|
| 52 |
-
|
| 53 |
-
provider: Mapped["Provider"] = relationship("Provider", back_populates="credentials")
|
| 54 |
-
|
| 55 |
-
class Alert(Base):
|
| 56 |
-
__tablename__ = "alerts"
|
| 57 |
-
|
| 58 |
-
id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True)
|
| 59 |
-
provider_id: Mapped[int] = mapped_column(Integer, ForeignKey("providers.id"))
|
| 60 |
-
message: Mapped[str] = mapped_column(String)
|
| 61 |
-
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
|
| 62 |
-
# Minimal schema for alerts as requested
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/cred_db_mcp/npi_client.py
DELETED
|
@@ -1,40 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import httpx
|
| 3 |
-
from typing import Optional, Dict, Any
|
| 4 |
-
|
| 5 |
-
class NPIClient:
|
| 6 |
-
def __init__(self, base_url: Optional[str] = None):
|
| 7 |
-
self.base_url = base_url or os.getenv("NPI_MCP_URL", "http://localhost:8001")
|
| 8 |
-
# Ensure no trailing slash
|
| 9 |
-
self.base_url = self.base_url.rstrip("/")
|
| 10 |
-
|
| 11 |
-
async def get_provider_by_npi(self, npi: str) -> Optional[Dict[str, Any]]:
|
| 12 |
-
"""
|
| 13 |
-
Calls the npi_mcp server to get provider details.
|
| 14 |
-
Assumes npi_mcp exposes a tool 'get_provider_by_npi' via HTTP.
|
| 15 |
-
|
| 16 |
-
The NPI MCP is expected to have a similar structure: POST /mcp/tools/get_provider_by_npi
|
| 17 |
-
"""
|
| 18 |
-
url = f"{self.base_url}/mcp/tools/get_provider_by_npi"
|
| 19 |
-
|
| 20 |
-
# We assume the NPI MCP accepts arguments in the body
|
| 21 |
-
payload = {"npi": npi}
|
| 22 |
-
|
| 23 |
-
try:
|
| 24 |
-
async with httpx.AsyncClient() as client:
|
| 25 |
-
response = await client.post(url, json=payload, timeout=10.0)
|
| 26 |
-
|
| 27 |
-
if response.status_code == 200:
|
| 28 |
-
data = response.json()
|
| 29 |
-
# If the tool wraps return value, e.g. {"result": ...} adjust here.
|
| 30 |
-
# For now assuming direct return or generic tool response.
|
| 31 |
-
return data
|
| 32 |
-
elif response.status_code == 404:
|
| 33 |
-
return None
|
| 34 |
-
else:
|
| 35 |
-
# Log error or raise
|
| 36 |
-
print(f"Error calling NPI MCP: {response.status_code} {response.text}")
|
| 37 |
-
return None
|
| 38 |
-
except httpx.RequestError as e:
|
| 39 |
-
print(f"Request error calling NPI MCP: {e}")
|
| 40 |
-
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/cred_db_mcp/schemas.py
DELETED
|
@@ -1,79 +0,0 @@
|
|
| 1 |
-
from pydantic import BaseModel, ConfigDict
|
| 2 |
-
from typing import Optional, List, Any
|
| 3 |
-
from datetime import date, datetime
|
| 4 |
-
|
| 5 |
-
# --- Base Models ---
|
| 6 |
-
|
| 7 |
-
class ProviderBase(BaseModel):
|
| 8 |
-
npi: Optional[str] = None
|
| 9 |
-
full_name: str
|
| 10 |
-
dept: Optional[str] = None
|
| 11 |
-
location: Optional[str] = None
|
| 12 |
-
primary_specialty: Optional[str] = None
|
| 13 |
-
is_active: bool = True
|
| 14 |
-
|
| 15 |
-
class ProviderRead(ProviderBase):
|
| 16 |
-
id: int
|
| 17 |
-
created_at: datetime
|
| 18 |
-
updated_at: datetime
|
| 19 |
-
model_config = ConfigDict(from_attributes=True)
|
| 20 |
-
|
| 21 |
-
class CredentialBase(BaseModel):
|
| 22 |
-
type: str
|
| 23 |
-
issuer: str
|
| 24 |
-
number: str
|
| 25 |
-
status: str
|
| 26 |
-
issue_date: Optional[date] = None
|
| 27 |
-
expiry_date: Optional[date] = None
|
| 28 |
-
last_verified_at: Optional[datetime] = None
|
| 29 |
-
metadata_json: Optional[Any] = None
|
| 30 |
-
|
| 31 |
-
class CredentialRead(CredentialBase):
|
| 32 |
-
id: int
|
| 33 |
-
provider_id: int
|
| 34 |
-
created_at: datetime
|
| 35 |
-
updated_at: datetime
|
| 36 |
-
model_config = ConfigDict(from_attributes=True)
|
| 37 |
-
|
| 38 |
-
class AlertRead(BaseModel):
|
| 39 |
-
id: int
|
| 40 |
-
provider_id: int
|
| 41 |
-
message: str
|
| 42 |
-
created_at: datetime
|
| 43 |
-
model_config = ConfigDict(from_attributes=True)
|
| 44 |
-
|
| 45 |
-
# --- Tool IO Models ---
|
| 46 |
-
|
| 47 |
-
class SyncProviderInput(BaseModel):
|
| 48 |
-
npi: str
|
| 49 |
-
|
| 50 |
-
class AddCredentialInput(BaseModel):
|
| 51 |
-
provider_id: int
|
| 52 |
-
type: str
|
| 53 |
-
issuer: str
|
| 54 |
-
number: str
|
| 55 |
-
expiry_date: str # Expecting YYYY-MM-DD string as per prompt, or we can parse it
|
| 56 |
-
|
| 57 |
-
class ListExpiringInput(BaseModel):
|
| 58 |
-
window_days: int
|
| 59 |
-
dept: Optional[str] = None
|
| 60 |
-
location: Optional[str] = None
|
| 61 |
-
|
| 62 |
-
class ExpiringCredentialItem(BaseModel):
|
| 63 |
-
provider: ProviderRead
|
| 64 |
-
credential: CredentialRead
|
| 65 |
-
days_to_expiry: int
|
| 66 |
-
risk_score: int
|
| 67 |
-
|
| 68 |
-
class GetSnapshotInput(BaseModel):
|
| 69 |
-
provider_id: Optional[int] = None
|
| 70 |
-
npi: Optional[str] = None
|
| 71 |
-
|
| 72 |
-
class ProviderSnapshot(BaseModel):
|
| 73 |
-
provider: ProviderRead
|
| 74 |
-
credentials: List[CredentialRead]
|
| 75 |
-
alerts: Optional[List[AlertRead]] = []
|
| 76 |
-
|
| 77 |
-
class ToolError(BaseModel):
|
| 78 |
-
error: str
|
| 79 |
-
details: Optional[str] = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/cred_db_mcp_server/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
The `cred_db_mcp_server` package exposes an MCP server wrapper around the `CRED_API` service.
|
| 3 |
+
It uses Gradio to provide the MCP endpoints.
|
| 4 |
+
"""
|
src/cred_db_mcp_server/config.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from dotenv import load_dotenv
|
| 3 |
+
|
| 4 |
+
load_dotenv()
|
| 5 |
+
|
| 6 |
+
CRED_API_BASE_URL = os.getenv("CRED_API_BASE_URL", "http://localhost:8000")
|
src/cred_db_mcp_server/main.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
from .tools import (
|
| 3 |
+
sync_provider_from_npi,
|
| 4 |
+
add_or_update_credential,
|
| 5 |
+
list_expiring_credentials,
|
| 6 |
+
get_provider_snapshot
|
| 7 |
+
)
|
| 8 |
+
|
| 9 |
+
def main():
|
| 10 |
+
"""
|
| 11 |
+
Main entry point for the Gradio MCP server.
|
| 12 |
+
"""
|
| 13 |
+
# Create the Gradio interface
|
| 14 |
+
# We use a TabbedInterface to organize the tools if accessed via UI,
|
| 15 |
+
# but primarily they are exposed as MCP tools.
|
| 16 |
+
|
| 17 |
+
# Tool 1: sync_provider_from_npi
|
| 18 |
+
iface_sync = gr.Interface(
|
| 19 |
+
fn=sync_provider_from_npi,
|
| 20 |
+
inputs=[gr.Textbox(label="NPI")],
|
| 21 |
+
outputs=gr.JSON(label="Provider Data"),
|
| 22 |
+
description="Syncs a provider's data from the NPI registry.",
|
| 23 |
+
allow_flagging="never"
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
# Tool 2: add_or_update_credential
|
| 27 |
+
iface_add_cred = gr.Interface(
|
| 28 |
+
fn=add_or_update_credential,
|
| 29 |
+
inputs=[
|
| 30 |
+
gr.Number(label="Provider ID", precision=0),
|
| 31 |
+
gr.Textbox(label="Type"),
|
| 32 |
+
gr.Textbox(label="Issuer"),
|
| 33 |
+
gr.Textbox(label="Number"),
|
| 34 |
+
gr.Textbox(label="Expiry Date (YYYY-MM-DD)"),
|
| 35 |
+
],
|
| 36 |
+
outputs=gr.JSON(label="Credential Data"),
|
| 37 |
+
description="Adds or updates a credential for a provider.",
|
| 38 |
+
allow_flagging="never"
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
# Tool 3: list_expiring_credentials
|
| 42 |
+
iface_expiring = gr.Interface(
|
| 43 |
+
fn=list_expiring_credentials,
|
| 44 |
+
inputs=[
|
| 45 |
+
gr.Number(label="Window Days", precision=0),
|
| 46 |
+
gr.Textbox(label="Department", value=None),
|
| 47 |
+
gr.Textbox(label="Location", value=None),
|
| 48 |
+
],
|
| 49 |
+
outputs=gr.JSON(label="Expiring Credentials"),
|
| 50 |
+
description="Lists credentials expiring within a certain number of days.",
|
| 51 |
+
allow_flagging="never"
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
# Tool 4: get_provider_snapshot
|
| 55 |
+
iface_snapshot = gr.Interface(
|
| 56 |
+
fn=get_provider_snapshot,
|
| 57 |
+
inputs=[
|
| 58 |
+
gr.Number(label="Provider ID", precision=0, value=None),
|
| 59 |
+
gr.Textbox(label="NPI", value=None),
|
| 60 |
+
],
|
| 61 |
+
outputs=gr.JSON(label="Provider Snapshot"),
|
| 62 |
+
description="Gets a snapshot of a provider's data including credentials and alerts.",
|
| 63 |
+
allow_flagging="never"
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
demo = gr.TabbedInterface(
|
| 67 |
+
[iface_sync, iface_add_cred, iface_expiring, iface_snapshot],
|
| 68 |
+
["Sync Provider", "Add/Update Credential", "List Expiring", "Provider Snapshot"]
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
# Launch with mcp_server=True to enable MCP endpoints
|
| 72 |
+
demo.launch(mcp_server=True)
|
| 73 |
+
|
| 74 |
+
if __name__ == "__main__":
|
| 75 |
+
main()
|
src/cred_db_mcp_server/schemas.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic import BaseModel
|
| 2 |
+
from typing import Optional, List
|
| 3 |
+
|
| 4 |
+
# Provider Schema
|
| 5 |
+
class Provider(BaseModel):
|
| 6 |
+
id: int
|
| 7 |
+
npi: str
|
| 8 |
+
name: str
|
| 9 |
+
# Add other fields as necessary based on CRED_API response
|
| 10 |
+
|
| 11 |
+
class Credential(BaseModel):
|
| 12 |
+
id: Optional[int] = None
|
| 13 |
+
provider_id: int
|
| 14 |
+
type: str
|
| 15 |
+
issuer: str
|
| 16 |
+
number: str
|
| 17 |
+
expiry_date: str
|
| 18 |
+
# Add other fields as necessary
|
| 19 |
+
|
| 20 |
+
class ExpiringCredential(BaseModel):
|
| 21 |
+
provider: Provider
|
| 22 |
+
credential: Credential
|
| 23 |
+
days_to_expiry: int
|
| 24 |
+
risk_score: float
|
| 25 |
+
|
| 26 |
+
class ProviderSnapshot(BaseModel):
|
| 27 |
+
provider: Provider
|
| 28 |
+
credentials: List[Credential]
|
| 29 |
+
alerts: Optional[List[dict]] = None
|
src/cred_db_mcp_server/tools.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This module implements the tools exposed by the MCP server.
|
| 3 |
+
Each function corresponds to an MCP tool.
|
| 4 |
+
"""
|
| 5 |
+
import httpx
|
| 6 |
+
from typing import Optional, List, Dict, Any
|
| 7 |
+
from .config import CRED_API_BASE_URL
|
| 8 |
+
# from .schemas import Provider, Credential, ExpiringCredential, ProviderSnapshot
|
| 9 |
+
|
| 10 |
+
# We can use simple types in signature for Gradio MCP compatibility
|
| 11 |
+
# but use schemas for internal validation if needed.
|
| 12 |
+
# Since the prompt asked for mapped errors, we'll wrap calls.
|
| 13 |
+
|
| 14 |
+
def sync_provider_from_npi(npi: str) -> Dict[str, Any]:
|
| 15 |
+
"""
|
| 16 |
+
Syncs a provider's data from the NPI registry.
|
| 17 |
+
|
| 18 |
+
Args:
|
| 19 |
+
npi (str): The NPI number of the provider.
|
| 20 |
+
|
| 21 |
+
Returns:
|
| 22 |
+
dict: The provider object returned by the API.
|
| 23 |
+
"""
|
| 24 |
+
url = f"{CRED_API_BASE_URL}/providers/sync_from_npi"
|
| 25 |
+
try:
|
| 26 |
+
with httpx.Client() as client:
|
| 27 |
+
response = client.post(url, json={"npi": npi})
|
| 28 |
+
response.raise_for_status()
|
| 29 |
+
return response.json()
|
| 30 |
+
except httpx.HTTPStatusError as e:
|
| 31 |
+
# Map HTTP errors to MCP/User friendly errors
|
| 32 |
+
if e.response.status_code == 404:
|
| 33 |
+
return {"error": f"Provider with NPI {npi} not found or sync failed."}
|
| 34 |
+
return {"error": f"API Error: {e.response.text}"}
|
| 35 |
+
except Exception as e:
|
| 36 |
+
return {"error": f"Connection Error: {str(e)}"}
|
| 37 |
+
|
| 38 |
+
def add_or_update_credential(
|
| 39 |
+
provider_id: int,
|
| 40 |
+
type: str,
|
| 41 |
+
issuer: str,
|
| 42 |
+
number: str,
|
| 43 |
+
expiry_date: str
|
| 44 |
+
) -> Dict[str, Any]:
|
| 45 |
+
"""
|
| 46 |
+
Adds or updates a credential for a provider.
|
| 47 |
+
|
| 48 |
+
Args:
|
| 49 |
+
provider_id (int): The internal ID of the provider.
|
| 50 |
+
type (str): The type of credential (e.g., 'Medical License').
|
| 51 |
+
issuer (str): The issuing body (e.g., 'State Board').
|
| 52 |
+
number (str): The credential number.
|
| 53 |
+
expiry_date (str): The expiry date in YYYY-MM-DD format.
|
| 54 |
+
|
| 55 |
+
Returns:
|
| 56 |
+
dict: The created or updated credential object.
|
| 57 |
+
"""
|
| 58 |
+
url = f"{CRED_API_BASE_URL}/credentials/add_or_update"
|
| 59 |
+
payload = {
|
| 60 |
+
"provider_id": provider_id,
|
| 61 |
+
"type": type,
|
| 62 |
+
"issuer": issuer,
|
| 63 |
+
"number": number,
|
| 64 |
+
"expiry_date": expiry_date
|
| 65 |
+
}
|
| 66 |
+
try:
|
| 67 |
+
with httpx.Client() as client:
|
| 68 |
+
response = client.post(url, json=payload)
|
| 69 |
+
response.raise_for_status()
|
| 70 |
+
return response.json()
|
| 71 |
+
except httpx.HTTPStatusError as e:
|
| 72 |
+
return {"error": f"API Error: {e.response.text}"}
|
| 73 |
+
except Exception as e:
|
| 74 |
+
return {"error": f"Connection Error: {str(e)}"}
|
| 75 |
+
|
| 76 |
+
def list_expiring_credentials(
|
| 77 |
+
window_days: int,
|
| 78 |
+
dept: Optional[str] = None,
|
| 79 |
+
location: Optional[str] = None
|
| 80 |
+
) -> List[Dict[str, Any]]:
|
| 81 |
+
"""
|
| 82 |
+
Lists credentials expiring within a certain number of days.
|
| 83 |
+
|
| 84 |
+
Args:
|
| 85 |
+
window_days (int): The number of days to check for expiry.
|
| 86 |
+
dept (str, optional): Filter by department.
|
| 87 |
+
location (str, optional): Filter by location.
|
| 88 |
+
|
| 89 |
+
Returns:
|
| 90 |
+
list: A list of objects containing provider, credential, days_to_expiry, and risk_score.
|
| 91 |
+
"""
|
| 92 |
+
url = f"{CRED_API_BASE_URL}/credentials/expiring"
|
| 93 |
+
payload = {
|
| 94 |
+
"window_days": window_days,
|
| 95 |
+
"dept": dept,
|
| 96 |
+
"location": location
|
| 97 |
+
}
|
| 98 |
+
# Remove None values to avoid sending them if API doesn't expect them or treat them as valid
|
| 99 |
+
payload = {k: v for k, v in payload.items() if v is not None}
|
| 100 |
+
|
| 101 |
+
try:
|
| 102 |
+
with httpx.Client() as client:
|
| 103 |
+
response = client.post(url, json=payload)
|
| 104 |
+
response.raise_for_status()
|
| 105 |
+
return response.json()
|
| 106 |
+
except httpx.HTTPStatusError as e:
|
| 107 |
+
# In case of error, return a list with an error dict or raise
|
| 108 |
+
# For MCP, returning structured error info is usually better than crashing
|
| 109 |
+
# But for list return type, we might need to handle differently.
|
| 110 |
+
# Here we assume the tool call handles exceptions or checks for "error" key in result if it was a dict.
|
| 111 |
+
# Since return type is List, we can't easily return a dict error.
|
| 112 |
+
# We'll return an empty list and print error or raise.
|
| 113 |
+
# Let's raise ValueError which Gradio might catch and show.
|
| 114 |
+
raise ValueError(f"API Error: {e.response.text}")
|
| 115 |
+
except Exception as e:
|
| 116 |
+
raise ConnectionError(f"Connection Error: {str(e)}")
|
| 117 |
+
|
| 118 |
+
def get_provider_snapshot(
|
| 119 |
+
provider_id: Optional[int] = None,
|
| 120 |
+
npi: Optional[str] = None
|
| 121 |
+
) -> Dict[str, Any]:
|
| 122 |
+
"""
|
| 123 |
+
Gets a snapshot of a provider's data including credentials and alerts.
|
| 124 |
+
|
| 125 |
+
Args:
|
| 126 |
+
provider_id (int, optional): The provider's internal ID.
|
| 127 |
+
npi (str, optional): The provider's NPI.
|
| 128 |
+
|
| 129 |
+
Returns:
|
| 130 |
+
dict: Object containing provider details, credentials, and alerts.
|
| 131 |
+
"""
|
| 132 |
+
if provider_id is None and npi is None:
|
| 133 |
+
return {"error": "Must provide either provider_id or npi."}
|
| 134 |
+
|
| 135 |
+
url = f"{CRED_API_BASE_URL}/providers/snapshot"
|
| 136 |
+
payload = {}
|
| 137 |
+
if provider_id is not None:
|
| 138 |
+
payload["provider_id"] = provider_id
|
| 139 |
+
if npi is not None:
|
| 140 |
+
payload["npi"] = npi
|
| 141 |
+
|
| 142 |
+
try:
|
| 143 |
+
with httpx.Client() as client:
|
| 144 |
+
response = client.post(url, json=payload)
|
| 145 |
+
response.raise_for_status()
|
| 146 |
+
return response.json()
|
| 147 |
+
except httpx.HTTPStatusError as e:
|
| 148 |
+
if e.response.status_code == 404:
|
| 149 |
+
return {"error": "Provider not found."}
|
| 150 |
+
return {"error": f"API Error: {e.response.text}"}
|
| 151 |
+
except Exception as e:
|
| 152 |
+
return {"error": f"Connection Error: {str(e)}"}
|
tests/test_cred_db_mcp_server.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import pytest
|
| 3 |
+
from unittest.mock import patch, MagicMock
|
| 4 |
+
from cred_db_mcp_server.tools import (
|
| 5 |
+
sync_provider_from_npi,
|
| 6 |
+
add_or_update_credential,
|
| 7 |
+
list_expiring_credentials,
|
| 8 |
+
get_provider_snapshot
|
| 9 |
+
)
|
| 10 |
+
import httpx
|
| 11 |
+
|
| 12 |
+
# Mock response data
|
| 13 |
+
MOCK_PROVIDER = {"id": 1, "npi": "1234567890", "name": "Dr. Test"}
|
| 14 |
+
MOCK_CREDENTIAL = {
|
| 15 |
+
"id": 10,
|
| 16 |
+
"provider_id": 1,
|
| 17 |
+
"type": "Medical License",
|
| 18 |
+
"issuer": "State Board",
|
| 19 |
+
"number": "MD12345",
|
| 20 |
+
"expiry_date": "2025-01-01"
|
| 21 |
+
}
|
| 22 |
+
MOCK_EXPIRING_LIST = [
|
| 23 |
+
{
|
| 24 |
+
"provider": MOCK_PROVIDER,
|
| 25 |
+
"credential": MOCK_CREDENTIAL,
|
| 26 |
+
"days_to_expiry": 30,
|
| 27 |
+
"risk_score": 0.5
|
| 28 |
+
}
|
| 29 |
+
]
|
| 30 |
+
MOCK_SNAPSHOT = {
|
| 31 |
+
"provider": MOCK_PROVIDER,
|
| 32 |
+
"credentials": [MOCK_CREDENTIAL],
|
| 33 |
+
"alerts": []
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
@patch("cred_db_mcp_server.tools.httpx.Client")
|
| 37 |
+
def test_sync_provider_from_npi(mock_client_cls):
|
| 38 |
+
mock_client = MagicMock()
|
| 39 |
+
mock_client_cls.return_value.__enter__.return_value = mock_client
|
| 40 |
+
mock_client.post.return_value.status_code = 200
|
| 41 |
+
mock_client.post.return_value.json.return_value = MOCK_PROVIDER
|
| 42 |
+
|
| 43 |
+
result = sync_provider_from_npi("1234567890")
|
| 44 |
+
assert result == MOCK_PROVIDER
|
| 45 |
+
mock_client.post.assert_called_once()
|
| 46 |
+
assert "sync_from_npi" in mock_client.post.call_args[0][0]
|
| 47 |
+
|
| 48 |
+
@patch("cred_db_mcp_server.tools.httpx.Client")
|
| 49 |
+
def test_add_or_update_credential(mock_client_cls):
|
| 50 |
+
mock_client = MagicMock()
|
| 51 |
+
mock_client_cls.return_value.__enter__.return_value = mock_client
|
| 52 |
+
mock_client.post.return_value.status_code = 200
|
| 53 |
+
mock_client.post.return_value.json.return_value = MOCK_CREDENTIAL
|
| 54 |
+
|
| 55 |
+
result = add_or_update_credential(1, "Medical License", "State Board", "MD12345", "2025-01-01")
|
| 56 |
+
assert result == MOCK_CREDENTIAL
|
| 57 |
+
mock_client.post.assert_called_once()
|
| 58 |
+
assert "add_or_update" in mock_client.post.call_args[0][0]
|
| 59 |
+
|
| 60 |
+
@patch("cred_db_mcp_server.tools.httpx.Client")
|
| 61 |
+
def test_list_expiring_credentials(mock_client_cls):
|
| 62 |
+
mock_client = MagicMock()
|
| 63 |
+
mock_client_cls.return_value.__enter__.return_value = mock_client
|
| 64 |
+
mock_client.post.return_value.status_code = 200
|
| 65 |
+
mock_client.post.return_value.json.return_value = MOCK_EXPIRING_LIST
|
| 66 |
+
|
| 67 |
+
result = list_expiring_credentials(30)
|
| 68 |
+
assert result == MOCK_EXPIRING_LIST
|
| 69 |
+
mock_client.post.assert_called_once()
|
| 70 |
+
assert "expiring" in mock_client.post.call_args[0][0]
|
| 71 |
+
|
| 72 |
+
@patch("cred_db_mcp_server.tools.httpx.Client")
|
| 73 |
+
def test_get_provider_snapshot(mock_client_cls):
|
| 74 |
+
mock_client = MagicMock()
|
| 75 |
+
mock_client_cls.return_value.__enter__.return_value = mock_client
|
| 76 |
+
mock_client.post.return_value.status_code = 200
|
| 77 |
+
mock_client.post.return_value.json.return_value = MOCK_SNAPSHOT
|
| 78 |
+
|
| 79 |
+
result = get_provider_snapshot(provider_id=1)
|
| 80 |
+
assert result == MOCK_SNAPSHOT
|
| 81 |
+
mock_client.post.assert_called_once()
|
| 82 |
+
assert "snapshot" in mock_client.post.call_args[0][0]
|
| 83 |
+
|
| 84 |
+
def test_get_provider_snapshot_no_args():
|
| 85 |
+
result = get_provider_snapshot()
|
| 86 |
+
assert "error" in result
|
uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|