合并代码

This commit is contained in:
wanyaokun
2024-08-28 19:58:37 +08:00
parent 20510a937b
commit 4020b603b1
12 changed files with 189 additions and 80 deletions
+8 -2
View File
@@ -4,7 +4,7 @@ import logging
from typing import Dict, List, Any, Optional, AsyncGenerator from typing import Dict, List, Any, Optional, AsyncGenerator
from aiostream import stream from aiostream import stream
from fastapi import APIRouter, Request from fastapi import APIRouter, Request,HTTPException
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from llama_index.core import BaseCallbackHandler from llama_index.core import BaseCallbackHandler
from llama_index.core.base.llms.types import ChatMessage from llama_index.core.base.llms.types import ChatMessage
@@ -16,6 +16,7 @@ from app.api.routers.request.base import userMng, conversations,message,paramete
from app.api.routers.request.models import ChatRequestData,ChatFileUploadRequest from app.api.routers.request.models import ChatRequestData,ChatFileUploadRequest
from app.engine import get_chat_engine from app.engine import get_chat_engine
import uuid import uuid
from app.api.routers.services.fileServices import FileLoadService
logger = logging.getLogger("uvicorn") logger = logging.getLogger("uvicorn")
@@ -473,4 +474,9 @@ async def query_parameters(user:str):
@r.post("") @r.post("")
def upload_file(request: ChatFileUploadRequest) -> List[str]: def upload_file(request: ChatFileUploadRequest) -> List[str]:
pass try:
logger.info("Processing file")
return FileLoadService.process_file(request.base64)
except Exception as e:
logger.error(f"Error processing file: {e}", exc_info=True)
raise HTTPException(status_code=500, detail="Error processing file")
@@ -0,0 +1,55 @@
import base64,os
from typing import List
from uuid import uuid4
import requests
from app.settings import init_settings
from app.engine.loaders import get_document_Types, get_documents,getFileCacahePath
from app.engine.vectordb import get_vector_store
from app.engine.generate import get_doc_store,run_pipeline,persist_storage
STORAGE_DIR = os.getenv("STORAGE_DIR", "storage")
class FileLoadService:
@staticmethod
def store_and_parse_file(file_data):
prjtoJson_url = os.getenv('PRJTOJSON_URL')
convert_url = prjtoJson_url +'/prj_convert_clt2json'
files ={'file':file_data}
response1 = requests.post(
url = convert_url,
files=files
)
load_url = prjtoJson_url +'/file_download'
response2 = requests.post(
url = load_url,
data=response1.text
)
with open('example.zip','wb') as file:
file.write(response2.content)
prjID = str(uuid4())
filePath = getFileCacahePath() + f'/Projects/{prjID}'
os.makedirs(filePath)
import zipfile
with zipfile.ZipFile('example.zip','r') as zip_File:
for zip_info in zip_File.infolist():
zip_info.filename = zip_info.filename.encode('cp437').decode('gbk')
zip_File.extract(zip_info,filePath)
os.remove('example.zip')
return f'Projects_{prjID}'
@staticmethod
def process_file(base64_content: str) -> List[str]:
docType = FileLoadService.store_and_parse_file(base64_content)
#生成向量并持久化至本地
init_settings()
documents = get_documents(docType)
for doc in documents:
doc.metadata["private"] = "false"
docstore = get_doc_store(docType)
vector_store = get_vector_store(docType)
_ = run_pipeline(docstore, vector_store, documents)
persist_storage(docstore, vector_store)
+3 -1
View File
@@ -87,7 +87,9 @@ class PrivateFileService:
nodes = pipeline.run(documents=documents) nodes = pipeline.run(documents=documents)
# Add the nodes to the index and persist it # Add the nodes to the index and persist it
current_index = get_index() indexs = get_index()
if len(indexs) > 0:
current_index = list(indexs.values())[0]
# Insert the documents into the index # Insert the documents into the index
if isinstance(current_index, LlamaCloudIndex): if isinstance(current_index, LlamaCloudIndex):
+3 -1
View File
@@ -25,7 +25,9 @@ def get_chat_engine(filters=None, params=None):
#tools.append(sql_query_tool) #tools.append(sql_query_tool)
# Add query tool if index exists # Add query tool if index exists
index = get_index() indexs = get_index()
if len(indexs) > 0:
index = list(indexs.values())[0]
if index is not None: if index is not None:
summary_query_engine = create_summary_query_engine(index,top_k,use_reranker,filters) summary_query_engine = create_summary_query_engine(index,top_k,use_reranker,filters)
summary_query_tool = QueryEngineTool.from_defaults( query_engine=summary_query_engine, name="summary_query_tool", summary_query_tool = QueryEngineTool.from_defaults( query_engine=summary_query_engine, name="summary_query_tool",
+14 -13
View File
@@ -1,22 +1,23 @@
import logging import logging
from llama_index.core.indices import VectorStoreIndex from llama_index.core.indices import VectorStoreIndex
from app.engine.vectordb import get_vector_store from app.engine.vectordb import get_vector_store
from app.engine.loaders import get_document_Types
logger = logging.getLogger("uvicorn") logger = logging.getLogger("uvicorn")
index = None indexs = {}
def get_index(params=None): def get_index(params=None):
global index global indexs
if index is None: if len(index) <= 0:
logger.info("Connecting vector store...") logger.info("Connecting vector store...")
docTypes = get_document_Types()
store = get_vector_store() for docType in docTypes:
# Load the index from the vector store store = get_vector_store(docType)
# If you are using a vector store that doesn't store text, # Load the index from the vector store
# you must load the index from both the vector store and the document store # If you are using a vector store that doesn't store text,
index = VectorStoreIndex.from_vector_store(store) # you must load the index from both the vector store and the document store
logger.info("Finished load index from vector store.") index = VectorStoreIndex.from_vector_store(store)
logger.info("Finished load index from vector store.")
return index indexs[docType] = index
return indexs
+60 -20
View File
@@ -3,39 +3,79 @@ import yaml
from app.engine.loaders.db import DBLoaderConfig, get_db_documents from app.engine.loaders.db import DBLoaderConfig, get_db_documents
from app.engine.loaders.file import FileLoaderConfig, get_file_documents from app.engine.loaders.file import FileLoaderConfig, get_file_documents
from app.engine.loaders.web import WebLoaderConfig, get_web_documents from app.engine.loaders.web import WebLoaderConfig, get_web_documents
import os
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def load_configs(): def load_configs():
with open("config/loaders.yaml") as f: with open("config/loaders.yaml",encoding='utf-8') as f:
configs = yaml.safe_load(f) configs = yaml.safe_load(f)
return configs return configs
def path_difference(path1:str, path2:str):
import os
path1 = os.path.abspath(path1)
path2 = os.path.abspath(path2)
def get_documents(): path1_parts = path1.split(os.path.sep)
path2_parts = path2.split(os.path.sep)
for i, part in enumerate(path1_parts):
if part != path2_parts[i]:
break
else:
i += 1
pathKey = ''
for j in range(i,len(path2_parts)):
pathKey+=path2_parts[j] + '_'
return pathKey[0:-1]
def getFileCacahePath():
rootPath = 'data'
configs = load_configs()
if configs is not None and len(configs.items()) > 0:
for loader_type, loader_config in configs.items():
if loader_type == "file":
rootPath = FileLoaderConfig(**loader_config).data_dir
break
return rootPath
def get_document_Types():
rootPath = getFileCacahePath()
types = []
dirStack = [rootPath]
while len(dirStack) > 0:
curDir = dirStack.pop()
dirs = [os.path.join(curDir, d) for d in os.listdir(curDir) if os.path.isdir(os.path.join(curDir, d))]
if len(dirs) > 0:
for dir in dirs:
dirStack.append(dir)
else:
types.append(path_difference(rootPath,curDir))
return types
def get_documents(docType:str):
documents = [] documents = []
config = load_configs() config = load_configs()
if config is None or len(config.items()) == 0: if config is None or len(config.items()) == 0:
return documents return documents
for loader_type, loader_config in config.items(): for loader_type, loader_config in config.items():
if loader_config.get('enable', True): # 检查 enable 字段 logger.info(
logger.info( f"Loading documents from loader: {loader_type}, config: {loader_config}"
f"Loading documents from loader: {loader_type}, config: {loader_config}" )
)
loader_config = loader_config or [] loader_config = loader_config or []
match loader_type: match loader_type:
case "file": case "file":
document = get_file_documents(FileLoaderConfig(**loader_config)) document = get_file_documents(FileLoaderConfig(**loader_config),docType)
case "web": case "web":
document = get_web_documents(WebLoaderConfig(**loader_config)) document = get_web_documents(WebLoaderConfig(**loader_config))
case "db": case "db":
document = get_db_documents(configs=[DBLoaderConfig(**cfg) for cfg in loader_config]) document = get_db_documents(configs=[DBLoaderConfig(**cfg) for cfg in loader_config])
case _: case _:
raise ValueError(f"Invalid loader type: {loader_type}") raise ValueError(f"Invalid loader type: {loader_type}")
documents.extend(document) documents.extend(document)
return documents return documents
+2 -2
View File
@@ -46,7 +46,7 @@ def llama_local_extractor() -> Dict[str, BaseReader]:
return {".json" : JSONReader(clean_json=False,levels_back=0)} return {".json" : JSONReader(clean_json=False,levels_back=0)}
def get_file_documents(config: FileLoaderConfig): def get_file_documents(config: FileLoaderConfig,childPath: str):
from llama_index.core.readers import SimpleDirectoryReader from llama_index.core.readers import SimpleDirectoryReader
try: try:
@@ -63,7 +63,7 @@ def get_file_documents(config: FileLoaderConfig):
file_extractor = llama_local_extractor() file_extractor = llama_local_extractor()
reader = SimpleDirectoryReader( reader = SimpleDirectoryReader(
config.data_dir, os.path.join(config.data_dir,childPath.replace('_','\\')),
recursive=True, recursive=True,
filename_as_id=True, filename_as_id=True,
raise_on_error=True, raise_on_error=True,
+9 -8
View File
@@ -5,12 +5,13 @@ from qdrant_client import qdrant_client
qclient = None qclient = None
def get_qdrant_vector_store(): def get_qdrant_vector_store(docType:str):
collection_name = os.getenv("VECTOR_STORE_COLLECTION", "default") collection_name = docType
vector_store_path = os.getenv("VECTOR_STORE_PATH") vector_store_path = os.getenv("VECTOR_STORE_PATH")
host=os.getenv("VECTOR_STORE_HOST", "127.0.0.1"), host=os.getenv("VECTOR_STORE_HOST", "127.0.0.1"),
port=int(os.getenv("VECTOR_STORE_PORT", "6333")), port=int(os.getenv("VECTOR_STORE_PORT", "6333")),
vector_store_path =os.path.join(vector_store_path,docType)
if not vector_store_path or not host: if not vector_store_path or not host:
raise ValueError( raise ValueError(
"Please provide either VECTOR_STORE_PATH or VECTOR_STORE_HOST and VECTOR_STORE_PORT" "Please provide either VECTOR_STORE_PATH or VECTOR_STORE_HOST and VECTOR_STORE_PORT"
@@ -32,9 +33,9 @@ def get_qdrant_vector_store():
vector_store = QdrantVectorStore(client=qclient, collection_name=collection_name) vector_store = QdrantVectorStore(client=qclient, collection_name=collection_name)
return vector_store return vector_store
def get_chroma_vector_store(): def get_chroma_vector_store(docType:str):
collection_name = os.getenv("VECTOR_STORE_COLLECTION", "default") collection_name = docType
vector_store_path = os.getenv("VECTOR_STORE_PATH") vector_store_path =os.path.join(os.getenv("VECTOR_STORE_PATH"),docType)
# if VECTOR_STORE_PATH is set, use a local ChromaVectorStore from the path # if VECTOR_STORE_PATH is set, use a local ChromaVectorStore from the path
# otherwise, use a remote ChromaVectorStore (ChromaDB Cloud is not supported yet) # otherwise, use a remote ChromaVectorStore (ChromaDB Cloud is not supported yet)
if vector_store_path: if vector_store_path:
@@ -55,16 +56,16 @@ def get_chroma_vector_store():
) )
return store return store
def get_vector_store(): def get_vector_store(docType:str):
store_type=os.getenv("VECTOR_STORE_TYPE") store_type=os.getenv("VECTOR_STORE_TYPE")
store = None store = None
match store_type: match store_type:
case "chroma": case "chroma":
store = get_chroma_vector_store() store = get_chroma_vector_store(docType)
case "qdrant": case "qdrant":
store = get_qdrant_vector_store() store = get_qdrant_vector_store(docType)
case _: case _:
raise ValueError(f"Invalid vector store type: {store_type}") raise ValueError(f"Invalid vector store type: {store_type}")
+32 -32
View File
@@ -3,46 +3,46 @@ file:
# use_llama_parse: Use LlamaParse if `true`. Needs a `LLAMA_CLOUD_API_KEY` from https://cloud.llamaindex.ai set as environment variable # use_llama_parse: Use LlamaParse if `true`. Needs a `LLAMA_CLOUD_API_KEY` from https://cloud.llamaindex.ai set as environment variable
use_llama_parse: false use_llama_parse: false
db: #db:
# The configuration for the database loader, only supports MySQL and PostgreSQL databases for now. # The configuration for the database loader, only supports MySQL and PostgreSQL databases for now.
# uri: The URI for the database. E.g.: mysql+pymysql://user:password@localhost:3306/db or postgresql+psycopg2://user:password@localhost:5432/db # uri: The URI for the database. E.g.: mysql+pymysql://user:password@localhost:3306/db or postgresql+psycopg2://user:password@localhost:5432/db
# query: The query to fetch data from the database. E.g.: SELECT * FROM table # query: The query to fetch data from the database. E.g.: SELECT * FROM table
- uri: mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zjinfo1 #- uri: mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zjinfo1
enable: true # 添加 enable 字段 #enable: true # 添加 enable 字段
queries: #queries:
- sql: select * from ProjectProperties; #- sql: select * from ProjectProperties;
explanation: "工程属性表数据,层级关系包含在博微电力造价工程文件格式_ProjectProperties.json文件中。" #explanation: "工程属性表数据,层级关系包含在博微电力造价工程文件格式_ProjectProperties.json文件中。"
- sql: select Id, ParentId, Level, Name, Code, Amount, Amount_Total from TotalCalculateTable; #- sql: select Id, ParentId, Level, Name, Code, Amount, Amount_Total from TotalCalculateTable;
explanation: "总算表数据,层级关系包含在博微电力造价工程文件格式_TotalCalculateTable.json文件中。" #explanation: "总算表数据,层级关系包含在博微电力造价工程文件格式_TotalCalculateTable.json文件中。"
- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '线路'; #- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '线路';
explanation: "专业类型为线路的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。" #explanation: "专业类型为线路的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。"
- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '余物清理'; #- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '余物清理';
explanation: "专业类型为余物清理的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。" #explanation: "专业类型为余物清理的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。"
- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '拆除线路'; #- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '拆除线路';
explanation: "专业类型为拆除线路的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。" #explanation: "专业类型为拆除线路的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。"
- sql: select Id, ParentId, Level, Name, Code, Rate, Amount from OtherFee; #- sql: select Id, ParentId, Level, Name, Code, Rate, Amount from OtherFee;
explanation: "其他费用表数据,层级关系包含在博微电力造价工程文件格式_OtherFee.json文件中" #explanation: "其他费用表数据,层级关系包含在博微电力造价工程文件格式_OtherFee.json文件中"
- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表' #- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表'
explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中" # explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(调试工程)aa' #- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(调试工程)aa'
explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中" #explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '大型土石方取费表' #- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '大型土石方取费表'
explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中" #explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(余物清理)' #- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(余物清理)'
explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中" #explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(余物清理)(1)' #- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(余物清理)(1)'
explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中" #explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(拆除)' #- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(拆除)'
explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中" #explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
- sql: select Name, Code, Calculation_Formula, Rate, from ProjectQuantities where Professional_Type = '线路' #- sql: select Name, Code, Calculation_Formula, Rate, from ProjectQuantities where Professional_Type = '线路'
explanation: "专业类型为线路的工程量表数据,层级关系包含在博微电力造价工程文件格式_ProjectQuantities.json文件中" #explanation: "专业类型为线路的工程量表数据,层级关系包含在博微电力造价工程文件格式_ProjectQuantities.json文件中"
- sql: select Name, Code, Calculation_Formula, Rate, from ProjectQuantities where Professional_Type = '余物清理' #- sql: select Name, Code, Calculation_Formula, Rate, from ProjectQuantities where Professional_Type = '余物清理'
explanation: "专业类型为余物清理的工程量表数据,层级关系包含在博微电力造价工程文件格式_ProjectQuantities.json文件中" #explanation: "专业类型为余物清理的工程量表数据,层级关系包含在博微电力造价工程文件格式_ProjectQuantities.json文件中"
#web: #web:
# driver_arguments: # driver_arguments:
# # The arguments to pass to the webdriver. E.g.: add --headless to run in headless mode # # The arguments to pass to the webdriver. E.g.: add --headless to run in headless mode
Binary file not shown.
+3 -1
View File
@@ -19,7 +19,9 @@ def main():
init_settings() init_settings()
init_observability() init_observability()
index = get_index() indexs = get_index()
if len(indexs) > 0:
index = list(indexs.values())[0]
top_k = 5 top_k = 5
filters = generate_filters([]) filters = generate_filters([])