新增工程信息、检索的知识片段节点回传、下一轮建议问题列表
This commit is contained in:
@@ -1,16 +1,23 @@
|
||||
import base64,os
|
||||
from typing import List
|
||||
import base64,os,mimetypes,requests,tempfile
|
||||
from typing import List,Dict,Any
|
||||
from uuid import uuid4
|
||||
import requests
|
||||
from app.settings import init_settings
|
||||
from app.engine.loaders import get_document_Types, get_documents,getFileCacahePath
|
||||
from app.engine.vectordb import get_vector_store
|
||||
from app.engine.generate import get_doc_store,run_pipeline,persist_storage
|
||||
import tempfile
|
||||
from llama_index.core.schema import Document
|
||||
from pathlib import Path
|
||||
from llama_index.core.readers.file.base import (
|
||||
_try_loading_included_file_formats as get_file_loaders_map,
|
||||
)
|
||||
from llama_index.readers.file import FlatReader
|
||||
from llama_index.core.ingestion import IngestionPipeline
|
||||
from llama_index.core import VectorStoreIndex
|
||||
from app.engine.index import get_index
|
||||
|
||||
STORAGE_DIR = os.getenv("STORAGE_DIR", "storage")
|
||||
|
||||
class FileLoadService:
|
||||
class PrjFileLoadService:
|
||||
@staticmethod
|
||||
def store_and_parse_file(file_data):
|
||||
prjtoJson_url = os.getenv('PRJTOJSON_URL')
|
||||
@@ -20,29 +27,40 @@ class FileLoadService:
|
||||
url = convert_url,
|
||||
files=files
|
||||
)
|
||||
if response1.text is None or response1.text=='':
|
||||
return None
|
||||
|
||||
load_url = prjtoJson_url +'/file_download'
|
||||
response2 = requests.post(
|
||||
url = load_url,
|
||||
data=response1.text
|
||||
)
|
||||
tempFilePath:str = tempfile.gettempdir() + f"\\{str(uuid4())}.zip"
|
||||
with open(tempFilePath,'wb') as file:
|
||||
file.write(response2.content)
|
||||
|
||||
prjID = str(uuid4())
|
||||
filePath = getFileCacahePath() + f'/Projects/{prjID}'
|
||||
os.makedirs(filePath)
|
||||
import zipfile
|
||||
with zipfile.ZipFile(tempFilePath,'r') as zip_File:
|
||||
for zip_info in zip_File.infolist():
|
||||
zip_info.filename = zip_info.filename.encode('cp437').decode('gbk')
|
||||
zip_File.extract(zip_info,filePath)
|
||||
os.remove(tempFilePath)
|
||||
return f'Projects_{prjID}'
|
||||
if response2.text is None or response2.content=='':
|
||||
return None
|
||||
|
||||
try:
|
||||
tempFilePath:str = tempfile.gettempdir() + f"\\{uuid4().hex}.zip"
|
||||
with open(tempFilePath,'wb') as file:
|
||||
file.write(response2.content)
|
||||
|
||||
prjID = str(uuid4())
|
||||
filePath = getFileCacahePath() + f'/Projects/{prjID}'
|
||||
os.makedirs(filePath)
|
||||
import zipfile
|
||||
with zipfile.ZipFile(tempFilePath,'r') as zip_File:
|
||||
for zip_info in zip_File.infolist():
|
||||
zip_info.filename = zip_info.filename.encode('cp437').decode('gbk')
|
||||
zip_File.extract(zip_info,filePath)
|
||||
os.remove(tempFilePath)
|
||||
return f'Projects_{prjID}'
|
||||
except Exception as e:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def process_file(base64_content: str) -> str:
|
||||
prjFlag = FileLoadService.store_and_parse_file(base64_content)
|
||||
prjFlag = PrjFileLoadService.store_and_parse_file(base64_content)
|
||||
if prjFlag is None:
|
||||
return None
|
||||
#生成向量并持久化至本地
|
||||
documents = get_documents(prjFlag)
|
||||
for doc in documents:
|
||||
@@ -53,3 +71,64 @@ class FileLoadService:
|
||||
persist_storage(docstore, vector_store)
|
||||
return prjFlag
|
||||
|
||||
class ChatFileService:
    """Store a file uploaded as a base64 data URL and add it to the index.

    All methods are static. Metadata about the upload being processed is
    accumulated in the class-level ``resluts`` dict by the helper methods.
    """

    # Directory that receives uploaded files; overridable through the
    # CHAT_UPLOAD_FILECACHE environment variable.
    PRIVATE_STORE_PATH = os.getenv('CHAT_UPLOAD_FILECACHE','output/uploaded')
    # Metadata of the most recently processed upload. The helpers below write
    # the keys 'mime_type', 'extension', 'id', 'file_name' and 'size'.
    # NOTE(review): this dict is shared by every call; concurrent uploads
    # would interleave their writes - confirm callers are serialized.
    resluts: Dict[str, Any] = {}

    @staticmethod
    def process_file(base64_content: str) -> dict:
        """Decode, store, parse and index an uploaded file.

        Args:
            base64_content: Data URL of the form
                ``data:<mime type>;base64,<payload>``.

        Returns:
            A snapshot of the upload metadata (mime_type, extension, id,
            file_name, size).

        Raises:
            ValueError: If no reader is registered for the file extension.
        """
        file_data, extension = ChatFileService.preprocess_base64_file(base64_content)
        documents = ChatFileService.store_and_parse_file(file_data, extension)

        current_index = get_index()
        # Run the ingestion pipeline exactly once; the original ran it twice
        # and discarded the first result.
        nodes = IngestionPipeline().run(documents=documents)
        if current_index is None:
            # No existing index was returned: build one from these nodes.
            current_index = VectorStoreIndex(nodes=nodes)
        else:
            current_index.insert_nodes(nodes=nodes)
        current_index.storage_context.persist(
            persist_dir=os.environ.get("STORAGE_DIR", "storage")
        )

        # Return a copy so a later upload cannot overwrite the metadata
        # already handed to this caller.
        return dict(ChatFileService.resluts)

    @staticmethod
    def preprocess_base64_file(base64_content: str) -> tuple:
        """Split a data URL into decoded bytes and a guessed file extension.

        Records 'mime_type' and 'extension' in ``resluts``.

        Args:
            base64_content: Data URL of the form
                ``data:<mime type>;base64,<payload>``.

        Returns:
            ``(file_bytes, extension)``; ``extension`` is whatever
            ``mimetypes.guess_extension`` returns for the MIME type, which is
            ``None`` when the type is unknown.
        """
        header, data = base64_content.split(",", 1)
        # header looks like "data:<mime type>;base64"
        mime_type = header.split(";")[0].split(":", 1)[1]
        extension = mimetypes.guess_extension(mime_type)
        ChatFileService.resluts['mime_type'] = mime_type
        ChatFileService.resluts['extension'] = extension
        return base64.b64decode(data), extension

    @staticmethod
    def store_and_parse_file(file_data, extension) -> "List[Document]":
        """Write the upload to PRIVATE_STORE_PATH and load it as documents.

        Records 'id', 'file_name' and 'size' in ``resluts``.

        Args:
            file_data: Raw bytes of the uploaded file.
            extension: File extension including the leading dot, e.g. ".pdf".

        Returns:
            The documents produced by the reader registered for ``extension``,
            each tagged with the stored file name and ``private = "true"``.

        Raises:
            ValueError: If no reader is registered for ``extension``.
        """
        # Check for a reader before touching the disk so an unsupported
        # upload does not leave an orphan file behind.
        reader_cls = ChatFileService.default_file_loaders_map().get(extension)
        if reader_cls is None:
            raise ValueError(f"File extension {extension} is not supported")

        os.makedirs(ChatFileService.PRIVATE_STORE_PATH, exist_ok=True)
        fileID = uuid4().hex
        file_name = f"{fileID}{extension}"
        file_path = Path(os.path.join(ChatFileService.PRIVATE_STORE_PATH, file_name))
        ChatFileService.resluts['id'] = fileID
        ChatFileService.resluts['file_name'] = file_name

        with open(file_path, "wb") as f:
            f.write(file_data)

        ChatFileService.resluts['size'] = os.path.getsize(file_path)

        documents = reader_cls().load_data(file_path)
        for doc in documents:
            doc.metadata["file_name"] = file_name
            doc.metadata["private"] = "true"
        return documents

    @staticmethod
    def default_file_loaders_map():
        """Return the extension-to-reader map, with ".txt" set to FlatReader."""
        default_loaders = get_file_loaders_map()
        default_loaders[".txt"] = FlatReader
        return default_loaders
|
||||
Reference in New Issue
Block a user