class FileLoadService:
    """Convert an uploaded project file, unpack the result and index it.

    The conversion service is located through the ``PRJTOJSON_URL``
    environment variable.
    """

    # (connect, read) timeouts in seconds for the conversion-service calls.
    # ``requests`` waits indefinitely when no timeout is passed.
    _HTTP_TIMEOUT = (10, 300)

    @staticmethod
    def _decode_member_name(zip_info):
        """Return the usable file name of a zip entry.

        ``zipfile`` decodes entry names as cp437 unless the entry carries the
        UTF-8 flag (bit 11 of ``flag_bits``).  Unflagged names are re-decoded
        as GBK; names already flagged as UTF-8, or names that cannot be
        re-decoded, are returned unchanged instead of raising.
        """
        if zip_info.flag_bits & 0x800:
            return zip_info.filename
        try:
            return zip_info.filename.encode('cp437').decode('gbk')
        except (UnicodeEncodeError, UnicodeDecodeError):
            return zip_info.filename

    @staticmethod
    def store_and_parse_file(file_data):
        """Send *file_data* to the conversion service and unpack its reply.

        The payload is posted to ``/prj_convert_clt2json``; the text of that
        response is posted to ``/file_download``, whose body is treated as a
        zip archive and extracted into ``<file cache path>/Projects/<uuid>``.

        Returns:
            The document-type key ``Projects_<uuid>`` of the new directory.

        Raises:
            ValueError: if ``PRJTOJSON_URL`` is not set.
            requests.HTTPError: if either service call returns an error status.
            zipfile.BadZipFile: if the downloaded body is not a zip archive.
        """
        import io
        import zipfile

        base_url = os.getenv('PRJTOJSON_URL')
        if not base_url:
            raise ValueError("PRJTOJSON_URL environment variable is not set")

        # NOTE(review): the payload is forwarded exactly as received (the
        # caller passes a base64 string) - confirm the service expects that.
        convert_reply = requests.post(
            url=base_url + '/prj_convert_clt2json',
            files={'file': file_data},
            timeout=FileLoadService._HTTP_TIMEOUT,
        )
        convert_reply.raise_for_status()

        download_reply = requests.post(
            url=base_url + '/file_download',
            data=convert_reply.text,
            timeout=FileLoadService._HTTP_TIMEOUT,
        )
        download_reply.raise_for_status()

        # Keep the archive in memory: a shared on-disk temp file would be
        # overwritten by concurrent uploads and left behind on errors.
        with zipfile.ZipFile(io.BytesIO(download_reply.content)) as archive:
            # Create the target directory only once the archive has opened,
            # so a bad download does not leave an empty project directory.
            prjID = str(uuid4())
            target_dir = getFileCacahePath() + f'/Projects/{prjID}'
            os.makedirs(target_dir)
            for zip_info in archive.infolist():
                zip_info.filename = FileLoadService._decode_member_name(zip_info)
                archive.extract(zip_info, target_dir)
        return f'Projects_{prjID}'

    @staticmethod
    def process_file(base64_content: str) -> List[str]:
        """Store the uploaded file, then build and persist its vector index.

        Every loaded document gets ``metadata["private"] = "false"`` before
        the ingestion pipeline runs.

        Returns:
            A one-element list holding the document-type key under which the
            upload was stored and indexed.
        """
        docType = FileLoadService.store_and_parse_file(base64_content)
        # Generate the vectors and persist them locally.
        init_settings()
        documents = get_documents(docType)
        for doc in documents:
            doc.metadata["private"] = "false"
        docstore = get_doc_store(docType)
        vector_store = get_vector_store(docType)
        run_pipeline(docstore, vector_store, documents)
        persist_storage(docstore, vector_store)
        return [docType]
# Cache of loaded indexes keyed by document type; filled on first use.
indexs = {}


def get_index(params=None):
    """Return the cached mapping of document type to vector-store index.

    On the first call (while the cache is empty) one index is loaded per
    document type reported by ``get_document_Types()``.  Later calls return
    the same dictionary without touching the vector stores again.

    Args:
        params: Unused; kept so existing callers keep working.

    Returns:
        The module-level ``indexs`` dictionary (possibly empty).
    """
    global indexs
    # The emptiness check must look at the cache itself.  The previous check
    # read the loop-local ``index`` before it was assigned, which raised
    # UnboundLocalError on every call.
    if not indexs:
        logger.info("Connecting vector store...")
        for docType in get_document_Types():
            store = get_vector_store(docType)
            # Load the index from the vector store.
            # If you are using a vector store that doesn't store text,
            # you must load the index from both the vector store and the
            # document store.
            indexs[docType] = VectorStoreIndex.from_vector_store(store)
            logger.info("Finished load index from vector store.")
    # NOTE(review): document types created after the first call are not
    # picked up until the process restarts - confirm that is intended.
    return indexs
def load_configs():
    """Parse ``config/loaders.yaml`` (read as UTF-8) and return its content.

    Returns whatever ``yaml.safe_load`` yields, which is ``None`` for an
    empty file.
    """
    with open("config/loaders.yaml", encoding='utf-8') as f:
        return yaml.safe_load(f)


def path_difference(path1: str, path2: str):
    """Return the part of *path2* that lies below its common prefix with *path1*.

    Both paths are made absolute, their shared leading components are
    dropped, and the remaining components of *path2* are joined with ``_``.
    The result is an empty string when *path2* equals *path1* or is one of
    its ancestors (this case used to raise ``IndexError``).
    """
    import os

    base_parts = os.path.abspath(path1).split(os.path.sep)
    target_parts = os.path.abspath(path2).split(os.path.sep)

    shared = 0
    # zip() stops at the shorter path, so a short path2 cannot overrun.
    for base_part, target_part in zip(base_parts, target_parts):
        if base_part != target_part:
            break
        shared += 1
    return '_'.join(target_parts[shared:])


def getFileCacahePath():
    """Return the data directory of the ``file`` loader.

    Falls back to ``'data'`` when the loader configuration is empty or has
    no ``file`` section.
    """
    configs = load_configs()
    if configs and "file" in configs:
        # An empty ``file:`` section parses to None; treat it as "no options".
        return FileLoaderConfig(**(configs["file"] or {})).data_dir
    return 'data'


def get_document_Types():
    """List the document-type keys found under the file cache directory.

    Every directory below the root that has no sub-directories yields one
    key: its path relative to the root with components joined by ``_``.  A
    root without sub-directories yields the single key ``''``; a missing
    root yields an empty list.

    The keys are sorted so that callers which pick the first entry get a
    stable result instead of one that depends on directory listing order.
    """
    import os

    rootPath = getFileCacahePath()
    if not os.path.isdir(rootPath):
        return []

    types = []
    pending = [rootPath]
    while pending:
        current = pending.pop()
        with os.scandir(current) as entries:
            children = [entry.path for entry in entries if entry.is_dir()]
        if children:
            pending.extend(children)
        else:
            types.append(path_difference(rootPath, current))
    return sorted(types)


def get_documents(docType: str):
    """Load documents from every loader configured in ``config/loaders.yaml``.

    Args:
        docType: Document-type key; it is forwarded to the ``file`` loader
            only.

    Returns:
        A list with the documents of all loaders, empty when nothing is
        configured.

    Raises:
        ValueError: if the configuration names an unknown loader type.
    """
    documents = []
    config = load_configs()
    if not config:
        return documents

    for loader_type, loader_config in config.items():
        # Log only the loader type: loader settings can hold database URIs
        # with embedded passwords and must not end up in the logs.
        logger.info("Loading documents from loader: %s", loader_type)

        # ``file`` and ``web`` take keyword options, ``db`` takes a list of
        # option mappings; an empty section falls back to the matching type.
        if loader_type == "file":
            document = get_file_documents(
                FileLoaderConfig(**(loader_config or {})), docType
            )
        elif loader_type == "web":
            document = get_web_documents(WebLoaderConfig(**(loader_config or {})))
        elif loader_type == "db":
            document = get_db_documents(
                configs=[DBLoaderConfig(**cfg) for cfg in (loader_config or [])]
            )
        else:
            raise ValueError(f"Invalid loader type: {loader_type}")
        documents.extend(document)

    return documents
os.getenv("VECTOR_STORE_COLLECTION", "default") +def get_qdrant_vector_store(docType:str): + collection_name = docType vector_store_path = os.getenv("VECTOR_STORE_PATH") host=os.getenv("VECTOR_STORE_HOST", "127.0.0.1"), port=int(os.getenv("VECTOR_STORE_PORT", "6333")), + vector_store_path =os.path.join(vector_store_path,docType) if not vector_store_path or not host: raise ValueError( "Please provide either VECTOR_STORE_PATH or VECTOR_STORE_HOST and VECTOR_STORE_PORT" @@ -32,9 +33,9 @@ def get_qdrant_vector_store(): vector_store = QdrantVectorStore(client=qclient, collection_name=collection_name) return vector_store -def get_chroma_vector_store(): - collection_name = os.getenv("VECTOR_STORE_COLLECTION", "default") - vector_store_path = os.getenv("VECTOR_STORE_PATH") +def get_chroma_vector_store(docType:str): + collection_name = docType + vector_store_path =os.path.join(os.getenv("VECTOR_STORE_PATH"),docType) # if VECTOR_STORE_PATH is set, use a local ChromaVectorStore from the path # otherwise, use a remote ChromaVectorStore (ChromaDB Cloud is not supported yet) if vector_store_path: @@ -55,16 +56,16 @@ def get_chroma_vector_store(): ) return store -def get_vector_store(): +def get_vector_store(docType:str): store_type=os.getenv("VECTOR_STORE_TYPE") store = None match store_type: case "chroma": - store = get_chroma_vector_store() + store = get_chroma_vector_store(docType) case "qdrant": - store = get_qdrant_vector_store() + store = get_qdrant_vector_store(docType) case _: raise ValueError(f"Invalid vector store type: {store_type}") diff --git a/backend/config/loaders.yaml b/backend/config/loaders.yaml index af5d2fe..b19f033 100644 --- a/backend/config/loaders.yaml +++ b/backend/config/loaders.yaml @@ -3,46 +3,46 @@ file: # use_llama_parse: Use LlamaParse if `true`. 
Needs a `LLAMA_CLOUD_API_KEY` from https://cloud.llamaindex.ai set as environment variable use_llama_parse: false -db: +#db: # The configuration for the database loader, only supports MySQL and PostgreSQL databases for now. # uri: The URI for the database. E.g.: mysql+pymysql://user:password@localhost:3306/db or postgresql+psycopg2://user:password@localhost:5432/db # query: The query to fetch data from the database. E.g.: SELECT * FROM table - - uri: mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zjinfo1 - enable: true # 添加 enable 字段 - queries: - - sql: select * from ProjectProperties; - explanation: "工程属性表数据,层级关系包含在博微电力造价工程文件格式_ProjectProperties.json文件中。" + #- uri: mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zjinfo1 + #enable: true # 添加 enable 字段 + #queries: + #- sql: select * from ProjectProperties; + #explanation: "工程属性表数据,层级关系包含在博微电力造价工程文件格式_ProjectProperties.json文件中。" - - sql: select Id, ParentId, Level, Name, Code, Amount, Amount_Total from TotalCalculateTable; - explanation: "总算表数据,层级关系包含在博微电力造价工程文件格式_TotalCalculateTable.json文件中。" + #- sql: select Id, ParentId, Level, Name, Code, Amount, Amount_Total from TotalCalculateTable; + #explanation: "总算表数据,层级关系包含在博微电力造价工程文件格式_TotalCalculateTable.json文件中。" - - sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '线路'; - explanation: "专业类型为线路的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。" - - sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '余物清理'; - explanation: "专业类型为余物清理的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。" - - sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '拆除线路'; - explanation: "专业类型为拆除线路的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。" + #- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from 
ProjectDivision where ProfessionalType = '线路'; + #explanation: "专业类型为线路的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。" + #- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '余物清理'; + #explanation: "专业类型为余物清理的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。" + #- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '拆除线路'; + #explanation: "专业类型为拆除线路的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。" - - sql: select Id, ParentId, Level, Name, Code, Rate, Amount from OtherFee; - explanation: "其他费用表数据,层级关系包含在博微电力造价工程文件格式_OtherFee.json文件中" + #- sql: select Id, ParentId, Level, Name, Code, Rate, Amount from OtherFee; + #explanation: "其他费用表数据,层级关系包含在博微电力造价工程文件格式_OtherFee.json文件中" - - sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表' - explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中" - - sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(调试工程)aa' - explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中" - - sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '大型土石方取费表' - explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中" - - sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(余物清理)' - explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中" - - sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(余物清理)(1)' - explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中" - - sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = 
'线路取费表(拆除)' - explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中" + #- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表' + # explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中" + #- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(调试工程)aa' + #explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中" + #- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '大型土石方取费表' + #explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中" + #- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(余物清理)' + #explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中" + #- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(余物清理)(1)' + #explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中" + #- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(拆除)' + #explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中" - - sql: select Name, Code, Calculation_Formula, Rate, from ProjectQuantities where Professional_Type = '线路' - explanation: "专业类型为线路的工程量表数据,层级关系包含在博微电力造价工程文件格式_ProjectQuantities.json文件中" - - sql: select Name, Code, Calculation_Formula, Rate, from ProjectQuantities where Professional_Type = '余物清理' - explanation: "专业类型为余物清理的工程量表数据,层级关系包含在博微电力造价工程文件格式_ProjectQuantities.json文件中" + #- sql: select Name, Code, Calculation_Formula, Rate, from ProjectQuantities where Professional_Type = '线路' + #explanation: "专业类型为线路的工程量表数据,层级关系包含在博微电力造价工程文件格式_ProjectQuantities.json文件中" + #- sql: select Name, Code, Calculation_Formula, Rate, from ProjectQuantities 
where Professional_Type = '余物清理' + #explanation: "专业类型为余物清理的工程量表数据,层级关系包含在博微电力造价工程文件格式_ProjectQuantities.json文件中" #web: # driver_arguments: # # The arguments to pass to the webdriver. E.g.: add --headless to run in headless mode diff --git a/backend/data/博微电力造价工程业务数据说明.docx b/backend/data/博微电力造价工程业务数据说明.docx index 670ce04..425772f 100644 Binary files a/backend/data/博微电力造价工程业务数据说明.docx and b/backend/data/博微电力造价工程业务数据说明.docx differ diff --git a/backend/data/工程造价基础知识.docx b/backend/data/工程造价基础知识.docx deleted file mode 100644 index bd7f91d..0000000 Binary files a/backend/data/工程造价基础知识.docx and /dev/null differ diff --git a/backend/tests/query.py b/backend/tests/query.py index 8c82d28..f45d0b5 100644 --- a/backend/tests/query.py +++ b/backend/tests/query.py @@ -19,7 +19,9 @@ def main(): init_settings() init_observability() - index = get_index() + indexs = get_index() + if len(indexs) > 0: + index = list(indexs.values())[0] top_k = 5 filters = generate_filters([])