google-labs-jules[bot] commited on
Commit
6682851
·
0 Parent(s):

Implement cred_db_mcp server with FastAPI and SQLite

Browse files

- Set up project structure with uv/pyproject.toml
- Implement SQLAlchemy 2.0 models for Providers and Credentials
- Implement MCP tools: sync_provider, add_credential, list_expiring, get_snapshot
- Expose tools via FastAPI endpoints
- Add comprehensive tests for DB logic

.gitignore ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ __pycache__/
2
+ *.pyc
3
+ *.db
4
+ .env
5
+ .pytest_cache/
README.md ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # Usage of cred_db_mcp
3
+
4
+ This server provides direct database access and logic for **CredentialWatch** via MCP-compliant tools exposed as HTTP endpoints.
5
+
6
+ ## Agents Usage
7
+
8
+ Agents can use this server to:
9
+ 1. **Onboard new providers**: Call `sync_provider_from_npi` to fetch details from the NPI registry and create a local record.
10
+ 2. **Manage Credentials**: Use `add_or_update_credential` to keep license data up to date.
11
+ 3. **Monitor Compliance**: Use `list_expiring_credentials` to proactively find providers who need to renew licenses.
12
+ 4. **Context Retrieval**: Use `get_provider_snapshot` to get all known data about a provider before answering user questions.
13
+
14
+ ## Running the Server
15
+
16
+ ```bash
17
+ # Install dependencies
18
+ pip install -e .
19
+
20
+ # Run
21
+ uvicorn src.cred_db_mcp.main:app --reload
22
+ ```
23
+
24
+ ## Example Tool Call
25
+
26
+ **Tool:** `list_expiring_credentials`
27
+ **Endpoint:** `POST /mcp/tools/list_expiring_credentials`
28
+ **Headers:** `Content-Type: application/json`
29
+
30
+ **Body:**
31
+ ```json
32
+ {
33
+ "window_days": 90,
34
+ "dept": "Cardiology"
35
+ }
36
+ ```
37
+
38
+ **Response:**
39
+ ```json
40
+ [
41
+ {
42
+ "provider": { "full_name": "Dr. Alice Smith", ... },
43
+ "credential": { "type": "state_license", "expiry_date": "2023-12-01", ... },
44
+ "days_to_expiry": 25,
45
+ "risk_score": 3
46
+ }
47
+ ]
48
+ ```
pyproject.toml ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "cred_db_mcp"
3
+ version = "0.1.0"
4
+ description = "Credential Database MCP Server"
5
+ requires-python = ">=3.11"
6
+ dependencies = [
7
+ "fastapi>=0.100.0",
8
+ "uvicorn>=0.20.0",
9
+ "sqlalchemy>=2.0.0",
10
+ "pydantic>=2.0.0",
11
+ "httpx>=0.24.0",
12
+ "python-dotenv>=1.0.0"
13
+ ]
14
+
15
+ [project.optional-dependencies]
16
+ test = [
17
+ "pytest>=7.0.0",
18
+ "pytest-asyncio>=0.21.0"
19
+ ]
20
+
21
+ [build-system]
22
+ requires = ["hatchling"]
23
+ build-backend = "hatchling.build"
24
+
25
+ [tool.pytest.ini_options]
26
+ pythonpath = "src"
27
+ testpaths = ["tests"]
src/cred_db_mcp/db.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from sqlalchemy import create_engine
3
+ from sqlalchemy.orm import sessionmaker, DeclarativeBase
4
+
5
+ # Default to a local file for dev/test if not specified
6
+ DB_PATH = os.getenv("DB_PATH", "credentialwatch.db")
7
+ DATABASE_URL = f"sqlite:///{DB_PATH}"
8
+
9
+ # specific check for checks that might run in different environments
10
+ if DB_PATH == ":memory:":
11
+ DATABASE_URL = "sqlite:///:memory:"
12
+
13
+ engine = create_engine(
14
+ DATABASE_URL,
15
+ connect_args={"check_same_thread": False} # Needed for SQLite
16
+ )
17
+
18
+ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
19
+
20
+ class Base(DeclarativeBase):
21
+ pass
22
+
23
+ def get_db():
24
+ """Dependency for FastAPI to get a DB session."""
25
+ db = SessionLocal()
26
+ try:
27
+ yield db
28
+ finally:
29
+ db.close()
30
+
31
+ def init_db():
32
+ """Helper to initialize the DB (create tables)."""
33
+ Base.metadata.create_all(bind=engine)
src/cred_db_mcp/main.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, Depends, HTTPException, Body
2
+ from sqlalchemy.orm import Session
3
+ from typing import List, Union
4
+
5
+ from .db import get_db, init_db
6
+ from .mcp_tools import MCPTools
7
+ from .schemas import (
8
+ SyncProviderInput, AddCredentialInput, ListExpiringInput, GetSnapshotInput,
9
+ ProviderRead, CredentialRead, ExpiringCredentialItem, ProviderSnapshot, ToolError
10
+ )
11
+
12
+ app = FastAPI(title="CredentialWatch DB MCP", version="0.1.0")
13
+
14
+ # Initialize DB on startup (for simple deployments)
15
+ @app.on_event("startup")
16
+ def on_startup():
17
+ init_db()
18
+
19
+ def get_mcp_tools(db: Session = Depends(get_db)) -> MCPTools:
20
+ return MCPTools(db)
21
+
22
+ # --- MCP Tool Endpoints ---
23
+ # Pattern: /mcp/tools/{tool_name}
24
+ # We use POST for all of them as they are "function calls"
25
+
26
+ @app.post("/mcp/tools/sync_provider_from_npi", response_model=Union[ProviderRead, ToolError])
27
+ async def sync_provider_from_npi(
28
+ args: SyncProviderInput,
29
+ tools: MCPTools = Depends(get_mcp_tools)
30
+ ):
31
+ """
32
+ Syncs provider data from the NPI Registry via npi_mcp.
33
+ """
34
+ result = await tools.sync_provider_from_npi(args.npi)
35
+ if isinstance(result, dict) and "error" in result:
36
+ # Return generic error structure instead of 400 for agent handling if preferred,
37
+ # but usually 400 is better for HTTP.
38
+ # The prompt says "return a structured error".
39
+ return ToolError(error=result["error"])
40
+ return result
41
+
42
+ @app.post("/mcp/tools/add_or_update_credential", response_model=Union[CredentialRead, ToolError])
43
+ def add_or_update_credential(
44
+ args: AddCredentialInput,
45
+ tools: MCPTools = Depends(get_mcp_tools)
46
+ ):
47
+ """
48
+ Adds or updates a credential record.
49
+ """
50
+ result = tools.add_or_update_credential(
51
+ provider_id=args.provider_id,
52
+ type=args.type,
53
+ issuer=args.issuer,
54
+ number=args.number,
55
+ expiry_date=args.expiry_date
56
+ )
57
+ if isinstance(result, dict) and "error" in result:
58
+ return ToolError(error=result["error"])
59
+ return result
60
+
61
+ @app.post("/mcp/tools/list_expiring_credentials", response_model=List[ExpiringCredentialItem])
62
+ def list_expiring_credentials(
63
+ args: ListExpiringInput,
64
+ tools: MCPTools = Depends(get_mcp_tools)
65
+ ):
66
+ """
67
+ Lists credentials expiring within a window.
68
+ """
69
+ return tools.list_expiring_credentials(
70
+ window_days=args.window_days,
71
+ dept=args.dept,
72
+ location=args.location
73
+ )
74
+
75
+ @app.post("/mcp/tools/get_provider_snapshot", response_model=Union[ProviderSnapshot, ToolError])
76
+ def get_provider_snapshot(
77
+ args: GetSnapshotInput,
78
+ tools: MCPTools = Depends(get_mcp_tools)
79
+ ):
80
+ """
81
+ Gets a full snapshot of a provider.
82
+ """
83
+ result = tools.get_provider_snapshot(
84
+ provider_id=args.provider_id,
85
+ npi=args.npi
86
+ )
87
+ if isinstance(result, dict) and "error" in result:
88
+ return ToolError(error=result["error"])
89
+ return result
90
+
91
+ # Simple info endpoint
92
+ @app.get("/")
93
+ def read_root():
94
+ return {
95
+ "service": "cred_db_mcp",
96
+ "status": "running",
97
+ "docs": "/docs"
98
+ }
src/cred_db_mcp/mcp_tools.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sqlalchemy.orm import Session
2
+ from sqlalchemy import select, and_, func
3
+ from datetime import date, timedelta, datetime
4
+ import json
5
+
6
+ from .models import Provider, Credential, Alert
7
+ from .schemas import (
8
+ ProviderRead, CredentialRead, AlertRead,
9
+ ExpiringCredentialItem, ProviderSnapshot
10
+ )
11
+ from .npi_client import NPIClient
12
+
13
+ class MCPTools:
14
+ def __init__(self, db: Session):
15
+ self.db = db
16
+ self.npi_client = NPIClient()
17
+
18
+ async def sync_provider_from_npi(self, npi: str):
19
+ # 1. Call npi_mcp
20
+ npi_data = await self.npi_client.get_provider_by_npi(npi)
21
+
22
+ if not npi_data:
23
+ return {"error": f"NPI {npi} not found in upstream registry"}
24
+
25
+ # Map NPI data to our schema.
26
+ # Assuming npi_data has keys: npi, first_name, last_name, taxonomy_desc, practice_address
27
+ # This mapping is hypothetical based on typical NPI registry fields.
28
+
29
+ full_name = f"{npi_data.get('first_name', '')} {npi_data.get('last_name', '')}".strip()
30
+ if not full_name:
31
+ full_name = npi_data.get('organization_name', 'Unknown')
32
+
33
+ specialty = npi_data.get('taxonomy_desc', 'Unknown')
34
+ address_obj = npi_data.get('practice_address', {})
35
+ # Simple string representation of location
36
+ location = f"{address_obj.get('city', '')}, {address_obj.get('state', '')}" if isinstance(address_obj, dict) else str(address_obj)
37
+
38
+ # 2. Update or Create in DB
39
+ provider = self.db.scalar(select(Provider).where(Provider.npi == npi))
40
+
41
+ if not provider:
42
+ provider = Provider(
43
+ npi=npi,
44
+ full_name=full_name,
45
+ primary_specialty=specialty,
46
+ location=location,
47
+ is_active=True
48
+ )
49
+ self.db.add(provider)
50
+ else:
51
+ provider.full_name = full_name
52
+ provider.primary_specialty = specialty
53
+ provider.location = location
54
+ # Don't overwrite dept if set locally? assuming upstream doesn't have dept.
55
+
56
+ self.db.commit()
57
+ self.db.refresh(provider)
58
+
59
+ return ProviderRead.model_validate(provider)
60
+
61
+ def add_or_update_credential(
62
+ self, provider_id: int, type: str, issuer: str, number: str, expiry_date: str
63
+ ):
64
+ # Parse expiry_date (assuming YYYY-MM-DD)
65
+ try:
66
+ exp_date = datetime.strptime(expiry_date, "%Y-%m-%d").date()
67
+ except ValueError:
68
+ return {"error": "Invalid date format. Use YYYY-MM-DD"}
69
+
70
+ # Check provider exists
71
+ provider = self.db.get(Provider, provider_id)
72
+ if not provider:
73
+ return {"error": f"Provider with ID {provider_id} not found"}
74
+
75
+ # Find existing credential
76
+ stmt = select(Credential).where(
77
+ and_(
78
+ Credential.provider_id == provider_id,
79
+ Credential.type == type,
80
+ Credential.issuer == issuer,
81
+ Credential.number == number
82
+ )
83
+ )
84
+ credential = self.db.scalar(stmt)
85
+
86
+ if credential:
87
+ credential.expiry_date = exp_date
88
+ # Heuristic status update
89
+ credential.status = "active" if exp_date > date.today() else "expired"
90
+ credential.updated_at = datetime.now()
91
+ else:
92
+ credential = Credential(
93
+ provider_id=provider_id,
94
+ type=type,
95
+ issuer=issuer,
96
+ number=number,
97
+ expiry_date=exp_date,
98
+ status="active" if exp_date > date.today() else "expired",
99
+ created_at=datetime.now()
100
+ )
101
+ self.db.add(credential)
102
+
103
+ self.db.commit()
104
+ self.db.refresh(credential)
105
+
106
+ return CredentialRead.model_validate(credential)
107
+
108
+ def list_expiring_credentials(
109
+ self, window_days: int, dept: str | None = None, location: str | None = None
110
+ ):
111
+ today = date.today()
112
+ cutoff = today + timedelta(days=window_days)
113
+
114
+ # Build Query
115
+ stmt = select(Credential, Provider).join(Provider).where(
116
+ and_(
117
+ Credential.status == "active",
118
+ Credential.expiry_date <= cutoff,
119
+ Credential.expiry_date >= today # Only future or today (past would be expired)
120
+ )
121
+ )
122
+
123
+ if dept:
124
+ stmt = stmt.where(Provider.dept == dept)
125
+ if location:
126
+ stmt = stmt.where(Provider.location.ilike(f"%{location}%"))
127
+
128
+ results = self.db.execute(stmt).all()
129
+
130
+ output = []
131
+ for cred, prov in results:
132
+ if not cred.expiry_date:
133
+ continue
134
+
135
+ days_to_expiry = (cred.expiry_date - today).days
136
+
137
+ # Risk Score Heuristic
138
+ # 3 for <30 days, 2 for 30–60, 1 for 60–90
139
+ if days_to_expiry < 30:
140
+ risk_score = 3
141
+ elif days_to_expiry < 60:
142
+ risk_score = 2
143
+ else:
144
+ risk_score = 1
145
+
146
+ output.append(ExpiringCredentialItem(
147
+ provider=ProviderRead.model_validate(prov),
148
+ credential=CredentialRead.model_validate(cred),
149
+ days_to_expiry=days_to_expiry,
150
+ risk_score=risk_score
151
+ ))
152
+
153
+ return output
154
+
155
+ def get_provider_snapshot(
156
+ self, provider_id: int | None = None, npi: str | None = None
157
+ ):
158
+ if not provider_id and not npi:
159
+ return {"error": "Must provide either provider_id or npi"}
160
+
161
+ stmt = select(Provider)
162
+ if provider_id:
163
+ stmt = stmt.where(Provider.id == provider_id)
164
+ else:
165
+ stmt = stmt.where(Provider.npi == npi)
166
+
167
+ provider = self.db.scalar(stmt)
168
+ if not provider:
169
+ return {"error": "Provider not found"}
170
+
171
+ # Get Credentials
172
+ creds = self.db.scalars(
173
+ select(Credential).where(Credential.provider_id == provider.id)
174
+ ).all()
175
+
176
+ # Get Alerts (Optional)
177
+ alerts = self.db.scalars(
178
+ select(Alert).where(Alert.provider_id == provider.id)
179
+ ).all()
180
+
181
+ return ProviderSnapshot(
182
+ provider=ProviderRead.model_validate(provider),
183
+ credentials=[CredentialRead.model_validate(c) for c in creds],
184
+ alerts=[AlertRead.model_validate(a) for a in alerts]
185
+ )
src/cred_db_mcp/models.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime, date
2
+ from typing import Optional, List, Any
3
+ from sqlalchemy import (
4
+ String, Integer, Boolean, DateTime, Date, ForeignKey, JSON, func
5
+ )
6
+ from sqlalchemy.orm import Mapped, mapped_column, relationship
7
+
8
+ from .db import Base
9
+
10
+ class Provider(Base):
11
+ __tablename__ = "providers"
12
+
13
+ id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True)
14
+ npi: Mapped[Optional[str]] = mapped_column(String, index=True, nullable=True)
15
+ full_name: Mapped[str] = mapped_column(String)
16
+ dept: Mapped[Optional[str]] = mapped_column(String, nullable=True)
17
+ location: Mapped[Optional[str]] = mapped_column(String, nullable=True)
18
+ primary_specialty: Mapped[Optional[str]] = mapped_column(String, nullable=True)
19
+ is_active: Mapped[bool] = mapped_column(Boolean, default=True)
20
+
21
+ created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
22
+ updated_at: Mapped[datetime] = mapped_column(
23
+ DateTime, server_default=func.now(), onupdate=func.now()
24
+ )
25
+
26
+ credentials: Mapped[List["Credential"]] = relationship(
27
+ "Credential", back_populates="provider", cascade="all, delete-orphan"
28
+ )
29
+
30
+ class Credential(Base):
31
+ __tablename__ = "credentials"
32
+
33
+ id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True)
34
+ provider_id: Mapped[int] = mapped_column(
35
+ Integer, ForeignKey("providers.id"), nullable=False
36
+ )
37
+ type: Mapped[str] = mapped_column(String) # e.g. "state_license", "board_cert"
38
+ issuer: Mapped[str] = mapped_column(String)
39
+ number: Mapped[str] = mapped_column(String)
40
+ status: Mapped[str] = mapped_column(String) # "active", "expired", etc.
41
+
42
+ issue_date: Mapped[Optional[date]] = mapped_column(Date, nullable=True)
43
+ expiry_date: Mapped[Optional[date]] = mapped_column(Date, nullable=True)
44
+ last_verified_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
45
+
46
+ metadata_json: Mapped[Optional[Any]] = mapped_column(JSON, nullable=True)
47
+
48
+ created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
49
+ updated_at: Mapped[datetime] = mapped_column(
50
+ DateTime, server_default=func.now(), onupdate=func.now()
51
+ )
52
+
53
+ provider: Mapped["Provider"] = relationship("Provider", back_populates="credentials")
54
+
55
+ class Alert(Base):
56
+ __tablename__ = "alerts"
57
+
58
+ id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True)
59
+ provider_id: Mapped[int] = mapped_column(Integer, ForeignKey("providers.id"))
60
+ message: Mapped[str] = mapped_column(String)
61
+ created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
62
+ # Minimal schema for alerts as requested
src/cred_db_mcp/npi_client.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import httpx
3
+ from typing import Optional, Dict, Any
4
+
5
+ class NPIClient:
6
+ def __init__(self, base_url: Optional[str] = None):
7
+ self.base_url = base_url or os.getenv("NPI_MCP_URL", "http://localhost:8001")
8
+ # Ensure no trailing slash
9
+ self.base_url = self.base_url.rstrip("/")
10
+
11
+ async def get_provider_by_npi(self, npi: str) -> Optional[Dict[str, Any]]:
12
+ """
13
+ Calls the npi_mcp server to get provider details.
14
+ Assumes npi_mcp exposes a tool 'get_provider_by_npi' via HTTP.
15
+
16
+ The NPI MCP is expected to have a similar structure: POST /mcp/tools/get_provider_by_npi
17
+ """
18
+ url = f"{self.base_url}/mcp/tools/get_provider_by_npi"
19
+
20
+ # We assume the NPI MCP accepts arguments in the body
21
+ payload = {"npi": npi}
22
+
23
+ try:
24
+ async with httpx.AsyncClient() as client:
25
+ response = await client.post(url, json=payload, timeout=10.0)
26
+
27
+ if response.status_code == 200:
28
+ data = response.json()
29
+ # If the tool wraps return value, e.g. {"result": ...} adjust here.
30
+ # For now assuming direct return or generic tool response.
31
+ return data
32
+ elif response.status_code == 404:
33
+ return None
34
+ else:
35
+ # Log error or raise
36
+ print(f"Error calling NPI MCP: {response.status_code} {response.text}")
37
+ return None
38
+ except httpx.RequestError as e:
39
+ print(f"Request error calling NPI MCP: {e}")
40
+ return None
src/cred_db_mcp/schemas.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel, ConfigDict
2
+ from typing import Optional, List, Any
3
+ from datetime import date, datetime
4
+
5
+ # --- Base Models ---
6
+
7
+ class ProviderBase(BaseModel):
8
+ npi: Optional[str] = None
9
+ full_name: str
10
+ dept: Optional[str] = None
11
+ location: Optional[str] = None
12
+ primary_specialty: Optional[str] = None
13
+ is_active: bool = True
14
+
15
+ class ProviderRead(ProviderBase):
16
+ id: int
17
+ created_at: datetime
18
+ updated_at: datetime
19
+ model_config = ConfigDict(from_attributes=True)
20
+
21
+ class CredentialBase(BaseModel):
22
+ type: str
23
+ issuer: str
24
+ number: str
25
+ status: str
26
+ issue_date: Optional[date] = None
27
+ expiry_date: Optional[date] = None
28
+ last_verified_at: Optional[datetime] = None
29
+ metadata_json: Optional[Any] = None
30
+
31
+ class CredentialRead(CredentialBase):
32
+ id: int
33
+ provider_id: int
34
+ created_at: datetime
35
+ updated_at: datetime
36
+ model_config = ConfigDict(from_attributes=True)
37
+
38
+ class AlertRead(BaseModel):
39
+ id: int
40
+ provider_id: int
41
+ message: str
42
+ created_at: datetime
43
+ model_config = ConfigDict(from_attributes=True)
44
+
45
+ # --- Tool IO Models ---
46
+
47
+ class SyncProviderInput(BaseModel):
48
+ npi: str
49
+
50
+ class AddCredentialInput(BaseModel):
51
+ provider_id: int
52
+ type: str
53
+ issuer: str
54
+ number: str
55
+ expiry_date: str # Expecting YYYY-MM-DD string as per prompt, or we can parse it
56
+
57
+ class ListExpiringInput(BaseModel):
58
+ window_days: int
59
+ dept: Optional[str] = None
60
+ location: Optional[str] = None
61
+
62
+ class ExpiringCredentialItem(BaseModel):
63
+ provider: ProviderRead
64
+ credential: CredentialRead
65
+ days_to_expiry: int
66
+ risk_score: int
67
+
68
+ class GetSnapshotInput(BaseModel):
69
+ provider_id: Optional[int] = None
70
+ npi: Optional[str] = None
71
+
72
+ class ProviderSnapshot(BaseModel):
73
+ provider: ProviderRead
74
+ credentials: List[CredentialRead]
75
+ alerts: Optional[List[AlertRead]] = []
76
+
77
+ class ToolError(BaseModel):
78
+ error: str
79
+ details: Optional[str] = None
tests/mock_npi_mcp.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from pydantic import BaseModel
3
+
4
+ app = FastAPI()
5
+
6
+ class NpiRequest(BaseModel):
7
+ npi: str
8
+
9
+ @app.post("/mcp/tools/get_provider_by_npi")
10
+ def get_provider_by_npi(req: NpiRequest):
11
+ if req.npi == "1234567890":
12
+ return {
13
+ "npi": "1234567890",
14
+ "first_name": "John",
15
+ "last_name": "Doe",
16
+ "taxonomy_desc": "Cardiology",
17
+ "practice_address": {
18
+ "city": "New York",
19
+ "state": "NY"
20
+ }
21
+ }
22
+ return None
tests/test_cred_db_mcp.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ from fastapi.testclient import TestClient
3
+ from sqlalchemy import create_engine
4
+ from sqlalchemy.orm import sessionmaker
5
+ from sqlalchemy.pool import StaticPool
6
+ from datetime import date, timedelta
7
+ import os
8
+
9
+ from src.cred_db_mcp.main import app, get_db
10
+ from src.cred_db_mcp.db import Base
11
+ from src.cred_db_mcp.models import Provider, Credential
12
+
13
+ # Setup in-memory DB for tests
14
+ # Use StaticPool to share the in-memory DB across multiple sessions/connections
15
+ SQLALCHEMY_DATABASE_URL = "sqlite:///:memory:"
16
+
17
+ engine = create_engine(
18
+ SQLALCHEMY_DATABASE_URL,
19
+ connect_args={"check_same_thread": False},
20
+ poolclass=StaticPool
21
+ )
22
+ TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
23
+
24
+ def override_get_db():
25
+ try:
26
+ db = TestingSessionLocal()
27
+ yield db
28
+ finally:
29
+ db.close()
30
+
31
+ app.dependency_overrides[get_db] = override_get_db
32
+
33
+ @pytest.fixture(scope="module")
34
+ def test_client():
35
+ # Create tables once for the module
36
+ Base.metadata.create_all(bind=engine)
37
+ client = TestClient(app)
38
+ yield client
39
+ Base.metadata.drop_all(bind=engine)
40
+
41
+ @pytest.fixture(autouse=True)
42
+ def clean_tables():
43
+ # Optional: Clear data between tests if needed, but for now just appending is fine
44
+ # as long as IDs/NPIs don't clash or we don't care about accumulation.
45
+ # To be safer, we can delete all data.
46
+ with engine.connect() as conn:
47
+ conn.execute(Credential.__table__.delete())
48
+ conn.execute(Provider.__table__.delete())
49
+ conn.commit()
50
+
51
+ @pytest.fixture
52
+ def db_session():
53
+ db = TestingSessionLocal()
54
+ yield db
55
+ db.close()
56
+
57
+ def test_read_root(test_client):
58
+ response = test_client.get("/")
59
+ assert response.status_code == 200
60
+ assert response.json()["service"] == "cred_db_mcp"
61
+
62
+ def test_add_and_snapshot_provider(test_client, db_session):
63
+ prov = Provider(
64
+ npi="999999", full_name="Test Doc", primary_specialty="General", is_active=True
65
+ )
66
+ db_session.add(prov)
67
+ db_session.commit()
68
+
69
+ response = test_client.post(
70
+ "/mcp/tools/get_provider_snapshot",
71
+ json={"npi": "999999"}
72
+ )
73
+ assert response.status_code == 200
74
+ data = response.json()
75
+ assert data["provider"]["full_name"] == "Test Doc"
76
+ assert len(data["credentials"]) == 0
77
+
78
+ def test_add_credential(test_client, db_session):
79
+ # Seed provider
80
+ prov = Provider(
81
+ npi="888888", full_name="Credential Doc", primary_specialty="Surgery", is_active=True
82
+ )
83
+ db_session.add(prov)
84
+ db_session.commit()
85
+ db_session.refresh(prov)
86
+
87
+ # Add Credential via tool
88
+ expiry = (date.today() + timedelta(days=100)).strftime("%Y-%m-%d")
89
+ response = test_client.post(
90
+ "/mcp/tools/add_or_update_credential",
91
+ json={
92
+ "provider_id": prov.id,
93
+ "type": "board_cert",
94
+ "issuer": "ABMS",
95
+ "number": "XYZ123",
96
+ "expiry_date": expiry
97
+ }
98
+ )
99
+ assert response.status_code == 200
100
+ data = response.json()
101
+ assert data["number"] == "XYZ123"
102
+ assert data["status"] == "active"
103
+
104
+ def test_list_expiring(test_client, db_session):
105
+ # Seed provider
106
+ prov = Provider(
107
+ npi="777777", full_name="Expiring Doc", dept="ER", location="NYC", is_active=True
108
+ )
109
+ db_session.add(prov)
110
+ db_session.commit()
111
+ db_session.refresh(prov)
112
+
113
+ # Add expiring credential (in 10 days)
114
+ # expiry date must be a date object for the model
115
+ expiry_1 = date.today() + timedelta(days=10)
116
+ cred = Credential(
117
+ provider_id=prov.id,
118
+ type="license",
119
+ issuer="State",
120
+ number="L1",
121
+ expiry_date=expiry_1,
122
+ status="active"
123
+ )
124
+ db_session.add(cred)
125
+
126
+ # Add non-expiring credential (in 100 days)
127
+ expiry_2 = date.today() + timedelta(days=100)
128
+ cred2 = Credential(
129
+ provider_id=prov.id,
130
+ type="license",
131
+ issuer="State",
132
+ number="L2",
133
+ expiry_date=expiry_2,
134
+ status="active"
135
+ )
136
+ db_session.add(cred2)
137
+ db_session.commit()
138
+
139
+ # Test tool
140
+ response = test_client.post(
141
+ "/mcp/tools/list_expiring_credentials",
142
+ json={
143
+ "window_days": 30,
144
+ "dept": "ER"
145
+ }
146
+ )
147
+ assert response.status_code == 200
148
+ data = response.json()
149
+ assert len(data) == 1
150
+ assert data[0]["credential"]["number"] == "L1"
151
+ assert data[0]["days_to_expiry"] == 10
152
+ assert data[0]["risk_score"] == 3 # < 30 days