Spaces:
Runtime error
Runtime error
dylanglenister
commited on
Commit
·
368f485
1
Parent(s):
c9573f1
Upgraded and reworked medical repo file.
Browse filesFollows the same style, structure, and logic as the other data access layer files.
Also includes tests.
- schemas/medical_memory_validator.json +41 -0
- schemas/medical_record_validator.json +40 -0
- src/data/repositories/medical.py +0 -183
- src/data/repositories/medical_memory.py +148 -0
- src/data/repositories/medical_record.py +93 -0
- src/models/medical.py +38 -0
- tests/test_medical_memory.py +122 -0
- tests/test_medical_record.py +104 -0
schemas/medical_memory_validator.json
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"$jsonSchema": {
|
| 3 |
+
"bsonType": "object",
|
| 4 |
+
"title": "Medical Memory validator",
|
| 5 |
+
"required": [
|
| 6 |
+
"patient_id",
|
| 7 |
+
"doctor_id",
|
| 8 |
+
"summary",
|
| 9 |
+
"created_at"
|
| 10 |
+
],
|
| 11 |
+
"properties": {
|
| 12 |
+
"patient_id": {
|
| 13 |
+
"bsonType": "objectId",
|
| 14 |
+
"description": "'patient_id' must be an objectId and is required."
|
| 15 |
+
},
|
| 16 |
+
"doctor_id": {
|
| 17 |
+
"bsonType": "objectId",
|
| 18 |
+
"description": "'doctor_id' must be an objectId and is required."
|
| 19 |
+
},
|
| 20 |
+
"session_id": {
|
| 21 |
+
"bsonType": "objectId",
|
| 22 |
+
"description": "'session_id' must be an objectId and is optional."
|
| 23 |
+
},
|
| 24 |
+
"summary": {
|
| 25 |
+
"bsonType": "string",
|
| 26 |
+
"description": "'summary' must be a string and is required."
|
| 27 |
+
},
|
| 28 |
+
"embedding": {
|
| 29 |
+
"bsonType": "array",
|
| 30 |
+
"description": "'embedding' must be an array of floats and is optional.",
|
| 31 |
+
"items": {
|
| 32 |
+
"bsonType": "double"
|
| 33 |
+
}
|
| 34 |
+
},
|
| 35 |
+
"created_at": {
|
| 36 |
+
"bsonType": "date",
|
| 37 |
+
"description": "'created_at' must be a date and is required."
|
| 38 |
+
}
|
| 39 |
+
}
|
| 40 |
+
}
|
| 41 |
+
}
|
schemas/medical_record_validator.json
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"$jsonSchema": {
|
| 3 |
+
"bsonType": "object",
|
| 4 |
+
"title": "Medical Record validator",
|
| 5 |
+
"required": [
|
| 6 |
+
"patient_id",
|
| 7 |
+
"doctor_id",
|
| 8 |
+
"record_type",
|
| 9 |
+
"details",
|
| 10 |
+
"created_at",
|
| 11 |
+
"updated_at"
|
| 12 |
+
],
|
| 13 |
+
"properties": {
|
| 14 |
+
"patient_id": {
|
| 15 |
+
"bsonType": "objectId",
|
| 16 |
+
"description": "'patient_id' must be an objectId and is required."
|
| 17 |
+
},
|
| 18 |
+
"doctor_id": {
|
| 19 |
+
"bsonType": "objectId",
|
| 20 |
+
"description": "'doctor_id' must be an objectId and is required."
|
| 21 |
+
},
|
| 22 |
+
"record_type": {
|
| 23 |
+
"bsonType": "string",
|
| 24 |
+
"description": "'record_type' must be a string and is required."
|
| 25 |
+
},
|
| 26 |
+
"details": {
|
| 27 |
+
"bsonType": "object",
|
| 28 |
+
"description": "'details' must be an object and is required."
|
| 29 |
+
},
|
| 30 |
+
"created_at": {
|
| 31 |
+
"bsonType": "date",
|
| 32 |
+
"description": "'created_at' must be a date and is required."
|
| 33 |
+
},
|
| 34 |
+
"updated_at": {
|
| 35 |
+
"bsonType": "date",
|
| 36 |
+
"description": "'updated_at' must be a date and is required."
|
| 37 |
+
}
|
| 38 |
+
}
|
| 39 |
+
}
|
| 40 |
+
}
|
src/data/repositories/medical.py
DELETED
|
@@ -1,183 +0,0 @@
|
|
| 1 |
-
# data/repositories/medical.py
|
| 2 |
-
"""
|
| 3 |
-
Medical records and memory management operations for MongoDB.
|
| 4 |
-
|
| 5 |
-
@Note Could this be split into two? One for records and one for memory.
|
| 6 |
-
"""
|
| 7 |
-
|
| 8 |
-
from datetime import datetime, timezone
|
| 9 |
-
from typing import Any
|
| 10 |
-
|
| 11 |
-
from bson import ObjectId
|
| 12 |
-
from pymongo import ASCENDING, DESCENDING
|
| 13 |
-
from pymongo.errors import (ConnectionFailure, DuplicateKeyError,
|
| 14 |
-
OperationFailure, PyMongoError)
|
| 15 |
-
|
| 16 |
-
from src.data.connection import (ActionFailed, Collections, get_collection,
|
| 17 |
-
setup_collection)
|
| 18 |
-
from src.utils.logger import logger
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
#def init(
|
| 22 |
-
# *,
|
| 23 |
-
# collection_name: str = Collections.MEDICAL_MEMORY,
|
| 24 |
-
# validator_path: str = "schemas/account_validator.json",
|
| 25 |
-
# drop: bool = False
|
| 26 |
-
#):
|
| 27 |
-
# if drop:
|
| 28 |
-
# get_collection(collection_name).drop()
|
| 29 |
-
# setup_collection(collection_name, validator_path)
|
| 30 |
-
|
| 31 |
-
def create_medical_record(
|
| 32 |
-
record_data: dict[str, Any],
|
| 33 |
-
/, *,
|
| 34 |
-
collection_name: str = Collections.MEDICAL_RECORDS
|
| 35 |
-
) -> str:
|
| 36 |
-
"""Creates a new medical record."""
|
| 37 |
-
collection = get_collection(collection_name)
|
| 38 |
-
now = datetime.now(timezone.utc)
|
| 39 |
-
record_data.update({"created_at": now, "updated_at": now})
|
| 40 |
-
result = collection.insert_one(record_data)
|
| 41 |
-
return str(result.inserted_id)
|
| 42 |
-
|
| 43 |
-
def get_user_medical_records(
|
| 44 |
-
user_id: str,
|
| 45 |
-
/, *,
|
| 46 |
-
collection_name: str = Collections.MEDICAL_RECORDS
|
| 47 |
-
) -> list[dict[str, Any]]:
|
| 48 |
-
"""Retrieves all medical records for a specific user."""
|
| 49 |
-
collection = get_collection(collection_name)
|
| 50 |
-
cursor = collection.find(
|
| 51 |
-
{"user_id": user_id}
|
| 52 |
-
).sort(
|
| 53 |
-
"created_at", ASCENDING
|
| 54 |
-
)
|
| 55 |
-
return list(cursor)
|
| 56 |
-
|
| 57 |
-
def add_medical_context(
|
| 58 |
-
user_id: str,
|
| 59 |
-
/,
|
| 60 |
-
summary: str,
|
| 61 |
-
*,
|
| 62 |
-
collection_name: str = Collections.MEDICAL_MEMORY
|
| 63 |
-
) -> str:
|
| 64 |
-
"""Adds a medical context summary for a user."""
|
| 65 |
-
collection = get_collection(collection_name)
|
| 66 |
-
doc = {
|
| 67 |
-
"_id": str(ObjectId()),
|
| 68 |
-
"user_id": user_id,
|
| 69 |
-
"summary": summary,
|
| 70 |
-
"timestamp": datetime.now(timezone.utc)
|
| 71 |
-
}
|
| 72 |
-
result = collection.insert_one(doc)
|
| 73 |
-
return str(result.inserted_id)
|
| 74 |
-
|
| 75 |
-
def get_medical_context(
|
| 76 |
-
user_id: str,
|
| 77 |
-
/,
|
| 78 |
-
limit: int | None = None,
|
| 79 |
-
*,
|
| 80 |
-
collection_name: str = Collections.MEDICAL_MEMORY
|
| 81 |
-
) -> list[dict[str, Any]]:
|
| 82 |
-
"""Retrieves medical context summaries for a user."""
|
| 83 |
-
collection = get_collection(collection_name)
|
| 84 |
-
cursor = collection.find(
|
| 85 |
-
{"user_id": user_id}
|
| 86 |
-
).sort(
|
| 87 |
-
"timestamp", DESCENDING
|
| 88 |
-
)
|
| 89 |
-
if limit:
|
| 90 |
-
cursor = cursor.limit(limit)
|
| 91 |
-
return list(cursor)
|
| 92 |
-
|
| 93 |
-
# TODO Delete context
|
| 94 |
-
|
| 95 |
-
def save_memory_summary(
|
| 96 |
-
*,
|
| 97 |
-
patient_id: str,
|
| 98 |
-
doctor_id: str,
|
| 99 |
-
summary: str,
|
| 100 |
-
embedding: list[float] | None = None,
|
| 101 |
-
created_at: datetime | None = None,
|
| 102 |
-
collection_name: str = Collections.MEDICAL_MEMORY
|
| 103 |
-
) -> str:
|
| 104 |
-
collection = get_collection(collection_name)
|
| 105 |
-
ts = created_at or datetime.now(timezone.utc)
|
| 106 |
-
doc = {
|
| 107 |
-
"patient_id": patient_id,
|
| 108 |
-
"doctor_id": doctor_id,
|
| 109 |
-
"summary": summary,
|
| 110 |
-
"created_at": ts
|
| 111 |
-
}
|
| 112 |
-
if embedding is not None:
|
| 113 |
-
doc["embedding"] = embedding
|
| 114 |
-
result = collection.insert_one(doc)
|
| 115 |
-
return str(result.inserted_id)
|
| 116 |
-
|
| 117 |
-
def get_recent_memory_summaries(
|
| 118 |
-
patient_id: str,
|
| 119 |
-
/,
|
| 120 |
-
*,
|
| 121 |
-
limit: int = 20,
|
| 122 |
-
collection_name: str = Collections.MEDICAL_MEMORY
|
| 123 |
-
) -> list[str]:
|
| 124 |
-
collection = get_collection(collection_name)
|
| 125 |
-
docs = list(collection.find({"patient_id": patient_id}).sort("created_at", DESCENDING).limit(limit))
|
| 126 |
-
return [d.get("summary", "") for d in docs]
|
| 127 |
-
|
| 128 |
-
def search_memory_summaries_semantic(
|
| 129 |
-
patient_id: str,
|
| 130 |
-
query_embedding: list[float],
|
| 131 |
-
/,
|
| 132 |
-
*,
|
| 133 |
-
limit: int = 5,
|
| 134 |
-
similarity_threshold: float = 0.5, # >= 50% semantic similarity
|
| 135 |
-
collection_name: str = Collections.MEDICAL_MEMORY
|
| 136 |
-
) -> list[dict[str, Any]]:
|
| 137 |
-
"""
|
| 138 |
-
Search memory summaries using semantic similarity with embeddings.
|
| 139 |
-
Returns list of {summary, similarity_score, created_at} sorted by similarity.
|
| 140 |
-
"""
|
| 141 |
-
collection = get_collection(collection_name)
|
| 142 |
-
|
| 143 |
-
# Get all summaries with embeddings for this patient
|
| 144 |
-
docs = list(collection.find({
|
| 145 |
-
"patient_id": patient_id,
|
| 146 |
-
"embedding": {"$exists": True}
|
| 147 |
-
}))
|
| 148 |
-
|
| 149 |
-
if not docs:
|
| 150 |
-
return []
|
| 151 |
-
|
| 152 |
-
# Calculate similarities
|
| 153 |
-
import numpy as np
|
| 154 |
-
query_vec = np.array(query_embedding, dtype="float32")
|
| 155 |
-
results = []
|
| 156 |
-
|
| 157 |
-
for doc in docs:
|
| 158 |
-
embedding = doc.get("embedding")
|
| 159 |
-
if not embedding:
|
| 160 |
-
continue
|
| 161 |
-
|
| 162 |
-
# Calculate cosine similarity
|
| 163 |
-
doc_vec = np.array(embedding, dtype="float32")
|
| 164 |
-
dot_product = np.dot(query_vec, doc_vec)
|
| 165 |
-
norm_query = np.linalg.norm(query_vec)
|
| 166 |
-
norm_doc = np.linalg.norm(doc_vec)
|
| 167 |
-
|
| 168 |
-
if norm_query == 0 or norm_doc == 0:
|
| 169 |
-
similarity = 0.0
|
| 170 |
-
else:
|
| 171 |
-
similarity = float(dot_product / (norm_query * norm_doc))
|
| 172 |
-
|
| 173 |
-
if similarity >= similarity_threshold:
|
| 174 |
-
results.append({
|
| 175 |
-
"summary": doc.get("summary", ""),
|
| 176 |
-
"similarity_score": similarity,
|
| 177 |
-
"created_at": doc.get("created_at"),
|
| 178 |
-
"session_id": doc.get("session_id") # if we add this field later
|
| 179 |
-
})
|
| 180 |
-
|
| 181 |
-
# Sort by similarity (highest first) and return top results
|
| 182 |
-
results.sort(key=lambda x: x["similarity_score"], reverse=True)
|
| 183 |
-
return results[:limit]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/data/repositories/medical_memory.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# src/data/repositories/medical_memory.py
|
| 2 |
+
"""
|
| 3 |
+
Medical memory management operations for MongoDB.
|
| 4 |
+
Medical memories are unstructured summaries, often with vector embeddings for semantic search.
|
| 5 |
+
|
| 6 |
+
## Fields
|
| 7 |
+
_id: index
|
| 8 |
+
patient_id: The patient this memory relates to
|
| 9 |
+
doctor_id: The doctor involved in the context of this memory
|
| 10 |
+
session_id: The chat session this memory was derived from (optional)
|
| 11 |
+
summary: The unstructured text summary of the medical context
|
| 12 |
+
embedding: The vector embedding of the summary for semantic search (optional)
|
| 13 |
+
created_at: The timestamp when the memory was created
|
| 14 |
+
"""
|
| 15 |
+
from datetime import datetime, timezone
|
| 16 |
+
|
| 17 |
+
import numpy as np
|
| 18 |
+
from bson import ObjectId
|
| 19 |
+
from bson.errors import InvalidId
|
| 20 |
+
from pymongo import DESCENDING
|
| 21 |
+
from pymongo.errors import ConnectionFailure, PyMongoError, WriteError
|
| 22 |
+
|
| 23 |
+
from src.data.connection import (ActionFailed, Collections, get_collection,
|
| 24 |
+
setup_collection)
|
| 25 |
+
from src.models.medical import MedicalMemory, SemanticSearchResult
|
| 26 |
+
from src.utils.logger import logger
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def init(
|
| 30 |
+
*,
|
| 31 |
+
collection_name: str = Collections.MEDICAL_MEMORY,
|
| 32 |
+
validator_path: str = "schemas/medical_memory_validator.json",
|
| 33 |
+
drop: bool = False
|
| 34 |
+
):
|
| 35 |
+
"""Initializes the medical_memory collection, applying schema validation."""
|
| 36 |
+
try:
|
| 37 |
+
if drop:
|
| 38 |
+
get_collection(collection_name).drop()
|
| 39 |
+
setup_collection(collection_name, validator_path)
|
| 40 |
+
except (ConnectionFailure, PyMongoError) as e:
|
| 41 |
+
logger().error(f"Failed to initialize collection '{collection_name}': {e}")
|
| 42 |
+
raise ActionFailed(f"Database operation failed during initialization: {e}") from e
|
| 43 |
+
|
| 44 |
+
def create_memory(
|
| 45 |
+
patient_id: str,
|
| 46 |
+
doctor_id: str,
|
| 47 |
+
summary: str,
|
| 48 |
+
session_id: str | None = None,
|
| 49 |
+
embedding: list[float] | None = None,
|
| 50 |
+
*,
|
| 51 |
+
collection_name: str = Collections.MEDICAL_MEMORY
|
| 52 |
+
) -> str:
|
| 53 |
+
"""Saves a new medical memory summary, raising ActionFailed on error."""
|
| 54 |
+
try:
|
| 55 |
+
collection = get_collection(collection_name)
|
| 56 |
+
doc = {
|
| 57 |
+
"patient_id": ObjectId(patient_id),
|
| 58 |
+
"doctor_id": ObjectId(doctor_id),
|
| 59 |
+
"summary": summary,
|
| 60 |
+
"created_at": datetime.now(timezone.utc)
|
| 61 |
+
}
|
| 62 |
+
if session_id:
|
| 63 |
+
doc["session_id"] = ObjectId(session_id)
|
| 64 |
+
if embedding:
|
| 65 |
+
doc["embedding"] = embedding
|
| 66 |
+
|
| 67 |
+
result = collection.insert_one(doc)
|
| 68 |
+
return str(result.inserted_id)
|
| 69 |
+
except InvalidId as e:
|
| 70 |
+
logger().error(f"Invalid ObjectId format provided for medical memory: {e}")
|
| 71 |
+
raise ActionFailed("Patient, Doctor, or Session ID is not a valid format.") from e
|
| 72 |
+
except (WriteError, ConnectionFailure, PyMongoError) as e:
|
| 73 |
+
logger().error(f"Failed to create medical memory: {e}")
|
| 74 |
+
raise ActionFailed("A database error occurred while creating the medical memory.") from e
|
| 75 |
+
|
| 76 |
+
def get_recent_memories(
|
| 77 |
+
patient_id: str,
|
| 78 |
+
limit: int = 20,
|
| 79 |
+
*,
|
| 80 |
+
collection_name: str = Collections.MEDICAL_MEMORY
|
| 81 |
+
) -> list[MedicalMemory]:
|
| 82 |
+
"""Retrieves the most recent memory summaries for a patient."""
|
| 83 |
+
try:
|
| 84 |
+
obj_patient_id = ObjectId(patient_id)
|
| 85 |
+
collection = get_collection(collection_name)
|
| 86 |
+
cursor = collection.find(
|
| 87 |
+
{"patient_id": obj_patient_id}
|
| 88 |
+
).sort("created_at", DESCENDING).limit(limit)
|
| 89 |
+
|
| 90 |
+
return [MedicalMemory.model_validate(doc) for doc in cursor]
|
| 91 |
+
except InvalidId as e:
|
| 92 |
+
logger().error(f"Invalid patient_id format for get_recent_memories: '{patient_id}'")
|
| 93 |
+
raise ActionFailed("The provided patient ID is not a valid format.") from e
|
| 94 |
+
except (ConnectionFailure, PyMongoError) as e:
|
| 95 |
+
logger().error(f"Database error retrieving recent memories for patient '{patient_id}': {e}")
|
| 96 |
+
raise ActionFailed("A database error occurred while retrieving recent memories.") from e
|
| 97 |
+
|
| 98 |
+
def search_memories_semantic(
|
| 99 |
+
patient_id: str,
|
| 100 |
+
query_embedding: list[float],
|
| 101 |
+
limit: int = 5,
|
| 102 |
+
*,
|
| 103 |
+
collection_name: str = Collections.MEDICAL_MEMORY
|
| 104 |
+
) -> list[SemanticSearchResult]:
|
| 105 |
+
"""Searches memory summaries using semantic similarity with embeddings."""
|
| 106 |
+
try:
|
| 107 |
+
obj_patient_id = ObjectId(patient_id)
|
| 108 |
+
collection = get_collection(collection_name)
|
| 109 |
+
|
| 110 |
+
# In a real-world scenario, this would be an Atlas Vector Search query.
|
| 111 |
+
# This implementation fetches all docs and calculates similarity in the client.
|
| 112 |
+
docs = list(collection.find({
|
| 113 |
+
"patient_id": obj_patient_id,
|
| 114 |
+
"embedding": {"$exists": True}
|
| 115 |
+
}))
|
| 116 |
+
|
| 117 |
+
if not docs:
|
| 118 |
+
return []
|
| 119 |
+
|
| 120 |
+
query_vec = np.array(query_embedding, dtype="float32")
|
| 121 |
+
results = []
|
| 122 |
+
for doc in docs:
|
| 123 |
+
doc_vec = np.array(doc["embedding"], dtype="float32")
|
| 124 |
+
|
| 125 |
+
# Calculate cosine similarity
|
| 126 |
+
dot_product = np.dot(query_vec, doc_vec)
|
| 127 |
+
norm_query = np.linalg.norm(query_vec)
|
| 128 |
+
norm_doc = np.linalg.norm(doc_vec)
|
| 129 |
+
|
| 130 |
+
if norm_query > 0 and norm_doc > 0:
|
| 131 |
+
similarity = float(dot_product / (norm_query * norm_doc))
|
| 132 |
+
result_data = {
|
| 133 |
+
"summary": doc["summary"],
|
| 134 |
+
"similarity_score": similarity,
|
| 135 |
+
"created_at": doc["created_at"],
|
| 136 |
+
"session_id": doc.get("session_id")
|
| 137 |
+
}
|
| 138 |
+
results.append(SemanticSearchResult.model_validate(result_data))
|
| 139 |
+
|
| 140 |
+
# Sort by similarity (highest first) and return top results
|
| 141 |
+
results.sort(key=lambda x: x.similarity_score, reverse=True)
|
| 142 |
+
return results[:limit]
|
| 143 |
+
except InvalidId as e:
|
| 144 |
+
logger().error(f"Invalid patient_id format for semantic search: '{patient_id}'")
|
| 145 |
+
raise ActionFailed("The provided patient ID is not a valid format.") from e
|
| 146 |
+
except (ConnectionFailure, PyMongoError) as e:
|
| 147 |
+
logger().error(f"Database error during semantic search for patient '{patient_id}': {e}")
|
| 148 |
+
raise ActionFailed("A database error occurred during the semantic search.") from e
|
src/data/repositories/medical_record.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# src/data/repositories/medical_record.py
|
| 2 |
+
"""
|
| 3 |
+
Medical record management operations for MongoDB.
|
| 4 |
+
Medical records are structured, factual pieces of information about a patient.
|
| 5 |
+
|
| 6 |
+
## Fields
|
| 7 |
+
_id: index
|
| 8 |
+
patient_id: The patient this record belongs to
|
| 9 |
+
doctor_id: The doctor who created or is associated with this record
|
| 10 |
+
record_type: The category of the record (e.g., 'Consultation', 'LabResult')
|
| 11 |
+
details: An object containing the specific, structured data for the record
|
| 12 |
+
created_at: The timestamp when the record was created
|
| 13 |
+
updated_at: The timestamp when the record was last modified
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
from datetime import datetime, timezone
|
| 17 |
+
from typing import Any
|
| 18 |
+
|
| 19 |
+
from bson import ObjectId
|
| 20 |
+
from bson.errors import InvalidId
|
| 21 |
+
from pymongo import ASCENDING
|
| 22 |
+
from pymongo.errors import ConnectionFailure, PyMongoError, WriteError
|
| 23 |
+
|
| 24 |
+
from src.data.connection import (ActionFailed, Collections, get_collection,
|
| 25 |
+
setup_collection)
|
| 26 |
+
from src.models.medical import MedicalRecord
|
| 27 |
+
from src.utils.logger import logger
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def init(
|
| 31 |
+
*,
|
| 32 |
+
collection_name: str = Collections.MEDICAL_RECORDS,
|
| 33 |
+
validator_path: str = "schemas/medical_record_validator.json",
|
| 34 |
+
drop: bool = False
|
| 35 |
+
):
|
| 36 |
+
"""Initializes the medical_records collection, applying schema validation."""
|
| 37 |
+
try:
|
| 38 |
+
if drop:
|
| 39 |
+
get_collection(collection_name).drop()
|
| 40 |
+
setup_collection(collection_name, validator_path)
|
| 41 |
+
except (ConnectionFailure, PyMongoError) as e:
|
| 42 |
+
logger().error(f"Failed to initialize collection '{collection_name}': {e}")
|
| 43 |
+
raise ActionFailed(f"Database operation failed during initialization: {e}") from e
|
| 44 |
+
|
| 45 |
+
def create_medical_record(
|
| 46 |
+
patient_id: str,
|
| 47 |
+
doctor_id: str,
|
| 48 |
+
record_type: str,
|
| 49 |
+
details: dict[str, Any],
|
| 50 |
+
*,
|
| 51 |
+
collection_name: str = Collections.MEDICAL_RECORDS
|
| 52 |
+
) -> str:
|
| 53 |
+
"""Creates a new medical record, raising ActionFailed on error."""
|
| 54 |
+
now = datetime.now(timezone.utc)
|
| 55 |
+
try:
|
| 56 |
+
collection = get_collection(collection_name)
|
| 57 |
+
record_data = {
|
| 58 |
+
"patient_id": ObjectId(patient_id),
|
| 59 |
+
"doctor_id": ObjectId(doctor_id),
|
| 60 |
+
"record_type": record_type,
|
| 61 |
+
"details": details,
|
| 62 |
+
"created_at": now,
|
| 63 |
+
"updated_at": now
|
| 64 |
+
}
|
| 65 |
+
result = collection.insert_one(record_data)
|
| 66 |
+
return str(result.inserted_id)
|
| 67 |
+
except InvalidId as e:
|
| 68 |
+
logger().error(f"Invalid ObjectId format provided for medical record: {e}")
|
| 69 |
+
raise ActionFailed("Patient ID or Doctor ID is not a valid format.") from e
|
| 70 |
+
except (WriteError, ConnectionFailure, PyMongoError) as e:
|
| 71 |
+
logger().error(f"Failed to create medical record: {e}")
|
| 72 |
+
raise ActionFailed("A database error occurred while creating the medical record.") from e
|
| 73 |
+
|
| 74 |
+
def get_records_for_patient(
|
| 75 |
+
patient_id: str,
|
| 76 |
+
*,
|
| 77 |
+
collection_name: str = Collections.MEDICAL_RECORDS
|
| 78 |
+
) -> list[MedicalRecord]:
|
| 79 |
+
"""Retrieves all medical records for a patient, sorted by creation date."""
|
| 80 |
+
try:
|
| 81 |
+
obj_patient_id = ObjectId(patient_id)
|
| 82 |
+
collection = get_collection(collection_name)
|
| 83 |
+
cursor = collection.find(
|
| 84 |
+
{"patient_id": obj_patient_id}
|
| 85 |
+
).sort("created_at", ASCENDING)
|
| 86 |
+
|
| 87 |
+
return [MedicalRecord.model_validate(doc) for doc in cursor]
|
| 88 |
+
except InvalidId as e:
|
| 89 |
+
logger().error(f"Invalid patient_id format for get_records_for_patient: '{patient_id}'")
|
| 90 |
+
raise ActionFailed("The provided patient ID is not a valid format.") from e
|
| 91 |
+
except (ConnectionFailure, PyMongoError) as e:
|
| 92 |
+
logger().error(f"Database error retrieving records for patient '{patient_id}': {e}")
|
| 93 |
+
raise ActionFailed("A database error occurred while retrieving medical records.") from e
|
src/models/medical.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# src/models/medical.py
|
| 2 |
+
|
| 3 |
+
from datetime import datetime
|
| 4 |
+
from typing import Any
|
| 5 |
+
|
| 6 |
+
from pydantic import BaseModel, ConfigDict
|
| 7 |
+
|
| 8 |
+
from models.repositories import BaseMongoModel, PyObjectId
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class MedicalRecord(BaseMongoModel):
|
| 12 |
+
"""A Pydantic model for a structured medical record."""
|
| 13 |
+
patient_id: PyObjectId
|
| 14 |
+
doctor_id: PyObjectId
|
| 15 |
+
record_type: str
|
| 16 |
+
details: dict[str, Any]
|
| 17 |
+
created_at: datetime
|
| 18 |
+
updated_at: datetime
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class MedicalMemory(BaseMongoModel):
|
| 22 |
+
"""A Pydantic model for a medical memory summary, used for semantic search."""
|
| 23 |
+
patient_id: PyObjectId
|
| 24 |
+
doctor_id: PyObjectId
|
| 25 |
+
session_id: PyObjectId | None = None
|
| 26 |
+
summary: str
|
| 27 |
+
embedding: list[float] | None = None
|
| 28 |
+
created_at: datetime
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class SemanticSearchResult(BaseModel):
|
| 32 |
+
"""A Pydantic model for the result of a semantic search."""
|
| 33 |
+
summary: str
|
| 34 |
+
similarity_score: float
|
| 35 |
+
created_at: datetime
|
| 36 |
+
session_id: PyObjectId | None = None
|
| 37 |
+
|
| 38 |
+
model_config = ConfigDict(frozen=True, from_attributes=True)
|
tests/test_medical_memory.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import unittest
|
| 2 |
+
from unittest.mock import patch
|
| 3 |
+
|
| 4 |
+
from bson import ObjectId
|
| 5 |
+
from pymongo.errors import ConnectionFailure
|
| 6 |
+
|
| 7 |
+
from src.data.connection import ActionFailed, Collections
|
| 8 |
+
from src.data.repositories import medical_memory as medical_memory_repo
|
| 9 |
+
from src.models.medical import MedicalMemory, SemanticSearchResult
|
| 10 |
+
from src.utils.logger import logger
|
| 11 |
+
from tests.base_test import BaseMongoTest
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class TestMedicalMemoryRepository(BaseMongoTest):
|
| 15 |
+
"""Test class for the 'happy path' of medical memory repository functions."""
|
| 16 |
+
|
| 17 |
+
def setUp(self):
|
| 18 |
+
"""Set up a clean test environment before each test."""
|
| 19 |
+
super().setUp()
|
| 20 |
+
self.test_collection = self._collections[Collections.MEDICAL_MEMORY]
|
| 21 |
+
medical_memory_repo.init(collection_name=self.test_collection, drop=True)
|
| 22 |
+
self.patient_id = str(ObjectId())
|
| 23 |
+
self.doctor_id = str(ObjectId())
|
| 24 |
+
self.session_id = str(ObjectId())
|
| 25 |
+
self.embedding = [0.1, 0.2, 0.3]
|
| 26 |
+
|
| 27 |
+
def test_init_functionality(self):
|
| 28 |
+
"""Test that the init function correctly sets up the collection."""
|
| 29 |
+
self.assertIn(self.test_collection, self.db.list_collection_names())
|
| 30 |
+
|
| 31 |
+
def test_create_memory(self):
|
| 32 |
+
"""Test successful creation of a medical memory with and without optional fields."""
|
| 33 |
+
# Test full creation
|
| 34 |
+
memory_id = medical_memory_repo.create_memory(
|
| 35 |
+
self.patient_id, self.doctor_id, "Full summary", self.session_id, self.embedding,
|
| 36 |
+
collection_name=self.test_collection
|
| 37 |
+
)
|
| 38 |
+
self.assertIsInstance(memory_id, str)
|
| 39 |
+
doc = self.get_doc_by_id(Collections.MEDICAL_MEMORY, memory_id)
|
| 40 |
+
self.assertIsNotNone(doc)
|
| 41 |
+
self.assertEqual(doc["summary"], "Full summary") # type: ignore
|
| 42 |
+
self.assertEqual(len(doc["embedding"]), 3) # type: ignore
|
| 43 |
+
|
| 44 |
+
# Test minimal creation
|
| 45 |
+
min_id = medical_memory_repo.create_memory(
|
| 46 |
+
self.patient_id, self.doctor_id, "Minimal summary", collection_name=self.test_collection
|
| 47 |
+
)
|
| 48 |
+
self.assertIsInstance(min_id, str)
|
| 49 |
+
|
| 50 |
+
def test_get_recent_memories(self):
|
| 51 |
+
"""Test retrieving recent memories, verifying sorting, filtering, and limit."""
|
| 52 |
+
medical_memory_repo.create_memory(self.patient_id, self.doctor_id, "Oldest", collection_name=self.test_collection)
|
| 53 |
+
medical_memory_repo.create_memory(str(ObjectId()), self.doctor_id, "Other Patient", collection_name=self.test_collection)
|
| 54 |
+
medical_memory_repo.create_memory(self.patient_id, self.doctor_id, "Newest", collection_name=self.test_collection)
|
| 55 |
+
|
| 56 |
+
memories = medical_memory_repo.get_recent_memories(self.patient_id, collection_name=self.test_collection)
|
| 57 |
+
self.assertEqual(len(memories), 2)
|
| 58 |
+
self.assertIsInstance(memories[0], MedicalMemory)
|
| 59 |
+
self.assertEqual(memories[0].summary, "Newest") # Descending sort order
|
| 60 |
+
|
| 61 |
+
# Test limit
|
| 62 |
+
limited = medical_memory_repo.get_recent_memories(self.patient_id, limit=1, collection_name=self.test_collection)
|
| 63 |
+
self.assertEqual(len(limited), 1)
|
| 64 |
+
|
| 65 |
+
def test_search_memories_semantic(self):
|
| 66 |
+
"""Test semantic search functionality, verifying similarity logic and sorting."""
|
| 67 |
+
# Create memories with known embeddings
|
| 68 |
+
vec_a = [1.0, 0.0, 0.0] # Most similar
|
| 69 |
+
vec_b = [0.7, 0.7, 0.0] # Less similar
|
| 70 |
+
vec_c = [0.0, 0.0, 1.0] # Not similar
|
| 71 |
+
medical_memory_repo.create_memory(self.patient_id, self.doctor_id, "Vec A", embedding=vec_a, collection_name=self.test_collection)
|
| 72 |
+
medical_memory_repo.create_memory(self.patient_id, self.doctor_id, "Vec B", embedding=vec_b, collection_name=self.test_collection)
|
| 73 |
+
medical_memory_repo.create_memory(self.patient_id, self.doctor_id, "Vec C", embedding=vec_c, collection_name=self.test_collection)
|
| 74 |
+
medical_memory_repo.create_memory(self.patient_id, self.doctor_id, "No Embedding", collection_name=self.test_collection)
|
| 75 |
+
|
| 76 |
+
query_embedding = [0.9, 0.1, 0.0]
|
| 77 |
+
results = medical_memory_repo.search_memories_semantic(self.patient_id, query_embedding, collection_name=self.test_collection)
|
| 78 |
+
|
| 79 |
+
self.assertEqual(len(results), 3) # Vec C should be filtered by default numpy math
|
| 80 |
+
self.assertIsInstance(results[0], SemanticSearchResult)
|
| 81 |
+
self.assertEqual(results[0].summary, "Vec A") # Most similar should be first
|
| 82 |
+
self.assertGreater(results[0].similarity_score, results[1].similarity_score)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class TestMedicalMemoryRepositoryExceptions(BaseMongoTest):
|
| 86 |
+
"""Test class for the exception handling of medical memory repository functions."""
|
| 87 |
+
|
| 88 |
+
def setUp(self):
|
| 89 |
+
"""Set up the test environment before each test."""
|
| 90 |
+
super().setUp()
|
| 91 |
+
self.test_collection = self._collections[Collections.MEDICAL_MEMORY]
|
| 92 |
+
medical_memory_repo.init(collection_name=self.test_collection, drop=True)
|
| 93 |
+
self.patient_id = str(ObjectId())
|
| 94 |
+
self.doctor_id = str(ObjectId())
|
| 95 |
+
|
| 96 |
+
def test_invalid_id_raises_action_failed(self):
|
| 97 |
+
"""Test that functions raise ActionFailed when given a malformed ObjectId string."""
|
| 98 |
+
with self.assertRaises(ActionFailed):
|
| 99 |
+
medical_memory_repo.create_memory("bad-id", self.doctor_id, "t", collection_name=self.test_collection)
|
| 100 |
+
with self.assertRaises(ActionFailed):
|
| 101 |
+
medical_memory_repo.get_recent_memories("bad-id", collection_name=self.test_collection)
|
| 102 |
+
with self.assertRaises(ActionFailed):
|
| 103 |
+
medical_memory_repo.search_memories_semantic("bad-id", [], collection_name=self.test_collection)
|
| 104 |
+
|
| 105 |
+
@patch('src.data.repositories.medical_memory.get_collection')
|
| 106 |
+
def test_all_functions_raise_on_connection_error(self, mock_get_collection):
|
| 107 |
+
"""Test that all repo functions catch generic PyMongoErrors and raise ActionFailed."""
|
| 108 |
+
mock_get_collection.side_effect = ConnectionFailure("Simulated connection error")
|
| 109 |
+
|
| 110 |
+
with self.assertRaises(ActionFailed):
|
| 111 |
+
medical_memory_repo.init(collection_name=self.test_collection, drop=True)
|
| 112 |
+
with self.assertRaises(ActionFailed):
|
| 113 |
+
medical_memory_repo.create_memory(self.patient_id, self.doctor_id, "t", collection_name=self.test_collection)
|
| 114 |
+
with self.assertRaises(ActionFailed):
|
| 115 |
+
medical_memory_repo.get_recent_memories(self.patient_id, collection_name=self.test_collection)
|
| 116 |
+
with self.assertRaises(ActionFailed):
|
| 117 |
+
medical_memory_repo.search_memories_semantic(self.patient_id, [], collection_name=self.test_collection)
|
| 118 |
+
|
| 119 |
+
if __name__ == "__main__":
|
| 120 |
+
logger().info("Starting MongoDB repository integration tests...")
|
| 121 |
+
unittest.main(verbosity=2)
|
| 122 |
+
logger().info("Tests completed and database connection closed.")
|
tests/test_medical_record.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
import unittest
|
| 3 |
+
from unittest.mock import patch
|
| 4 |
+
|
| 5 |
+
from bson import ObjectId
|
| 6 |
+
from pymongo.errors import ConnectionFailure
|
| 7 |
+
|
| 8 |
+
from src.data.connection import ActionFailed, Collections
|
| 9 |
+
from src.data.repositories import medical_record as medical_record_repo
|
| 10 |
+
from src.models.medical import MedicalRecord
|
| 11 |
+
from src.utils.logger import logger
|
| 12 |
+
from tests.base_test import BaseMongoTest
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class TestMedicalRecordRepository(BaseMongoTest):
|
| 16 |
+
"""Test class for the 'happy path' of medical record repository functions."""
|
| 17 |
+
|
| 18 |
+
def setUp(self):
|
| 19 |
+
"""Set up a clean test environment before each test."""
|
| 20 |
+
super().setUp()
|
| 21 |
+
self.test_collection = self._collections[Collections.MEDICAL_RECORDS]
|
| 22 |
+
medical_record_repo.init(collection_name=self.test_collection, drop=True)
|
| 23 |
+
self.patient_id = str(ObjectId())
|
| 24 |
+
self.doctor_id = str(ObjectId())
|
| 25 |
+
|
| 26 |
+
def test_init_functionality(self):
|
| 27 |
+
"""Test that the init function correctly sets up the collection."""
|
| 28 |
+
self.assertIn(self.test_collection, self.db.list_collection_names())
|
| 29 |
+
|
| 30 |
+
def test_create_medical_record(self):
|
| 31 |
+
"""Test successful creation of a medical record."""
|
| 32 |
+
record_id = medical_record_repo.create_medical_record(
|
| 33 |
+
patient_id=self.patient_id,
|
| 34 |
+
doctor_id=self.doctor_id,
|
| 35 |
+
record_type="Consultation",
|
| 36 |
+
details={"symptoms": "Fever, cough", "diagnosis": "Common cold"},
|
| 37 |
+
collection_name=self.test_collection
|
| 38 |
+
)
|
| 39 |
+
self.assertIsInstance(record_id, str)
|
| 40 |
+
doc = self.get_doc_by_id(Collections.MEDICAL_RECORDS, record_id)
|
| 41 |
+
self.assertIsNotNone(doc)
|
| 42 |
+
self.assertEqual(doc["record_type"], "Consultation") # type: ignore
|
| 43 |
+
self.assertEqual(str(doc["patient_id"]), self.patient_id) # type: ignore
|
| 44 |
+
|
| 45 |
+
def test_get_records_for_patient(self):
|
| 46 |
+
"""Test retrieving all records for a patient, verifying sorting and filtering."""
|
| 47 |
+
other_patient_id = str(ObjectId())
|
| 48 |
+
# Create records, sleeping to ensure distinct creation timestamps for sorting check
|
| 49 |
+
r1_id = medical_record_repo.create_medical_record(self.patient_id, self.doctor_id, "R1", {}, collection_name=self.test_collection)
|
| 50 |
+
time.sleep(0.01)
|
| 51 |
+
medical_record_repo.create_medical_record(other_patient_id, self.doctor_id, "Other", {}, collection_name=self.test_collection)
|
| 52 |
+
time.sleep(0.01)
|
| 53 |
+
r2_id = medical_record_repo.create_medical_record(self.patient_id, self.doctor_id, "R2", {}, collection_name=self.test_collection)
|
| 54 |
+
|
| 55 |
+
# Retrieve records for the target patient
|
| 56 |
+
records = medical_record_repo.get_records_for_patient(self.patient_id, collection_name=self.test_collection)
|
| 57 |
+
|
| 58 |
+
# Verify correct filtering, count, and type
|
| 59 |
+
self.assertEqual(len(records), 2)
|
| 60 |
+
self.assertIsInstance(records[0], MedicalRecord)
|
| 61 |
+
|
| 62 |
+
# Verify sorting (ascending by creation date)
|
| 63 |
+
self.assertEqual(records[0].id, r1_id)
|
| 64 |
+
self.assertEqual(records[1].id, r2_id)
|
| 65 |
+
|
| 66 |
+
# Test edge case: patient with no records
|
| 67 |
+
no_records = medical_record_repo.get_records_for_patient(str(ObjectId()), collection_name=self.test_collection)
|
| 68 |
+
self.assertEqual(len(no_records), 0)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class TestMedicalRecordRepositoryExceptions(BaseMongoTest):
|
| 72 |
+
"""Test class for the exception handling of medical record repository functions."""
|
| 73 |
+
|
| 74 |
+
def setUp(self):
|
| 75 |
+
"""Set up the test environment before each test."""
|
| 76 |
+
super().setUp()
|
| 77 |
+
self.test_collection = self._collections[Collections.MEDICAL_RECORDS]
|
| 78 |
+
medical_record_repo.init(collection_name=self.test_collection, drop=True)
|
| 79 |
+
self.patient_id = str(ObjectId())
|
| 80 |
+
self.doctor_id = str(ObjectId())
|
| 81 |
+
|
| 82 |
+
def test_invalid_id_raises_action_failed(self):
|
| 83 |
+
"""Test that functions raise ActionFailed when given a malformed ObjectId string."""
|
| 84 |
+
with self.assertRaises(ActionFailed):
|
| 85 |
+
medical_record_repo.create_medical_record("bad-id", self.doctor_id, "t", {}, collection_name=self.test_collection)
|
| 86 |
+
with self.assertRaises(ActionFailed):
|
| 87 |
+
medical_record_repo.get_records_for_patient("bad-id", collection_name=self.test_collection)
|
| 88 |
+
|
| 89 |
+
@patch('src.data.repositories.medical_record.get_collection')
|
| 90 |
+
def test_all_functions_raise_on_connection_error(self, mock_get_collection):
|
| 91 |
+
"""Test that all repo functions catch generic PyMongoErrors and raise ActionFailed."""
|
| 92 |
+
mock_get_collection.side_effect = ConnectionFailure("Simulated connection error")
|
| 93 |
+
|
| 94 |
+
with self.assertRaises(ActionFailed):
|
| 95 |
+
medical_record_repo.init(collection_name=self.test_collection, drop=True)
|
| 96 |
+
with self.assertRaises(ActionFailed):
|
| 97 |
+
medical_record_repo.create_medical_record(self.patient_id, self.doctor_id, "t", {}, collection_name=self.test_collection)
|
| 98 |
+
with self.assertRaises(ActionFailed):
|
| 99 |
+
medical_record_repo.get_records_for_patient(self.patient_id, collection_name=self.test_collection)
|
| 100 |
+
|
| 101 |
+
if __name__ == "__main__":
|
| 102 |
+
logger().info("Starting MongoDB repository integration tests...")
|
| 103 |
+
unittest.main(verbosity=2)
|
| 104 |
+
logger().info("Tests completed and database connection closed.")
|