import asyncio
import logging
import os
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional

import numpy as np
import requests

from lightrag import LightRAG, QueryParam
from lightrag.kg.shared_storage import initialize_pipeline_status
from lightrag.utils import EmbeddingFunc


class CloudflareWorker:
    """Thin async client for the Cloudflare Workers AI REST API.

    Exposes an LLM completion function and an embedding function in the
    call signatures that LightRAG expects.
    """

    def __init__(
        self,
        cloudflare_api_key: str,
        api_base_url: str,
        llm_model_name: str,
        embedding_model_name: str,
        max_tokens: int = 4080,
        max_response_tokens: int = 4080,
    ):
        self.cloudflare_api_key = cloudflare_api_key
        self.api_base_url = api_base_url
        self.llm_model_name = llm_model_name
        self.embedding_model_name = embedding_model_name
        self.max_tokens = max_tokens
        self.max_response_tokens = max_response_tokens

    async def _send_request(self, model_name: str, input_: dict):
        """POST a payload to the Cloudflare endpoint for the given model.

        Returns an np.ndarray for embedding responses, a str for chat
        responses, or None on failure.
        """
        headers = {"Authorization": f"Bearer {self.cloudflare_api_key}"}

        try:
            # requests is synchronous, so run it in a worker thread
            # (asyncio.to_thread, Python 3.9+) to avoid blocking the event
            # loop while LightRAG issues concurrent calls.
            def _post() -> dict:
                return requests.post(
                    f"{self.api_base_url}{model_name}",
                    headers=headers,
                    json=input_,
                    timeout=120,  # avoid hanging forever on a stuck connection
                ).json()

            response_raw = await asyncio.to_thread(_post)
            result = response_raw.get("result", {})

            # Embedding endpoints return vectors under "data"; chat
            # endpoints return text under "response".
            if "data" in result:
                return np.array(result["data"])
            if "response" in result:
                return result["response"]

            raise ValueError("Unexpected Cloudflare response format")
        except Exception as e:
            # Errors are logged and surfaced as None; callers must handle it.
            logging.error(f"Cloudflare API error: {e}")
            return None

    async def query(self, prompt: str, system_prompt: str = "", **kwargs) -> str:
        """Chat-completion function in the signature LightRAG expects."""
        # LightRAG passes its hashing key-value store through kwargs;
        # the HTTP API has no use for it.
        kwargs.pop("hashing_kv", None)

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ]

        input_ = {
            "messages": messages,
            "max_tokens": self.max_tokens,
            "response_token_limit": self.max_response_tokens,
        }

        return await self._send_request(self.llm_model_name, input_)

    async def embedding_chunk(self, texts: List[str]) -> np.ndarray:
        """Embed a batch of text chunks; returns one vector per input text."""
        input_ = {
            "text": texts,
            "max_tokens": self.max_tokens,
            "response_token_limit": self.max_response_tokens,
        }

        return await self._send_request(self.embedding_model_name, input_)
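

# Example construction (a sketch; the URL shape and model identifiers are
# illustrative Cloudflare Workers AI values, not defined by this module):
#
#   worker = CloudflareWorker(
#       cloudflare_api_key=os.environ["CLOUDFLARE_API_KEY"],
#       api_base_url="https://api.cloudflare.com/client/v4/accounts/<ACCOUNT_ID>/ai/run/",
#       llm_model_name="@cf/meta/llama-3.1-8b-instruct",
#       embedding_model_name="@cf/baai/bge-m3",
#   )
#   answer = await worker.query("Summarize the evacuation procedure.")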


class LightRAGManager:
    """Manages LightRAG instances and per-conversation memory.

    Holds one shared fire-safety knowledge base plus on-demand custom
    per-user instances, keyed by "{user_id}_{ai_id}".
    """

    def __init__(
        self,
        cloudflare_worker: CloudflareWorker,
        base_working_dir: str = "/app/lightrag_storage",
    ):
        self.cloudflare_worker = cloudflare_worker
        self.base_working_dir = Path(base_working_dir)

        if not self.base_working_dir.exists():
            logging.error(f"Base directory {self.base_working_dir} does not exist!")

        self.rag_instances: Dict[str, LightRAG] = {}
        self.conversation_memory: Dict[str, List[Dict]] = {}
        self.fire_safety_rag = None
        self.logger = logging.getLogger(__name__)

    async def initialize_fire_safety_rag(self) -> LightRAG:
        """Create (once) and return the shared fire-safety RAG instance."""
        if self.fire_safety_rag is not None:
            return self.fire_safety_rag

        working_dir = self.base_working_dir / "fire_safety"

        if not working_dir.exists():
            self.logger.error(f"Fire safety directory {working_dir} does not exist!")
            raise RuntimeError(f"Directory {working_dir} not found")

        self.fire_safety_rag = await self._create_rag_instance(str(working_dir))

        # Ingest bundled knowledge files on first startup.
        await self._auto_migrate_fire_safety()

        return self.fire_safety_rag

    async def _auto_migrate_fire_safety(self):
        """Auto-migrate fire safety knowledge from multiple sources on startup."""
        try:
            working_dir = self.base_working_dir / "fire_safety"
            existing_files = list(working_dir.glob("*.json"))

            self.logger.info(f"Checking for existing files in {working_dir}")
            self.logger.info(
                f"Found {len(existing_files)} existing JSON files: "
                f"{[f.name for f in existing_files]}"
            )

            # LightRAG persists its stores as JSON files; if any are present,
            # the knowledge base was already built on a previous run.
            if existing_files:
                self.logger.info("Fire safety knowledge already exists, skipping migration")
                return

            # Candidate knowledge files, checked in order.
            knowledge_sources = [
                "/app/book.txt",
                "/app/book.pdf",
                "/app/fire_safety.txt",
                "/app/fire_safety.pdf",
                "/app/regulations.txt",
                "/app/regulations.pdf",
                "./book.txt",
                "./book.pdf",
            ]

            self.logger.info("🔍 Scanning for knowledge files...")

            # Log the /app directory contents to aid debugging missing files.
            app_dir = Path("/app")
            if app_dir.exists():
                all_files = list(app_dir.glob("*"))
                self.logger.info(f"Files in /app/: {[f.name for f in all_files if f.is_file()]}")

            combined_content = ""
            processed_files = []

            for source_file in knowledge_sources:
                self.logger.info(f"Checking for file: {source_file}")

                if not os.path.exists(source_file):
                    self.logger.debug(f"❌ File not found: {source_file}")
                    continue

                self.logger.info(f"✅ Found knowledge source: {source_file}")

                try:
                    if source_file.endswith('.txt'):
                        with open(source_file, 'r', encoding='utf-8') as f:
                            content = f.read()
                        if content.strip():
                            combined_content += (
                                f"\n\n=== Content from {source_file} ===\n\n{content}\n\n"
                            )
                            processed_files.append(source_file)
                            self.logger.info(f"📄 Processed {source_file}: {len(content)} characters")
                        else:
                            self.logger.warning(f"⚠️ File {source_file} is empty")

                    elif source_file.endswith('.pdf'):
                        self.logger.info(f"📄 Extracting PDF content from {source_file}")
                        content = await self._extract_pdf_content(source_file)
                        if content and content.strip():
                            combined_content += (
                                f"\n\n=== Content from {source_file} ===\n\n{content}\n\n"
                            )
                            processed_files.append(source_file)
                            self.logger.info(f"📄 Processed {source_file}: {len(content)} characters")
                        else:
                            self.logger.warning(f"⚠️ Could not extract content from {source_file}")

                except Exception as e:
                    self.logger.error(f"❌ Error processing {source_file}: {e}")
                    continue

            self.logger.info(f"📊 Total processed files: {len(processed_files)}")
            self.logger.info(f"📊 Combined content length: {len(combined_content)} characters")

            if combined_content.strip():
                self.logger.info("🚀 Inserting combined content into LightRAG...")
                try:
                    await self.fire_safety_rag.ainsert(combined_content)
                    self.logger.info(
                        f"✅ Successfully migrated fire safety knowledge from "
                        f"{len(processed_files)} files: {processed_files}"
                    )

                    created_files = list(working_dir.glob("*.json"))
                    self.logger.info(f"📁 Created LightRAG files: {[f.name for f in created_files]}")
                except Exception as e:
                    self.logger.error(f"❌ Failed to insert content into LightRAG: {e}")
                    raise
            else:
                self.logger.warning("⚠️ No fire safety knowledge files found or readable")
                self.logger.info("💡 Expected files: book.txt, book.pdf in /app/ directory")

        except Exception as e:
            self.logger.error(f"❌ Auto-migration failed: {e}")
            import traceback
            self.logger.error(f"Full traceback: {traceback.format_exc()}")

    async def _extract_pdf_content(self, pdf_path: str) -> str:
        """Extract text content from a PDF file.

        Tries PyPDF2, then pdfplumber, then PyMuPDF, returning the first
        non-empty extraction; returns "" if every backend fails.
        """
        try:
            self.logger.info(f"📄 Attempting to extract PDF content from: {pdf_path}")

            # Backend 1: PyPDF2.
            try:
                import PyPDF2

                content = ""
                self.logger.info("🔧 Using PyPDF2 for extraction")

                with open(pdf_path, 'rb') as file:
                    pdf_reader = PyPDF2.PdfReader(file)
                    self.logger.info(f"📊 PDF has {len(pdf_reader.pages)} pages")

                    for page_num, page in enumerate(pdf_reader.pages):
                        text = page.extract_text()
                        if text and text.strip():
                            content += f"\n--- Page {page_num + 1} ---\n{text}\n"

                if content.strip():
                    self.logger.info(f"✅ PyPDF2: Extracted {len(content)} characters from {pdf_path}")
                    return content
                self.logger.warning("⚠️ PyPDF2: No text content extracted")

            except ImportError:
                self.logger.warning("❌ PyPDF2 not available, trying alternative methods")
            except Exception as e:
                self.logger.warning(f"❌ PyPDF2 extraction failed: {e}")

            # Backend 2: pdfplumber.
            try:
                import pdfplumber

                content = ""
                self.logger.info("🔧 Using pdfplumber for extraction")

                with pdfplumber.open(pdf_path) as pdf:
                    self.logger.info(f"📊 PDF has {len(pdf.pages)} pages")

                    for page_num, page in enumerate(pdf.pages):
                        text = page.extract_text()
                        if text and text.strip():
                            content += f"\n--- Page {page_num + 1} ---\n{text}\n"

                if content.strip():
                    self.logger.info(f"✅ pdfplumber: Extracted {len(content)} characters from {pdf_path}")
                    return content
                self.logger.warning("⚠️ pdfplumber: No text content extracted")

            except ImportError:
                self.logger.warning("❌ pdfplumber not available")
            except Exception as e:
                self.logger.warning(f"❌ pdfplumber extraction failed: {e}")

            # Backend 3: PyMuPDF (imported as fitz).
            try:
                import fitz

                content = ""
                self.logger.info("🔧 Using PyMuPDF for extraction")

                pdf_document = fitz.open(pdf_path)
                self.logger.info(f"📊 PDF has {pdf_document.page_count} pages")

                for page_num in range(pdf_document.page_count):
                    page = pdf_document[page_num]
                    text = page.get_text()
                    if text and text.strip():
                        content += f"\n--- Page {page_num + 1} ---\n{text}\n"
                pdf_document.close()

                if content.strip():
                    self.logger.info(f"✅ PyMuPDF: Extracted {len(content)} characters from {pdf_path}")
                    return content
                self.logger.warning("⚠️ PyMuPDF: No text content extracted")

            except ImportError:
                self.logger.warning("❌ PyMuPDF not available")
            except Exception as e:
                self.logger.warning(f"❌ PyMuPDF extraction failed: {e}")

            # All backends failed (or the PDF is image-only with no text layer).
            self.logger.error(f"❌ Could not extract text from PDF: {pdf_path}")
            self.logger.info("💡 Please convert PDF to text format or check if PDF contains text (not just images)")
            return ""

        except Exception as e:
            self.logger.error(f"❌ PDF extraction error for {pdf_path}: {e}")
            return ""

    async def create_custom_rag(self, user_id: str, ai_id: str, knowledge_texts: List[str]) -> LightRAG:
        """Create (or return a cached) per-user RAG instance and seed it with knowledge."""
        instance_key = f"{user_id}_{ai_id}"

        if instance_key in self.rag_instances:
            return self.rag_instances[instance_key]

        working_dir = self.base_working_dir / "custom" / user_id / ai_id

        try:
            working_dir.mkdir(parents=True, exist_ok=True)
        except Exception as e:
            self.logger.error(f"Failed to create custom AI directory: {e}")
            raise

        rag = await self._create_rag_instance(str(working_dir))

        if knowledge_texts:
            await self._insert_knowledge_batch(rag, knowledge_texts)

        self.rag_instances[instance_key] = rag
        return rag

    async def _create_rag_instance(self, working_dir: str) -> LightRAG:
        """Build a LightRAG instance wired to the Cloudflare worker callables."""
        try:
            rag = LightRAG(
                working_dir=working_dir,
                max_parallel_insert=2,
                llm_model_func=self.cloudflare_worker.query,
                llm_model_name=self.cloudflare_worker.llm_model_name,
                llm_model_max_token_size=4080,
                embedding_func=EmbeddingFunc(
                    embedding_dim=1024,  # must match the embedding model's output size
                    max_token_size=2048,
                    func=self.cloudflare_worker.embedding_chunk,
                ),
                graph_storage="NetworkXStorage",
                vector_storage="NanoVectorDBStorage",
            )

            await rag.initialize_storages()
            await initialize_pipeline_status()

            return rag
        except Exception as e:
            self.logger.error(f"Failed to create RAG instance: {e}")
            raise

    async def _insert_knowledge_batch(self, rag: LightRAG, texts: List[str]):
        """Insert texts one at a time, skipping any that fail."""
        for text in texts:
            if text.strip():
                try:
                    await rag.ainsert(text)
                except Exception as e:
                    self.logger.error(f"Failed to insert knowledge: {e}")
                    continue

    async def query_with_memory(
        self,
        rag: LightRAG,
        question: str,
        conversation_id: str,
        mode: str = "hybrid",
        max_memory_turns: int = 10,
    ) -> str:
        """Answer a question with recent conversation history prepended."""
        try:
            memory = self.conversation_memory.get(conversation_id, [])
            context_prompt = self._build_context_prompt(question, memory, max_memory_turns)

            response = await rag.aquery(
                context_prompt,
                param=QueryParam(mode=mode),
            )

            self._update_conversation_memory(conversation_id, question, response)
            return response
        except Exception as e:
            self.logger.error(f"Query with memory failed: {e}")

            # Fall back to a direct query without conversation context.
            try:
                response = await rag.aquery(question, param=QueryParam(mode=mode))
                return response
            except Exception as e2:
                self.logger.error(f"Direct query also failed: {e2}")
                return "I apologize, but I'm experiencing technical difficulties. Please try again later."

    def _build_context_prompt(self, question: str, memory: List[Dict], max_turns: int) -> str:
        """Prepend up to max_turns recent exchanges to the current question.

        Produces, for example:
            Previous conversation:
            User: <earlier question>
            Assistant: <earlier answer>

            Current question: <question>
        """
        if not memory:
            return question

        # Each turn is two messages (user + assistant), so keep 2 * max_turns entries.
        recent_memory = memory[-max_turns * 2:]

        context = "Previous conversation:\n"
        for msg in recent_memory:
            role = msg['role']
            # Truncate long messages to keep the prompt compact.
            content = (msg['content'][:200] + "...") if len(msg['content']) > 200 else msg['content']
            context += f"{role.title()}: {content}\n"

        context += f"\nCurrent question: {question}"
        return context

    def _update_conversation_memory(self, conversation_id: str, question: str, response: str):
        """Append the latest exchange, capping history at 50 messages."""
        if conversation_id not in self.conversation_memory:
            self.conversation_memory[conversation_id] = []

        memory = self.conversation_memory[conversation_id]

        memory.append({
            'role': 'user',
            'content': question,
            'timestamp': datetime.now().isoformat()
        })

        memory.append({
            'role': 'assistant',
            'content': response,
            'timestamp': datetime.now().isoformat()
        })

        # Keep only the most recent 50 messages (25 exchanges).
        if len(memory) > 50:
            self.conversation_memory[conversation_id] = memory[-50:]

    async def get_rag_instance(
        self,
        ai_type: str,
        user_id: Optional[str] = None,
        ai_id: Optional[str] = None,
    ) -> LightRAG:
        """Dispatch to the shared fire-safety instance or a cached custom one."""
        if ai_type == "fire-safety":
            return await self.initialize_fire_safety_rag()
        elif ai_type == "custom" and user_id and ai_id:
            instance_key = f"{user_id}_{ai_id}"
            if instance_key not in self.rag_instances:
                working_dir = self.base_working_dir / "custom" / user_id / ai_id
                if working_dir.exists():
                    rag = await self._create_rag_instance(str(working_dir))
                    self.rag_instances[instance_key] = rag
                else:
                    raise ValueError(f"Custom AI {ai_id} knowledge base not found")
            return self.rag_instances[instance_key]
        else:
            raise ValueError(f"Unknown AI type: {ai_type}")

    def clear_conversation_memory(self, conversation_id: str):
        if conversation_id in self.conversation_memory:
            del self.conversation_memory[conversation_id]

    async def cleanup(self):
        for rag in self.rag_instances.values():
            if hasattr(rag, 'finalize_storages'):
                await rag.finalize_storages()

        if self.fire_safety_rag and hasattr(self.fire_safety_rag, 'finalize_storages'):
            await self.fire_safety_rag.finalize_storages()


# Module-level singleton so route handlers can share one manager.
lightrag_manager: Optional[LightRAGManager] = None


async def initialize_lightrag_manager(cloudflare_worker: CloudflareWorker) -> LightRAGManager:
    global lightrag_manager
    if lightrag_manager is None:
        lightrag_manager = LightRAGManager(cloudflare_worker)
    return lightrag_manager


def get_lightrag_manager() -> LightRAGManager:
    if lightrag_manager is None:
        raise RuntimeError("LightRAG manager not initialized")
    return lightrag_manager
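

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch. The environment variable names below are
# illustrative assumptions, not configuration defined elsewhere in this
# module; everything else uses the classes and helpers above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    async def _demo():
        worker = CloudflareWorker(
            cloudflare_api_key=os.environ["CLOUDFLARE_API_KEY"],   # assumed env var
            api_base_url=os.environ["CLOUDFLARE_API_BASE_URL"],    # assumed env var, e.g. .../ai/run/
            llm_model_name=os.environ["CLOUDFLARE_LLM_MODEL"],     # assumed env var
            embedding_model_name=os.environ["CLOUDFLARE_EMBED_MODEL"],  # assumed env var
        )
        manager = await initialize_lightrag_manager(worker)

        # Requires the fire_safety working directory and knowledge files
        # described in _auto_migrate_fire_safety to be present.
        rag = await manager.get_rag_instance("fire-safety")
        answer = await manager.query_with_memory(
            rag,
            question="What does the knowledge base say about fire doors?",
            conversation_id="demo",
        )
        print(answer)
        await manager.cleanup()

    asyncio.run(_demo())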