dylanglenister commited on
Commit
a00215a
·
1 Parent(s): 1ec4bcf

Merging duplicate functions in account

Browse files
src/api/routes/chat.py CHANGED
@@ -36,11 +36,7 @@ async def chat_endpoint(
36
  # Get or create user profile (doctor as current user profile)
37
  user_profile = state.memory_system.get_user(request.user_id)
38
  if not user_profile:
39
- state.memory_system.create_user(
40
- role=request.user_role or "Anonymous",
41
- user_id=request.user_id,
42
- speciality=request.user_specialty or None
43
- )
44
  except Exception as e:
45
  logger().error(f"Error retrieving or creating user profile: {e}")
46
  logger().error(f"Request data: {request.model_dump()}")
 
36
  # Get or create user profile (doctor as current user profile)
37
  user_profile = state.memory_system.get_user(request.user_id)
38
  if not user_profile:
39
+ state.memory_system.create_user()
 
 
 
 
40
  except Exception as e:
41
  logger().error(f"Error retrieving or creating user profile: {e}")
42
  logger().error(f"Request data: {request.model_dump()}")
src/api/routes/doctors.py CHANGED
@@ -2,8 +2,8 @@
2
 
3
  from fastapi import APIRouter, HTTPException
4
 
5
- from src.data.repositories.account import (create_doctor, get_all_doctors,
6
- get_doctor_by_name, search_doctors)
7
  from src.models.user import DoctorCreateRequest
8
  from src.utils.logger import logger
9
 
@@ -13,7 +13,7 @@ router = APIRouter(prefix="/doctors", tags=["Doctors"])
13
  async def get_all_doctors_route(limit: int = 50):
14
  try:
15
  logger().info(f"GET /doctors limit={limit}")
16
- results = get_all_doctors(limit=limit)
17
  logger().info(f"Retrieved {len(results)} doctors")
18
  return {"results": results}
19
  except Exception as e:
@@ -24,11 +24,10 @@ async def get_all_doctors_route(limit: int = 50):
24
  async def create_doctor_profile(req: DoctorCreateRequest):
25
  try:
26
  logger().info(f"POST /doctors name={req.name}")
27
- doctor_id = create_doctor(
28
  name=req.name,
29
  role=req.role,
30
- specialty=req.specialty,
31
- medical_roles=req.medical_roles
32
  )
33
  logger().info(f"Created doctor {req.name} id={doctor_id}")
34
  return {"doctor_id": doctor_id, "name": req.name}
@@ -40,7 +39,7 @@ async def create_doctor_profile(req: DoctorCreateRequest):
40
  async def get_doctor(doctor_name: str):
41
  try:
42
  logger().info(f"GET /doctors/{doctor_name}")
43
- doctor = get_doctor_by_name(doctor_name)
44
  if not doctor:
45
  raise HTTPException(status_code=404, detail="Doctor not found")
46
  return doctor
@@ -54,7 +53,7 @@ async def get_doctor(doctor_name: str):
54
  async def search_doctors_route(q: str, limit: int = 10):
55
  try:
56
  logger().info(f"GET /doctors/search q='{q}' limit={limit}")
57
- results = search_doctors(q, limit=limit)
58
  logger().info(f"Doctor search returned {len(results)} results")
59
  return {"results": results}
60
  except Exception as e:
 
2
 
3
  from fastapi import APIRouter, HTTPException
4
 
5
+ from src.data.repositories.account import (create_account, get_all_accounts,
6
+ get_account_by_name, search_accounts)
7
  from src.models.user import DoctorCreateRequest
8
  from src.utils.logger import logger
9
 
 
13
  async def get_all_doctors_route(limit: int = 50):
14
  try:
15
  logger().info(f"GET /doctors limit={limit}")
16
+ results = get_all_accounts(limit=limit)
17
  logger().info(f"Retrieved {len(results)} doctors")
18
  return {"results": results}
19
  except Exception as e:
 
24
  async def create_doctor_profile(req: DoctorCreateRequest):
25
  try:
26
  logger().info(f"POST /doctors name={req.name}")
27
+ doctor_id = create_account(
28
  name=req.name,
29
  role=req.role,
30
+ specialty=req.specialty
 
31
  )
32
  logger().info(f"Created doctor {req.name} id={doctor_id}")
33
  return {"doctor_id": doctor_id, "name": req.name}
 
39
  async def get_doctor(doctor_name: str):
40
  try:
41
  logger().info(f"GET /doctors/{doctor_name}")
42
+ doctor = get_account_by_name(doctor_name)
43
  if not doctor:
44
  raise HTTPException(status_code=404, detail="Doctor not found")
45
  return doctor
 
53
  async def search_doctors_route(q: str, limit: int = 10):
54
  try:
55
  logger().info(f"GET /doctors/search q='{q}' limit={limit}")
56
+ results = search_accounts(q, limit=limit)
57
  logger().info(f"Doctor search returned {len(results)} results")
58
  return {"results": results}
59
  except Exception as e:
src/api/routes/user.py CHANGED
@@ -17,23 +17,16 @@ async def create_user_profile(
17
  """Create or update user profile"""
18
  try:
19
  # Persist to in-memory profile (existing behavior)
20
- user = state.memory_system.create_user(user_id=request.user_id, name=request.name)
21
- user.set_preference("role", request.role)
22
- if request.specialty:
23
- user.set_preference("specialty", request.specialty)
24
- if request.medical_roles:
25
- user.set_preference("medical_roles", request.medical_roles)
26
 
27
  # Persist to MongoDB accounts collection
28
  account_id = create_account(
29
  request.name,
30
  request.role,
31
- request.specialty or None,
32
- request.medical_roles or [request.role] if request.role else [],
33
- user_id=request.user_id
34
  )
35
 
36
- return {"message": "User profile created successfully", "user_id": request.user_id, "account_id": account_id}
37
  except Exception as e:
38
  logger().error(f"Error creating user profile: {e}")
39
  raise HTTPException(status_code=500, detail=str(e))
 
17
  """Create or update user profile"""
18
  try:
19
  # Persist to in-memory profile (existing behavior)
20
+ #state.memory_system.create_user(name=request.name, role=request.role, speciality=request.specialty)
 
 
 
 
 
21
 
22
  # Persist to MongoDB accounts collection
23
  account_id = create_account(
24
  request.name,
25
  request.role,
26
+ request.specialty
 
 
27
  )
28
 
29
+ return {"message": "User profile created successfully", "account_id": account_id}
30
  except Exception as e:
31
  logger().error(f"Error creating user profile: {e}")
32
  raise HTTPException(status_code=500, detail=str(e))
src/core/memory.py CHANGED
@@ -2,7 +2,6 @@
2
 
3
  import uuid
4
  from datetime import datetime, timezone
5
- from typing import Any
6
 
7
  from src.core.profile import UserProfile
8
  from src.core.session import ChatSession
@@ -21,27 +20,20 @@ class MemoryLRU:
21
 
22
  def create_user(self,
23
  name: str = "Anonymous",
24
- role: str | None = None,
25
- speciality: str | None = None,
26
- roles: list[str] = [],
27
- preferences: dict[str, Any] = {},
28
- *,
29
- user_id: str,
30
  ) -> UserProfile:
31
  """Creates a new user profile."""
32
- account.create_account(
33
  name=name,
34
  role=role,
35
- speciality=speciality,
36
- roles=roles,
37
- preferences=preferences,
38
- user_id=user_id
39
  )
40
  return UserProfile(user_id, name)
41
 
42
  def get_user(self, user_id: str) -> UserProfile | None:
43
  """Retrieves a user profile by its ID."""
44
- data = account.get_user_profile(user_id)
45
  return UserProfile.from_dict(data) if data else None
46
 
47
  def create_session(self, user_id: str, title: str = "New Chat") -> str:
@@ -91,14 +83,6 @@ class MemoryLRU:
91
  """Deletes a chat session."""
92
  session.delete_session(session_id)
93
 
94
- def set_user_preferences(
95
- self,
96
- user_id: str,
97
- update_data: dict[str, Any]
98
- ):
99
- """Sets a preference for a user."""
100
- account.set_user_preferences(user_id, update_data)
101
-
102
  def add(self, user_id: str, summary: str):
103
  """Adds a medical context summary for a user."""
104
  medical.add_medical_context(user_id, summary)
 
2
 
3
  import uuid
4
  from datetime import datetime, timezone
 
5
 
6
  from src.core.profile import UserProfile
7
  from src.core.session import ChatSession
 
20
 
21
  def create_user(self,
22
  name: str = "Anonymous",
23
+ role: str = "Other",
24
+ speciality: str | None = None
 
 
 
 
25
  ) -> UserProfile:
26
  """Creates a new user profile."""
27
+ user_id = account.create_account(
28
  name=name,
29
  role=role,
30
+ specialty=speciality
 
 
 
31
  )
32
  return UserProfile(user_id, name)
33
 
34
  def get_user(self, user_id: str) -> UserProfile | None:
35
  """Retrieves a user profile by its ID."""
36
+ data = account.get_account(user_id)
37
  return UserProfile.from_dict(data) if data else None
38
 
39
  def create_session(self, user_id: str, title: str = "New Chat") -> str:
 
83
  """Deletes a chat session."""
84
  session.delete_session(session_id)
85
 
 
 
 
 
 
 
 
 
86
  def add(self, user_id: str, summary: str):
87
  """Adds a medical context summary for a user."""
88
  medical.add_medical_context(user_id, summary)
src/data/repositories/account.py CHANGED
@@ -28,8 +28,21 @@ from src.utils.logger import logger
28
 
29
  ACCOUNTS_COLLECTION = "accounts"
30
 
 
 
 
 
 
 
 
 
 
 
31
  def create():
32
- create_collection(ACCOUNTS_COLLECTION, "schemas/account_validator.json")
 
 
 
33
 
34
  def get_account_frame(
35
  *,
@@ -40,12 +53,9 @@ def get_account_frame(
40
 
41
  def create_account(
42
  name: str,
43
- role: str | None = None,
44
- speciality: str | None = None,
45
- roles: list[str] = [],
46
- preferences: dict[str, Any] = {},
47
  *,
48
- user_id: str,
49
  collection_name: str = ACCOUNTS_COLLECTION
50
  ) -> str:
51
  """
@@ -56,18 +66,15 @@ def create_account(
56
  collection = get_collection(collection_name)
57
  now = datetime.now(timezone.utc)
58
  user_data: dict[str, Any] = {
59
- "_id": user_id,
60
  "name": name,
61
  "role" : role,
62
- "speciality": speciality,
63
- "medical_roles": roles,
64
  "created_at": now,
65
  "updated_at": now
66
  }
67
 
68
  try:
69
  result = collection.insert_one(user_data)
70
- set_user_preferences(user_id, preferences)
71
  logger().info(f"Created new account: {result.inserted_id}")
72
  return str(result.inserted_id)
73
  except DuplicateKeyError as e:
@@ -93,12 +100,12 @@ def update_account(
93
  )
94
  return result.modified_count > 0
95
 
96
- def get_user_profile(
97
  user_id: str,
98
  /, *,
99
  collection_name: str = ACCOUNTS_COLLECTION
100
  ) -> dict[str, Any] | None:
101
- """Retrieves a user profile by ID and updates their last_seen timestamp."""
102
  collection = get_collection(collection_name)
103
  now = datetime.now(timezone.utc)
104
  return collection.find_one_and_update(
@@ -111,93 +118,21 @@ def get_user_profile(
111
  return_document=True
112
  )
113
 
114
- def set_user_preferences(
115
- user_id: str,
116
- /,
117
- preferences: dict[str, Any],
118
- *,
119
- collection_name: str = ACCOUNTS_COLLECTION
120
- ) -> bool:
121
- """Sets a preference for a user."""
122
- try:
123
- collection = get_collection(collection_name)
124
- preferences = {f"preferences.{key}": value for key, value in preferences}
125
- preferences["updated_at"] = datetime.now(timezone.utc)
126
- result = collection.update_one(
127
- {"_id": user_id},
128
- {
129
- "$set": preferences
130
- }
131
- )
132
- if result.matched_count == 0:
133
- raise EntryNotFound(f"User with ID '{user_id}' not found.")
134
-
135
- return result.modified_count > 0
136
- except PyMongoError as e:
137
- logger().error(f"An error occurred with the database operation: {e}")
138
- return False
139
- except EntryNotFound as e:
140
- logger().error(e)
141
- return False
142
-
143
- # TODO Below methods are unverified
144
-
145
- def create_doctor(
146
- *,
147
- name: str,
148
- role: str | None = None,
149
- specialty: str | None = None,
150
- medical_roles: list[str] | None = None
151
- ) -> str:
152
- """Create a new doctor profile"""
153
- collection = get_collection(ACCOUNTS_COLLECTION)
154
- now = datetime.now(timezone.utc)
155
- doctor_doc = {
156
- "name": name,
157
- "role": role,
158
- "specialty": specialty,
159
- "medical_roles": medical_roles or [],
160
- "created_at": now,
161
- "updated_at": now
162
- }
163
- try:
164
- result = collection.insert_one(doctor_doc)
165
- logger().info(f"Created new doctor: {name} with id {result.inserted_id}")
166
- return str(result.inserted_id)
167
- except Exception as e:
168
- logger().error(f"Error creating doctor: {e}")
169
- raise e
170
-
171
-
172
- def get_doctor_by_name(name: str) -> dict[str, Any] | None:
173
- """Get doctor by name from accounts collection"""
174
  collection = get_collection(ACCOUNTS_COLLECTION)
175
- doctor = collection.find_one({
176
- "name": name,
177
- "role": {
178
- "$in": [
179
- "Doctor",
180
- "Healthcare Prof",
181
- "General Practitioner",
182
- "Cardiologist",
183
- "Pediatrician",
184
- "Neurologist",
185
- "Dermatologist"
186
- ]
187
- }
188
- })
189
- if doctor:
190
- doctor["_id"] = str(doctor.get("_id")) if doctor.get("_id") else None
191
- return doctor
192
 
193
-
194
- def search_doctors(query: str, limit: int = 10) -> list[dict[str, Any]]:
195
- """Search doctors by name (case-insensitive contains) from accounts collection"""
196
  collection = get_collection(ACCOUNTS_COLLECTION)
197
  if not query:
198
  return []
199
 
200
- logger().info(f"Searching doctors with query: '{query}', limit: {limit}")
201
 
202
  # Build a regex for name search
203
  pattern = re.compile(re.escape(query), re.IGNORECASE)
@@ -221,30 +156,18 @@ def search_doctors(query: str, limit: int = 10) -> list[dict[str, Any]]:
221
  for d in cursor:
222
  d["_id"] = str(d.get("_id")) if d.get("_id") else None
223
  results.append(d)
224
- logger().info(f"Found {len(results)} doctors matching query")
225
  return results
226
  except Exception as e:
227
- logger().error(f"Error in search_doctors: {e}")
228
  return []
229
 
230
 
231
- def get_all_doctors(limit: int = 50) -> list[dict[str, Any]]:
232
  """Get all doctors with optional limit from accounts collection"""
233
  collection = get_collection(ACCOUNTS_COLLECTION)
234
  try:
235
- cursor = collection.find({
236
- "role": {
237
- "$in": [
238
- "Doctor",
239
- "Healthcare Prof",
240
- "General Practitioner",
241
- "Cardiologist",
242
- "Pediatrician",
243
- "Neurologist",
244
- "Dermatologist"
245
- ]
246
- }
247
- }).sort("name", ASCENDING).limit(limit)
248
  results = []
249
  for d in cursor:
250
  d["_id"] = str(d.get("_id")) if d.get("_id") else None
 
28
 
29
  ACCOUNTS_COLLECTION = "accounts"
30
 
31
+ VALID_ROLES = [
32
+ "Doctor",
33
+ "Healthcare Prof",
34
+ "Nurse",
35
+ "Caregiver",
36
+ "Physicion",
37
+ "Medical Student",
38
+ "Other"
39
+ ]
40
+
41
  def create():
42
+ create_collection(
43
+ ACCOUNTS_COLLECTION,
44
+ "schemas/account_validator.json"
45
+ )
46
 
47
  def get_account_frame(
48
  *,
 
53
 
54
  def create_account(
55
  name: str,
56
+ role: str,
57
+ specialty: str | None = None,
 
 
58
  *,
 
59
  collection_name: str = ACCOUNTS_COLLECTION
60
  ) -> str:
61
  """
 
66
  collection = get_collection(collection_name)
67
  now = datetime.now(timezone.utc)
68
  user_data: dict[str, Any] = {
 
69
  "name": name,
70
  "role" : role,
71
+ "specialty": specialty,
 
72
  "created_at": now,
73
  "updated_at": now
74
  }
75
 
76
  try:
77
  result = collection.insert_one(user_data)
 
78
  logger().info(f"Created new account: {result.inserted_id}")
79
  return str(result.inserted_id)
80
  except DuplicateKeyError as e:
 
100
  )
101
  return result.modified_count > 0
102
 
103
+ def get_account(
104
  user_id: str,
105
  /, *,
106
  collection_name: str = ACCOUNTS_COLLECTION
107
  ) -> dict[str, Any] | None:
108
+ """Retrieves an account by ID and updates their last_seen timestamp."""
109
  collection = get_collection(collection_name)
110
  now = datetime.now(timezone.utc)
111
  return collection.find_one_and_update(
 
118
  return_document=True
119
  )
120
 
121
+ def get_account_by_name(name: str) -> dict[str, Any] | None:
122
+ """Get account by name from accounts collection"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
  collection = get_collection(ACCOUNTS_COLLECTION)
124
+ account = collection.find_one({"name": name})
125
+ #if account:
126
+ # account["_id"] = str(account.get("_id")) if account.get("_id") else None
127
+ return account
 
 
 
 
 
 
 
 
 
 
 
 
 
128
 
129
+ def search_accounts(query: str, limit: int = 10) -> list[dict[str, Any]]:
130
+ """Search accounts by name (case-insensitive contains) from accounts collection"""
 
131
  collection = get_collection(ACCOUNTS_COLLECTION)
132
  if not query:
133
  return []
134
 
135
+ logger().info(f"Searching accounts with query: '{query}', limit: {limit}")
136
 
137
  # Build a regex for name search
138
  pattern = re.compile(re.escape(query), re.IGNORECASE)
 
156
  for d in cursor:
157
  d["_id"] = str(d.get("_id")) if d.get("_id") else None
158
  results.append(d)
159
+ logger().info(f"Found {len(results)} accounts matching query")
160
  return results
161
  except Exception as e:
162
+ logger().error(f"Error in search_account: {e}")
163
  return []
164
 
165
 
166
+ def get_all_accounts(limit: int = 50) -> list[dict[str, Any]]:
167
  """Get all doctors with optional limit from accounts collection"""
168
  collection = get_collection(ACCOUNTS_COLLECTION)
169
  try:
170
+ cursor = collection.find().sort("name", ASCENDING).limit(limit)
 
 
 
 
 
 
 
 
 
 
 
 
171
  results = []
172
  for d in cursor:
173
  d["_id"] = str(d.get("_id")) if d.get("_id") else None
src/models/user.py CHANGED
@@ -4,11 +4,14 @@ from pydantic import BaseModel
4
 
5
 
6
  class UserProfileRequest(BaseModel):
7
- user_id: str
8
  name: str
9
  role: str
10
  specialty: str | None = None
11
- medical_roles: list[str] | None = None
 
 
 
 
12
 
13
  class PatientCreateRequest(BaseModel):
14
  name: str
@@ -31,9 +34,3 @@ class PatientUpdateRequest(BaseModel):
31
  medications: list[str] | None = None
32
  past_assessment_summary: str | None = None
33
  assigned_doctor_id: str | None = None
34
-
35
- class DoctorCreateRequest(BaseModel):
36
- name: str
37
- role: str | None = None
38
- specialty: str | None = None
39
- medical_roles: list[str] | None = None
 
4
 
5
 
6
  class UserProfileRequest(BaseModel):
 
7
  name: str
8
  role: str
9
  specialty: str | None = None
10
+
11
+ class DoctorCreateRequest(BaseModel):
12
+ name: str
13
+ role: str
14
+ specialty: str | None = None
15
 
16
  class PatientCreateRequest(BaseModel):
17
  name: str
 
34
  medications: list[str] | None = None
35
  past_assessment_summary: str | None = None
36
  assigned_doctor_id: str | None = None
 
 
 
 
 
 
tests/mongo_test.py CHANGED
@@ -51,7 +51,7 @@ class TestMongoDBRepositories(unittest.TestCase):
51
  success = account_repo.update_account(user_id, {"name": "Updated Name"}, collection_name=test_coll)
52
  self.assertTrue(success)
53
 
54
- profile = account_repo.get_user_profile(user_id, collection_name=test_coll)
55
  self.assertIsNotNone(profile)
56
  self.assertEqual(profile["name"], "Updated Name") # type: ignore
57
 
 
51
  success = account_repo.update_account(user_id, {"name": "Updated Name"}, collection_name=test_coll)
52
  self.assertTrue(success)
53
 
54
+ profile = account_repo.get_account(user_id, collection_name=test_coll)
55
  self.assertIsNotNone(profile)
56
  self.assertEqual(profile["name"], "Updated Name") # type: ignore
57