新增工程信息、检索的知识片段节点回传、下一轮建议问题列表

This commit is contained in:
wanyaokun
2024-08-30 16:42:36 +08:00
parent 73565b26e4
commit e7628809ad
11 changed files with 411 additions and 134 deletions
@@ -1,16 +1,23 @@
import base64,os
from typing import List
import base64,os,mimetypes,requests,tempfile
from typing import List,Dict,Any
from uuid import uuid4
import requests
from app.settings import init_settings
from app.engine.loaders import get_document_Types, get_documents,getFileCacahePath
from app.engine.vectordb import get_vector_store
from app.engine.generate import get_doc_store,run_pipeline,persist_storage
import tempfile
from llama_index.core.schema import Document
from pathlib import Path
from llama_index.core.readers.file.base import (
_try_loading_included_file_formats as get_file_loaders_map,
)
from llama_index.readers.file import FlatReader
from llama_index.core.ingestion import IngestionPipeline
from llama_index.core import VectorStoreIndex
from app.engine.index import get_index
STORAGE_DIR = os.getenv("STORAGE_DIR", "storage")
class FileLoadService:
class PrjFileLoadService:
@staticmethod
def store_and_parse_file(file_data):
prjtoJson_url = os.getenv('PRJTOJSON_URL')
@@ -20,29 +27,40 @@ class FileLoadService:
url = convert_url,
files=files
)
if response1.text is None or response1.text=='':
return None
load_url = prjtoJson_url +'/file_download'
response2 = requests.post(
url = load_url,
data=response1.text
)
tempFilePath:str = tempfile.gettempdir() + f"\\{str(uuid4())}.zip"
with open(tempFilePath,'wb') as file:
file.write(response2.content)
prjID = str(uuid4())
filePath = getFileCacahePath() + f'/Projects/{prjID}'
os.makedirs(filePath)
import zipfile
with zipfile.ZipFile(tempFilePath,'r') as zip_File:
for zip_info in zip_File.infolist():
zip_info.filename = zip_info.filename.encode('cp437').decode('gbk')
zip_File.extract(zip_info,filePath)
os.remove(tempFilePath)
return f'Projects_{prjID}'
if response2.text is None or response2.content=='':
return None
try:
tempFilePath:str = tempfile.gettempdir() + f"\\{uuid4().hex}.zip"
with open(tempFilePath,'wb') as file:
file.write(response2.content)
prjID = str(uuid4())
filePath = getFileCacahePath() + f'/Projects/{prjID}'
os.makedirs(filePath)
import zipfile
with zipfile.ZipFile(tempFilePath,'r') as zip_File:
for zip_info in zip_File.infolist():
zip_info.filename = zip_info.filename.encode('cp437').decode('gbk')
zip_File.extract(zip_info,filePath)
os.remove(tempFilePath)
return f'Projects_{prjID}'
except Exception as e:
return None
@staticmethod
def process_file(base64_content: str) -> str:
prjFlag = FileLoadService.store_and_parse_file(base64_content)
prjFlag = PrjFileLoadService.store_and_parse_file(base64_content)
if prjFlag is None:
return None
#生成向量并持久化至本地
documents = get_documents(prjFlag)
for doc in documents:
@@ -53,3 +71,64 @@ class FileLoadService:
persist_storage(docstore, vector_store)
return prjFlag
class ChatFileService:
PRIVATE_STORE_PATH = os.getenv('CHAT_UPLOAD_FILECACHE','output/uploaded')
resluts:Dict[str,Any] = {}
@staticmethod
def process_file(base64_content: str) -> dict:
file_data, extension = ChatFileService.preprocess_base64_file(base64_content)
documents = ChatFileService.store_and_parse_file(file_data, extension)
pipeline = IngestionPipeline()
nodes = pipeline.run(documents=documents)
current_index = get_index()
pipeline = IngestionPipeline()
nodes = pipeline.run(documents=documents)
if current_index is None:
current_index = VectorStoreIndex(nodes=nodes)
else:
current_index.insert_nodes(nodes=nodes)
current_index.storage_context.persist(
persist_dir=os.environ.get("STORAGE_DIR", "storage")
)
return ChatFileService.resluts
@staticmethod
def preprocess_base64_file(base64_content: str) -> tuple:
header, data = base64_content.split(",", 1)
mime_type = header.split(";")[0].split(":", 1)[1]
extension = mimetypes.guess_extension(mime_type)
ChatFileService.resluts['mime_type'] = mime_type
ChatFileService.resluts['extension'] = extension
return base64.b64decode(data), extension
@staticmethod
def store_and_parse_file(file_data, extension) -> List[Document]:
os.makedirs(ChatFileService.PRIVATE_STORE_PATH, exist_ok=True)
fileID = uuid4().hex
file_name = f"{fileID}{extension}"
file_path = Path(os.path.join(ChatFileService.PRIVATE_STORE_PATH, file_name))
ChatFileService.resluts['id'] = fileID
ChatFileService.resluts['file_name'] = file_name
with open(file_path, "wb") as f:
f.write(file_data)
ChatFileService.resluts['size'] = os.path.getsize(file_path)
reader_cls = ChatFileService.default_file_loaders_map().get(extension)
if reader_cls is None:
raise ValueError(f"File extension {extension} is not supported")
reader = reader_cls()
documents = reader.load_data(file_path)
for doc in documents:
doc.metadata["file_name"] = file_name
doc.metadata["private"] = "true"
return documents
@staticmethod
def default_file_loaders_map():
default_loaders = get_file_loaders_map()
default_loaders[".txt"] = FlatReader
return default_loaders
@@ -0,0 +1,43 @@
from typing import List
from app.api.routers.request.base import message
from llama_index.core.prompts import PromptTemplate
from llama_index.core.settings import Settings
from pydantic import BaseModel
NEXT_QUESTIONS_SUGGESTION_PROMPT = PromptTemplate(
"你是一个乐于助人的助手!你的任务是对用户可能会问的下一个问题给出建议。 "
"\n这是对话历史记录"
"\n---------------------\n{conversation}\n---------------------"
"考虑到对话历史记录,仅限于现在知识库已有内容, 请给我 $number_of_questions 个你接下来可能会问题的问题!"
)
N_QUESTION_TO_GENERATE = 3
class NextQuestions(BaseModel):
"""A list of questions that user might ask next"""
questions: List[str]
class NextQuestionSuggestion:
@staticmethod
async def suggest_next_questions(
message_id: str,
number_of_questions: int = N_QUESTION_TO_GENERATE,
) -> List[str]:
last_user_message = None
last_assistant_message = None
results = message().query(message_id)
if len(results) > 0:
last_user_message = results[0]['query']
last_assistant_message = results[0]['answer']
conversation: str = f"{last_user_message}\n{last_assistant_message}"
output: NextQuestions = await Settings.llm.astructured_predict(
NextQuestions,
prompt=NEXT_QUESTIONS_SUGGESTION_PROMPT,
conversation=conversation,
nun_questions=number_of_questions,
)
return output.questions
return []