dylanglenister committed on
Commit
368f485
·
1 Parent(s): c9573f1

Upgraded and reworked medical repo file.

Browse files

Follows the same style, structure, and logic as the other data access layer files.
Also includes tests.

schemas/medical_memory_validator.json ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "$jsonSchema": {
3
+ "bsonType": "object",
4
+ "title": "Medical Memory validator",
5
+ "required": [
6
+ "patient_id",
7
+ "doctor_id",
8
+ "summary",
9
+ "created_at"
10
+ ],
11
+ "properties": {
12
+ "patient_id": {
13
+ "bsonType": "objectId",
14
+ "description": "'patient_id' must be an objectId and is required."
15
+ },
16
+ "doctor_id": {
17
+ "bsonType": "objectId",
18
+ "description": "'doctor_id' must be an objectId and is required."
19
+ },
20
+ "session_id": {
21
+ "bsonType": "objectId",
22
+ "description": "'session_id' must be an objectId and is optional."
23
+ },
24
+ "summary": {
25
+ "bsonType": "string",
26
+ "description": "'summary' must be a string and is required."
27
+ },
28
+ "embedding": {
29
+ "bsonType": "array",
30
+ "description": "'embedding' must be an array of floats and is optional.",
31
+ "items": {
32
+ "bsonType": "double"
33
+ }
34
+ },
35
+ "created_at": {
36
+ "bsonType": "date",
37
+ "description": "'created_at' must be a date and is required."
38
+ }
39
+ }
40
+ }
41
+ }
schemas/medical_record_validator.json ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "$jsonSchema": {
3
+ "bsonType": "object",
4
+ "title": "Medical Record validator",
5
+ "required": [
6
+ "patient_id",
7
+ "doctor_id",
8
+ "record_type",
9
+ "details",
10
+ "created_at",
11
+ "updated_at"
12
+ ],
13
+ "properties": {
14
+ "patient_id": {
15
+ "bsonType": "objectId",
16
+ "description": "'patient_id' must be an objectId and is required."
17
+ },
18
+ "doctor_id": {
19
+ "bsonType": "objectId",
20
+ "description": "'doctor_id' must be an objectId and is required."
21
+ },
22
+ "record_type": {
23
+ "bsonType": "string",
24
+ "description": "'record_type' must be a string and is required."
25
+ },
26
+ "details": {
27
+ "bsonType": "object",
28
+ "description": "'details' must be an object and is required."
29
+ },
30
+ "created_at": {
31
+ "bsonType": "date",
32
+ "description": "'created_at' must be a date and is required."
33
+ },
34
+ "updated_at": {
35
+ "bsonType": "date",
36
+ "description": "'updated_at' must be a date and is required."
37
+ }
38
+ }
39
+ }
40
+ }
src/data/repositories/medical.py DELETED
@@ -1,183 +0,0 @@
1
- # data/repositories/medical.py
2
- """
3
- Medical records and memory management operations for MongoDB.
4
-
5
- @Note Could this be split into two? One for records and one for memory.
6
- """
7
-
8
- from datetime import datetime, timezone
9
- from typing import Any
10
-
11
- from bson import ObjectId
12
- from pymongo import ASCENDING, DESCENDING
13
- from pymongo.errors import (ConnectionFailure, DuplicateKeyError,
14
- OperationFailure, PyMongoError)
15
-
16
- from src.data.connection import (ActionFailed, Collections, get_collection,
17
- setup_collection)
18
- from src.utils.logger import logger
19
-
20
-
21
- #def init(
22
- # *,
23
- # collection_name: str = Collections.MEDICAL_MEMORY,
24
- # validator_path: str = "schemas/account_validator.json",
25
- # drop: bool = False
26
- #):
27
- # if drop:
28
- # get_collection(collection_name).drop()
29
- # setup_collection(collection_name, validator_path)
30
-
31
- def create_medical_record(
32
- record_data: dict[str, Any],
33
- /, *,
34
- collection_name: str = Collections.MEDICAL_RECORDS
35
- ) -> str:
36
- """Creates a new medical record."""
37
- collection = get_collection(collection_name)
38
- now = datetime.now(timezone.utc)
39
- record_data.update({"created_at": now, "updated_at": now})
40
- result = collection.insert_one(record_data)
41
- return str(result.inserted_id)
42
-
43
- def get_user_medical_records(
44
- user_id: str,
45
- /, *,
46
- collection_name: str = Collections.MEDICAL_RECORDS
47
- ) -> list[dict[str, Any]]:
48
- """Retrieves all medical records for a specific user."""
49
- collection = get_collection(collection_name)
50
- cursor = collection.find(
51
- {"user_id": user_id}
52
- ).sort(
53
- "created_at", ASCENDING
54
- )
55
- return list(cursor)
56
-
57
- def add_medical_context(
58
- user_id: str,
59
- /,
60
- summary: str,
61
- *,
62
- collection_name: str = Collections.MEDICAL_MEMORY
63
- ) -> str:
64
- """Adds a medical context summary for a user."""
65
- collection = get_collection(collection_name)
66
- doc = {
67
- "_id": str(ObjectId()),
68
- "user_id": user_id,
69
- "summary": summary,
70
- "timestamp": datetime.now(timezone.utc)
71
- }
72
- result = collection.insert_one(doc)
73
- return str(result.inserted_id)
74
-
75
- def get_medical_context(
76
- user_id: str,
77
- /,
78
- limit: int | None = None,
79
- *,
80
- collection_name: str = Collections.MEDICAL_MEMORY
81
- ) -> list[dict[str, Any]]:
82
- """Retrieves medical context summaries for a user."""
83
- collection = get_collection(collection_name)
84
- cursor = collection.find(
85
- {"user_id": user_id}
86
- ).sort(
87
- "timestamp", DESCENDING
88
- )
89
- if limit:
90
- cursor = cursor.limit(limit)
91
- return list(cursor)
92
-
93
- # TODO Delete context
94
-
95
- def save_memory_summary(
96
- *,
97
- patient_id: str,
98
- doctor_id: str,
99
- summary: str,
100
- embedding: list[float] | None = None,
101
- created_at: datetime | None = None,
102
- collection_name: str = Collections.MEDICAL_MEMORY
103
- ) -> str:
104
- collection = get_collection(collection_name)
105
- ts = created_at or datetime.now(timezone.utc)
106
- doc = {
107
- "patient_id": patient_id,
108
- "doctor_id": doctor_id,
109
- "summary": summary,
110
- "created_at": ts
111
- }
112
- if embedding is not None:
113
- doc["embedding"] = embedding
114
- result = collection.insert_one(doc)
115
- return str(result.inserted_id)
116
-
117
- def get_recent_memory_summaries(
118
- patient_id: str,
119
- /,
120
- *,
121
- limit: int = 20,
122
- collection_name: str = Collections.MEDICAL_MEMORY
123
- ) -> list[str]:
124
- collection = get_collection(collection_name)
125
- docs = list(collection.find({"patient_id": patient_id}).sort("created_at", DESCENDING).limit(limit))
126
- return [d.get("summary", "") for d in docs]
127
-
128
- def search_memory_summaries_semantic(
129
- patient_id: str,
130
- query_embedding: list[float],
131
- /,
132
- *,
133
- limit: int = 5,
134
- similarity_threshold: float = 0.5, # >= 50% semantic similarity
135
- collection_name: str = Collections.MEDICAL_MEMORY
136
- ) -> list[dict[str, Any]]:
137
- """
138
- Search memory summaries using semantic similarity with embeddings.
139
- Returns list of {summary, similarity_score, created_at} sorted by similarity.
140
- """
141
- collection = get_collection(collection_name)
142
-
143
- # Get all summaries with embeddings for this patient
144
- docs = list(collection.find({
145
- "patient_id": patient_id,
146
- "embedding": {"$exists": True}
147
- }))
148
-
149
- if not docs:
150
- return []
151
-
152
- # Calculate similarities
153
- import numpy as np
154
- query_vec = np.array(query_embedding, dtype="float32")
155
- results = []
156
-
157
- for doc in docs:
158
- embedding = doc.get("embedding")
159
- if not embedding:
160
- continue
161
-
162
- # Calculate cosine similarity
163
- doc_vec = np.array(embedding, dtype="float32")
164
- dot_product = np.dot(query_vec, doc_vec)
165
- norm_query = np.linalg.norm(query_vec)
166
- norm_doc = np.linalg.norm(doc_vec)
167
-
168
- if norm_query == 0 or norm_doc == 0:
169
- similarity = 0.0
170
- else:
171
- similarity = float(dot_product / (norm_query * norm_doc))
172
-
173
- if similarity >= similarity_threshold:
174
- results.append({
175
- "summary": doc.get("summary", ""),
176
- "similarity_score": similarity,
177
- "created_at": doc.get("created_at"),
178
- "session_id": doc.get("session_id") # if we add this field later
179
- })
180
-
181
- # Sort by similarity (highest first) and return top results
182
- results.sort(key=lambda x: x["similarity_score"], reverse=True)
183
- return results[:limit]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/data/repositories/medical_memory.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/data/repositories/medical_memory.py
2
+ """
3
+ Medical memory management operations for MongoDB.
4
+ Medical memories are unstructured summaries, often with vector embeddings for semantic search.
5
+
6
+ ## Fields
7
+ _id: index
8
+ patient_id: The patient this memory relates to
9
+ doctor_id: The doctor involved in the context of this memory
10
+ session_id: The chat session this memory was derived from (optional)
11
+ summary: The unstructured text summary of the medical context
12
+ embedding: The vector embedding of the summary for semantic search (optional)
13
+ created_at: The timestamp when the memory was created
14
+ """
15
+ from datetime import datetime, timezone
16
+
17
+ import numpy as np
18
+ from bson import ObjectId
19
+ from bson.errors import InvalidId
20
+ from pymongo import DESCENDING
21
+ from pymongo.errors import ConnectionFailure, PyMongoError, WriteError
22
+
23
+ from src.data.connection import (ActionFailed, Collections, get_collection,
24
+ setup_collection)
25
+ from src.models.medical import MedicalMemory, SemanticSearchResult
26
+ from src.utils.logger import logger
27
+
28
+
29
def init(
    *,
    collection_name: str = Collections.MEDICAL_MEMORY,
    validator_path: str = "schemas/medical_memory_validator.json",
    drop: bool = False
):
    """
    Prepare the medical memory collection for use.

    Optionally drops the existing collection, then hands the collection name
    and validator file to `setup_collection`.

    Args:
        collection_name: Name of the collection to initialise.
        validator_path: Path to the JSON schema validator file.
        drop: When True, drop the collection before setting it up.

    Raises:
        ActionFailed: If a PyMongo error occurs while dropping or setting up
            the collection.
    """
    try:
        if drop:
            # Remove any existing data so setup starts from an empty collection.
            existing = get_collection(collection_name)
            existing.drop()
        setup_collection(collection_name, validator_path)
    except (ConnectionFailure, PyMongoError) as exc:
        logger().error(f"Failed to initialize collection '{collection_name}': {exc}")
        raise ActionFailed(f"Database operation failed during initialization: {exc}") from exc
43
+
44
def create_memory(
    patient_id: str,
    doctor_id: str,
    summary: str,
    session_id: str | None = None,
    embedding: list[float] | None = None,
    *,
    collection_name: str = Collections.MEDICAL_MEMORY
) -> str:
    """
    Save a new medical memory summary.

    Args:
        patient_id: String form of the patient's ObjectId.
        doctor_id: String form of the doctor's ObjectId.
        summary: The unstructured text summary to store.
        session_id: String form of the originating session's ObjectId.
            Omitted from the document when falsy.
        embedding: Vector embedding of the summary. Omitted from the document
            when None or empty. Components are stored as Python floats.
        collection_name: Name of the collection to insert into.

    Returns:
        The inserted document's id as a string.

    Raises:
        ActionFailed: If an id is malformed, the embedding contains
            non-numeric values, or the database operation fails.
    """
    # The collection validator only accepts BSON doubles inside 'embedding'.
    # Integer components (e.g. [1, 0, 0]) or numpy scalars would otherwise be
    # rejected on insert, so normalise every component to a Python float first.
    vector: list[float] | None = None
    if embedding is not None:
        try:
            vector = [float(component) for component in embedding]
        except (TypeError, ValueError) as e:
            logger().error(f"Invalid embedding provided for medical memory: {e}")
            raise ActionFailed("Embedding must be a sequence of numbers.") from e

    try:
        collection = get_collection(collection_name)
        doc = {
            "patient_id": ObjectId(patient_id),
            "doctor_id": ObjectId(doctor_id),
            "summary": summary,
            "created_at": datetime.now(timezone.utc)
        }
        if session_id:
            doc["session_id"] = ObjectId(session_id)
        if vector:
            doc["embedding"] = vector

        result = collection.insert_one(doc)
        return str(result.inserted_id)
    except InvalidId as e:
        logger().error(f"Invalid ObjectId format provided for medical memory: {e}")
        raise ActionFailed("Patient, Doctor, or Session ID is not a valid format.") from e
    except (WriteError, ConnectionFailure, PyMongoError) as e:
        logger().error(f"Failed to create medical memory: {e}")
        raise ActionFailed("A database error occurred while creating the medical memory.") from e
75
+
76
def get_recent_memories(
    patient_id: str,
    limit: int = 20,
    *,
    collection_name: str = Collections.MEDICAL_MEMORY
) -> list[MedicalMemory]:
    """
    Fetch a patient's memory summaries, newest first.

    Args:
        patient_id: String form of the patient's ObjectId.
        limit: Value passed to the cursor's `limit`.
        collection_name: Name of the collection to query.

    Returns:
        The matching documents validated into `MedicalMemory` models, sorted
        by `created_at` descending.

    Raises:
        ActionFailed: If `patient_id` is malformed or the query fails.
    """
    try:
        patient_filter = {"patient_id": ObjectId(patient_id)}
        memories = get_collection(collection_name)
        newest_first = memories.find(patient_filter)
        newest_first = newest_first.sort("created_at", DESCENDING)
        newest_first = newest_first.limit(limit)

        return [MedicalMemory.model_validate(raw) for raw in newest_first]
    except InvalidId as exc:
        logger().error(f"Invalid patient_id format for get_recent_memories: '{patient_id}'")
        raise ActionFailed("The provided patient ID is not a valid format.") from exc
    except (ConnectionFailure, PyMongoError) as exc:
        logger().error(f"Database error retrieving recent memories for patient '{patient_id}': {exc}")
        raise ActionFailed("A database error occurred while retrieving recent memories.") from exc
97
+
98
def search_memories_semantic(
    patient_id: str,
    query_embedding: list[float],
    limit: int = 5,
    *,
    collection_name: str = Collections.MEDICAL_MEMORY
) -> list[SemanticSearchResult]:
    """
    Search a patient's memory summaries by cosine similarity to a query vector.

    All of the patient's documents that have an `embedding` field are fetched
    and scored client-side. No similarity threshold is applied: every scorable
    document is a candidate, and the top `limit` by score are returned.

    Args:
        patient_id: String form of the patient's ObjectId.
        query_embedding: The query vector. An empty or all-zero vector yields
            no results because cosine similarity is undefined for it.
        limit: Maximum number of results to return. Values below zero are
            treated as zero.
        collection_name: Name of the collection to query.

    Returns:
        `SemanticSearchResult` items sorted by `similarity_score`, highest
        first.

    Raises:
        ActionFailed: If `patient_id` is malformed or the query fails.
    """
    try:
        obj_patient_id = ObjectId(patient_id)
        collection = get_collection(collection_name)

        # In a real-world scenario, this would be an Atlas Vector Search query.
        # This implementation fetches all docs and calculates similarity in the client.
        docs = list(collection.find({
            "patient_id": obj_patient_id,
            "embedding": {"$exists": True}
        }))

        if not docs:
            return []

        # The query vector and its norm do not change per document, so compute
        # them once instead of on every loop iteration.
        query_vec = np.array(query_embedding, dtype="float32")
        norm_query = np.linalg.norm(query_vec)
        if norm_query == 0:
            return []

        results = []
        for doc in docs:
            doc_vec = np.array(doc["embedding"], dtype="float32")

            # Skip embeddings that cannot be compared with the query (different
            # length, empty, or null); np.dot would raise ValueError on them.
            if doc_vec.shape != query_vec.shape:
                continue

            norm_doc = np.linalg.norm(doc_vec)
            if norm_doc == 0:
                continue

            # Cosine similarity
            similarity = float(np.dot(query_vec, doc_vec) / (norm_query * norm_doc))
            result_data = {
                "summary": doc["summary"],
                "similarity_score": similarity,
                "created_at": doc["created_at"],
                "session_id": doc.get("session_id")
            }
            results.append(SemanticSearchResult.model_validate(result_data))

        # Sort by similarity (highest first) and return top results.
        # Clamp the limit so a negative value cannot slice from the tail.
        results.sort(key=lambda x: x.similarity_score, reverse=True)
        return results[:max(limit, 0)]
    except InvalidId as e:
        logger().error(f"Invalid patient_id format for semantic search: '{patient_id}'")
        raise ActionFailed("The provided patient ID is not a valid format.") from e
    except (ConnectionFailure, PyMongoError) as e:
        logger().error(f"Database error during semantic search for patient '{patient_id}': {e}")
        raise ActionFailed("A database error occurred during the semantic search.") from e
src/data/repositories/medical_record.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/data/repositories/medical_record.py
2
+ """
3
+ Medical record management operations for MongoDB.
4
+ Medical records are structured, factual pieces of information about a patient.
5
+
6
+ ## Fields
7
+ _id: index
8
+ patient_id: The patient this record belongs to
9
+ doctor_id: The doctor who created or is associated with this record
10
+ record_type: The category of the record (e.g., 'Consultation', 'LabResult')
11
+ details: An object containing the specific, structured data for the record
12
+ created_at: The timestamp when the record was created
13
+ updated_at: The timestamp when the record was last modified
14
+ """
15
+
16
+ from datetime import datetime, timezone
17
+ from typing import Any
18
+
19
+ from bson import ObjectId
20
+ from bson.errors import InvalidId
21
+ from pymongo import ASCENDING
22
+ from pymongo.errors import ConnectionFailure, PyMongoError, WriteError
23
+
24
+ from src.data.connection import (ActionFailed, Collections, get_collection,
25
+ setup_collection)
26
+ from src.models.medical import MedicalRecord
27
+ from src.utils.logger import logger
28
+
29
+
30
def init(
    *,
    collection_name: str = Collections.MEDICAL_RECORDS,
    validator_path: str = "schemas/medical_record_validator.json",
    drop: bool = False
):
    """
    Prepare the medical records collection for use.

    Optionally drops the existing collection, then hands the collection name
    and validator file to `setup_collection`.

    Args:
        collection_name: Name of the collection to initialise.
        validator_path: Path to the JSON schema validator file.
        drop: When True, drop the collection before setting it up.

    Raises:
        ActionFailed: If a PyMongo error occurs while dropping or setting up
            the collection.
    """
    try:
        if drop:
            # Remove any existing data so setup starts from an empty collection.
            existing = get_collection(collection_name)
            existing.drop()
        setup_collection(collection_name, validator_path)
    except (ConnectionFailure, PyMongoError) as exc:
        logger().error(f"Failed to initialize collection '{collection_name}': {exc}")
        raise ActionFailed(f"Database operation failed during initialization: {exc}") from exc
44
+
45
def create_medical_record(
    patient_id: str,
    doctor_id: str,
    record_type: str,
    details: dict[str, Any],
    *,
    collection_name: str = Collections.MEDICAL_RECORDS
) -> str:
    """
    Insert a new structured medical record.

    Args:
        patient_id: String form of the patient's ObjectId.
        doctor_id: String form of the doctor's ObjectId.
        record_type: Category of the record (e.g. 'Consultation').
        details: The structured data for the record.
        collection_name: Name of the collection to insert into.

    Returns:
        The inserted document's id as a string.

    Raises:
        ActionFailed: If an id is malformed or the database operation fails.
    """
    # One timestamp is shared so created_at and updated_at match on creation.
    stamp = datetime.now(timezone.utc)
    try:
        records = get_collection(collection_name)
        new_record = dict(
            patient_id=ObjectId(patient_id),
            doctor_id=ObjectId(doctor_id),
            record_type=record_type,
            details=details,
            created_at=stamp,
            updated_at=stamp,
        )
        inserted_id = records.insert_one(new_record).inserted_id
        return str(inserted_id)
    except InvalidId as exc:
        logger().error(f"Invalid ObjectId format provided for medical record: {exc}")
        raise ActionFailed("Patient ID or Doctor ID is not a valid format.") from exc
    except (WriteError, ConnectionFailure, PyMongoError) as exc:
        logger().error(f"Failed to create medical record: {exc}")
        raise ActionFailed("A database error occurred while creating the medical record.") from exc
73
+
74
def get_records_for_patient(
    patient_id: str,
    *,
    collection_name: str = Collections.MEDICAL_RECORDS
) -> list[MedicalRecord]:
    """
    Fetch every medical record for a patient, oldest first.

    Args:
        patient_id: String form of the patient's ObjectId.
        collection_name: Name of the collection to query.

    Returns:
        The matching documents validated into `MedicalRecord` models, sorted
        by `created_at` ascending. Empty when the patient has no records.

    Raises:
        ActionFailed: If `patient_id` is malformed or the query fails.
    """
    try:
        patient_filter = {"patient_id": ObjectId(patient_id)}
        records = get_collection(collection_name)
        oldest_first = records.find(patient_filter)
        oldest_first = oldest_first.sort("created_at", ASCENDING)

        return [MedicalRecord.model_validate(raw) for raw in oldest_first]
    except InvalidId as exc:
        logger().error(f"Invalid patient_id format for get_records_for_patient: '{patient_id}'")
        raise ActionFailed("The provided patient ID is not a valid format.") from exc
    except (ConnectionFailure, PyMongoError) as exc:
        logger().error(f"Database error retrieving records for patient '{patient_id}': {exc}")
        raise ActionFailed("A database error occurred while retrieving medical records.") from exc
src/models/medical.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/models/medical.py
2
+
3
+ from datetime import datetime
4
+ from typing import Any
5
+
6
+ from pydantic import BaseModel, ConfigDict
7
+
8
+ from models.repositories import BaseMongoModel, PyObjectId
9
+
10
+
11
class MedicalRecord(BaseMongoModel):
    """
    A Pydantic model for a structured medical record.

    The fields match the required properties of
    schemas/medical_record_validator.json.
    """
    # The patient this record belongs to.
    patient_id: PyObjectId
    # The doctor who created or is associated with this record.
    doctor_id: PyObjectId
    # Category of the record, e.g. 'Consultation' or 'LabResult'.
    record_type: str
    # The specific, structured data for the record.
    details: dict[str, Any]
    # When the record was created.
    created_at: datetime
    # When the record was last modified.
    updated_at: datetime
19
+
20
+
21
class MedicalMemory(BaseMongoModel):
    """
    A Pydantic model for a medical memory summary, used for semantic search.

    The fields match the properties of schemas/medical_memory_validator.json;
    `session_id` and `embedding` are optional there and default to None here.
    """
    # The patient this memory relates to.
    patient_id: PyObjectId
    # The doctor involved in the context of this memory.
    doctor_id: PyObjectId
    # The chat session this memory was derived from, if any.
    session_id: PyObjectId | None = None
    # The unstructured text summary of the medical context.
    summary: str
    # Vector embedding of the summary for semantic search, if one was stored.
    embedding: list[float] | None = None
    # When the memory was created.
    created_at: datetime
29
+
30
+
31
class SemanticSearchResult(BaseModel):
    """
    A Pydantic model for the result of a semantic search.

    Unlike the stored-document models in this module, this derives from plain
    `BaseModel`: it is a computed result, not a persisted document.
    """
    # Text of the matched memory summary.
    summary: str
    # Cosine similarity between the query and the memory's embedding.
    similarity_score: float
    # When the matched memory was created.
    created_at: datetime
    # The session the matched memory came from, if it recorded one.
    session_id: PyObjectId | None = None

    # frozen=True makes instances immutable after creation;
    # from_attributes=True allows validation from objects with matching attributes.
    model_config = ConfigDict(frozen=True, from_attributes=True)
tests/test_medical_memory.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+ from unittest.mock import patch
3
+
4
+ from bson import ObjectId
5
+ from pymongo.errors import ConnectionFailure
6
+
7
+ from src.data.connection import ActionFailed, Collections
8
+ from src.data.repositories import medical_memory as medical_memory_repo
9
+ from src.models.medical import MedicalMemory, SemanticSearchResult
10
+ from src.utils.logger import logger
11
+ from tests.base_test import BaseMongoTest
12
+
13
+
14
class TestMedicalMemoryRepository(BaseMongoTest):
    """Test class for the 'happy path' of medical memory repository functions."""

    def setUp(self):
        """Set up a clean test environment before each test."""
        super().setUp()
        self.test_collection = self._collections[Collections.MEDICAL_MEMORY]
        medical_memory_repo.init(collection_name=self.test_collection, drop=True)
        self.patient_id = str(ObjectId())
        self.doctor_id = str(ObjectId())
        self.session_id = str(ObjectId())
        self.embedding = [0.1, 0.2, 0.3]

    def test_init_functionality(self):
        """Test that the init function correctly sets up the collection."""
        self.assertIn(self.test_collection, self.db.list_collection_names())

    def test_create_memory(self):
        """Test successful creation of a medical memory with and without optional fields."""
        # Test full creation
        memory_id = medical_memory_repo.create_memory(
            self.patient_id, self.doctor_id, "Full summary", self.session_id, self.embedding,
            collection_name=self.test_collection
        )
        self.assertIsInstance(memory_id, str)
        doc = self.get_doc_by_id(Collections.MEDICAL_MEMORY, memory_id)
        self.assertIsNotNone(doc)
        self.assertEqual(doc["summary"], "Full summary")  # type: ignore
        self.assertEqual(len(doc["embedding"]), 3)  # type: ignore

        # Test minimal creation
        min_id = medical_memory_repo.create_memory(
            self.patient_id, self.doctor_id, "Minimal summary", collection_name=self.test_collection
        )
        self.assertIsInstance(min_id, str)

    def test_get_recent_memories(self):
        """Test retrieving recent memories, verifying sorting, filtering, and limit."""
        # Local import: this module does not import `time` at file level.
        import time

        # Sleep between inserts so each memory gets a distinct created_at.
        # BSON dates keep only millisecond precision, so back-to-back inserts
        # can tie and make the descending sort order nondeterministic.
        medical_memory_repo.create_memory(self.patient_id, self.doctor_id, "Oldest", collection_name=self.test_collection)
        time.sleep(0.01)
        medical_memory_repo.create_memory(str(ObjectId()), self.doctor_id, "Other Patient", collection_name=self.test_collection)
        time.sleep(0.01)
        medical_memory_repo.create_memory(self.patient_id, self.doctor_id, "Newest", collection_name=self.test_collection)

        memories = medical_memory_repo.get_recent_memories(self.patient_id, collection_name=self.test_collection)
        self.assertEqual(len(memories), 2)
        self.assertIsInstance(memories[0], MedicalMemory)
        self.assertEqual(memories[0].summary, "Newest")  # Descending sort order
        self.assertEqual(memories[1].summary, "Oldest")

        # Test limit
        limited = medical_memory_repo.get_recent_memories(self.patient_id, limit=1, collection_name=self.test_collection)
        self.assertEqual(len(limited), 1)

    def test_search_memories_semantic(self):
        """Test semantic search functionality, verifying similarity logic and sorting."""
        # Create memories with known embeddings
        vec_a = [1.0, 0.0, 0.0]  # Most similar
        vec_b = [0.7, 0.7, 0.0]  # Less similar
        vec_c = [0.0, 0.0, 1.0]  # Orthogonal to the query (similarity 0.0)
        medical_memory_repo.create_memory(self.patient_id, self.doctor_id, "Vec A", embedding=vec_a, collection_name=self.test_collection)
        medical_memory_repo.create_memory(self.patient_id, self.doctor_id, "Vec B", embedding=vec_b, collection_name=self.test_collection)
        medical_memory_repo.create_memory(self.patient_id, self.doctor_id, "Vec C", embedding=vec_c, collection_name=self.test_collection)
        medical_memory_repo.create_memory(self.patient_id, self.doctor_id, "No Embedding", collection_name=self.test_collection)

        query_embedding = [0.9, 0.1, 0.0]
        results = medical_memory_repo.search_memories_semantic(self.patient_id, query_embedding, collection_name=self.test_collection)

        # No similarity threshold is applied, so Vec C is still returned with a
        # score of 0.0; only the memory without an embedding is excluded.
        self.assertEqual(len(results), 3)
        self.assertIsInstance(results[0], SemanticSearchResult)
        self.assertEqual(results[0].summary, "Vec A")  # Most similar should be first
        self.assertGreater(results[0].similarity_score, results[1].similarity_score)
        self.assertEqual([r.summary for r in results], ["Vec A", "Vec B", "Vec C"])
83
+
84
+
85
class TestMedicalMemoryRepositoryExceptions(BaseMongoTest):
    """Checks that medical memory repository failures surface as ActionFailed."""

    def setUp(self):
        """Create a fresh collection and valid ids for each test."""
        super().setUp()
        self.test_collection = self._collections[Collections.MEDICAL_MEMORY]
        medical_memory_repo.init(collection_name=self.test_collection, drop=True)
        self.patient_id = str(ObjectId())
        self.doctor_id = str(ObjectId())

    def test_invalid_id_raises_action_failed(self):
        """A malformed ObjectId string must raise ActionFailed from every function."""
        name = self.test_collection
        self.assertRaises(
            ActionFailed,
            medical_memory_repo.create_memory, "bad-id", self.doctor_id, "t", collection_name=name
        )
        self.assertRaises(
            ActionFailed,
            medical_memory_repo.get_recent_memories, "bad-id", collection_name=name
        )
        self.assertRaises(
            ActionFailed,
            medical_memory_repo.search_memories_semantic, "bad-id", [], collection_name=name
        )

    @patch('src.data.repositories.medical_memory.get_collection')
    def test_all_functions_raise_on_connection_error(self, mock_get_collection):
        """A simulated connection failure must raise ActionFailed from every function."""
        mock_get_collection.side_effect = ConnectionFailure("Simulated connection error")
        name = self.test_collection

        self.assertRaises(
            ActionFailed,
            medical_memory_repo.init, collection_name=name, drop=True
        )
        self.assertRaises(
            ActionFailed,
            medical_memory_repo.create_memory, self.patient_id, self.doctor_id, "t", collection_name=name
        )
        self.assertRaises(
            ActionFailed,
            medical_memory_repo.get_recent_memories, self.patient_id, collection_name=name
        )
        self.assertRaises(
            ActionFailed,
            medical_memory_repo.search_memories_semantic, self.patient_id, [], collection_name=name
        )
118
+
119
if __name__ == "__main__":
    logger().info("Starting MongoDB repository integration tests...")
    try:
        # unittest.main() finishes by raising SystemExit, so code placed after
        # it never runs. The finally block lets the completion message be
        # logged while the exit code still propagates.
        unittest.main(verbosity=2)
    finally:
        logger().info("Tests completed and database connection closed.")
tests/test_medical_record.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import unittest
3
+ from unittest.mock import patch
4
+
5
+ from bson import ObjectId
6
+ from pymongo.errors import ConnectionFailure
7
+
8
+ from src.data.connection import ActionFailed, Collections
9
+ from src.data.repositories import medical_record as medical_record_repo
10
+ from src.models.medical import MedicalRecord
11
+ from src.utils.logger import logger
12
+ from tests.base_test import BaseMongoTest
13
+
14
+
15
class TestMedicalRecordRepository(BaseMongoTest):
    """Happy-path tests for the medical record repository functions."""

    def setUp(self):
        """Create a fresh collection and valid ids for each test."""
        super().setUp()
        self.test_collection = self._collections[Collections.MEDICAL_RECORDS]
        medical_record_repo.init(collection_name=self.test_collection, drop=True)
        self.patient_id = str(ObjectId())
        self.doctor_id = str(ObjectId())

    def test_init_functionality(self):
        """The collection must exist after init has run."""
        existing_names = self.db.list_collection_names()
        self.assertIn(self.test_collection, existing_names)

    def test_create_medical_record(self):
        """Creating a record returns its id and stores the given fields."""
        consultation = {"symptoms": "Fever, cough", "diagnosis": "Common cold"}
        record_id = medical_record_repo.create_medical_record(
            patient_id=self.patient_id,
            doctor_id=self.doctor_id,
            record_type="Consultation",
            details=consultation,
            collection_name=self.test_collection
        )
        self.assertIsInstance(record_id, str)

        stored = self.get_doc_by_id(Collections.MEDICAL_RECORDS, record_id)
        self.assertIsNotNone(stored)
        self.assertEqual(stored["record_type"], "Consultation")  # type: ignore
        self.assertEqual(str(stored["patient_id"]), self.patient_id)  # type: ignore

    def test_get_records_for_patient(self):
        """Records come back filtered to the patient and ordered oldest first."""
        create = medical_record_repo.create_medical_record
        target = {"collection_name": self.test_collection}
        other_patient_id = str(ObjectId())

        # Sleep between inserts so creation timestamps differ for the sort check.
        first_id = create(self.patient_id, self.doctor_id, "R1", {}, **target)
        time.sleep(0.01)
        create(other_patient_id, self.doctor_id, "Other", {}, **target)
        time.sleep(0.01)
        second_id = create(self.patient_id, self.doctor_id, "R2", {}, **target)

        found = medical_record_repo.get_records_for_patient(self.patient_id, **target)

        # Only the target patient's two records, as MedicalRecord models.
        self.assertEqual(len(found), 2)
        self.assertIsInstance(found[0], MedicalRecord)

        # Ascending by creation date.
        self.assertEqual(found[0].id, first_id)
        self.assertEqual(found[1].id, second_id)

        # Edge case: a patient with no records yields an empty list.
        nothing = medical_record_repo.get_records_for_patient(str(ObjectId()), **target)
        self.assertEqual(len(nothing), 0)
69
+
70
+
71
class TestMedicalRecordRepositoryExceptions(BaseMongoTest):
    """Checks that medical record repository failures surface as ActionFailed."""

    def setUp(self):
        """Create a fresh collection and valid ids for each test."""
        super().setUp()
        self.test_collection = self._collections[Collections.MEDICAL_RECORDS]
        medical_record_repo.init(collection_name=self.test_collection, drop=True)
        self.patient_id = str(ObjectId())
        self.doctor_id = str(ObjectId())

    def test_invalid_id_raises_action_failed(self):
        """A malformed ObjectId string must raise ActionFailed from every function."""
        name = self.test_collection
        self.assertRaises(
            ActionFailed,
            medical_record_repo.create_medical_record, "bad-id", self.doctor_id, "t", {}, collection_name=name
        )
        self.assertRaises(
            ActionFailed,
            medical_record_repo.get_records_for_patient, "bad-id", collection_name=name
        )

    @patch('src.data.repositories.medical_record.get_collection')
    def test_all_functions_raise_on_connection_error(self, mock_get_collection):
        """A simulated connection failure must raise ActionFailed from every function."""
        mock_get_collection.side_effect = ConnectionFailure("Simulated connection error")
        name = self.test_collection

        self.assertRaises(
            ActionFailed,
            medical_record_repo.init, collection_name=name, drop=True
        )
        self.assertRaises(
            ActionFailed,
            medical_record_repo.create_medical_record, self.patient_id, self.doctor_id, "t", {}, collection_name=name
        )
        self.assertRaises(
            ActionFailed,
            medical_record_repo.get_records_for_patient, self.patient_id, collection_name=name
        )
100
+
101
if __name__ == "__main__":
    logger().info("Starting MongoDB repository integration tests...")
    try:
        # unittest.main() finishes by raising SystemExit, so code placed after
        # it never runs. The finally block lets the completion message be
        # logged while the exit code still propagates.
        unittest.main(verbosity=2)
    finally:
        logger().info("Tests completed and database connection closed.")