YousefMohtady1 committed
Commit ecba0c9 · 1 Parent(s): c60f264

Introduce `RagPipeline` for conversational RAG, including vector store management and query processing.

Files changed (1)
  1. core/rag_pipeline.py +33 -9
core/rag_pipeline.py CHANGED
@@ -1,11 +1,13 @@
 import logging
 import time
+import os
 from langchain_core.messages import HumanMessage, AIMessage
 from langchain.chains import create_history_aware_retriever, create_retrieval_chain
 from langchain.chains.combine_documents import create_stuff_documents_chain
 from config.settings import Settings
 from services.llm_client import LLMClient
 from services.vector_store import VectorStore
+from services.document_processor import DocumentProcessor
 from core.prompts import get_chat_prompt, get_contextualize_prompt
 
 logger = logging.getLogger(__name__)
@@ -15,6 +17,28 @@ class RagPipeline:
         try:
             self.llm = LLMClient().get_llm()
             self.vector_store = VectorStore()
+
+            try:
+                logger.info("Checking Vector Database integrity...")
+                test_retriever = self.vector_store.get_retriever(k=1)
+                test_retriever.invoke("test")
+                logger.info("Vector Database is healthy and ready.")
+
+            except Exception as e:
+                logger.warning(f"Database seems empty or corrupt ({str(e)}). Rebuilding from PDF...")
+
+                try:
+                    processor = DocumentProcessor()
+                    docs = processor.process_documents("./data")
+
+                    if not docs:
+                        logger.error("No documents found in ./data folder to ingest!")
+                    else:
+                        self.vector_store.add_documents(docs)
+                        logger.info(f"Successfully rebuilt database with {len(docs)} chunks.")
+                except Exception as build_error:
+                    logger.error(f"Failed to rebuild database: {str(build_error)}")
+
             self.retriever = self.vector_store.get_retriever(k=5)
             self.prompt = get_chat_prompt()
             self.history_aware_retriever = create_history_aware_retriever(
@@ -30,7 +54,6 @@ class RagPipeline:
                 self.history_aware_retriever,
                 self.question_answer_chain
             )
-            self.chat_history = []
             logger.info("RAG pipeline initialized successfully")
 
         except Exception as e:
@@ -38,22 +61,23 @@ class RagPipeline:
             raise e
 
     def clear_history(self):
-        self.chat_history = []
-        logger.info("Chat history cleared")
+        pass
 
     def process_query(self, question:str, chat_history: list = []):
         start_time = time.time()
         try:
             logger.info(f"Processing query: {question}")
+            langchain_history = []
+            for msg in chat_history:
+                if msg[0] == "human":
+                    langchain_history.append(HumanMessage(content=msg[1]))
+                elif msg[0] == "ai":
+                    langchain_history.append(AIMessage(content=msg[1]))
+
             response = self.rag_chain.invoke({
                 "input": question,
-                "chat_history": self.chat_history
+                "chat_history": langchain_history
             })
-
-            self.chat_history.extend([
-                HumanMessage(content=question),
-                AIMessage(content=response["answer"])
-            ])
 
             latency = time.time() - start_time
 
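For reference, a minimal usage sketch of the revised API (not part of the commit). It assumes `RagPipeline()` takes no constructor arguments and that `process_query()` returns the retrieval chain's response dict, whose generated text LangChain's `create_retrieval_chain` exposes under the `"answer"` key (the same key the old code read). The `("human" | "ai", text)` tuple format for `chat_history` is taken from the new conversion loop in `process_query`; since the pipeline no longer stores `self.chat_history`, the caller owns the conversation state and `clear_history()` is now a no-op.

```python
from core.rag_pipeline import RagPipeline

# Assumption: RagPipeline() needs no arguments; __init__ probes the vector store
# and rebuilds it from ./data if the integrity check fails.
pipeline = RagPipeline()

# History is supplied per call as ("human" | "ai", text) tuples, matching the
# conversion loop in process_query; the pipeline itself keeps no session state.
history = [
    ("human", "Which folder are the source documents read from?"),
    ("ai", "Documents are ingested from the ./data folder."),
]

# Assumption: process_query returns the chain's response dict, so the generated
# text is available under the "answer" key used elsewhere in the pipeline.
response = pipeline.process_query("How is the vector store rebuilt?", chat_history=history)
print(response["answer"])
```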