# Define the script's usage example
USAGE_EXAMPLE = """
Example usage:

To process the input *.txt files at input_path and save the vector db output at output_db:
python create_vector_db.py --input_path input_path --output_db output_db --chunk_size 1000 --chunk_overlap 200

Required arguments:
- --input_path: Path to the input dir containing the .txt files.
- --output_db: Path to the output vector db.

Optional arguments:
- --chunk_size: Size of the chunks in characters (default: 1000).
- --chunk_overlap: Overlap between chunks in characters (default: 200).
- --db_type: Type of vector store (default: faiss).
"""
import os
import sys
import argparse
import logging
import uuid

from langchain_community.document_loaders import DirectoryLoader, UnstructuredURLLoader
from langchain_community.embeddings import HuggingFaceInstructEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter, CharacterTextSplitter
from langchain_community.vectorstores import FAISS, Chroma, Qdrant

vectordb_dir = os.path.dirname(os.path.abspath(__file__))
utils_dir = os.path.abspath(os.path.join(vectordb_dir, ".."))
repo_dir = os.path.abspath(os.path.join(utils_dir, ".."))
sys.path.append(repo_dir)
sys.path.append(utils_dir)
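# The sys.path additions above are assumed to make the repo root and the utils package
# importable when this script is run directly (this file is expected to live two levels
# below the repo root); APIGateway is the repo's wrapper for loading embedding models.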
from utils.model_wrappers.api_gateway import APIGateway
EMBEDDING_MODEL = "intfloat/e5-large-v2"
NORMALIZE_EMBEDDINGS = True
VECTORDB_LOG_FILE_NAME = "vector_db.log"

# Configure the logger
logging.basicConfig(
    level=logging.INFO,  # Set the logging level (e.g., INFO, DEBUG)
    format="%(asctime)s [%(levelname)s] - %(message)s",  # Define the log message format
    handlers=[
        logging.StreamHandler(),  # Output logs to the console
        logging.FileHandler(VECTORDB_LOG_FILE_NAME),
    ],
)

# Create a logger object
logger = logging.getLogger(__name__)
class VectorDb():
    """
    A class for creating, updating and loading FAISS, Chroma or Qdrant vector databases,
    to use them in retrieval augmented generation tasks with langchain

    Attributes:
        collection_id (str): unique identifier used to name new vector collections
        vector_collections (set): names of the collections created by this instance

    Methods:
        load_files: Load files from an input directory as langchain documents
        get_text_chunks: Get text chunks from a list of documents
        get_token_chunks: Get token chunks from a list of documents
        create_vector_store: Create a vector store from chunks and an embedding model
        load_vdb: Load a previously stored vector database
        update_vdb: Update an existing vector store with new chunks
        create_vdb: Create a vector database from the raw files in a specific input directory
    """

    def __init__(self) -> None:
        self.collection_id = str(uuid.uuid4())
        self.vector_collections = set()
    def load_files(self, input_path, recursive=False, load_txt=True, load_pdf=False, urls=None) -> list:
        """Load files from input location

        Args:
            input_path : input location of files
            recursive (bool, optional): flag to load files recursively. Defaults to False.
            load_txt (bool, optional): flag to load txt files. Defaults to True.
            load_pdf (bool, optional): flag to load pdf files. Defaults to False.
            urls (list, optional): list of urls to load. Defaults to None.

        Returns:
            list: list of documents
        """
        docs = []
        text_loader_kwargs = {"autodetect_encoding": True}
        if input_path is not None:
            if load_txt:
                loader = DirectoryLoader(
                    input_path, glob="*.txt", recursive=recursive, show_progress=True, loader_kwargs=text_loader_kwargs
                )
                docs.extend(loader.load())
            if load_pdf:
                loader = DirectoryLoader(
                    input_path, glob="*.pdf", recursive=recursive, show_progress=True, loader_kwargs=text_loader_kwargs
                )
                docs.extend(loader.load())
        if urls:
            loader = UnstructuredURLLoader(urls=urls)
            docs.extend(loader.load())
        logger.info(f"Total {len(docs)} files loaded")
        return docs
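    # Example (sketch, with hypothetical paths/URLs): load local .txt files plus two web pages.
    #   docs = VectorDb().load_files("./data", recursive=True, load_txt=True,
    #                                urls=["https://example.com/a", "https://example.com/b"])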
    def get_text_chunks(self, docs: list, chunk_size: int, chunk_overlap: int, meta_data: list = None) -> list:
        """Gets text chunks. If meta_data is not None, it will create chunks with metadata elements.

        Args:
            docs (list): list of documents or texts. If no metadata is passed, this parameter is a list of documents.
                If metadata is passed, this parameter is a list of texts.
            chunk_size (int): chunk size in number of characters
            chunk_overlap (int): chunk overlap in number of characters
            meta_data (list, optional): list of metadata in dictionary format. Defaults to None.

        Returns:
            list: list of documents
        """
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size, chunk_overlap=chunk_overlap, length_function=len
        )
        if meta_data is None:
            logger.info("Splitter: splitting documents")
            chunks = text_splitter.split_documents(docs)
        else:
            logger.info("Splitter: creating documents with metadata")
            chunks = text_splitter.create_documents(docs, meta_data)
        logger.info(f"Total {len(chunks)} chunks created")
        return chunks
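    # Example (sketch, with placeholder texts): chunking raw strings while attaching per-text metadata.
    #   chunks = VectorDb().get_text_chunks(["first text", "second text"], chunk_size=1000, chunk_overlap=200,
    #                                       meta_data=[{"source": "a.txt"}, {"source": "b.txt"}])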
    def get_token_chunks(self, docs: list, chunk_size: int, chunk_overlap: int, tokenizer) -> list:
        """Gets token chunks, splitting documents by token count using the given tokenizer.

        Args:
            docs (list): list of documents
            chunk_size (int): chunk size in number of tokens
            chunk_overlap (int): chunk overlap in number of tokens
            tokenizer: HuggingFace tokenizer used to measure chunk lengths

        Returns:
            list: list of documents
        """
        text_splitter = CharacterTextSplitter.from_huggingface_tokenizer(
            tokenizer, chunk_size=chunk_size, chunk_overlap=chunk_overlap
        )
        logger.info("Splitter: splitting documents")
        chunks = text_splitter.split_documents(docs)
        logger.info(f"Total {len(chunks)} chunks created")
        return chunks
    def create_vector_store(self, chunks: list, embeddings: HuggingFaceInstructEmbeddings, db_type: str,
                            output_db: str = None, collection_name: str = None):
        """Creates a vector store

        Args:
            chunks (list): list of chunks
            embeddings (HuggingFaceInstructEmbeddings): embedding model
            db_type (str): vector db type ("faiss", "chroma" or "qdrant")
            output_db (str, optional): output path to save the vector db. Defaults to None.
            collection_name (str, optional): collection name for chroma. Defaults to a name
                derived from this instance's collection id.

        Returns:
            the created vector store
        """
        if collection_name is None:
            collection_name = f"collection_{self.collection_id}"
        logger.info(f"This is the collection name: {collection_name}")

        if db_type == "faiss":
            vector_store = FAISS.from_documents(
                documents=chunks,
                embedding=embeddings
            )
            if output_db:
                vector_store.save_local(output_db)
        elif db_type == "chroma":
            if output_db:
                vector_store = Chroma()
                vector_store.delete_collection()
                vector_store = Chroma.from_documents(
                    documents=chunks,
                    embedding=embeddings,
                    persist_directory=output_db,
                    collection_name=collection_name
                )
            else:
                vector_store = Chroma()
                vector_store.delete_collection()
                vector_store = Chroma.from_documents(
                    documents=chunks,
                    embedding=embeddings,
                    collection_name=collection_name
                )
            self.vector_collections.add(collection_name)
        elif db_type == "qdrant":
            if output_db:
                vector_store = Qdrant.from_documents(
                    documents=chunks,
                    embedding=embeddings,
                    path=output_db,
                    collection_name="test_collection",
                )
            else:
                vector_store = Qdrant.from_documents(
                    documents=chunks,
                    embedding=embeddings,
                    collection_name="test_collection",
                )
        else:
            raise ValueError(f"Unsupported database type: {db_type}")

        if output_db:
            logger.info(f"Vector store saved to {output_db}")
        return vector_store
    def load_vdb(self, persist_directory, embedding_model, db_type="chroma", collection_name=None):
        """Loads a previously stored vector database from disk.

        Args:
            persist_directory (str): path of the stored vector db
            embedding_model: embedding model used when the vector db was created
            db_type (str, optional): vector db type ("faiss" or "chroma"). Defaults to "chroma".
            collection_name (str, optional): chroma collection name. Defaults to None.

        Returns:
            the loaded vector store
        """
        if db_type == "faiss":
            vector_store = FAISS.load_local(persist_directory, embedding_model, allow_dangerous_deserialization=True)
        elif db_type == "chroma":
            if collection_name:
                vector_store = Chroma(
                    persist_directory=persist_directory,
                    embedding_function=embedding_model,
                    collection_name=collection_name
                )
            else:
                vector_store = Chroma(
                    persist_directory=persist_directory,
                    embedding_function=embedding_model
                )
        elif db_type == "qdrant":
            # TODO: Implement Qdrant loading
            raise NotImplementedError("Loading a Qdrant vector store is not implemented yet")
        else:
            raise ValueError(f"Unsupported database type: {db_type}")
        return vector_store
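    # Example (sketch; assumes APIGateway.load_embedding_model can be called with only the type
    # argument and that ./data/vector_db holds a previously persisted Chroma store):
    #   embeddings = APIGateway.load_embedding_model(type="cpu")
    #   store = VectorDb().load_vdb("./data/vector_db", embeddings, db_type="chroma")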
    def update_vdb(self, chunks: list, embeddings, db_type: str, input_db: str = None,
                   output_db: str = None):
        """Updates an existing vector store with new chunks.

        Args:
            chunks (list): list of chunks to add
            embeddings: embedding model used when the vector db was created
            db_type (str): vector db type (only "faiss" is currently supported)
            input_db (str, optional): path of the vector db to update. Defaults to None.
            output_db (str, optional): output path to save the updated vector db. Defaults to None.
        """
        if db_type == "faiss":
            vector_store = FAISS.load_local(input_db, embeddings, allow_dangerous_deserialization=True)
            new_vector_store = self.create_vector_store(chunks, embeddings, db_type, None)
            vector_store.merge_from(new_vector_store)
            if output_db:
                vector_store.save_local(output_db)
        elif db_type == "chroma":
            # TODO: implement update method for chroma
            raise NotImplementedError("Updating a Chroma vector store is not implemented yet")
        elif db_type == "qdrant":
            # TODO: implement update method for qdrant
            raise NotImplementedError("Updating a Qdrant vector store is not implemented yet")
        else:
            raise ValueError(f"Unsupported database type: {db_type}")
        return vector_store
    def create_vdb(
        self,
        input_path,
        chunk_size,
        chunk_overlap,
        db_type,
        output_db=None,
        recursive=False,
        tokenizer=None,
        load_txt=True,
        load_pdf=False,
        urls=None,
        embedding_type="cpu",
        batch_size=None,
        coe=None,
        select_expert=None
    ):
        """Creates a vector database from the raw files in the input directory and/or the given urls."""
        docs = self.load_files(input_path, recursive=recursive, load_txt=load_txt, load_pdf=load_pdf, urls=urls)
        if tokenizer is None:
            chunks = self.get_text_chunks(docs, chunk_size, chunk_overlap)
        else:
            chunks = self.get_token_chunks(docs, chunk_size, chunk_overlap, tokenizer)
        embeddings = APIGateway.load_embedding_model(
            type=embedding_type,
            batch_size=batch_size,
            coe=coe,
            select_expert=select_expert
        )
        vector_store = self.create_vector_store(chunks, embeddings, db_type, output_db)
        return vector_store
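# Example (sketch, with placeholder paths): building a FAISS store programmatically instead of via the CLI.
#   vectordb = VectorDb()
#   vector_store = vectordb.create_vdb(
#       "./data/docs", chunk_size=1000, chunk_overlap=200, db_type="faiss", output_db="./data/vector_db"
#   )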
def dir_path(path):
    if os.path.isdir(path):
        return path
    else:
        raise argparse.ArgumentTypeError(f"readable_dir:{path} is not a valid path")
# Parse the arguments
def parse_arguments():
    parser = argparse.ArgumentParser(description="Process command line arguments.")
    parser.add_argument("-input_path", type=dir_path, help="path to input directory")
    parser.add_argument("--chunk_size", type=int, help="chunk size for splitting")
    parser.add_argument("--chunk_overlap", type=int, help="chunk overlap for splitting")
    parser.add_argument("-output_path", type=dir_path, help="path to output directory")
    return parser.parse_args()
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Process data with optional chunking")

    # Required arguments
    parser.add_argument("--input_path", type=str, required=True, help="Path to the input directory")
    parser.add_argument("--output_db", type=str, required=True, help="Path to the output vectordb")

    # Optional arguments
    parser.add_argument(
        "--chunk_size", type=int, default=1000, help="Chunk size (default: 1000)"
    )
    parser.add_argument(
        "--chunk_overlap", type=int, default=200, help="Chunk overlap (default: 200)"
    )
    parser.add_argument(
        "--db_type",
        type=str,
        default="faiss",
        help="Type of vector store (default: faiss)",
    )

    args = parser.parse_args()

    vectordb = VectorDb()

    # output_db is passed by keyword so it lands on the right parameter
    # (create_vdb takes chunk_size, chunk_overlap and db_type before output_db)
    vectordb.create_vdb(
        args.input_path,
        args.chunk_size,
        args.chunk_overlap,
        args.db_type,
        output_db=args.output_db,
    )