nithin varghese commited on
Commit
da86ac7
·
unverified ·
2 Parent(s): 6682851 8e07f8c

Merge pull request #1 from humandotlearning/jules/cred-db-mcp-server

Browse files
pyproject.toml CHANGED
@@ -1,5 +1,5 @@
1
  [project]
2
- name = "cred_db_mcp"
3
  version = "0.1.0"
4
  description = "Credential Database MCP Server"
5
  requires-python = ">=3.11"
@@ -9,7 +9,8 @@ dependencies = [
9
  "sqlalchemy>=2.0.0",
10
  "pydantic>=2.0.0",
11
  "httpx>=0.24.0",
12
- "python-dotenv>=1.0.0"
 
13
  ]
14
 
15
  [project.optional-dependencies]
@@ -22,6 +23,9 @@ test = [
22
  requires = ["hatchling"]
23
  build-backend = "hatchling.build"
24
 
 
 
 
25
  [tool.pytest.ini_options]
26
  pythonpath = "src"
27
  testpaths = ["tests"]
 
1
  [project]
2
+ name = "cred_db_mcp_server"
3
  version = "0.1.0"
4
  description = "Credential Database MCP Server"
5
  requires-python = ">=3.11"
 
9
  "sqlalchemy>=2.0.0",
10
  "pydantic>=2.0.0",
11
  "httpx>=0.24.0",
12
+ "python-dotenv>=1.0.0",
13
+ "gradio>=5.0.0",
14
  ]
15
 
16
  [project.optional-dependencies]
 
23
  requires = ["hatchling"]
24
  build-backend = "hatchling.build"
25
 
26
+ [tool.hatch.build.targets.wheel]
27
+ packages = ["src/cred_db_mcp_server"]
28
+
29
  [tool.pytest.ini_options]
30
  pythonpath = "src"
31
  testpaths = ["tests"]
src/cred_db_mcp/db.py DELETED
@@ -1,33 +0,0 @@
1
- import os
2
- from sqlalchemy import create_engine
3
- from sqlalchemy.orm import sessionmaker, DeclarativeBase
4
-
5
- # Default to a local file for dev/test if not specified
6
- DB_PATH = os.getenv("DB_PATH", "credentialwatch.db")
7
- DATABASE_URL = f"sqlite:///{DB_PATH}"
8
-
9
- # specific check for checks that might run in different environments
10
- if DB_PATH == ":memory:":
11
- DATABASE_URL = "sqlite:///:memory:"
12
-
13
- engine = create_engine(
14
- DATABASE_URL,
15
- connect_args={"check_same_thread": False} # Needed for SQLite
16
- )
17
-
18
- SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
19
-
20
- class Base(DeclarativeBase):
21
- pass
22
-
23
- def get_db():
24
- """Dependency for FastAPI to get a DB session."""
25
- db = SessionLocal()
26
- try:
27
- yield db
28
- finally:
29
- db.close()
30
-
31
- def init_db():
32
- """Helper to initialize the DB (create tables)."""
33
- Base.metadata.create_all(bind=engine)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/cred_db_mcp/main.py DELETED
@@ -1,98 +0,0 @@
1
- from fastapi import FastAPI, Depends, HTTPException, Body
2
- from sqlalchemy.orm import Session
3
- from typing import List, Union
4
-
5
- from .db import get_db, init_db
6
- from .mcp_tools import MCPTools
7
- from .schemas import (
8
- SyncProviderInput, AddCredentialInput, ListExpiringInput, GetSnapshotInput,
9
- ProviderRead, CredentialRead, ExpiringCredentialItem, ProviderSnapshot, ToolError
10
- )
11
-
12
- app = FastAPI(title="CredentialWatch DB MCP", version="0.1.0")
13
-
14
- # Initialize DB on startup (for simple deployments)
15
- @app.on_event("startup")
16
- def on_startup():
17
- init_db()
18
-
19
- def get_mcp_tools(db: Session = Depends(get_db)) -> MCPTools:
20
- return MCPTools(db)
21
-
22
- # --- MCP Tool Endpoints ---
23
- # Pattern: /mcp/tools/{tool_name}
24
- # We use POST for all of them as they are "function calls"
25
-
26
- @app.post("/mcp/tools/sync_provider_from_npi", response_model=Union[ProviderRead, ToolError])
27
- async def sync_provider_from_npi(
28
- args: SyncProviderInput,
29
- tools: MCPTools = Depends(get_mcp_tools)
30
- ):
31
- """
32
- Syncs provider data from the NPI Registry via npi_mcp.
33
- """
34
- result = await tools.sync_provider_from_npi(args.npi)
35
- if isinstance(result, dict) and "error" in result:
36
- # Return generic error structure instead of 400 for agent handling if preferred,
37
- # but usually 400 is better for HTTP.
38
- # The prompt says "return a structured error".
39
- return ToolError(error=result["error"])
40
- return result
41
-
42
- @app.post("/mcp/tools/add_or_update_credential", response_model=Union[CredentialRead, ToolError])
43
- def add_or_update_credential(
44
- args: AddCredentialInput,
45
- tools: MCPTools = Depends(get_mcp_tools)
46
- ):
47
- """
48
- Adds or updates a credential record.
49
- """
50
- result = tools.add_or_update_credential(
51
- provider_id=args.provider_id,
52
- type=args.type,
53
- issuer=args.issuer,
54
- number=args.number,
55
- expiry_date=args.expiry_date
56
- )
57
- if isinstance(result, dict) and "error" in result:
58
- return ToolError(error=result["error"])
59
- return result
60
-
61
- @app.post("/mcp/tools/list_expiring_credentials", response_model=List[ExpiringCredentialItem])
62
- def list_expiring_credentials(
63
- args: ListExpiringInput,
64
- tools: MCPTools = Depends(get_mcp_tools)
65
- ):
66
- """
67
- Lists credentials expiring within a window.
68
- """
69
- return tools.list_expiring_credentials(
70
- window_days=args.window_days,
71
- dept=args.dept,
72
- location=args.location
73
- )
74
-
75
- @app.post("/mcp/tools/get_provider_snapshot", response_model=Union[ProviderSnapshot, ToolError])
76
- def get_provider_snapshot(
77
- args: GetSnapshotInput,
78
- tools: MCPTools = Depends(get_mcp_tools)
79
- ):
80
- """
81
- Gets a full snapshot of a provider.
82
- """
83
- result = tools.get_provider_snapshot(
84
- provider_id=args.provider_id,
85
- npi=args.npi
86
- )
87
- if isinstance(result, dict) and "error" in result:
88
- return ToolError(error=result["error"])
89
- return result
90
-
91
- # Simple info endpoint
92
- @app.get("/")
93
- def read_root():
94
- return {
95
- "service": "cred_db_mcp",
96
- "status": "running",
97
- "docs": "/docs"
98
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/cred_db_mcp/mcp_tools.py DELETED
@@ -1,185 +0,0 @@
1
- from sqlalchemy.orm import Session
2
- from sqlalchemy import select, and_, func
3
- from datetime import date, timedelta, datetime
4
- import json
5
-
6
- from .models import Provider, Credential, Alert
7
- from .schemas import (
8
- ProviderRead, CredentialRead, AlertRead,
9
- ExpiringCredentialItem, ProviderSnapshot
10
- )
11
- from .npi_client import NPIClient
12
-
13
- class MCPTools:
14
- def __init__(self, db: Session):
15
- self.db = db
16
- self.npi_client = NPIClient()
17
-
18
- async def sync_provider_from_npi(self, npi: str):
19
- # 1. Call npi_mcp
20
- npi_data = await self.npi_client.get_provider_by_npi(npi)
21
-
22
- if not npi_data:
23
- return {"error": f"NPI {npi} not found in upstream registry"}
24
-
25
- # Map NPI data to our schema.
26
- # Assuming npi_data has keys: npi, first_name, last_name, taxonomy_desc, practice_address
27
- # This mapping is hypothetical based on typical NPI registry fields.
28
-
29
- full_name = f"{npi_data.get('first_name', '')} {npi_data.get('last_name', '')}".strip()
30
- if not full_name:
31
- full_name = npi_data.get('organization_name', 'Unknown')
32
-
33
- specialty = npi_data.get('taxonomy_desc', 'Unknown')
34
- address_obj = npi_data.get('practice_address', {})
35
- # Simple string representation of location
36
- location = f"{address_obj.get('city', '')}, {address_obj.get('state', '')}" if isinstance(address_obj, dict) else str(address_obj)
37
-
38
- # 2. Update or Create in DB
39
- provider = self.db.scalar(select(Provider).where(Provider.npi == npi))
40
-
41
- if not provider:
42
- provider = Provider(
43
- npi=npi,
44
- full_name=full_name,
45
- primary_specialty=specialty,
46
- location=location,
47
- is_active=True
48
- )
49
- self.db.add(provider)
50
- else:
51
- provider.full_name = full_name
52
- provider.primary_specialty = specialty
53
- provider.location = location
54
- # Don't overwrite dept if set locally? assuming upstream doesn't have dept.
55
-
56
- self.db.commit()
57
- self.db.refresh(provider)
58
-
59
- return ProviderRead.model_validate(provider)
60
-
61
- def add_or_update_credential(
62
- self, provider_id: int, type: str, issuer: str, number: str, expiry_date: str
63
- ):
64
- # Parse expiry_date (assuming YYYY-MM-DD)
65
- try:
66
- exp_date = datetime.strptime(expiry_date, "%Y-%m-%d").date()
67
- except ValueError:
68
- return {"error": "Invalid date format. Use YYYY-MM-DD"}
69
-
70
- # Check provider exists
71
- provider = self.db.get(Provider, provider_id)
72
- if not provider:
73
- return {"error": f"Provider with ID {provider_id} not found"}
74
-
75
- # Find existing credential
76
- stmt = select(Credential).where(
77
- and_(
78
- Credential.provider_id == provider_id,
79
- Credential.type == type,
80
- Credential.issuer == issuer,
81
- Credential.number == number
82
- )
83
- )
84
- credential = self.db.scalar(stmt)
85
-
86
- if credential:
87
- credential.expiry_date = exp_date
88
- # Heuristic status update
89
- credential.status = "active" if exp_date > date.today() else "expired"
90
- credential.updated_at = datetime.now()
91
- else:
92
- credential = Credential(
93
- provider_id=provider_id,
94
- type=type,
95
- issuer=issuer,
96
- number=number,
97
- expiry_date=exp_date,
98
- status="active" if exp_date > date.today() else "expired",
99
- created_at=datetime.now()
100
- )
101
- self.db.add(credential)
102
-
103
- self.db.commit()
104
- self.db.refresh(credential)
105
-
106
- return CredentialRead.model_validate(credential)
107
-
108
- def list_expiring_credentials(
109
- self, window_days: int, dept: str | None = None, location: str | None = None
110
- ):
111
- today = date.today()
112
- cutoff = today + timedelta(days=window_days)
113
-
114
- # Build Query
115
- stmt = select(Credential, Provider).join(Provider).where(
116
- and_(
117
- Credential.status == "active",
118
- Credential.expiry_date <= cutoff,
119
- Credential.expiry_date >= today # Only future or today (past would be expired)
120
- )
121
- )
122
-
123
- if dept:
124
- stmt = stmt.where(Provider.dept == dept)
125
- if location:
126
- stmt = stmt.where(Provider.location.ilike(f"%{location}%"))
127
-
128
- results = self.db.execute(stmt).all()
129
-
130
- output = []
131
- for cred, prov in results:
132
- if not cred.expiry_date:
133
- continue
134
-
135
- days_to_expiry = (cred.expiry_date - today).days
136
-
137
- # Risk Score Heuristic
138
- # 3 for <30 days, 2 for 30–60, 1 for 60–90
139
- if days_to_expiry < 30:
140
- risk_score = 3
141
- elif days_to_expiry < 60:
142
- risk_score = 2
143
- else:
144
- risk_score = 1
145
-
146
- output.append(ExpiringCredentialItem(
147
- provider=ProviderRead.model_validate(prov),
148
- credential=CredentialRead.model_validate(cred),
149
- days_to_expiry=days_to_expiry,
150
- risk_score=risk_score
151
- ))
152
-
153
- return output
154
-
155
- def get_provider_snapshot(
156
- self, provider_id: int | None = None, npi: str | None = None
157
- ):
158
- if not provider_id and not npi:
159
- return {"error": "Must provide either provider_id or npi"}
160
-
161
- stmt = select(Provider)
162
- if provider_id:
163
- stmt = stmt.where(Provider.id == provider_id)
164
- else:
165
- stmt = stmt.where(Provider.npi == npi)
166
-
167
- provider = self.db.scalar(stmt)
168
- if not provider:
169
- return {"error": "Provider not found"}
170
-
171
- # Get Credentials
172
- creds = self.db.scalars(
173
- select(Credential).where(Credential.provider_id == provider.id)
174
- ).all()
175
-
176
- # Get Alerts (Optional)
177
- alerts = self.db.scalars(
178
- select(Alert).where(Alert.provider_id == provider.id)
179
- ).all()
180
-
181
- return ProviderSnapshot(
182
- provider=ProviderRead.model_validate(provider),
183
- credentials=[CredentialRead.model_validate(c) for c in creds],
184
- alerts=[AlertRead.model_validate(a) for a in alerts]
185
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/cred_db_mcp/models.py DELETED
@@ -1,62 +0,0 @@
1
- from datetime import datetime, date
2
- from typing import Optional, List, Any
3
- from sqlalchemy import (
4
- String, Integer, Boolean, DateTime, Date, ForeignKey, JSON, func
5
- )
6
- from sqlalchemy.orm import Mapped, mapped_column, relationship
7
-
8
- from .db import Base
9
-
10
- class Provider(Base):
11
- __tablename__ = "providers"
12
-
13
- id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True)
14
- npi: Mapped[Optional[str]] = mapped_column(String, index=True, nullable=True)
15
- full_name: Mapped[str] = mapped_column(String)
16
- dept: Mapped[Optional[str]] = mapped_column(String, nullable=True)
17
- location: Mapped[Optional[str]] = mapped_column(String, nullable=True)
18
- primary_specialty: Mapped[Optional[str]] = mapped_column(String, nullable=True)
19
- is_active: Mapped[bool] = mapped_column(Boolean, default=True)
20
-
21
- created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
22
- updated_at: Mapped[datetime] = mapped_column(
23
- DateTime, server_default=func.now(), onupdate=func.now()
24
- )
25
-
26
- credentials: Mapped[List["Credential"]] = relationship(
27
- "Credential", back_populates="provider", cascade="all, delete-orphan"
28
- )
29
-
30
- class Credential(Base):
31
- __tablename__ = "credentials"
32
-
33
- id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True)
34
- provider_id: Mapped[int] = mapped_column(
35
- Integer, ForeignKey("providers.id"), nullable=False
36
- )
37
- type: Mapped[str] = mapped_column(String) # e.g. "state_license", "board_cert"
38
- issuer: Mapped[str] = mapped_column(String)
39
- number: Mapped[str] = mapped_column(String)
40
- status: Mapped[str] = mapped_column(String) # "active", "expired", etc.
41
-
42
- issue_date: Mapped[Optional[date]] = mapped_column(Date, nullable=True)
43
- expiry_date: Mapped[Optional[date]] = mapped_column(Date, nullable=True)
44
- last_verified_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
45
-
46
- metadata_json: Mapped[Optional[Any]] = mapped_column(JSON, nullable=True)
47
-
48
- created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
49
- updated_at: Mapped[datetime] = mapped_column(
50
- DateTime, server_default=func.now(), onupdate=func.now()
51
- )
52
-
53
- provider: Mapped["Provider"] = relationship("Provider", back_populates="credentials")
54
-
55
- class Alert(Base):
56
- __tablename__ = "alerts"
57
-
58
- id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True)
59
- provider_id: Mapped[int] = mapped_column(Integer, ForeignKey("providers.id"))
60
- message: Mapped[str] = mapped_column(String)
61
- created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
62
- # Minimal schema for alerts as requested
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/cred_db_mcp/npi_client.py DELETED
@@ -1,40 +0,0 @@
1
- import os
2
- import httpx
3
- from typing import Optional, Dict, Any
4
-
5
- class NPIClient:
6
- def __init__(self, base_url: Optional[str] = None):
7
- self.base_url = base_url or os.getenv("NPI_MCP_URL", "http://localhost:8001")
8
- # Ensure no trailing slash
9
- self.base_url = self.base_url.rstrip("/")
10
-
11
- async def get_provider_by_npi(self, npi: str) -> Optional[Dict[str, Any]]:
12
- """
13
- Calls the npi_mcp server to get provider details.
14
- Assumes npi_mcp exposes a tool 'get_provider_by_npi' via HTTP.
15
-
16
- The NPI MCP is expected to have a similar structure: POST /mcp/tools/get_provider_by_npi
17
- """
18
- url = f"{self.base_url}/mcp/tools/get_provider_by_npi"
19
-
20
- # We assume the NPI MCP accepts arguments in the body
21
- payload = {"npi": npi}
22
-
23
- try:
24
- async with httpx.AsyncClient() as client:
25
- response = await client.post(url, json=payload, timeout=10.0)
26
-
27
- if response.status_code == 200:
28
- data = response.json()
29
- # If the tool wraps return value, e.g. {"result": ...} adjust here.
30
- # For now assuming direct return or generic tool response.
31
- return data
32
- elif response.status_code == 404:
33
- return None
34
- else:
35
- # Log error or raise
36
- print(f"Error calling NPI MCP: {response.status_code} {response.text}")
37
- return None
38
- except httpx.RequestError as e:
39
- print(f"Request error calling NPI MCP: {e}")
40
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/cred_db_mcp/schemas.py DELETED
@@ -1,79 +0,0 @@
1
- from pydantic import BaseModel, ConfigDict
2
- from typing import Optional, List, Any
3
- from datetime import date, datetime
4
-
5
- # --- Base Models ---
6
-
7
- class ProviderBase(BaseModel):
8
- npi: Optional[str] = None
9
- full_name: str
10
- dept: Optional[str] = None
11
- location: Optional[str] = None
12
- primary_specialty: Optional[str] = None
13
- is_active: bool = True
14
-
15
- class ProviderRead(ProviderBase):
16
- id: int
17
- created_at: datetime
18
- updated_at: datetime
19
- model_config = ConfigDict(from_attributes=True)
20
-
21
- class CredentialBase(BaseModel):
22
- type: str
23
- issuer: str
24
- number: str
25
- status: str
26
- issue_date: Optional[date] = None
27
- expiry_date: Optional[date] = None
28
- last_verified_at: Optional[datetime] = None
29
- metadata_json: Optional[Any] = None
30
-
31
- class CredentialRead(CredentialBase):
32
- id: int
33
- provider_id: int
34
- created_at: datetime
35
- updated_at: datetime
36
- model_config = ConfigDict(from_attributes=True)
37
-
38
- class AlertRead(BaseModel):
39
- id: int
40
- provider_id: int
41
- message: str
42
- created_at: datetime
43
- model_config = ConfigDict(from_attributes=True)
44
-
45
- # --- Tool IO Models ---
46
-
47
- class SyncProviderInput(BaseModel):
48
- npi: str
49
-
50
- class AddCredentialInput(BaseModel):
51
- provider_id: int
52
- type: str
53
- issuer: str
54
- number: str
55
- expiry_date: str # Expecting YYYY-MM-DD string as per prompt, or we can parse it
56
-
57
- class ListExpiringInput(BaseModel):
58
- window_days: int
59
- dept: Optional[str] = None
60
- location: Optional[str] = None
61
-
62
- class ExpiringCredentialItem(BaseModel):
63
- provider: ProviderRead
64
- credential: CredentialRead
65
- days_to_expiry: int
66
- risk_score: int
67
-
68
- class GetSnapshotInput(BaseModel):
69
- provider_id: Optional[int] = None
70
- npi: Optional[str] = None
71
-
72
- class ProviderSnapshot(BaseModel):
73
- provider: ProviderRead
74
- credentials: List[CredentialRead]
75
- alerts: Optional[List[AlertRead]] = []
76
-
77
- class ToolError(BaseModel):
78
- error: str
79
- details: Optional[str] = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/cred_db_mcp_server/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ """
2
+ The `cred_db_mcp_server` package exposes an MCP server wrapper around the `CRED_API` service.
3
+ It uses Gradio to provide the MCP endpoints.
4
+ """
src/cred_db_mcp_server/config.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ import os
2
+ from dotenv import load_dotenv
3
+
4
+ load_dotenv()
5
+
6
+ CRED_API_BASE_URL = os.getenv("CRED_API_BASE_URL", "http://localhost:8000")
src/cred_db_mcp_server/main.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from .tools import (
3
+ sync_provider_from_npi,
4
+ add_or_update_credential,
5
+ list_expiring_credentials,
6
+ get_provider_snapshot
7
+ )
8
+
9
+ def main():
10
+ """
11
+ Main entry point for the Gradio MCP server.
12
+ """
13
+ # Create the Gradio interface
14
+ # We use a TabbedInterface to organize the tools if accessed via UI,
15
+ # but primarily they are exposed as MCP tools.
16
+
17
+ # Tool 1: sync_provider_from_npi
18
+ iface_sync = gr.Interface(
19
+ fn=sync_provider_from_npi,
20
+ inputs=[gr.Textbox(label="NPI")],
21
+ outputs=gr.JSON(label="Provider Data"),
22
+ description="Syncs a provider's data from the NPI registry.",
23
+ allow_flagging="never"
24
+ )
25
+
26
+ # Tool 2: add_or_update_credential
27
+ iface_add_cred = gr.Interface(
28
+ fn=add_or_update_credential,
29
+ inputs=[
30
+ gr.Number(label="Provider ID", precision=0),
31
+ gr.Textbox(label="Type"),
32
+ gr.Textbox(label="Issuer"),
33
+ gr.Textbox(label="Number"),
34
+ gr.Textbox(label="Expiry Date (YYYY-MM-DD)"),
35
+ ],
36
+ outputs=gr.JSON(label="Credential Data"),
37
+ description="Adds or updates a credential for a provider.",
38
+ allow_flagging="never"
39
+ )
40
+
41
+ # Tool 3: list_expiring_credentials
42
+ iface_expiring = gr.Interface(
43
+ fn=list_expiring_credentials,
44
+ inputs=[
45
+ gr.Number(label="Window Days", precision=0),
46
+ gr.Textbox(label="Department", value=None),
47
+ gr.Textbox(label="Location", value=None),
48
+ ],
49
+ outputs=gr.JSON(label="Expiring Credentials"),
50
+ description="Lists credentials expiring within a certain number of days.",
51
+ allow_flagging="never"
52
+ )
53
+
54
+ # Tool 4: get_provider_snapshot
55
+ iface_snapshot = gr.Interface(
56
+ fn=get_provider_snapshot,
57
+ inputs=[
58
+ gr.Number(label="Provider ID", precision=0, value=None),
59
+ gr.Textbox(label="NPI", value=None),
60
+ ],
61
+ outputs=gr.JSON(label="Provider Snapshot"),
62
+ description="Gets a snapshot of a provider's data including credentials and alerts.",
63
+ allow_flagging="never"
64
+ )
65
+
66
+ demo = gr.TabbedInterface(
67
+ [iface_sync, iface_add_cred, iface_expiring, iface_snapshot],
68
+ ["Sync Provider", "Add/Update Credential", "List Expiring", "Provider Snapshot"]
69
+ )
70
+
71
+ # Launch with mcp_server=True to enable MCP endpoints
72
+ demo.launch(mcp_server=True)
73
+
74
+ if __name__ == "__main__":
75
+ main()
src/cred_db_mcp_server/schemas.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel
2
+ from typing import Optional, List
3
+
4
+ # Provider Schema
5
+ class Provider(BaseModel):
6
+ id: int
7
+ npi: str
8
+ name: str
9
+ # Add other fields as necessary based on CRED_API response
10
+
11
+ class Credential(BaseModel):
12
+ id: Optional[int] = None
13
+ provider_id: int
14
+ type: str
15
+ issuer: str
16
+ number: str
17
+ expiry_date: str
18
+ # Add other fields as necessary
19
+
20
+ class ExpiringCredential(BaseModel):
21
+ provider: Provider
22
+ credential: Credential
23
+ days_to_expiry: int
24
+ risk_score: float
25
+
26
+ class ProviderSnapshot(BaseModel):
27
+ provider: Provider
28
+ credentials: List[Credential]
29
+ alerts: Optional[List[dict]] = None
src/cred_db_mcp_server/tools.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This module implements the tools exposed by the MCP server.
3
+ Each function corresponds to an MCP tool.
4
+ """
5
+ import httpx
6
+ from typing import Optional, List, Dict, Any
7
+ from .config import CRED_API_BASE_URL
8
+ # from .schemas import Provider, Credential, ExpiringCredential, ProviderSnapshot
9
+
10
+ # We can use simple types in signature for Gradio MCP compatibility
11
+ # but use schemas for internal validation if needed.
12
+ # Since the prompt asked for mapped errors, we'll wrap calls.
13
+
14
+ def sync_provider_from_npi(npi: str) -> Dict[str, Any]:
15
+ """
16
+ Syncs a provider's data from the NPI registry.
17
+
18
+ Args:
19
+ npi (str): The NPI number of the provider.
20
+
21
+ Returns:
22
+ dict: The provider object returned by the API.
23
+ """
24
+ url = f"{CRED_API_BASE_URL}/providers/sync_from_npi"
25
+ try:
26
+ with httpx.Client() as client:
27
+ response = client.post(url, json={"npi": npi})
28
+ response.raise_for_status()
29
+ return response.json()
30
+ except httpx.HTTPStatusError as e:
31
+ # Map HTTP errors to MCP/User friendly errors
32
+ if e.response.status_code == 404:
33
+ return {"error": f"Provider with NPI {npi} not found or sync failed."}
34
+ return {"error": f"API Error: {e.response.text}"}
35
+ except Exception as e:
36
+ return {"error": f"Connection Error: {str(e)}"}
37
+
38
+ def add_or_update_credential(
39
+ provider_id: int,
40
+ type: str,
41
+ issuer: str,
42
+ number: str,
43
+ expiry_date: str
44
+ ) -> Dict[str, Any]:
45
+ """
46
+ Adds or updates a credential for a provider.
47
+
48
+ Args:
49
+ provider_id (int): The internal ID of the provider.
50
+ type (str): The type of credential (e.g., 'Medical License').
51
+ issuer (str): The issuing body (e.g., 'State Board').
52
+ number (str): The credential number.
53
+ expiry_date (str): The expiry date in YYYY-MM-DD format.
54
+
55
+ Returns:
56
+ dict: The created or updated credential object.
57
+ """
58
+ url = f"{CRED_API_BASE_URL}/credentials/add_or_update"
59
+ payload = {
60
+ "provider_id": provider_id,
61
+ "type": type,
62
+ "issuer": issuer,
63
+ "number": number,
64
+ "expiry_date": expiry_date
65
+ }
66
+ try:
67
+ with httpx.Client() as client:
68
+ response = client.post(url, json=payload)
69
+ response.raise_for_status()
70
+ return response.json()
71
+ except httpx.HTTPStatusError as e:
72
+ return {"error": f"API Error: {e.response.text}"}
73
+ except Exception as e:
74
+ return {"error": f"Connection Error: {str(e)}"}
75
+
76
+ def list_expiring_credentials(
77
+ window_days: int,
78
+ dept: Optional[str] = None,
79
+ location: Optional[str] = None
80
+ ) -> List[Dict[str, Any]]:
81
+ """
82
+ Lists credentials expiring within a certain number of days.
83
+
84
+ Args:
85
+ window_days (int): The number of days to check for expiry.
86
+ dept (str, optional): Filter by department.
87
+ location (str, optional): Filter by location.
88
+
89
+ Returns:
90
+ list: A list of objects containing provider, credential, days_to_expiry, and risk_score.
91
+ """
92
+ url = f"{CRED_API_BASE_URL}/credentials/expiring"
93
+ payload = {
94
+ "window_days": window_days,
95
+ "dept": dept,
96
+ "location": location
97
+ }
98
+ # Remove None values to avoid sending them if API doesn't expect them or treat them as valid
99
+ payload = {k: v for k, v in payload.items() if v is not None}
100
+
101
+ try:
102
+ with httpx.Client() as client:
103
+ response = client.post(url, json=payload)
104
+ response.raise_for_status()
105
+ return response.json()
106
+ except httpx.HTTPStatusError as e:
107
+ # In case of error, return a list with an error dict or raise
108
+ # For MCP, returning structured error info is usually better than crashing
109
+ # But for list return type, we might need to handle differently.
110
+ # Here we assume the tool call handles exceptions or checks for "error" key in result if it was a dict.
111
+ # Since return type is List, we can't easily return a dict error.
112
+ # We'll return an empty list and print error or raise.
113
+ # Let's raise ValueError which Gradio might catch and show.
114
+ raise ValueError(f"API Error: {e.response.text}")
115
+ except Exception as e:
116
+ raise ConnectionError(f"Connection Error: {str(e)}")
117
+
118
+ def get_provider_snapshot(
119
+ provider_id: Optional[int] = None,
120
+ npi: Optional[str] = None
121
+ ) -> Dict[str, Any]:
122
+ """
123
+ Gets a snapshot of a provider's data including credentials and alerts.
124
+
125
+ Args:
126
+ provider_id (int, optional): The provider's internal ID.
127
+ npi (str, optional): The provider's NPI.
128
+
129
+ Returns:
130
+ dict: Object containing provider details, credentials, and alerts.
131
+ """
132
+ if provider_id is None and npi is None:
133
+ return {"error": "Must provide either provider_id or npi."}
134
+
135
+ url = f"{CRED_API_BASE_URL}/providers/snapshot"
136
+ payload = {}
137
+ if provider_id is not None:
138
+ payload["provider_id"] = provider_id
139
+ if npi is not None:
140
+ payload["npi"] = npi
141
+
142
+ try:
143
+ with httpx.Client() as client:
144
+ response = client.post(url, json=payload)
145
+ response.raise_for_status()
146
+ return response.json()
147
+ except httpx.HTTPStatusError as e:
148
+ if e.response.status_code == 404:
149
+ return {"error": "Provider not found."}
150
+ return {"error": f"API Error: {e.response.text}"}
151
+ except Exception as e:
152
+ return {"error": f"Connection Error: {str(e)}"}
tests/test_cred_db_mcp_server.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import pytest
3
+ from unittest.mock import patch, MagicMock
4
+ from cred_db_mcp_server.tools import (
5
+ sync_provider_from_npi,
6
+ add_or_update_credential,
7
+ list_expiring_credentials,
8
+ get_provider_snapshot
9
+ )
10
+ import httpx
11
+
12
+ # Mock response data
13
+ MOCK_PROVIDER = {"id": 1, "npi": "1234567890", "name": "Dr. Test"}
14
+ MOCK_CREDENTIAL = {
15
+ "id": 10,
16
+ "provider_id": 1,
17
+ "type": "Medical License",
18
+ "issuer": "State Board",
19
+ "number": "MD12345",
20
+ "expiry_date": "2025-01-01"
21
+ }
22
+ MOCK_EXPIRING_LIST = [
23
+ {
24
+ "provider": MOCK_PROVIDER,
25
+ "credential": MOCK_CREDENTIAL,
26
+ "days_to_expiry": 30,
27
+ "risk_score": 0.5
28
+ }
29
+ ]
30
+ MOCK_SNAPSHOT = {
31
+ "provider": MOCK_PROVIDER,
32
+ "credentials": [MOCK_CREDENTIAL],
33
+ "alerts": []
34
+ }
35
+
36
+ @patch("cred_db_mcp_server.tools.httpx.Client")
37
+ def test_sync_provider_from_npi(mock_client_cls):
38
+ mock_client = MagicMock()
39
+ mock_client_cls.return_value.__enter__.return_value = mock_client
40
+ mock_client.post.return_value.status_code = 200
41
+ mock_client.post.return_value.json.return_value = MOCK_PROVIDER
42
+
43
+ result = sync_provider_from_npi("1234567890")
44
+ assert result == MOCK_PROVIDER
45
+ mock_client.post.assert_called_once()
46
+ assert "sync_from_npi" in mock_client.post.call_args[0][0]
47
+
48
+ @patch("cred_db_mcp_server.tools.httpx.Client")
49
+ def test_add_or_update_credential(mock_client_cls):
50
+ mock_client = MagicMock()
51
+ mock_client_cls.return_value.__enter__.return_value = mock_client
52
+ mock_client.post.return_value.status_code = 200
53
+ mock_client.post.return_value.json.return_value = MOCK_CREDENTIAL
54
+
55
+ result = add_or_update_credential(1, "Medical License", "State Board", "MD12345", "2025-01-01")
56
+ assert result == MOCK_CREDENTIAL
57
+ mock_client.post.assert_called_once()
58
+ assert "add_or_update" in mock_client.post.call_args[0][0]
59
+
60
+ @patch("cred_db_mcp_server.tools.httpx.Client")
61
+ def test_list_expiring_credentials(mock_client_cls):
62
+ mock_client = MagicMock()
63
+ mock_client_cls.return_value.__enter__.return_value = mock_client
64
+ mock_client.post.return_value.status_code = 200
65
+ mock_client.post.return_value.json.return_value = MOCK_EXPIRING_LIST
66
+
67
+ result = list_expiring_credentials(30)
68
+ assert result == MOCK_EXPIRING_LIST
69
+ mock_client.post.assert_called_once()
70
+ assert "expiring" in mock_client.post.call_args[0][0]
71
+
72
+ @patch("cred_db_mcp_server.tools.httpx.Client")
73
+ def test_get_provider_snapshot(mock_client_cls):
74
+ mock_client = MagicMock()
75
+ mock_client_cls.return_value.__enter__.return_value = mock_client
76
+ mock_client.post.return_value.status_code = 200
77
+ mock_client.post.return_value.json.return_value = MOCK_SNAPSHOT
78
+
79
+ result = get_provider_snapshot(provider_id=1)
80
+ assert result == MOCK_SNAPSHOT
81
+ mock_client.post.assert_called_once()
82
+ assert "snapshot" in mock_client.post.call_args[0][0]
83
+
84
+ def test_get_provider_snapshot_no_args():
85
+ result = get_provider_snapshot()
86
+ assert "error" in result
uv.lock ADDED
The diff for this file is too large to render. See raw diff