合并代码
This commit is contained in:
@@ -4,7 +4,7 @@ import logging
|
||||
from typing import Dict, List, Any, Optional, AsyncGenerator
|
||||
|
||||
from aiostream import stream
|
||||
from fastapi import APIRouter, Request
|
||||
from fastapi import APIRouter, Request,HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from llama_index.core import BaseCallbackHandler
|
||||
from llama_index.core.base.llms.types import ChatMessage
|
||||
@@ -16,6 +16,7 @@ from app.api.routers.request.base import userMng, conversations,message,paramete
|
||||
from app.api.routers.request.models import ChatRequestData,ChatFileUploadRequest
|
||||
from app.engine import get_chat_engine
|
||||
import uuid
|
||||
from app.api.routers.services.fileServices import FileLoadService
|
||||
|
||||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
@@ -473,4 +474,9 @@ async def query_parameters(user:str):
|
||||
|
||||
@r.post("")
|
||||
def upload_file(request: ChatFileUploadRequest) -> List[str]:
|
||||
pass
|
||||
try:
|
||||
logger.info("Processing file")
|
||||
return FileLoadService.process_file(request.base64)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing file: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Error processing file")
|
||||
@@ -0,0 +1,55 @@
|
||||
import base64,os
|
||||
from typing import List
|
||||
from uuid import uuid4
|
||||
import requests
|
||||
from app.settings import init_settings
|
||||
from app.engine.loaders import get_document_Types, get_documents,getFileCacahePath
|
||||
from app.engine.vectordb import get_vector_store
|
||||
from app.engine.generate import get_doc_store,run_pipeline,persist_storage
|
||||
|
||||
|
||||
STORAGE_DIR = os.getenv("STORAGE_DIR", "storage")
|
||||
|
||||
class FileLoadService:
|
||||
@staticmethod
|
||||
def store_and_parse_file(file_data):
|
||||
prjtoJson_url = os.getenv('PRJTOJSON_URL')
|
||||
convert_url = prjtoJson_url +'/prj_convert_clt2json'
|
||||
files ={'file':file_data}
|
||||
response1 = requests.post(
|
||||
url = convert_url,
|
||||
files=files
|
||||
)
|
||||
load_url = prjtoJson_url +'/file_download'
|
||||
response2 = requests.post(
|
||||
url = load_url,
|
||||
data=response1.text
|
||||
)
|
||||
|
||||
with open('example.zip','wb') as file:
|
||||
file.write(response2.content)
|
||||
|
||||
prjID = str(uuid4())
|
||||
filePath = getFileCacahePath() + f'/Projects/{prjID}'
|
||||
os.makedirs(filePath)
|
||||
import zipfile
|
||||
with zipfile.ZipFile('example.zip','r') as zip_File:
|
||||
for zip_info in zip_File.infolist():
|
||||
zip_info.filename = zip_info.filename.encode('cp437').decode('gbk')
|
||||
zip_File.extract(zip_info,filePath)
|
||||
os.remove('example.zip')
|
||||
return f'Projects_{prjID}'
|
||||
|
||||
@staticmethod
|
||||
def process_file(base64_content: str) -> List[str]:
|
||||
docType = FileLoadService.store_and_parse_file(base64_content)
|
||||
#生成向量并持久化至本地
|
||||
init_settings()
|
||||
documents = get_documents(docType)
|
||||
for doc in documents:
|
||||
doc.metadata["private"] = "false"
|
||||
docstore = get_doc_store(docType)
|
||||
vector_store = get_vector_store(docType)
|
||||
_ = run_pipeline(docstore, vector_store, documents)
|
||||
persist_storage(docstore, vector_store)
|
||||
|
||||
Reference in New Issue
Block a user