合并代码
This commit is contained in:
@@ -4,7 +4,7 @@ import logging
|
||||
from typing import Dict, List, Any, Optional, AsyncGenerator
|
||||
|
||||
from aiostream import stream
|
||||
from fastapi import APIRouter, Request
|
||||
from fastapi import APIRouter, Request,HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from llama_index.core import BaseCallbackHandler
|
||||
from llama_index.core.base.llms.types import ChatMessage
|
||||
@@ -16,6 +16,7 @@ from app.api.routers.request.base import userMng, conversations,message,paramete
|
||||
from app.api.routers.request.models import ChatRequestData,ChatFileUploadRequest
|
||||
from app.engine import get_chat_engine
|
||||
import uuid
|
||||
from app.api.routers.services.fileServices import FileLoadService
|
||||
|
||||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
@@ -473,4 +474,9 @@ async def query_parameters(user:str):
|
||||
|
||||
@r.post("")
|
||||
def upload_file(request: ChatFileUploadRequest) -> List[str]:
|
||||
pass
|
||||
try:
|
||||
logger.info("Processing file")
|
||||
return FileLoadService.process_file(request.base64)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing file: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Error processing file")
|
||||
@@ -0,0 +1,55 @@
|
||||
import base64,os
|
||||
from typing import List
|
||||
from uuid import uuid4
|
||||
import requests
|
||||
from app.settings import init_settings
|
||||
from app.engine.loaders import get_document_Types, get_documents,getFileCacahePath
|
||||
from app.engine.vectordb import get_vector_store
|
||||
from app.engine.generate import get_doc_store,run_pipeline,persist_storage
|
||||
|
||||
|
||||
STORAGE_DIR = os.getenv("STORAGE_DIR", "storage")
|
||||
|
||||
class FileLoadService:
|
||||
@staticmethod
|
||||
def store_and_parse_file(file_data):
|
||||
prjtoJson_url = os.getenv('PRJTOJSON_URL')
|
||||
convert_url = prjtoJson_url +'/prj_convert_clt2json'
|
||||
files ={'file':file_data}
|
||||
response1 = requests.post(
|
||||
url = convert_url,
|
||||
files=files
|
||||
)
|
||||
load_url = prjtoJson_url +'/file_download'
|
||||
response2 = requests.post(
|
||||
url = load_url,
|
||||
data=response1.text
|
||||
)
|
||||
|
||||
with open('example.zip','wb') as file:
|
||||
file.write(response2.content)
|
||||
|
||||
prjID = str(uuid4())
|
||||
filePath = getFileCacahePath() + f'/Projects/{prjID}'
|
||||
os.makedirs(filePath)
|
||||
import zipfile
|
||||
with zipfile.ZipFile('example.zip','r') as zip_File:
|
||||
for zip_info in zip_File.infolist():
|
||||
zip_info.filename = zip_info.filename.encode('cp437').decode('gbk')
|
||||
zip_File.extract(zip_info,filePath)
|
||||
os.remove('example.zip')
|
||||
return f'Projects_{prjID}'
|
||||
|
||||
@staticmethod
|
||||
def process_file(base64_content: str) -> List[str]:
|
||||
docType = FileLoadService.store_and_parse_file(base64_content)
|
||||
#生成向量并持久化至本地
|
||||
init_settings()
|
||||
documents = get_documents(docType)
|
||||
for doc in documents:
|
||||
doc.metadata["private"] = "false"
|
||||
docstore = get_doc_store(docType)
|
||||
vector_store = get_vector_store(docType)
|
||||
_ = run_pipeline(docstore, vector_store, documents)
|
||||
persist_storage(docstore, vector_store)
|
||||
|
||||
@@ -87,7 +87,9 @@ class PrivateFileService:
|
||||
nodes = pipeline.run(documents=documents)
|
||||
|
||||
# Add the nodes to the index and persist it
|
||||
current_index = get_index()
|
||||
indexs = get_index()
|
||||
if len(indexs) > 0:
|
||||
current_index = list(indexs.values())[0]
|
||||
|
||||
# Insert the documents into the index
|
||||
if isinstance(current_index, LlamaCloudIndex):
|
||||
|
||||
@@ -25,7 +25,9 @@ def get_chat_engine(filters=None, params=None):
|
||||
#tools.append(sql_query_tool)
|
||||
|
||||
# Add query tool if index exists
|
||||
index = get_index()
|
||||
indexs = get_index()
|
||||
if len(indexs) > 0:
|
||||
index = list(indexs.values())[0]
|
||||
if index is not None:
|
||||
summary_query_engine = create_summary_query_engine(index,top_k,use_reranker,filters)
|
||||
summary_query_tool = QueryEngineTool.from_defaults( query_engine=summary_query_engine, name="summary_query_tool",
|
||||
|
||||
@@ -1,22 +1,23 @@
|
||||
import logging
|
||||
from llama_index.core.indices import VectorStoreIndex
|
||||
from app.engine.vectordb import get_vector_store
|
||||
|
||||
from app.engine.loaders import get_document_Types
|
||||
|
||||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
index = None
|
||||
indexs = {}
|
||||
|
||||
def get_index(params=None):
|
||||
global index
|
||||
if index is None:
|
||||
global indexs
|
||||
if len(index) <= 0:
|
||||
logger.info("Connecting vector store...")
|
||||
|
||||
store = get_vector_store()
|
||||
docTypes = get_document_Types()
|
||||
for docType in docTypes:
|
||||
store = get_vector_store(docType)
|
||||
# Load the index from the vector store
|
||||
# If you are using a vector store that doesn't store text,
|
||||
# you must load the index from both the vector store and the document store
|
||||
index = VectorStoreIndex.from_vector_store(store)
|
||||
logger.info("Finished load index from vector store.")
|
||||
|
||||
return index
|
||||
indexs[docType] = index
|
||||
return indexs
|
||||
|
||||
@@ -3,25 +3,65 @@ import yaml
|
||||
from app.engine.loaders.db import DBLoaderConfig, get_db_documents
|
||||
from app.engine.loaders.file import FileLoaderConfig, get_file_documents
|
||||
from app.engine.loaders.web import WebLoaderConfig, get_web_documents
|
||||
import os
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_configs():
|
||||
with open("config/loaders.yaml") as f:
|
||||
with open("config/loaders.yaml",encoding='utf-8') as f:
|
||||
configs = yaml.safe_load(f)
|
||||
return configs
|
||||
|
||||
def path_difference(path1:str, path2:str):
|
||||
import os
|
||||
path1 = os.path.abspath(path1)
|
||||
path2 = os.path.abspath(path2)
|
||||
|
||||
def get_documents():
|
||||
path1_parts = path1.split(os.path.sep)
|
||||
path2_parts = path2.split(os.path.sep)
|
||||
|
||||
for i, part in enumerate(path1_parts):
|
||||
if part != path2_parts[i]:
|
||||
break
|
||||
else:
|
||||
i += 1
|
||||
|
||||
pathKey = ''
|
||||
for j in range(i,len(path2_parts)):
|
||||
pathKey+=path2_parts[j] + '_'
|
||||
return pathKey[0:-1]
|
||||
|
||||
def getFileCacahePath():
|
||||
rootPath = 'data'
|
||||
configs = load_configs()
|
||||
if configs is not None and len(configs.items()) > 0:
|
||||
for loader_type, loader_config in configs.items():
|
||||
if loader_type == "file":
|
||||
rootPath = FileLoaderConfig(**loader_config).data_dir
|
||||
break
|
||||
return rootPath
|
||||
|
||||
def get_document_Types():
|
||||
rootPath = getFileCacahePath()
|
||||
types = []
|
||||
dirStack = [rootPath]
|
||||
while len(dirStack) > 0:
|
||||
curDir = dirStack.pop()
|
||||
dirs = [os.path.join(curDir, d) for d in os.listdir(curDir) if os.path.isdir(os.path.join(curDir, d))]
|
||||
if len(dirs) > 0:
|
||||
for dir in dirs:
|
||||
dirStack.append(dir)
|
||||
else:
|
||||
types.append(path_difference(rootPath,curDir))
|
||||
return types
|
||||
|
||||
def get_documents(docType:str):
|
||||
documents = []
|
||||
config = load_configs()
|
||||
|
||||
if config is None or len(config.items()) == 0:
|
||||
return documents
|
||||
|
||||
for loader_type, loader_config in config.items():
|
||||
if loader_config.get('enable', True): # 检查 enable 字段
|
||||
logger.info(
|
||||
f"Loading documents from loader: {loader_type}, config: {loader_config}"
|
||||
)
|
||||
@@ -29,7 +69,7 @@ def get_documents():
|
||||
loader_config = loader_config or []
|
||||
match loader_type:
|
||||
case "file":
|
||||
document = get_file_documents(FileLoaderConfig(**loader_config))
|
||||
document = get_file_documents(FileLoaderConfig(**loader_config),docType)
|
||||
case "web":
|
||||
document = get_web_documents(WebLoaderConfig(**loader_config))
|
||||
case "db":
|
||||
|
||||
@@ -46,7 +46,7 @@ def llama_local_extractor() -> Dict[str, BaseReader]:
|
||||
return {".json" : JSONReader(clean_json=False,levels_back=0)}
|
||||
|
||||
|
||||
def get_file_documents(config: FileLoaderConfig):
|
||||
def get_file_documents(config: FileLoaderConfig,childPath: str):
|
||||
from llama_index.core.readers import SimpleDirectoryReader
|
||||
|
||||
try:
|
||||
@@ -63,7 +63,7 @@ def get_file_documents(config: FileLoaderConfig):
|
||||
file_extractor = llama_local_extractor()
|
||||
|
||||
reader = SimpleDirectoryReader(
|
||||
config.data_dir,
|
||||
os.path.join(config.data_dir,childPath.replace('_','\\')),
|
||||
recursive=True,
|
||||
filename_as_id=True,
|
||||
raise_on_error=True,
|
||||
|
||||
@@ -5,12 +5,13 @@ from qdrant_client import qdrant_client
|
||||
|
||||
qclient = None
|
||||
|
||||
def get_qdrant_vector_store():
|
||||
collection_name = os.getenv("VECTOR_STORE_COLLECTION", "default")
|
||||
def get_qdrant_vector_store(docType:str):
|
||||
collection_name = docType
|
||||
vector_store_path = os.getenv("VECTOR_STORE_PATH")
|
||||
host=os.getenv("VECTOR_STORE_HOST", "127.0.0.1"),
|
||||
port=int(os.getenv("VECTOR_STORE_PORT", "6333")),
|
||||
|
||||
vector_store_path =os.path.join(vector_store_path,docType)
|
||||
if not vector_store_path or not host:
|
||||
raise ValueError(
|
||||
"Please provide either VECTOR_STORE_PATH or VECTOR_STORE_HOST and VECTOR_STORE_PORT"
|
||||
@@ -32,9 +33,9 @@ def get_qdrant_vector_store():
|
||||
vector_store = QdrantVectorStore(client=qclient, collection_name=collection_name)
|
||||
return vector_store
|
||||
|
||||
def get_chroma_vector_store():
|
||||
collection_name = os.getenv("VECTOR_STORE_COLLECTION", "default")
|
||||
vector_store_path = os.getenv("VECTOR_STORE_PATH")
|
||||
def get_chroma_vector_store(docType:str):
|
||||
collection_name = docType
|
||||
vector_store_path =os.path.join(os.getenv("VECTOR_STORE_PATH"),docType)
|
||||
# if VECTOR_STORE_PATH is set, use a local ChromaVectorStore from the path
|
||||
# otherwise, use a remote ChromaVectorStore (ChromaDB Cloud is not supported yet)
|
||||
if vector_store_path:
|
||||
@@ -55,16 +56,16 @@ def get_chroma_vector_store():
|
||||
)
|
||||
return store
|
||||
|
||||
def get_vector_store():
|
||||
def get_vector_store(docType:str):
|
||||
store_type=os.getenv("VECTOR_STORE_TYPE")
|
||||
|
||||
store = None
|
||||
|
||||
match store_type:
|
||||
case "chroma":
|
||||
store = get_chroma_vector_store()
|
||||
store = get_chroma_vector_store(docType)
|
||||
case "qdrant":
|
||||
store = get_qdrant_vector_store()
|
||||
store = get_qdrant_vector_store(docType)
|
||||
case _:
|
||||
raise ValueError(f"Invalid vector store type: {store_type}")
|
||||
|
||||
|
||||
+32
-32
@@ -3,46 +3,46 @@ file:
|
||||
# use_llama_parse: Use LlamaParse if `true`. Needs a `LLAMA_CLOUD_API_KEY` from https://cloud.llamaindex.ai set as environment variable
|
||||
use_llama_parse: false
|
||||
|
||||
db:
|
||||
#db:
|
||||
# The configuration for the database loader, only supports MySQL and PostgreSQL databases for now.
|
||||
# uri: The URI for the database. E.g.: mysql+pymysql://user:password@localhost:3306/db or postgresql+psycopg2://user:password@localhost:5432/db
|
||||
# query: The query to fetch data from the database. E.g.: SELECT * FROM table
|
||||
- uri: mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zjinfo1
|
||||
enable: true # 添加 enable 字段
|
||||
queries:
|
||||
- sql: select * from ProjectProperties;
|
||||
explanation: "工程属性表数据,层级关系包含在博微电力造价工程文件格式_ProjectProperties.json文件中。"
|
||||
#- uri: mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zjinfo1
|
||||
#enable: true # 添加 enable 字段
|
||||
#queries:
|
||||
#- sql: select * from ProjectProperties;
|
||||
#explanation: "工程属性表数据,层级关系包含在博微电力造价工程文件格式_ProjectProperties.json文件中。"
|
||||
|
||||
- sql: select Id, ParentId, Level, Name, Code, Amount, Amount_Total from TotalCalculateTable;
|
||||
explanation: "总算表数据,层级关系包含在博微电力造价工程文件格式_TotalCalculateTable.json文件中。"
|
||||
#- sql: select Id, ParentId, Level, Name, Code, Amount, Amount_Total from TotalCalculateTable;
|
||||
#explanation: "总算表数据,层级关系包含在博微电力造价工程文件格式_TotalCalculateTable.json文件中。"
|
||||
|
||||
- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '线路';
|
||||
explanation: "专业类型为线路的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。"
|
||||
- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '余物清理';
|
||||
explanation: "专业类型为余物清理的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。"
|
||||
- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '拆除线路';
|
||||
explanation: "专业类型为拆除线路的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。"
|
||||
#- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '线路';
|
||||
#explanation: "专业类型为线路的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。"
|
||||
#- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '余物清理';
|
||||
#explanation: "专业类型为余物清理的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。"
|
||||
#- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '拆除线路';
|
||||
#explanation: "专业类型为拆除线路的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。"
|
||||
|
||||
- sql: select Id, ParentId, Level, Name, Code, Rate, Amount from OtherFee;
|
||||
explanation: "其他费用表数据,层级关系包含在博微电力造价工程文件格式_OtherFee.json文件中"
|
||||
#- sql: select Id, ParentId, Level, Name, Code, Rate, Amount from OtherFee;
|
||||
#explanation: "其他费用表数据,层级关系包含在博微电力造价工程文件格式_OtherFee.json文件中"
|
||||
|
||||
- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表'
|
||||
explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
|
||||
- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(调试工程)aa'
|
||||
explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
|
||||
- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '大型土石方取费表'
|
||||
explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
|
||||
- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(余物清理)'
|
||||
explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
|
||||
- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(余物清理)(1)'
|
||||
explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
|
||||
- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(拆除)'
|
||||
explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
|
||||
#- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表'
|
||||
# explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
|
||||
#- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(调试工程)aa'
|
||||
#explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
|
||||
#- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '大型土石方取费表'
|
||||
#explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
|
||||
#- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(余物清理)'
|
||||
#explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
|
||||
#- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(余物清理)(1)'
|
||||
#explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
|
||||
#- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(拆除)'
|
||||
#explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
|
||||
|
||||
- sql: select Name, Code, Calculation_Formula, Rate, from ProjectQuantities where Professional_Type = '线路'
|
||||
explanation: "专业类型为线路的工程量表数据,层级关系包含在博微电力造价工程文件格式_ProjectQuantities.json文件中"
|
||||
- sql: select Name, Code, Calculation_Formula, Rate, from ProjectQuantities where Professional_Type = '余物清理'
|
||||
explanation: "专业类型为余物清理的工程量表数据,层级关系包含在博微电力造价工程文件格式_ProjectQuantities.json文件中"
|
||||
#- sql: select Name, Code, Calculation_Formula, Rate, from ProjectQuantities where Professional_Type = '线路'
|
||||
#explanation: "专业类型为线路的工程量表数据,层级关系包含在博微电力造价工程文件格式_ProjectQuantities.json文件中"
|
||||
#- sql: select Name, Code, Calculation_Formula, Rate, from ProjectQuantities where Professional_Type = '余物清理'
|
||||
#explanation: "专业类型为余物清理的工程量表数据,层级关系包含在博微电力造价工程文件格式_ProjectQuantities.json文件中"
|
||||
#web:
|
||||
# driver_arguments:
|
||||
# # The arguments to pass to the webdriver. E.g.: add --headless to run in headless mode
|
||||
|
||||
Binary file not shown.
Binary file not shown.
@@ -19,7 +19,9 @@ def main():
|
||||
init_settings()
|
||||
init_observability()
|
||||
|
||||
index = get_index()
|
||||
indexs = get_index()
|
||||
if len(indexs) > 0:
|
||||
index = list(indexs.values())[0]
|
||||
|
||||
top_k = 5
|
||||
filters = generate_filters([])
|
||||
|
||||
Reference in New Issue
Block a user