From 7e58a1a223a14d83efb3b4d62bb58f084e0432ff Mon Sep 17 00:00:00 2001 From: wanyaokun <12345678> Date: Tue, 13 Aug 2024 13:10:52 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AE=9E=E7=8E=B0=E5=A4=9A=E5=B7=A5=E7=A8=8B?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E5=AD=98=E5=82=A8=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/app/api/services/file.py | 4 ++- backend/app/engine/__init__.py | 4 ++- backend/app/engine/generate.py | 36 ++++++++++----------- backend/app/engine/index.py | 27 ++++++++-------- backend/app/engine/loaders/__init__.py | 44 ++++++++++++++++++++++++-- backend/app/engine/loaders/file.py | 7 ++-- backend/app/engine/vectordb.py | 20 +++++++----- backend/run-data.bat | 2 +- backend/tests/query.py | 4 ++- 9 files changed, 97 insertions(+), 51 deletions(-) diff --git a/backend/app/api/services/file.py b/backend/app/api/services/file.py index a478570..e8eb54c 100644 --- a/backend/app/api/services/file.py +++ b/backend/app/api/services/file.py @@ -87,7 +87,9 @@ class PrivateFileService: nodes = pipeline.run(documents=documents) # Add the nodes to the index and persist it - current_index = get_index() + indexs = get_index() + if len(indexs) > 0: + current_index = list(indexs.values())[0] # Insert the documents into the index if isinstance(current_index, LlamaCloudIndex): diff --git a/backend/app/engine/__init__.py b/backend/app/engine/__init__.py index def5e51..97d92c3 100644 --- a/backend/app/engine/__init__.py +++ b/backend/app/engine/__init__.py @@ -43,7 +43,9 @@ def get_chat_engine(filters=None, params=None): description="来源于一个由博微公司电力造价软件编制的造价工程文件。该文件以多张表格的形式存储存储了整个工程的全部数据内容。适用于以详细的自然语言查询表格数据方式查询造价工程各项具体属性、费用的数值。请先使用“zj_query_tool”无法解决才使用本工具") # Add query tool if index exists - index = get_index() + indexs = get_index() + if len(indexs) > 0: + index = list(indexs.values())[0] if index is not None: summary_index = SummaryIndex(index.vector_store.get_nodes(node_ids=None)) summary_query_engine = summary_index.as_query_engine() diff --git a/backend/app/engine/generate.py b/backend/app/engine/generate.py index 115c175..1e05731 100644 --- a/backend/app/engine/generate.py +++ b/backend/app/engine/generate.py @@ -5,7 +5,7 @@ load_dotenv() import logging import os -from app.engine.loaders import get_documents +from app.engine.loaders import get_document_Types, get_documents from app.engine.vectordb import get_vector_store from app.settings import init_settings from llama_index.core.ingestion import IngestionPipeline @@ -19,17 +19,16 @@ logger = logging.getLogger() STORAGE_DIR = os.getenv("STORAGE_DIR", "storage") - -def get_doc_store(): +def get_doc_store(docType:str): # If the storage directory is there, load the document store from it. # If not, set up an in-memory document store since we can't load from a directory that doesn't exist. - if os.path.exists(STORAGE_DIR): - return SimpleDocumentStore.from_persist_dir(STORAGE_DIR) + storeDir = os.path.join(STORAGE_DIR,docType) + if os.path.exists(storeDir): + return SimpleDocumentStore.from_persist_dir(storeDir) else: return SimpleDocumentStore() - def run_pipeline(docstore, vector_store, documents): pipeline = IngestionPipeline( transformations=[ @@ -49,7 +48,6 @@ def run_pipeline(docstore, vector_store, documents): return nodes - def persist_storage(docstore, vector_store): storage_context = StorageContext.from_defaults( docstore=docstore, @@ -57,28 +55,28 @@ def persist_storage(docstore, vector_store): ) storage_context.persist(STORAGE_DIR) - def generate_datasource(): init_settings() logger.info("Generate index for the provided data") # Get the stores and documents or create new ones - documents = get_documents() - # Set private=false to mark the document as public (required for filtering) - for doc in documents: - doc.metadata["private"] = "false" - docstore = get_doc_store() - vector_store = get_vector_store() + docTypes = get_document_Types() + for docType in docTypes: + documents = get_documents(docType) + # Set private=false to mark the document as public (required for filtering) + for doc in documents: + doc.metadata["private"] = "false" + docstore = get_doc_store(docType) + vector_store = get_vector_store(docType) - # Run the ingestion pipeline - _ = run_pipeline(docstore, vector_store, documents) + # Run the ingestion pipeline + _ = run_pipeline(docstore, vector_store, documents) - # Build the index and persist storage - persist_storage(docstore, vector_store) + # Build the index and persist storage + persist_storage(docstore, vector_store) logger.info("Finished generating the index") - if __name__ == "__main__": from phoenix.trace import using_project with using_project(os.getenv("PHOENIX_PROJECT_NAME") + "_generate") as obj: diff --git a/backend/app/engine/index.py b/backend/app/engine/index.py index b21e695..6f6164d 100644 --- a/backend/app/engine/index.py +++ b/backend/app/engine/index.py @@ -1,22 +1,23 @@ import logging from llama_index.core.indices import VectorStoreIndex from app.engine.vectordb import get_vector_store - +from app.engine.generate import get_document_Types logger = logging.getLogger("uvicorn") -index = None +indexs = {} def get_index(params=None): - global index - if index is None: + global indexs + if len(index) <= 0: logger.info("Connecting vector store...") - - store = get_vector_store() - # Load the index from the vector store - # If you are using a vector store that doesn't store text, - # you must load the index from both the vector store and the document store - index = VectorStoreIndex.from_vector_store(store) - logger.info("Finished load index from vector store.") - - return index + docTypes = get_document_Types() + for docType in docTypes: + store = get_vector_store(docType) + # Load the index from the vector store + # If you are using a vector store that doesn't store text, + # you must load the index from both the vector store and the document store + index = VectorStoreIndex.from_vector_store(store) + logger.info("Finished load index from vector store.") + indexs[docType] = index + return indexs diff --git a/backend/app/engine/loaders/__init__.py b/backend/app/engine/loaders/__init__.py index a220170..e36c167 100644 --- a/backend/app/engine/loaders/__init__.py +++ b/backend/app/engine/loaders/__init__.py @@ -13,8 +13,48 @@ def load_configs(): configs = yaml.safe_load(f) return configs +def path_difference(path1:str, path2:str): + import os + path1 = os.path.abspath(path1) + path2 = os.path.abspath(path2) -def get_documents(): + path1_parts = path1.split(os.path.sep) + path2_parts = path2.split(os.path.sep) + + for i, part in enumerate(path1_parts): + if part != path2_parts[i]: + break + else: + i += 1 + + pathKey = '' + for j in range(i,len(path2_parts)): + pathKey+=path2_parts[j] + '_' + return pathKey[0:-1] + +def get_document_Types(): + import os + rootPath = 'data' + configs = load_configs() + if configs is not None and len(configs.items()) > 0: + for loader_type, loader_config in configs.items(): + if loader_type == "file": + rootPath = FileLoaderConfig(**loader_config).data_dir + break + + types = [] + dirStack = [rootPath] + while len(dirStack) > 0: + curDir = dirStack.pop() + dirs = [os.path.join(curDir, d) for d in os.listdir(curDir) if os.path.isdir(os.path.join(curDir, d))] + if len(dirs) > 0: + for dir in dirs: + dirStack.append(dir) + else: + types.append(path_difference(rootPath,curDir)) + return types + +def get_documents(docType:str): documents = [] config = load_configs() if config is None or len(config.items()) == 0: @@ -28,7 +68,7 @@ def get_documents(): loader_config = loader_config or [] match loader_type: case "file": - document = get_file_documents(FileLoaderConfig(**loader_config)) + document = get_file_documents(FileLoaderConfig(**loader_config),docType) case "web": document = get_web_documents(WebLoaderConfig(**loader_config)) case "db": diff --git a/backend/app/engine/loaders/file.py b/backend/app/engine/loaders/file.py index 1db99ce..b390c99 100644 --- a/backend/app/engine/loaders/file.py +++ b/backend/app/engine/loaders/file.py @@ -20,7 +20,6 @@ class FileLoaderConfig(BaseModel): raise ValueError(f"Directory '{v}' does not exist") return v - def llama_parse_parser(): if os.getenv("LLAMA_CLOUD_API_KEY") is None: raise ValueError( @@ -35,7 +34,6 @@ def llama_parse_parser(): ) return parser - def llama_parse_extractor() -> Dict[str, LlamaParse]: from llama_parse.utils import SUPPORTED_FILE_TYPES @@ -45,8 +43,7 @@ def llama_parse_extractor() -> Dict[str, LlamaParse]: def llama_local_extractor() -> Dict[str, BaseReader]: return {"json" : JSONReader} - -def get_file_documents(config: FileLoaderConfig): +def get_file_documents(config: FileLoaderConfig, childPath: str): from llama_index.core.readers import SimpleDirectoryReader try: @@ -63,7 +60,7 @@ def get_file_documents(config: FileLoaderConfig): file_extractor = llama_local_extractor() reader = SimpleDirectoryReader( - config.data_dir, + os.path.join(config.data_dir,childPath.replace('_','\\')), recursive=True, filename_as_id=True, raise_on_error=True, diff --git a/backend/app/engine/vectordb.py b/backend/app/engine/vectordb.py index f3f2a7d..5a72ca4 100644 --- a/backend/app/engine/vectordb.py +++ b/backend/app/engine/vectordb.py @@ -5,12 +5,14 @@ from qdrant_client import qdrant_client qclient = None -def get_qdrant_vector_store(): - collection_name = os.getenv("VECTOR_STORE_COLLECTION", "default") +def get_qdrant_vector_store(docType:str): + collection_name = docType + #collection_name = os.getenv("VECTOR_STORE_COLLECTION", "default") vector_store_path = os.getenv("VECTOR_STORE_PATH") host=os.getenv("VECTOR_STORE_HOST", "127.0.0.1"), port=int(os.getenv("VECTOR_STORE_PORT", "6333")), - + + vector_store_path =os.path.join(vector_store_path,docType) if not vector_store_path or not host: raise ValueError( "Please provide either VECTOR_STORE_PATH or VECTOR_STORE_HOST and VECTOR_STORE_PORT" @@ -32,9 +34,11 @@ def get_qdrant_vector_store(): vector_store = QdrantVectorStore(client=qclient, collection_name=collection_name) return vector_store -def get_chroma_vector_store(): - collection_name = os.getenv("VECTOR_STORE_COLLECTION", "default") +def get_chroma_vector_store(docType:str): + #collection_name = os.getenv("VECTOR_STORE_COLLECTION", "default") + collection_name = docType vector_store_path = os.getenv("VECTOR_STORE_PATH") + vector_store_path =os.path.join(vector_store_path,docType) # if VECTOR_STORE_PATH is set, use a local ChromaVectorStore from the path # otherwise, use a remote ChromaVectorStore (ChromaDB Cloud is not supported yet) if vector_store_path: @@ -55,16 +59,16 @@ def get_chroma_vector_store(): ) return store -def get_vector_store(): +def get_vector_store(docType:str): store_type=os.getenv("VECTOR_STORE_TYPE") store = None match store_type: case "chroma": - store = get_chroma_vector_store() + store = get_chroma_vector_store(docType) case "qdrant": - store = get_qdrant_vector_store() + store = get_qdrant_vector_store(docType) case _: raise ValueError(f"Invalid vector store type: {store_type}") diff --git a/backend/run-data.bat b/backend/run-data.bat index 8314019..b25bba4 100644 --- a/backend/run-data.bat +++ b/backend/run-data.bat @@ -1,4 +1,4 @@ rmdir /S /Q storage_vector rmdir /S /Q storage -C:\Users\liuyue\AppData\Local\pypoetry\Cache\virtualenvs\app-laEO4lY0-py3.11\Scripts\python app/engine/generate.py \ No newline at end of file +python app/engine/generate.py \ No newline at end of file diff --git a/backend/tests/query.py b/backend/tests/query.py index 48ca304..7ab9517 100644 --- a/backend/tests/query.py +++ b/backend/tests/query.py @@ -19,7 +19,9 @@ def main(): init_settings() init_observability() - index = get_index() + indexs = get_index() + if len(indexs) > 0: + index = list(indexs.values())[0] top_k = 5 filters = generate_filters([])