134 lines
5.2 KiB
Python
134 lines
5.2 KiB
Python
import base64,os,mimetypes,requests,tempfile
|
|
from typing import List,Dict,Any
|
|
from uuid import uuid4
|
|
from app.settings import init_settings
|
|
from app.engine.loaders import get_document_Types, get_documents,getFileCacahePath
|
|
from app.engine.vectordb import get_vector_store
|
|
from app.engine.generate import get_doc_store,run_pipeline,persist_storage
|
|
from llama_index.core.schema import Document
|
|
from pathlib import Path
|
|
from llama_index.core.readers.file.base import (
|
|
_try_loading_included_file_formats as get_file_loaders_map,
|
|
)
|
|
from llama_index.readers.file import FlatReader
|
|
from llama_index.core.ingestion import IngestionPipeline
|
|
from llama_index.core import VectorStoreIndex
|
|
from app.engine.index import get_index
|
|
|
|
# Storage directory, overridable via the STORAGE_DIR environment variable (defaults to "storage").
STORAGE_DIR = os.environ.get("STORAGE_DIR", "storage")
|
|
|
|
class PrjFileLoadService:
    """Convert uploaded project files via the external PRJ-to-JSON service and index them."""

    @staticmethod
    def _decode_member_name(zip_info) -> str:
        """Return the name a zip member should be extracted under.

        zipfile decodes member names as UTF-8 when the member's UTF-8 flag
        (bit 0x800) is set, and as cp437 otherwise.  Names without the flag are
        re-decoded as GBK, which is what the conversion service is assumed to
        emit (NOTE(review): confirm against the service).  Names that cannot be
        round-tripped that way are kept unchanged instead of aborting extraction.
        """
        if zip_info.flag_bits & 0x800:
            # Already decoded correctly as UTF-8; re-decoding would corrupt it.
            return zip_info.filename
        try:
            return zip_info.filename.encode('cp437').decode('gbk')
        except UnicodeError:
            return zip_info.filename

    @staticmethod
    def store_and_parse_file(file_data):
        """Convert *file_data* remotely, unpack the resulting archive, and return its project flag.

        The payload is posted as the multipart ``file`` field to
        ``$PRJTOJSON_URL/prj_convert_clt2json``.  The text response is then
        posted to ``$PRJTOJSON_URL/file_download``, whose body is treated as a
        zip archive and extracted into ``<file cache>/Projects/<uuid>``.

        :param file_data: file payload passed to ``requests`` as the ``file`` field.
        :returns: ``"Projects_<uuid>"`` on success, or ``None`` in any of these cases:
            ``PRJTOJSON_URL`` is unset, either service call returns an empty body,
            or unpacking fails.  Network errors from ``requests`` propagate.
        """
        import io
        import logging
        import shutil
        import zipfile

        logger = logging.getLogger(__name__)

        base_url = os.getenv('PRJTOJSON_URL')
        if not base_url:
            logger.error("PRJTOJSON_URL is not set; cannot convert the project file")
            return None

        # NOTE(review): no timeout is set, so a hung conversion service blocks indefinitely.
        converted = requests.post(
            url=base_url + '/prj_convert_clt2json',
            files={'file': file_data},
        )
        if not converted.text:
            return None

        archive = requests.post(
            url=base_url + '/file_download',
            data=converted.text,
        )
        # ``content`` is bytes: test emptiness directly (comparing to '' is always False).
        if not archive.content:
            return None

        prj_id = str(uuid4())
        project_dir = None
        try:
            # Read the archive from memory: no temp file to leak, no platform-specific path.
            with zipfile.ZipFile(io.BytesIO(archive.content)) as zip_file:
                project_dir = os.path.join(getFileCacahePath(), 'Projects', prj_id)
                os.makedirs(project_dir)
                for zip_info in zip_file.infolist():
                    zip_info.filename = PrjFileLoadService._decode_member_name(zip_info)
                    zip_file.extract(zip_info, project_dir)
        except Exception:
            # Best-effort contract: report failure as None, but leave a trace and no debris.
            logger.exception("Failed to unpack converted project archive %s", prj_id)
            if project_dir is not None:
                shutil.rmtree(project_dir, ignore_errors=True)
            return None
        return f'Projects_{prj_id}'

    @staticmethod
    def process_file(base64_content: str) -> str:
        """Convert and unpack a project upload, then embed its documents into the project's store.

        :param base64_content: despite its name, forwarded unchanged to
            ``store_and_parse_file`` as the multipart ``file`` payload; no base64
            decoding happens here.
        :returns: the ``Projects_<uuid>`` flag for the stored project, or ``None``
            when conversion or unpacking failed.
        """
        prj_flag = PrjFileLoadService.store_and_parse_file(base64_content)
        if prj_flag is None:
            return None

        # Generate vectors and persist them locally.
        documents = get_documents(prj_flag)
        for doc in documents:
            doc.metadata["private"] = "false"
        docstore = get_doc_store(prj_flag)
        vector_store = get_vector_store(prj_flag)
        run_pipeline(docstore, vector_store, documents)
        persist_storage(docstore, vector_store)
        return prj_flag
|
|
|
|
class ChatFileService:
    """Store files uploaded in chat as data URLs and add them to the index from get_index()."""

    # Directory for uploaded chat files; overridable via CHAT_UPLOAD_FILECACHE.
    PRIVATE_STORE_PATH = os.getenv('CHAT_UPLOAD_FILECACHE', 'output/uploaded')

    # Legacy process-wide metadata of the most recent upload.  The helpers below
    # write here only when no per-call ``results`` dict is passed.  process_file
    # returns its own dict and merely mirrors the values here for older readers.
    resluts: Dict[str, Any] = {}

    @staticmethod
    def process_file(base64_content: str) -> dict:
        """Decode an uploaded data URL, ingest it into the vector index, and persist the index.

        :param base64_content: a ``data:<mime>;base64,<payload>`` string.
        :returns: a new dict for this upload with ``mime_type``, ``extension``,
            ``id``, ``file_name`` and ``size``.
        :raises ValueError: if the URL is malformed or the extension has no reader.
        """
        # Per-call metadata, so concurrent or successive uploads cannot overwrite each other.
        results: Dict[str, Any] = {}
        file_data, extension = ChatFileService.preprocess_base64_file(
            base64_content, results=results
        )
        documents = ChatFileService.store_and_parse_file(
            file_data, extension, results=results
        )

        nodes = IngestionPipeline().run(documents=documents)

        current_index = get_index()
        if current_index is None:
            current_index = VectorStoreIndex(nodes=nodes)
        else:
            current_index.insert_nodes(nodes=nodes)
        current_index.storage_context.persist(persist_dir=STORAGE_DIR)

        ChatFileService.resluts.update(results)
        return results

    @staticmethod
    def preprocess_base64_file(base64_content: str, *, results=None) -> tuple:
        """Split a ``data:<mime>;base64,<payload>`` URL into decoded bytes and an extension.

        :param base64_content: the data URL to decode.
        :param results: dict that receives ``mime_type`` and ``extension``;
            defaults to the legacy class-level ``resluts`` dict.
        :returns: ``(decoded_bytes, extension)``.  The extension comes from
            ``mimetypes.guess_extension`` and is ``None`` for unknown MIME types.
        """
        if results is None:
            results = ChatFileService.resluts
        header, data = base64_content.split(",", 1)
        mime_type = header.split(";")[0].split(":", 1)[1]
        extension = mimetypes.guess_extension(mime_type)
        results['mime_type'] = mime_type
        results['extension'] = extension
        return base64.b64decode(data), extension

    @staticmethod
    def store_and_parse_file(file_data, extension, *, results=None) -> List["Document"]:
        """Write the upload under PRIVATE_STORE_PATH and load it into documents.

        :param file_data: raw file bytes.
        :param extension: file extension including the dot, used to choose a reader.
        :param results: dict that receives ``id``, ``file_name`` and ``size``;
            defaults to the legacy class-level ``resluts`` dict.
        :returns: the loaded documents, each tagged with ``file_name`` and ``private="true"``.
        :raises ValueError: if no reader is registered for *extension*.
        """
        if results is None:
            results = ChatFileService.resluts

        # Resolve the reader first so unsupported uploads are rejected before anything is written.
        reader_cls = ChatFileService.default_file_loaders_map().get(extension)
        if reader_cls is None:
            raise ValueError(f"File extension {extension} is not supported")

        os.makedirs(ChatFileService.PRIVATE_STORE_PATH, exist_ok=True)
        file_id = uuid4().hex
        file_name = f"{file_id}{extension}"
        file_path = Path(ChatFileService.PRIVATE_STORE_PATH) / file_name
        results['id'] = file_id
        results['file_name'] = file_name

        file_path.write_bytes(file_data)
        results['size'] = file_path.stat().st_size

        documents = reader_cls().load_data(file_path)
        for doc in documents:
            doc.metadata["file_name"] = file_name
            doc.metadata["private"] = "true"
        return documents

    @staticmethod
    def default_file_loaders_map():
        """Return llama_index's extension-to-reader map, with ``.txt`` routed to FlatReader."""
        default_loaders = get_file_loaders_map()
        default_loaders[".txt"] = FlatReader
        return default_loaders