Commit · ecba0c9
Parent(s): c60f264
Introduce `RagPipeline` for conversational RAG, including vector store management and query processing.
core/rag_pipeline.py  CHANGED  (+33 -9)
@@ -1,11 +1,13 @@
 import logging
 import time
+import os
 from langchain_core.messages import HumanMessage, AIMessage
 from langchain.chains import create_history_aware_retriever, create_retrieval_chain
 from langchain.chains.combine_documents import create_stuff_documents_chain
 from config.settings import Settings
 from services.llm_client import LLMClient
 from services.vector_store import VectorStore
+from services.document_processor import DocumentProcessor
 from core.prompts import get_chat_prompt, get_contextualize_prompt

 logger = logging.getLogger(__name__)
@@ -15,6 +17,28 @@ class RagPipeline:
         try:
             self.llm = LLMClient().get_llm()
             self.vector_store = VectorStore()
+
+            try:
+                logger.info("Checking Vector Database integrity...")
+                test_retriever = self.vector_store.get_retriever(k=1)
+                test_retriever.invoke("test")
+                logger.info("Vector Database is healthy and ready.")
+
+            except Exception as e:
+                logger.warning(f"Database seems empty or corrupt ({str(e)}). Rebuilding from PDF...")
+
+                try:
+                    processor = DocumentProcessor()
+                    docs = processor.process_documents("./data")
+
+                    if not docs:
+                        logger.error("No documents found in ./data folder to ingest!")
+                    else:
+                        self.vector_store.add_documents(docs)
+                        logger.info(f"Successfully rebuilt database with {len(docs)} chunks.")
+                except Exception as build_error:
+                    logger.error(f"Failed to rebuild database: {str(build_error)}")
+
             self.retriever = self.vector_store.get_retriever(k=5)
             self.prompt = get_chat_prompt()
             self.history_aware_retriever = create_history_aware_retriever(
@@ -30,7 +54,6 @@ class RagPipeline:
                 self.history_aware_retriever,
                 self.question_answer_chain
             )
-            self.chat_history = []
             logger.info("RAG pipeline initialized successfully")

         except Exception as e:
@@ -38,22 +61,23 @@
             raise e

     def clear_history(self):
-        self.chat_history = []
-        logger.info("Chat history cleared")
+        pass

     def process_query(self, question:str, chat_history: list = []):
         start_time = time.time()
         try:
             logger.info(f"Processing query: {question}")
+            langchain_history = []
+            for msg in chat_history:
+                if msg[0] == "human":
+                    langchain_history.append(HumanMessage(content=msg[1]))
+                elif msg[0] == "ai":
+                    langchain_history.append(AIMessage(content=msg[1]))
+
             response = self.rag_chain.invoke({
                 "input": question,
-                "chat_history": self.chat_history
+                "chat_history": langchain_history
             })
-
-            self.chat_history.extend([
-                HumanMessage(content=question),
-                AIMessage(content=response["answer"])
-            ])

             latency = time.time() - start_time

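After this commit the pipeline is stateless with respect to conversation history: __init__ no longer creates self.chat_history, clear_history is a no-op, and process_query expects the caller to pass history as ("human" | "ai", text) tuples, which it converts to HumanMessage/AIMessage objects before invoking the chain. A minimal caller-side sketch follows (not part of the commit; process_query's return value is not visible in the hunk, so the response["answer"] access assumes it returns the chain's response dict, whose "answer" key the pre-change code read):

from core.rag_pipeline import RagPipeline

pipeline = RagPipeline()
chat_history = []  # conversation state now lives with the caller

question = "What does the document cover?"
response = pipeline.process_query(question, chat_history)

# Assumption: process_query returns rag_chain.invoke()'s response dict.
chat_history.append(("human", question))
chat_history.append(("ai", response["answer"]))

# Resetting the conversation is also the caller's job now
# (clear_history does nothing after this commit):
chat_history = []

Keeping history with the caller lets the UI layer own conversation state, so a single RagPipeline instance can plausibly serve multiple sessions without cross-talk.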