|
|
""" |
|
|
Session Manager Module - Enhanced Version |
|
|
|
|
|
Handles per-user state isolation, persistence, and recovery for multi-user deployments. |
|
|
|
|
|
Features: |
|
|
- Save/load sessions to/from disk |
|
|
- Session restoration on reconnect |
|
|
- Graph state persistence |
|
|
- RLHF data collection with detailed interaction tracking |
|
|
- Session archiving before cleanup |
|
|
- Thread-safe operations |
|
|
""" |
|
|
|
|
|
import json |
|
|
import logging |
|
|
import threading |
|
|
import glob |
|
|
from datetime import datetime, timedelta |
|
|
from pathlib import Path |
|
|
from dataclasses import dataclass, field, asdict |
|
|
from typing import Dict, List, Optional, Any |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
@dataclass
class ChatMessage:
    """One message in a session's chat transcript.

    Attributes are plain strings so the message serializes directly to JSON.
    ``timestamp`` defaults to the local time of construction in ISO-8601
    form, and ``language`` defaults to ``"en"``.
    """

    # Author of the message; Session.get_title looks for the value "user".
    role: str
    # Raw message text.
    content: str
    # ISO-8601 creation time, filled in automatically when omitted.
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
    # Language code recorded alongside the message.
    language: str = "en"

    def to_dict(self) -> Dict:
        """Serialize this message into a new plain dict keyed by field name."""
        serialized = asdict(self)
        return serialized

    @classmethod
    def from_dict(cls, data: Dict) -> "ChatMessage":
        """Rebuild a message from a mapping shaped like ``to_dict`` output."""
        kwargs = dict(data)
        return cls(**kwargs)
|
|
|
|
|
|
|
|
@dataclass
class UserInteraction:
    """A single recorded user action, kept for RLHF data collection.

    Only ``timestamp`` and ``action_type`` are required; the remaining
    fields describe the target node, free text, a feedback label and any
    extra metadata, and default to empty values.
    """

    # ISO-8601 time the interaction happened (supplied by the caller).
    timestamp: str
    # Kind of action, e.g. the "feedback" actions recorded by SessionManager.
    action_type: str
    node_id: Optional[str] = None
    content: Optional[str] = None
    feedback: Optional[str] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict:
        """Return the interaction as a dict, keys in declaration order.

        Note: ``metadata`` is placed in the result by reference, not copied.
        """
        field_names = (
            'timestamp',
            'action_type',
            'node_id',
            'content',
            'feedback',
            'metadata',
        )
        return {name: getattr(self, name) for name in field_names}

    @classmethod
    def from_dict(cls, data: Dict) -> 'UserInteraction':
        """Recreate an interaction from a dict produced by ``to_dict``."""
        return cls(**data)
|
|
|
|
|
|
|
|
@dataclass
class Session:
    """User session state with RLHF tracking.

    Holds the chat transcript, recorded user interactions, an optional
    graph-state snapshot, free-form metadata, and per-label feedback counts.
    """

    session_id: str
    created_at: str = field(default_factory=lambda: datetime.now().isoformat())
    updated_at: str = field(default_factory=lambda: datetime.now().isoformat())
    language: str = "en"
    chat_history: List[ChatMessage] = field(default_factory=list)
    interactions: List[UserInteraction] = field(default_factory=list)
    graph_state: Optional[Dict] = None
    metadata: Dict[str, Any] = field(default_factory=dict)
    # Only feedback labels present as keys here are tallied by add_interaction.
    feedback_count: Dict[str, int] = field(
        default_factory=lambda: {'correct': 0, 'incorrect': 0, 'partial': 0}
    )

    def add_message(self, role: str, content: str, language: str = "en"):
        """Append a message to the chat history and bump ``updated_at``.

        This only changes in-memory state; nothing is written to disk here.
        """
        self.chat_history.append(ChatMessage(
            role=role,
            content=content,
            language=language
        ))
        self.updated_at = datetime.now().isoformat()

    def add_interaction(
        self,
        action_type: str,
        node_id: Optional[str] = None,
        content: Optional[str] = None,
        feedback: Optional[str] = None,
        metadata: Optional[Dict] = None
    ):
        """Record an interaction for RLHF and bump ``updated_at``.

        If ``feedback`` is a label already tracked in ``feedback_count``, its
        counter is incremented; other labels are stored on the interaction
        but not counted.
        """
        interaction = UserInteraction(
            timestamp=datetime.now().isoformat(),
            action_type=action_type,
            node_id=node_id,
            content=content,
            feedback=feedback,
            metadata=metadata or {}
        )
        self.interactions.append(interaction)
        self.updated_at = datetime.now().isoformat()

        if feedback and feedback in self.feedback_count:
            self.feedback_count[feedback] += 1

    def get_title(self, max_length: int = 35) -> str:
        """Return the first user message, truncated to ``max_length`` chars.

        An ellipsis is appended when truncation happens; ``"New Chat"`` is
        returned if there is no user message yet.
        """
        for msg in self.chat_history:
            if msg.role == "user":
                title = msg.content[:max_length]
                return title + "..." if len(msg.content) > max_length else title
        return "New Chat"

    def to_dict(self) -> Dict:
        """Serialize the full session; feedback counts go under ``feedback_summary``."""
        return {
            "session_id": self.session_id,
            "created_at": self.created_at,
            "updated_at": self.updated_at,
            "language": self.language,
            "chat_history": [m.to_dict() for m in self.chat_history],
            "interactions": [i.to_dict() for i in self.interactions],
            "graph_state": self.graph_state,
            "metadata": self.metadata,
            "feedback_summary": self.feedback_count,
        }

    def export_rlhf_data(self) -> Dict:
        """Export session data formatted for RLHF training (no metadata/updated_at)."""
        return {
            "session_id": self.session_id,
            "created_at": self.created_at,
            "language": self.language,
            "interactions": [i.to_dict() for i in self.interactions],
            "feedback_summary": self.feedback_count,
            "chat_history": [m.to_dict() for m in self.chat_history],
            "graph_state": self.graph_state,
        }

    @classmethod
    def from_dict(cls, data: Dict) -> "Session":
        """Rebuild a session from a dict produced by ``to_dict``.

        Missing or null optional fields fall back to their defaults. The
        stored ``feedback_summary`` is merged over zeroed default counters,
        so a summary saved with only some labels still tallies every default
        label on later feedback (``add_interaction`` only counts keys that
        already exist).

        Raises:
            KeyError: if ``session_id`` is missing.
        """
        now = datetime.now().isoformat()

        feedback_count = {'correct': 0, 'incorrect': 0, 'partial': 0}
        feedback_count.update(data.get("feedback_summary") or {})

        return cls(
            session_id=data["session_id"],
            created_at=data.get("created_at") or now,
            updated_at=data.get("updated_at") or now,
            language=data.get("language") or "en",
            chat_history=[
                ChatMessage.from_dict(m) for m in data.get("chat_history") or []
            ],
            interactions=[
                UserInteraction.from_dict(i) for i in data.get("interactions") or []
            ],
            graph_state=data.get("graph_state"),
            metadata=data.get("metadata") or {},
            feedback_count=feedback_count,
        )
|
|
|
|
|
|
|
|
class SessionManager:
    """
    Manages user sessions with persistence, RLHF tracking, and cleanup.

    Live sessions are cached in memory and mirrored to
    ``<sessions_dir>/<session_id>.json``. Sessions that expire or are evicted
    are archived under ``<sessions_dir>/archive`` before removal.

    Methods that read or mutate the in-memory cache hold a re-entrant lock,
    so one manager can be shared by multiple threads.
    """

    # Exceptions that indicate an unreadable or malformed session file.
    _LOAD_ERRORS = (OSError, ValueError, KeyError, TypeError, AttributeError)

    def __init__(
        self,
        sessions_dir: str = "./data/sessions",
        max_sessions: int = 1000,
        session_max_age_hours: int = 24
    ):
        """
        Args:
            sessions_dir: Directory for per-session JSON files. It is created
                if missing, together with an ``archive`` subdirectory.
            max_sessions: When ``create_session`` pushes the in-memory count
                above this, the least recently updated sessions are evicted.
            session_max_age_hours: Sessions whose ``updated_at`` is older than
                this are treated as expired.
        """
        self.sessions_dir = Path(sessions_dir)
        self.sessions_dir.mkdir(parents=True, exist_ok=True)

        # Expired/evicted sessions are written here for RLHF data collection.
        self.archive_dir = self.sessions_dir / "archive"
        self.archive_dir.mkdir(parents=True, exist_ok=True)

        self.max_sessions = max_sessions
        self.session_max_age = timedelta(hours=session_max_age_hours)

        # In-memory cache of live sessions, guarded by self._lock.
        self._sessions: Dict[str, Session] = {}
        # Re-entrant because public methods call each other while holding it.
        self._lock = threading.RLock()

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def get_session(self, session_id: str) -> Optional[Session]:
        """
        Get a session by ID, loading it from disk if it is not cached.

        A session whose ``updated_at`` is older than ``session_max_age`` is
        expired. It is archived, removed from memory AND disk, and ``None``
        is returned. This applies to disk-loaded sessions too. Deleting the
        file is what prevents an expired session from being silently
        re-loaded by the next call.

        Raises:
            ValueError: if ``session_id`` is not a plain file name.
        """
        self._validate_session_id(session_id)
        with self._lock:
            session = self._sessions.get(session_id)
            if session is None:
                session = self._load_from_disk(session_id)
                if session is None:
                    return None

            cutoff = datetime.now() - self.session_max_age
            if self._is_expired(session, cutoff):
                self._expire_session(session_id, session)
                return None

            self._sessions[session_id] = session
            return session

    def create_session(self, session_id: str) -> Session:
        """
        Create a fresh session, cache it, and persist it immediately.

        Replaces any cached session with the same ID. If the cache now holds
        more than ``max_sessions`` entries, the oldest ones are evicted.

        Raises:
            ValueError: if ``session_id`` is not a plain file name.
        """
        self._validate_session_id(session_id)
        with self._lock:
            session = Session(session_id=session_id)
            self._sessions[session_id] = session
            self._save_to_disk(session)

            if len(self._sessions) > self.max_sessions:
                self._cleanup_old_sessions()

            logger.info(f"Created new session: {session_id}")
            return session

    def get_or_create(self, session_id: str) -> Session:
        """Return the live session for ``session_id``, creating one if needed.

        Runs under the lock so two threads asking for the same new ID cannot
        both create (and overwrite) a session.
        """
        with self._lock:
            session = self.get_session(session_id)
            if session is None:
                session = self.create_session(session_id)
            return session

    def save_session(self, session_id: str):
        """Bump ``updated_at`` on a cached session and write it to disk.

        Does nothing if the session is not currently cached.
        """
        with self._lock:
            session = self._sessions.get(session_id)
            if session:
                session.updated_at = datetime.now().isoformat()
                self._save_to_disk(session)

    def update_graph_state(self, session_id: str, state: Dict):
        """Replace the session's graph state (creating the session if needed) and save."""
        with self._lock:
            session = self.get_or_create(session_id)
            session.graph_state = state
            session.updated_at = datetime.now().isoformat()
            self._save_to_disk(session)

    def update_language(self, session_id: str, language: str):
        """Set the detected language on a live session (no-op if none).

        The change stays in memory and reaches disk on the next save.
        """
        with self._lock:
            session = self.get_session(session_id)
            if session:
                session.language = language

    def record_interaction(
        self,
        session_id: str,
        action_type: str,
        node_id: Optional[str] = None,
        content: Optional[str] = None,
        feedback: Optional[str] = None,
        metadata: Optional[Dict] = None
    ):
        """Record a user interaction for RLHF on a live session (no-op if none).

        Feedback and graph-editing actions are saved to disk immediately.
        Other interactions stay in memory until the next save.
        """
        with self._lock:
            session = self.get_session(session_id)
            if session is None:
                return
            session.add_interaction(action_type, node_id, content, feedback, metadata)

            if action_type in ['feedback', 'prune', 'inject', 'resurrect']:
                self.save_session(session_id)

    def add_feedback(
        self,
        session_id: str,
        node_id: str,
        feedback_type: str,
        context: str = ""
    ):
        """Record user feedback on a reasoning node as a 'feedback' interaction."""
        self.record_interaction(
            session_id=session_id,
            action_type='feedback',
            node_id=node_id,
            feedback=feedback_type,
            metadata={'context': context}
        )

    def delete_session(self, session_id: str, archive: bool = True):
        """
        Delete a session from memory and disk, optionally archiving it first.

        If ``archive`` is true and the session is not cached, it is loaded
        from disk so its RLHF data is still archived.

        Raises:
            ValueError: if ``session_id`` is not a plain file name.
            OSError: if the session file exists but cannot be removed.
        """
        with self._lock:
            session = self._sessions.pop(session_id, None)

            if archive:
                if session is None:
                    session = self._load_from_disk(session_id)
                if session is not None:
                    self._archive_session(session)

            self._remove_session_file(session_id)

    def list_sessions(self, limit: int = 20) -> List[Dict]:
        """
        List the most recently updated sessions found on disk.

        Files are read as UTF-8, matching how they are written. Relying on the
        platform default encoding would silently drop sessions containing
        non-ASCII text on non-UTF-8 locales. Unreadable files are skipped.

        Returns:
            Up to ``limit`` dicts with ``id``, ``title``, ``timestamp`` and
            ``language`` keys, newest ``timestamp`` first.
        """
        sessions = []
        for session_file in self.sessions_dir.glob("*.json"):
            try:
                with open(session_file, encoding='utf-8') as f:
                    data = json.load(f)
                sessions.append({
                    "id": data["session_id"],
                    "title": self._get_title(data),
                    # `or ""` keeps the sort key a str even if the field is null.
                    "timestamp": data.get("updated_at") or "",
                    "language": data.get("language", "en"),
                })
            except self._LOAD_ERRORS as e:
                logger.debug(f"Skipping unreadable session file {session_file.name}: {e}")
                continue

        sessions.sort(key=lambda x: x["timestamp"], reverse=True)
        return sessions[:limit]

    def cleanup_stale_sessions(self):
        """
        Archive and remove every session not updated within ``session_max_age``.

        Two passes, both under the lock:

        1. Cached sessions whose ``updated_at`` is before the cutoff are
           archived, dropped from memory, and their files deleted.
        2. Remaining ``*.json`` files that do not back a cached session and
           whose modification time is before the cutoff are archived
           (best-effort) and deleted. Files of cached sessions are skipped
           because in-memory state can be newer than the file's mtime.
        """
        cutoff = datetime.now() - self.session_max_age

        with self._lock:
            expired_ids = [
                sid for sid, session in self._sessions.items()
                if self._is_expired(session, cutoff)
            ]
            for sid in expired_ids:
                self._expire_session(sid, self._sessions[sid])

            removed_files = 0
            for session_file in self.sessions_dir.glob("*.json"):
                if session_file.stem in self._sessions:
                    continue
                try:
                    mtime = datetime.fromtimestamp(session_file.stat().st_mtime)
                    if mtime >= cutoff:
                        continue
                    self._archive_session_file(session_file)
                    session_file.unlink()
                except (OSError, ValueError, OverflowError) as e:
                    logger.debug(f"Skipping stale-file cleanup for {session_file.name}: {e}")
                    continue
                removed_files += 1
                logger.debug(f"Cleaned up stale session: {session_file.stem}")

        total = len(expired_ids) + removed_files
        if total:
            logger.info(f"Cleaned up {total} stale sessions")

    def export_all_rlhf_data(self) -> List[Dict]:
        """Export RLHF data for every session currently cached in memory."""
        with self._lock:
            return [session.export_rlhf_data() for session in self._sessions.values()]

    def get_stats(self) -> Dict:
        """Return counts of cached sessions, their interactions, and tallied feedback."""
        with self._lock:
            total_interactions = sum(len(s.interactions) for s in self._sessions.values())
            total_feedback = sum(sum(s.feedback_count.values()) for s in self._sessions.values())

            return {
                'active_sessions': len(self._sessions),
                'total_interactions': total_interactions,
                'total_feedback': total_feedback,
            }

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _validate_session_id(session_id: str) -> None:
        """
        Reject session IDs that are not a plain file name.

        IDs become file names, so an ID such as ``"../x"`` (or a drive prefix
        on Windows) would otherwise read or write outside ``sessions_dir``.

        Raises:
            ValueError: if the ID contains path components.
        """
        if Path(session_id).name != session_id:
            raise ValueError(f"Invalid session id: {session_id!r}")

    def _session_file(self, session_id: str) -> Path:
        """Return the validated path of a session's JSON file."""
        self._validate_session_id(session_id)
        return self.sessions_dir / f"{session_id}.json"

    def _remove_session_file(self, session_id: str):
        """Delete a session's JSON file; a missing file is not an error."""
        try:
            self._session_file(session_id).unlink()
        except FileNotFoundError:
            pass

    @staticmethod
    def _is_expired(session: Session, cutoff: datetime) -> bool:
        """Return True if ``session.updated_at`` is before ``cutoff``.

        Unparsable or incomparable timestamps (e.g. timezone-aware values)
        count as not expired, so a malformed field never destroys a session.
        """
        try:
            return datetime.fromisoformat(session.updated_at) < cutoff
        except (TypeError, ValueError):
            return False

    def _expire_session(self, session_id: str, session: Session):
        """Archive an expired session, then drop it from memory and disk.

        File removal is best-effort: a failure is logged, not raised.
        """
        self._archive_session(session)
        self._sessions.pop(session_id, None)
        try:
            self._remove_session_file(session_id)
        except OSError as e:
            logger.warning(f"Failed to remove session file for {session_id}: {e}")

    def _archive_session_file(self, session_file: Path):
        """Best-effort archive of a session that exists only on disk."""
        try:
            with open(session_file, encoding='utf-8') as f:
                session = Session.from_dict(json.load(f))
        except self._LOAD_ERRORS as e:
            logger.warning(f"Could not archive stale session file {session_file.name}: {e}")
            return
        self._archive_session(session)

    @staticmethod
    def _write_json_atomic(path: Path, payload: Dict):
        """
        Write ``payload`` as indented UTF-8 JSON to ``path``.

        The data goes to a sibling ``.tmp`` file, which then replaces
        ``path`` via ``Path.replace`` (an atomic rename on POSIX). Readers
        never see a half-written file. A failure during serialization leaves
        the previous file intact instead of truncated.
        """
        tmp_path = path.with_name(path.name + ".tmp")
        try:
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(payload, f, indent=2, ensure_ascii=False)
            tmp_path.replace(path)
        except BaseException:
            try:
                tmp_path.unlink()
            except OSError:
                pass
            raise

    def _archive_session(self, session: Session):
        """
        Write ``session.export_rlhf_data()`` to the archive directory.

        Best-effort: failures are logged, never raised, so archiving cannot
        block cleanup. The file name combines the session ID with the date
        part of ``created_at``, so re-archiving a session overwrites its
        earlier archive.
        """
        try:
            self._validate_session_id(session.session_id)
            filename = self.archive_dir / f"session_{session.session_id}_{session.created_at[:10]}.json"
            self._write_json_atomic(filename, session.export_rlhf_data())
            logger.info(f"Archived session {session.session_id}")
        except (OSError, TypeError, ValueError) as e:
            logger.error(f"Failed to archive session {session.session_id}: {e}")

    def _load_from_disk(self, session_id: str) -> Optional[Session]:
        """
        Read ``<sessions_dir>/<session_id>.json`` (UTF-8) and rebuild the session.

        Returns:
            The session, or ``None`` if the file is missing or malformed.
            Malformed files are logged as warnings.

        Raises:
            ValueError: if ``session_id`` is not a plain file name.
        """
        session_file = self._session_file(session_id)
        try:
            with open(session_file, encoding='utf-8') as f:
                data = json.load(f)
            session = Session.from_dict(data)
        except FileNotFoundError:
            return None
        except self._LOAD_ERRORS as e:
            logger.warning(f"Failed to load session {session_id}: {e}")
            return None

        logger.info(f"Loaded session {session_id} from disk")
        return session

    def _save_to_disk(self, session: Session):
        """
        Persist ``session.to_dict()`` to its JSON file atomically.

        Failures (I/O errors, non-JSON-serializable graph state or metadata,
        invalid IDs) are logged as warnings rather than raised.
        """
        try:
            session_file = self._session_file(session.session_id)
            self._write_json_atomic(session_file, session.to_dict())
            logger.debug(f"Saved session {session.session_id} to disk")
        except (OSError, TypeError, ValueError) as e:
            logger.warning(f"Failed to save session {session.session_id}: {e}")

    def _cleanup_old_sessions(self):
        """
        Evict the least recently updated sessions when over ``max_sessions``.

        Removes enough sessions to get back down to ``max_sessions``, and at
        least 10% of the cache so eviction does not run on every create.
        Using ``len // 10`` alone removes nothing for caches under 10
        entries, so small limits were never enforced. Evicted sessions are
        archived and their files deleted.
        """
        with self._lock:
            by_age = sorted(self._sessions.items(), key=lambda item: item[1].updated_at)
            excess = len(by_age) - self.max_sessions
            count = max(excess, len(by_age) // 10)
            for session_id, _session in by_age[:count]:
                self.delete_session(session_id, archive=True)

    def _get_title(self, data: Dict) -> str:
        """Return the first user message of serialized session ``data`` (max 35 chars + '...')."""
        for msg in data.get("chat_history", []):
            if msg.get("role") == "user":
                content = msg.get("content", "")
                return content[:35] + "..." if len(content) > 35 else content
        return "New Chat"
|
|
|
|
|
|
|
|
|
|
|
# Process-wide singleton, created lazily by get_session_manager().
_session_manager: Optional[SessionManager] = None
# Guards creation of the singleton so concurrent first calls build only one.
_session_manager_lock = threading.Lock()


def get_session_manager() -> SessionManager:
    """Return the process-wide SessionManager, creating it on first use.

    Settings come from ``get_config()`` in the package's ``config`` module.
    The ``None`` check is repeated while holding a lock, so two threads
    calling this for the first time cannot each construct a manager. Two
    managers would split the in-memory session cache between instances.
    """
    global _session_manager
    if _session_manager is None:
        with _session_manager_lock:
            if _session_manager is None:
                from .config import get_config

                config = get_config()
                _session_manager = SessionManager(
                    sessions_dir=str(config.sessions_dir),
                    max_sessions=config.max_sessions,
                    session_max_age_hours=config.session_max_age_hours
                )
    return _session_manager
|
|
|
|
|
|
|
|
def generate_session_id(length: int = 12) -> str:
    """Generate a random, hard-to-guess session ID for a new user.

    The ID is ``length`` lowercase hex characters drawn from the ``secrets``
    CSPRNG. Every character is random, giving 48 bits at the default length.
    A UUID string truncated to 12 characters spends one of them on a hyphen.
    Lowercase hex is also safe as a file name on case-insensitive file
    systems.

    Args:
        length: Number of characters in the ID; must be positive.

    Raises:
        ValueError: if ``length`` is not positive.
    """
    import secrets

    if length <= 0:
        raise ValueError("length must be positive")
    # token_hex(n) yields 2*n hex chars; round up, then trim to exact length.
    return secrets.token_hex((length + 1) // 2)[:length]
|
|
|