合并代码

This commit is contained in:
wanyaokun
2024-08-28 19:58:37 +08:00
parent 20510a937b
commit 4020b603b1
12 changed files with 189 additions and 80 deletions
+8 -2
View File
@@ -4,7 +4,7 @@ import logging
from typing import Dict, List, Any, Optional, AsyncGenerator
from aiostream import stream
from fastapi import APIRouter, Request
from fastapi import APIRouter, Request,HTTPException
from fastapi.responses import StreamingResponse
from llama_index.core import BaseCallbackHandler
from llama_index.core.base.llms.types import ChatMessage
@@ -16,6 +16,7 @@ from app.api.routers.request.base import userMng, conversations,message,paramete
from app.api.routers.request.models import ChatRequestData,ChatFileUploadRequest
from app.engine import get_chat_engine
import uuid
from app.api.routers.services.fileServices import FileLoadService
logger = logging.getLogger("uvicorn")
@@ -473,4 +474,9 @@ async def query_parameters(user:str):
@r.post("")
def upload_file(request: ChatFileUploadRequest) -> List[str]:
pass
try:
logger.info("Processing file")
return FileLoadService.process_file(request.base64)
except Exception as e:
logger.error(f"Error processing file: {e}", exc_info=True)
raise HTTPException(status_code=500, detail="Error processing file")
@@ -0,0 +1,55 @@
import base64,os
from typing import List
from uuid import uuid4
import requests
from app.settings import init_settings
from app.engine.loaders import get_document_Types, get_documents,getFileCacahePath
from app.engine.vectordb import get_vector_store
from app.engine.generate import get_doc_store,run_pipeline,persist_storage
STORAGE_DIR = os.getenv("STORAGE_DIR", "storage")
class FileLoadService:
@staticmethod
def store_and_parse_file(file_data):
prjtoJson_url = os.getenv('PRJTOJSON_URL')
convert_url = prjtoJson_url +'/prj_convert_clt2json'
files ={'file':file_data}
response1 = requests.post(
url = convert_url,
files=files
)
load_url = prjtoJson_url +'/file_download'
response2 = requests.post(
url = load_url,
data=response1.text
)
with open('example.zip','wb') as file:
file.write(response2.content)
prjID = str(uuid4())
filePath = getFileCacahePath() + f'/Projects/{prjID}'
os.makedirs(filePath)
import zipfile
with zipfile.ZipFile('example.zip','r') as zip_File:
for zip_info in zip_File.infolist():
zip_info.filename = zip_info.filename.encode('cp437').decode('gbk')
zip_File.extract(zip_info,filePath)
os.remove('example.zip')
return f'Projects_{prjID}'
@staticmethod
def process_file(base64_content: str) -> List[str]:
docType = FileLoadService.store_and_parse_file(base64_content)
#生成向量并持久化至本地
init_settings()
documents = get_documents(docType)
for doc in documents:
doc.metadata["private"] = "false"
docstore = get_doc_store(docType)
vector_store = get_vector_store(docType)
_ = run_pipeline(docstore, vector_store, documents)
persist_storage(docstore, vector_store)
+3 -1
View File
@@ -87,7 +87,9 @@ class PrivateFileService:
nodes = pipeline.run(documents=documents)
# Add the nodes to the index and persist it
current_index = get_index()
indexs = get_index()
if len(indexs) > 0:
current_index = list(indexs.values())[0]
# Insert the documents into the index
if isinstance(current_index, LlamaCloudIndex):
+3 -1
View File
@@ -25,7 +25,9 @@ def get_chat_engine(filters=None, params=None):
#tools.append(sql_query_tool)
# Add query tool if index exists
index = get_index()
indexs = get_index()
if len(indexs) > 0:
index = list(indexs.values())[0]
if index is not None:
summary_query_engine = create_summary_query_engine(index,top_k,use_reranker,filters)
summary_query_tool = QueryEngineTool.from_defaults( query_engine=summary_query_engine, name="summary_query_tool",
+14 -13
View File
@@ -1,22 +1,23 @@
import logging
from llama_index.core.indices import VectorStoreIndex
from app.engine.vectordb import get_vector_store
from app.engine.loaders import get_document_Types
logger = logging.getLogger("uvicorn")
index = None
indexs = {}
def get_index(params=None):
global index
if index is None:
global indexs
if len(index) <= 0:
logger.info("Connecting vector store...")
store = get_vector_store()
# Load the index from the vector store
# If you are using a vector store that doesn't store text,
# you must load the index from both the vector store and the document store
index = VectorStoreIndex.from_vector_store(store)
logger.info("Finished load index from vector store.")
return index
docTypes = get_document_Types()
for docType in docTypes:
store = get_vector_store(docType)
# Load the index from the vector store
# If you are using a vector store that doesn't store text,
# you must load the index from both the vector store and the document store
index = VectorStoreIndex.from_vector_store(store)
logger.info("Finished load index from vector store.")
indexs[docType] = index
return indexs
+60 -20
View File
@@ -3,39 +3,79 @@ import yaml
from app.engine.loaders.db import DBLoaderConfig, get_db_documents
from app.engine.loaders.file import FileLoaderConfig, get_file_documents
from app.engine.loaders.web import WebLoaderConfig, get_web_documents
import os
logger = logging.getLogger(__name__)
def load_configs():
with open("config/loaders.yaml") as f:
with open("config/loaders.yaml",encoding='utf-8') as f:
configs = yaml.safe_load(f)
return configs
def path_difference(path1:str, path2:str):
import os
path1 = os.path.abspath(path1)
path2 = os.path.abspath(path2)
def get_documents():
path1_parts = path1.split(os.path.sep)
path2_parts = path2.split(os.path.sep)
for i, part in enumerate(path1_parts):
if part != path2_parts[i]:
break
else:
i += 1
pathKey = ''
for j in range(i,len(path2_parts)):
pathKey+=path2_parts[j] + '_'
return pathKey[0:-1]
def getFileCacahePath():
rootPath = 'data'
configs = load_configs()
if configs is not None and len(configs.items()) > 0:
for loader_type, loader_config in configs.items():
if loader_type == "file":
rootPath = FileLoaderConfig(**loader_config).data_dir
break
return rootPath
def get_document_Types():
rootPath = getFileCacahePath()
types = []
dirStack = [rootPath]
while len(dirStack) > 0:
curDir = dirStack.pop()
dirs = [os.path.join(curDir, d) for d in os.listdir(curDir) if os.path.isdir(os.path.join(curDir, d))]
if len(dirs) > 0:
for dir in dirs:
dirStack.append(dir)
else:
types.append(path_difference(rootPath,curDir))
return types
def get_documents(docType:str):
documents = []
config = load_configs()
if config is None or len(config.items()) == 0:
return documents
return documents
for loader_type, loader_config in config.items():
if loader_config.get('enable', True): # 检查 enable 字段
logger.info(
f"Loading documents from loader: {loader_type}, config: {loader_config}"
)
logger.info(
f"Loading documents from loader: {loader_type}, config: {loader_config}"
)
loader_config = loader_config or []
match loader_type:
case "file":
document = get_file_documents(FileLoaderConfig(**loader_config))
case "web":
document = get_web_documents(WebLoaderConfig(**loader_config))
case "db":
document = get_db_documents(configs=[DBLoaderConfig(**cfg) for cfg in loader_config])
case _:
raise ValueError(f"Invalid loader type: {loader_type}")
documents.extend(document)
loader_config = loader_config or []
match loader_type:
case "file":
document = get_file_documents(FileLoaderConfig(**loader_config),docType)
case "web":
document = get_web_documents(WebLoaderConfig(**loader_config))
case "db":
document = get_db_documents(configs=[DBLoaderConfig(**cfg) for cfg in loader_config])
case _:
raise ValueError(f"Invalid loader type: {loader_type}")
documents.extend(document)
return documents
+2 -2
View File
@@ -46,7 +46,7 @@ def llama_local_extractor() -> Dict[str, BaseReader]:
return {".json" : JSONReader(clean_json=False,levels_back=0)}
def get_file_documents(config: FileLoaderConfig):
def get_file_documents(config: FileLoaderConfig,childPath: str):
from llama_index.core.readers import SimpleDirectoryReader
try:
@@ -63,7 +63,7 @@ def get_file_documents(config: FileLoaderConfig):
file_extractor = llama_local_extractor()
reader = SimpleDirectoryReader(
config.data_dir,
os.path.join(config.data_dir,childPath.replace('_','\\')),
recursive=True,
filename_as_id=True,
raise_on_error=True,
+9 -8
View File
@@ -5,12 +5,13 @@ from qdrant_client import qdrant_client
qclient = None
def get_qdrant_vector_store():
collection_name = os.getenv("VECTOR_STORE_COLLECTION", "default")
def get_qdrant_vector_store(docType:str):
collection_name = docType
vector_store_path = os.getenv("VECTOR_STORE_PATH")
host=os.getenv("VECTOR_STORE_HOST", "127.0.0.1"),
port=int(os.getenv("VECTOR_STORE_PORT", "6333")),
vector_store_path =os.path.join(vector_store_path,docType)
if not vector_store_path or not host:
raise ValueError(
"Please provide either VECTOR_STORE_PATH or VECTOR_STORE_HOST and VECTOR_STORE_PORT"
@@ -32,9 +33,9 @@ def get_qdrant_vector_store():
vector_store = QdrantVectorStore(client=qclient, collection_name=collection_name)
return vector_store
def get_chroma_vector_store():
collection_name = os.getenv("VECTOR_STORE_COLLECTION", "default")
vector_store_path = os.getenv("VECTOR_STORE_PATH")
def get_chroma_vector_store(docType:str):
collection_name = docType
vector_store_path =os.path.join(os.getenv("VECTOR_STORE_PATH"),docType)
# if VECTOR_STORE_PATH is set, use a local ChromaVectorStore from the path
# otherwise, use a remote ChromaVectorStore (ChromaDB Cloud is not supported yet)
if vector_store_path:
@@ -55,16 +56,16 @@ def get_chroma_vector_store():
)
return store
def get_vector_store():
def get_vector_store(docType:str):
store_type=os.getenv("VECTOR_STORE_TYPE")
store = None
match store_type:
case "chroma":
store = get_chroma_vector_store()
store = get_chroma_vector_store(docType)
case "qdrant":
store = get_qdrant_vector_store()
store = get_qdrant_vector_store(docType)
case _:
raise ValueError(f"Invalid vector store type: {store_type}")