新增工程信息、检索的知识片段节点回传、下一轮建议问题列表
This commit is contained in:
+127
-50
@@ -3,7 +3,6 @@ import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Dict, List, Any, Optional, AsyncGenerator
|
||||
from collections import deque
|
||||
|
||||
from aiostream import stream
|
||||
from fastapi import APIRouter, Request,HTTPException
|
||||
@@ -13,17 +12,19 @@ from llama_index.core.base.llms.types import ChatMessage
|
||||
from llama_index.core.callbacks import CBEventType
|
||||
from llama_index.core.chat_engine.types import StreamingAgentChatResponse
|
||||
from llama_index.core.tools import ToolOutput
|
||||
from llama_index.core.schema import NodeWithScore
|
||||
from pydantic import BaseModel
|
||||
from app.api.routers.request.base import userMng, conversations,message,parameter,feedback
|
||||
from app.api.routers.request.base import userMng, conversations,message,ProjectInfo,feedback
|
||||
from app.api.routers.request.baseConfig import *
|
||||
from app.api.routers.request.models import ChatRequestData,ChatFileUploadRequest
|
||||
from app.engine import get_chat_engine
|
||||
import uuid
|
||||
from app.api.routers.services.fileServices import FileLoadService
|
||||
from app.api.routers.services.fileServices import PrjFileLoadService,ChatFileService
|
||||
from app.api.routers.services.suggestion import NextQuestionSuggestion
|
||||
import time
|
||||
|
||||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
api_router = r = APIRouter()
|
||||
v1_router = v = APIRouter()
|
||||
|
||||
class ChatCallbackEvent(BaseModel):
|
||||
@@ -32,7 +33,7 @@ class ChatCallbackEvent(BaseModel):
|
||||
|
||||
def get_common_param(self)-> dict:
|
||||
return {
|
||||
'event': self.event_type.name,
|
||||
'event': self.event_type.value,
|
||||
'conversation_id':self.payload.get("conversation_id"),
|
||||
'message_id': self.payload.get("message_id"),
|
||||
'created_at': int(time.time()),
|
||||
@@ -48,7 +49,7 @@ class ChatCallbackEvent(BaseModel):
|
||||
"workflow_id": self.payload.get('workflow_id'),
|
||||
"sequence_number": 1709,
|
||||
"inputs": {
|
||||
"sys.query": self.payload.get('query'),
|
||||
"sys.query": f"开始查询 {self.payload.get('query')}",
|
||||
"sys.files": [],
|
||||
"sys.conversation_id": self.payload.get('conversation_id'),
|
||||
"sys.user_id": self.payload.get('use_id')
|
||||
@@ -93,7 +94,7 @@ class ChatCallbackEvent(BaseModel):
|
||||
"id": self.payload.get('nodeid'),
|
||||
"node_id": self.payload.get('nodeid'),
|
||||
"node_type": "http-request",
|
||||
"title": self.payload.get('title'),
|
||||
"title": f"正在执行事件:{self.payload.get('title')}",
|
||||
"index": self.payload.get('index'),
|
||||
"predecessor_node_id": self.payload.get('predecessor_node_id'),
|
||||
"inputs": '',
|
||||
@@ -111,7 +112,7 @@ class ChatCallbackEvent(BaseModel):
|
||||
"id": self.payload.get('nodeid'),
|
||||
"node_id": self.payload.get('nodeid'),
|
||||
"node_type": "http-request",
|
||||
"title": self.payload.get('title'),
|
||||
"title": f"事件执行结束:{self.payload.get('title')}",
|
||||
"index": self.payload.get('index'),
|
||||
"predecessor_node_id": self.payload.get('predecessor_node_id'),
|
||||
"inputs": '',
|
||||
@@ -138,15 +139,54 @@ class ChatCallbackEvent(BaseModel):
|
||||
|
||||
def get_MessageEnd_param(self) -> dict:
    """Build the ``message_end`` response, including the retrieved knowledge segments.

    Starts from ``get_common_param()`` and adds the message id plus a
    ``metadata`` object with two entries:

    * ``retriever_resources`` - one entry per node found under the
      ``source_node`` payload key (empty list when the key is absent or None).
    * ``usage`` - fixed placeholder figures; nothing here is measured.

    Returns:
        The dict to be serialised for the ``message_end`` event.
    """
    params = self.get_common_param()

    # A missing/None payload entry simply means "no retrieval happened".
    source_nodes = self.payload.get('source_node') or []

    retriever_resources = []
    for position, source_node in enumerate(source_nodes):
        metadata = source_node.node.metadata or {}
        content = source_node.text
        retriever_resources.append({
            "position": position,
            "dataset_id": metadata.get("pipeline_id"),
            "dataset_name": metadata.get("file_name"),
            "document_id": source_node.node_id,
            "document_name": metadata.get("file_name"),
            "data_source_type": "upload_file",
            "segment_id": source_node.node_id,
            "retriever_from": "workflow",
            "score": source_node.score,
            "hit_count": 1,
            # Report the real segment length instead of a fixed constant.
            "word_count": len(content or ""),
            "segment_position": position,
            "index_node_hash": "",
            "content": content
        })

    params.update({
        'id': self.payload.get('message_id'),
        'metadata': {
            "retriever_resources": retriever_resources,
            # Placeholder usage block: constant values, kept for client compatibility.
            "usage": {
                "prompt_tokens": 4972,
                "prompt_unit_price": "0.0",
                "prompt_price_unit": "0.0",
                "prompt_price": "0.0",
                "completion_tokens": 332,
                "completion_unit_price": "0.0",
                "completion_price_unit": "0.0",
                "completion_price": "0.0",
                "total_tokens": 5304,
                "total_price": "0.0",
                "currency": "USD",
                "latency": 4.897703120019287
            }
        }
    })
    return params
|
||||
|
||||
def to_response(self)-> dict|None:
|
||||
try:
|
||||
match self.event_type:
|
||||
match self.event_type.value:
|
||||
case "workflow_started":
|
||||
return self.get_WorkflowStart_param()
|
||||
case "workflow_finished":
|
||||
@@ -180,24 +220,18 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
|
||||
]
|
||||
super().__init__(ignored_events, ignored_events)
|
||||
self._aqueue = asyncio.Queue()
|
||||
self._response:str = ''
|
||||
self._response: StreamingAgentChatResponse = None
|
||||
self._params:Dict[str,Any] = params
|
||||
self._nodeStack:deque = deque()
|
||||
self._nodeStack:List[str] = []
|
||||
|
||||
#添加工作流开始事件
|
||||
data:ChatRequestData = self._params['data']
|
||||
args:Dict[str,Any] = self._params['ids']
|
||||
args.update(
|
||||
{
|
||||
'use_id': data.user,
|
||||
'query': data.query,
|
||||
'conversation_id': data.conversation_id
|
||||
}
|
||||
)
|
||||
wf_event = ChatCallbackEvent(event_type = ChatEventType.WORKFLOW_START,payload = args)
|
||||
wf_event = self.makeWorkflow_startEvent()
|
||||
if wf_event.to_response() is not None:
|
||||
self._aqueue.put_nowait(wf_event)
|
||||
|
||||
def setResponse(self, response: "StreamingAgentChatResponse"):
    """Remember the streaming chat response on this handler.

    The stored object is read later by ``makeMessage_EndEvent`` to obtain
    ``source_nodes``.
    """
    self._response = response
|
||||
|
||||
def on_event_start(
|
||||
self,
|
||||
event_type: CBEventType,
|
||||
@@ -208,7 +242,7 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
|
||||
logger.info("event_start:{} type:{} payload:{}\n".format(event_id, event_type, payload))
|
||||
|
||||
self._nodeStack.append(event_id)
|
||||
nindex = self._nodeStack.count() - 1
|
||||
nindex = len(self._nodeStack) - 1
|
||||
args:Dict[str,Any] = self._params['ids']
|
||||
args.update(
|
||||
{
|
||||
@@ -222,7 +256,6 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
|
||||
if nd_event.to_response() is not None:
|
||||
self._aqueue.put_nowait(nd_event)
|
||||
|
||||
|
||||
def on_event_end(
|
||||
self,
|
||||
event_type: CBEventType,
|
||||
@@ -236,7 +269,7 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
|
||||
args:Dict[str,Any] = self._params['ids']
|
||||
nodeID = self._nodeStack[-1]
|
||||
if nodeID == event_id:
|
||||
nindex = self._nodeStack.count() - 1
|
||||
nindex = len(self._nodeStack) - 1
|
||||
args.update(
|
||||
{
|
||||
'nodeid':event_id,
|
||||
@@ -250,7 +283,6 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
|
||||
self._aqueue.put_nowait(nd_event)
|
||||
self._nodeStack.pop()
|
||||
|
||||
|
||||
def start_trace(self, trace_id: Optional[str] = None) -> None:
|
||||
"""No-op."""
|
||||
logger.info("trace_start:{}\n".format(trace_id))
|
||||
@@ -262,21 +294,12 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
|
||||
) -> None:
|
||||
"""No-op."""
|
||||
logger.info("trace_end:{} trace_map:{}\n".format(trace_id, trace_map))
|
||||
data:ChatRequestData = self._params['data']
|
||||
args:Dict[str,Any] = self._params['ids']
|
||||
args.update(
|
||||
{
|
||||
'response':self._response,
|
||||
'conversation_id': data.conversation_id
|
||||
}
|
||||
)
|
||||
wf_event = ChatCallbackEvent(event_type = ChatEventType.WORKFLOW_FINISHED,payload = args)
|
||||
|
||||
wf_event = self.makeWorkflow_finishedEvent()
|
||||
if wf_event.to_response() is not None:
|
||||
self._aqueue.put_nowait(wf_event)
|
||||
|
||||
|
||||
args:Dict[str,Any] = self._params['ids']
|
||||
msgEnt_event = ChatCallbackEvent(event_type = ChatEventType.MESSAGE_END,payload = args)
|
||||
msgEnt_event = self.makeMessage_EndEvent()
|
||||
if msgEnt_event.to_response() is not None:
|
||||
self._aqueue.put_nowait(msgEnt_event)
|
||||
|
||||
@@ -287,6 +310,38 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
|
||||
def makeWorkflow_startEvent(self)->ChatCallbackEvent:
    """Create the ``WORKFLOW_START`` event for the current request.

    The user id, query and conversation id from the request data are written
    into the shared ``ids`` dict (in place), and that same dict becomes the
    event payload.
    """
    request_data: ChatRequestData = self._params['data']
    shared_ids: Dict[str, Any] = self._params['ids']

    # In-place mutation is intentional: other events reuse this dict.
    shared_ids['use_id'] = request_data.user
    shared_ids['query'] = request_data.query
    shared_ids['conversation_id'] = request_data.conversation_id

    return ChatCallbackEvent(
        event_type=ChatEventType.WORKFLOW_START,
        payload=shared_ids,
    )
|
||||
|
||||
def makeWorkflow_finishedEvent(self)->ChatCallbackEvent:
    """Create the ``WORKFLOW_FINISHED`` event for the current request.

    Writes an empty ``response`` and the request's conversation id into the
    shared ``ids`` dict (in place) and uses that dict as the event payload.
    """
    request_data: ChatRequestData = self._params['data']
    shared_ids: Dict[str, Any] = self._params['ids']

    shared_ids['response'] = ''
    shared_ids['conversation_id'] = request_data.conversation_id

    return ChatCallbackEvent(
        event_type=ChatEventType.WORKFLOW_FINISHED,
        payload=shared_ids,
    )
|
||||
|
||||
def makeMessage_EndEvent(self)->ChatCallbackEvent:
    """Create the ``MESSAGE_END`` event.

    When a chat response has been registered via ``setResponse``, its
    ``source_nodes`` are stored under the ``source_node`` key of the shared
    ``ids`` dict before that dict is used as the payload.
    """
    shared_ids: Dict[str, Any] = self._params['ids']
    chat_response = self._response
    if chat_response is not None:
        shared_ids['source_node'] = chat_response.source_nodes
    return ChatCallbackEvent(event_type=ChatEventType.MESSAGE_END, payload=shared_ids)
|
||||
|
||||
class IDManager:
|
||||
def createID(self):
|
||||
return {
|
||||
@@ -390,7 +445,7 @@ async def post_conversations(request: Request, data: ChatRequestData):
|
||||
params = data.inputs or {}
|
||||
|
||||
# 获取聊天引擎对象
|
||||
chat_engine = get_chat_engine(filters=filters, params=params,prjFlag = data.prjFlag)
|
||||
chat_engine = get_chat_engine(filters=filters, params=params)
|
||||
|
||||
# 启动聊天事件监听
|
||||
ids = IDManager().createID()
|
||||
@@ -399,7 +454,7 @@ async def post_conversations(request: Request, data: ChatRequestData):
|
||||
|
||||
# 执行异步聊天
|
||||
response = await chat_engine.astream_chat(data.query)
|
||||
|
||||
event_handler.setResponse(response)
|
||||
# 返回异步消息回应
|
||||
return ChatStreamResponse(request, event_handler, response, data,ids)
|
||||
|
||||
@@ -468,29 +523,51 @@ async def query_conversations(user:str, first_id:str = None, limit:str = None, p
|
||||
|
||||
@v.get("/parameters")
async def query_parameters(user:str):
    """Return the application parameter configuration.

    The names of all stored projects are handed to ``ParamterCfg`` as
    ``projectInfo``. ``user`` is accepted as a query parameter but is not
    used in the body.
    """
    project_names = ProjectInfo().projectNames()
    return BaseConfig().ParamterCfg(projectInfo=project_names)
|
||||
|
||||
@v.post("/messages/{message_id}/feedbacks")
async def post_feedbacks(request: Request,message_id:str,params:Dict[str,Any]):
    """Store or revoke the rating for one message.

    A ``rating`` of ``None`` deletes the stored feedback. Any other rating is
    saved together with the query/answer of the message, provided the message
    exists.
    """
    rating = params['rating']
    if rating is None:
        feedback().delete(message_id)
        return

    matches = message().query(message_id)
    if len(matches) > 0:
        found = matches[0]
        feedback().add(
            message_id=message_id,
            query=found['query'],
            answer=found['answer'],
            rating=params['rating'],
        )
|
||||
|
||||
@v.post("/files/upload")
def upload_file(request: ChatFileUploadRequest):
    """Ingest a base64-encoded chat attachment and describe the stored file.

    Returns the id/name/size/extension/mime_type reported by
    ``ChatFileService.process_file`` plus a freshly generated ``created_by``
    UUID and the current epoch second as ``created_at``.

    Raises:
        HTTPException: status 500 for any failure during processing.
    """
    try:
        logger.info("Processing file")
        file_info = ChatFileService.process_file(request.base64)
        response = {
            key: file_info.get(key)
            for key in ('id', 'name', 'size', 'extension', 'mime_type')
        }
        response['created_by'] = str(uuid.uuid4())
        response['created_at'] = int(time.time())
        return response
    except Exception as e:
        logger.error(f"Error processing file: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail="Error processing file")
|
||||
|
||||
@v.post("/applications")
def upload_file(request: ChatFileUploadRequest):
    """Ingest a base64-encoded project file via ``PrjFileLoadService``.

    Returns whatever ``PrjFileLoadService.process_file`` returns.

    Raises:
        HTTPException: status 500 for any failure during processing.
    """
    try:
        logger.info("Processing file")
        project_flag = PrjFileLoadService.process_file(request.base64)
    except Exception as e:
        logger.error(f"Error processing file: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail="Error processing file")
    return project_flag
|
||||
|
||||
@v.post("/messages/{message_id}/suggested")
async def post_suggested(request: Request,message_id:str,user:str):
    """Return suggested follow-up questions for the given message.

    The suggestions come from
    ``NextQuestionSuggestion.suggest_next_questions``; ``request`` and
    ``user`` are not used in the body.
    """
    suggestions = await NextQuestionSuggestion.suggest_next_questions(message_id)
    body = {"result": "success"}
    body["data"] = suggestions
    return body
|
||||
@@ -2,7 +2,7 @@ from datetime import datetime
|
||||
import uuid
|
||||
from app.api.routers.request.baseConfig import BaseConfig
|
||||
from app.api.routers.request.dbOrm import DBManager
|
||||
|
||||
from typing import List
|
||||
dbManage = DBManager()
|
||||
|
||||
class conversations:
|
||||
@@ -122,8 +122,9 @@ class message:
|
||||
def delete(self,user_id:str):
|
||||
dbManage.delete(self._tableName,user_id = user_id)
|
||||
|
||||
def query(self,**condition):
|
||||
def query(self,id:str):
|
||||
results = []
|
||||
condition = {'id':id}
|
||||
records = dbManage.query(self._tableName,**condition)
|
||||
for record in records:
|
||||
results.append(record.dict())
|
||||
@@ -153,3 +154,33 @@ class feedback:
|
||||
if len(records) > 0:
|
||||
return records[0].dict()
|
||||
return None
|
||||
|
||||
class ProjectInfo:
    """Access helper for the ``projectInfos`` table (project name and project flag)."""

    def __init__(self) -> None:
        self._tableName = 'projectInfos'
        # Called on every construction; presumably a no-op when the table
        # already exists - TODO confirm against DBManager.createTable.
        dbManage.createTable(self._tableName)

    def add(self,name:str,flag:str):
        """Insert one project record.

        Args:
            name: Project display name (stored under the key ``prjectName``).
            flag: Project flag (stored under the key ``prjFlag``).
        """
        record = {
            'prjectName': name,
            'prjFlag': flag
        }
        dbManage.addRecord(self._tableName,record)

    def projectNames(self)->List[str]:
        """Return the names of all stored projects.

        Records whose name is empty or missing (``None``) are skipped so the
        result only ever contains usable strings.
        """
        records = dbManage.query(self._tableName)
        candidate_names = (record.dict().get('prjectName') for record in records)
        return [name for name in candidate_names if name]

    def prjFalg(self,name:str):
        """Return the flag of the first project called ``name``, or ``''`` if none matches."""
        for record in dbManage.query(self._tableName):
            data: dict = record.dict()
            if data.get('prjectName') == name:
                return data['prjFlag']
        return ''
|
||||
@@ -5,7 +5,8 @@ from enum import Enum
|
||||
class BaseConfig(BaseModel):
|
||||
projectInfo:str = os.getenv("PROJECT_TITLE","您好,我是博微工程理解小助手,您可以问我有关[线路工程]工程数据的相关问题!")
|
||||
|
||||
def ParamterCfg(self):
|
||||
def ParamterCfg(self,**args):
|
||||
projectInfo = args.get('projectInfo')
|
||||
questions = os.getenv("CONVERSATION_STARTERS", "dev")
|
||||
return{
|
||||
"opening_statement": self.projectInfo,
|
||||
@@ -30,7 +31,18 @@ class BaseConfig(BaseModel):
|
||||
"more_like_this": {
|
||||
"enabled": False
|
||||
},
|
||||
"user_input_form": [],
|
||||
"user_input_form": [
|
||||
{
|
||||
"select": {
|
||||
"variable": "projectname",
|
||||
"label": "\u5de5\u7a0b\u540d\u79f0",
|
||||
"type": "select",
|
||||
"max_length": 48,
|
||||
"required": True,
|
||||
"options": [projectInfo]
|
||||
}
|
||||
}
|
||||
],
|
||||
"sensitive_word_avoidance": {
|
||||
"enabled": False
|
||||
},
|
||||
|
||||
@@ -55,6 +55,13 @@ class FeedBackOrm(Base):
|
||||
answer = Column(String)
|
||||
rating = Column(String)
|
||||
|
||||
class ProjectInfoOrm(Base):
    """ORM mapping for the ``projectInfos`` table."""
    __tablename__ = "projectInfos"

    # Project flag; serves as the primary key of the table.
    prjFlag = Column(String,primary_key=True)
    # Project display name (column name is spelled "prjectName" throughout the code base).
    prjectName = Column(String)
|
||||
|
||||
|
||||
#数据结构
|
||||
class ConversationModel(BaseModel):
|
||||
id: str
|
||||
@@ -121,6 +128,17 @@ class FeedBackModel(BaseModel):
|
||||
def orm(cls):
|
||||
return FeedBackOrm
|
||||
|
||||
class ProjectInfoModel(BaseModel):
    """Pydantic counterpart of ``ProjectInfoOrm``."""
    # Project display name.
    prjectName:str
    # Project flag.
    prjFlag:str

    class Config:
        # Allow the model to be populated from attribute-bearing objects.
        from_attributes=True

    @classmethod
    def orm(cls):
        """Return the ORM class that backs this model."""
        return ProjectInfoOrm
|
||||
|
||||
class DBManager:
|
||||
def __init__(self) -> None:
|
||||
DATABASE_URL = os.getenv("SQLITE_DATABASE_URL")
|
||||
|
||||
@@ -10,7 +10,6 @@ class ChatRequestData(BaseModel):
|
||||
response_mode: str
|
||||
files: Any
|
||||
conversation_id: str = None
|
||||
prjFlag:Optional[str] = ''
|
||||
|
||||
class ChatFileUploadRequest(BaseModel):
|
||||
base64: str
|
||||
|
||||
@@ -1,16 +1,23 @@
|
||||
import base64,os
|
||||
from typing import List
|
||||
import base64,os,mimetypes,requests,tempfile
|
||||
from typing import List,Dict,Any
|
||||
from uuid import uuid4
|
||||
import requests
|
||||
from app.settings import init_settings
|
||||
from app.engine.loaders import get_document_Types, get_documents,getFileCacahePath
|
||||
from app.engine.vectordb import get_vector_store
|
||||
from app.engine.generate import get_doc_store,run_pipeline,persist_storage
|
||||
import tempfile
|
||||
from llama_index.core.schema import Document
|
||||
from pathlib import Path
|
||||
from llama_index.core.readers.file.base import (
|
||||
_try_loading_included_file_formats as get_file_loaders_map,
|
||||
)
|
||||
from llama_index.readers.file import FlatReader
|
||||
from llama_index.core.ingestion import IngestionPipeline
|
||||
from llama_index.core import VectorStoreIndex
|
||||
from app.engine.index import get_index
|
||||
|
||||
STORAGE_DIR = os.getenv("STORAGE_DIR", "storage")
|
||||
|
||||
class FileLoadService:
|
||||
class PrjFileLoadService:
|
||||
@staticmethod
|
||||
def store_and_parse_file(file_data):
|
||||
prjtoJson_url = os.getenv('PRJTOJSON_URL')
|
||||
@@ -20,29 +27,40 @@ class FileLoadService:
|
||||
url = convert_url,
|
||||
files=files
|
||||
)
|
||||
if response1.text is None or response1.text=='':
|
||||
return None
|
||||
|
||||
load_url = prjtoJson_url +'/file_download'
|
||||
response2 = requests.post(
|
||||
url = load_url,
|
||||
data=response1.text
|
||||
)
|
||||
tempFilePath:str = tempfile.gettempdir() + f"\\{str(uuid4())}.zip"
|
||||
with open(tempFilePath,'wb') as file:
|
||||
file.write(response2.content)
|
||||
if response2.text is None or response2.content=='':
|
||||
return None
|
||||
|
||||
prjID = str(uuid4())
|
||||
filePath = getFileCacahePath() + f'/Projects/{prjID}'
|
||||
os.makedirs(filePath)
|
||||
import zipfile
|
||||
with zipfile.ZipFile(tempFilePath,'r') as zip_File:
|
||||
for zip_info in zip_File.infolist():
|
||||
zip_info.filename = zip_info.filename.encode('cp437').decode('gbk')
|
||||
zip_File.extract(zip_info,filePath)
|
||||
os.remove(tempFilePath)
|
||||
return f'Projects_{prjID}'
|
||||
try:
|
||||
tempFilePath:str = tempfile.gettempdir() + f"\\{uuid4().hex}.zip"
|
||||
with open(tempFilePath,'wb') as file:
|
||||
file.write(response2.content)
|
||||
|
||||
prjID = str(uuid4())
|
||||
filePath = getFileCacahePath() + f'/Projects/{prjID}'
|
||||
os.makedirs(filePath)
|
||||
import zipfile
|
||||
with zipfile.ZipFile(tempFilePath,'r') as zip_File:
|
||||
for zip_info in zip_File.infolist():
|
||||
zip_info.filename = zip_info.filename.encode('cp437').decode('gbk')
|
||||
zip_File.extract(zip_info,filePath)
|
||||
os.remove(tempFilePath)
|
||||
return f'Projects_{prjID}'
|
||||
except Exception as e:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def process_file(base64_content: str) -> str:
|
||||
prjFlag = FileLoadService.store_and_parse_file(base64_content)
|
||||
prjFlag = PrjFileLoadService.store_and_parse_file(base64_content)
|
||||
if prjFlag is None:
|
||||
return None
|
||||
#生成向量并持久化至本地
|
||||
documents = get_documents(prjFlag)
|
||||
for doc in documents:
|
||||
@@ -53,3 +71,64 @@ class FileLoadService:
|
||||
persist_storage(docstore, vector_store)
|
||||
return prjFlag
|
||||
|
||||
class ChatFileService:
    """Store a chat attachment on disk, parse it and add it to the vector index."""

    PRIVATE_STORE_PATH = os.getenv('CHAT_UPLOAD_FILECACHE','output/uploaded')
    # Metadata of the most recent upload. Kept for backward compatibility;
    # ``process_file`` now builds a fresh dict per call so that concurrent or
    # consecutive uploads no longer overwrite each other's values.
    resluts:Dict[str,Any] = {}

    @staticmethod
    def process_file(base64_content: str) -> dict:
        """Decode, store, parse and index one uploaded file.

        Args:
            base64_content: Data URL of the form ``data:<mime>;base64,<payload>``.

        Returns:
            A dict with the keys ``mime_type``, ``extension``, ``id``,
            ``file_name``, ``name`` and ``size`` describing the stored file.

        Raises:
            ValueError: If no reader is registered for the file's extension.
        """
        results: Dict[str, Any] = {}
        file_data, extension = ChatFileService.preprocess_base64_file(base64_content, results)
        documents = ChatFileService.store_and_parse_file(file_data, extension, results)

        # Run the ingestion pipeline exactly once per upload.
        nodes = IngestionPipeline().run(documents=documents)

        current_index = get_index()
        if current_index is None:
            current_index = VectorStoreIndex(nodes=nodes)
        else:
            current_index.insert_nodes(nodes=nodes)
        current_index.storage_context.persist(
            persist_dir=os.environ.get("STORAGE_DIR", "storage")
        )

        ChatFileService.resluts = results
        return results

    @staticmethod
    def preprocess_base64_file(base64_content: str, results: Dict[str, Any] = None) -> tuple:
        """Split a data URL into raw bytes and a guessed file extension.

        Args:
            base64_content: Data URL of the form ``data:<mime>;base64,<payload>``.
            results: Dict that receives ``mime_type`` and ``extension``. When
                omitted, the class-level ``resluts`` dict is updated instead.

        Returns:
            ``(decoded_bytes, extension)``; ``extension`` is ``None`` when the
            MIME type is unknown to ``mimetypes``.
        """
        header, data = base64_content.split(",", 1)
        mime_type = header.split(";")[0].split(":", 1)[1]
        extension = mimetypes.guess_extension(mime_type)

        target = ChatFileService.resluts if results is None else results
        target['mime_type'] = mime_type
        target['extension'] = extension
        return base64.b64decode(data), extension

    @staticmethod
    def store_and_parse_file(file_data, extension, results: Dict[str, Any] = None) -> "List[Document]":
        """Write the bytes to the private upload folder and load them as documents.

        Args:
            file_data: Raw file content.
            extension: File extension including the leading dot.
            results: Dict that receives ``id``, ``file_name``, ``name`` and
                ``size``. When omitted, the class-level ``resluts`` dict is
                updated instead.

        Returns:
            The documents produced by the reader, each tagged with
            ``file_name`` and ``private`` metadata.

        Raises:
            ValueError: If no reader is registered for ``extension``.
        """
        # Reject unsupported types before touching the disk so that no orphan
        # file is left behind.
        reader_cls = ChatFileService.default_file_loaders_map().get(extension)
        if reader_cls is None:
            raise ValueError(f"File extension {extension} is not supported")

        target = ChatFileService.resluts if results is None else results

        os.makedirs(ChatFileService.PRIVATE_STORE_PATH, exist_ok=True)
        fileID = uuid4().hex
        file_name = f"{fileID}{extension}"
        file_path = Path(os.path.join(ChatFileService.PRIVATE_STORE_PATH, file_name))
        target['id'] = fileID
        target['file_name'] = file_name
        # The upload route reads 'name'; expose the stored file name under it too.
        target['name'] = file_name

        with open(file_path, "wb") as f:
            f.write(file_data)
        target['size'] = os.path.getsize(file_path)

        documents = reader_cls().load_data(file_path)
        for doc in documents:
            doc.metadata["file_name"] = file_name
            doc.metadata["private"] = "true"
        return documents

    @staticmethod
    def default_file_loaders_map():
        """Return the extension -> reader-class map, with ``.txt`` mapped to ``FlatReader``."""
        default_loaders = get_file_loaders_map()
        default_loaders[".txt"] = FlatReader
        return default_loaders
|
||||
@@ -0,0 +1,43 @@
|
||||
from typing import List
|
||||
|
||||
from app.api.routers.request.base import message
|
||||
from llama_index.core.prompts import PromptTemplate
|
||||
from llama_index.core.settings import Settings
|
||||
from pydantic import BaseModel
|
||||
|
||||
# Prompt for follow-up question suggestions. Both ``{conversation}`` and
# ``{number_of_questions}`` are format variables filled in by
# ``astructured_predict``; the previous ``$number_of_questions`` spelling was
# sent to the model verbatim.
NEXT_QUESTIONS_SUGGESTION_PROMPT = PromptTemplate(
    "你是一个乐于助人的助手!你的任务是对用户可能会问的下一个问题给出建议。 "
    "\n这是对话历史记录"
    "\n---------------------\n{conversation}\n---------------------"
    "考虑到对话历史记录,仅限于现在知识库已有内容, 请给我 {number_of_questions} 个你接下来可能会问题的问题!"
)
# Default number of suggestions requested from the model.
N_QUESTION_TO_GENERATE = 3


class NextQuestions(BaseModel):
    """A list of questions that user might ask next"""

    questions: List[str]


class NextQuestionSuggestion:
    """Generates follow-up question suggestions for a stored message."""

    @staticmethod
    async def suggest_next_questions(
        message_id: str,
        number_of_questions: int = N_QUESTION_TO_GENERATE,
    ) -> List[str]:
        """Suggest questions the user might ask after the given message.

        Args:
            message_id: Id of the stored message whose query/answer pair forms
                the conversation context.
            number_of_questions: How many suggestions to ask the model for.

        Returns:
            The suggested questions, or an empty list when the message is not found.
        """
        results = message().query(message_id)
        if len(results) == 0:
            return []

        last_user_message = results[0]['query']
        last_assistant_message = results[0]['answer']
        conversation: str = f"{last_user_message}\n{last_assistant_message}"

        output: NextQuestions = await Settings.llm.astructured_predict(
            NextQuestions,
            prompt=NEXT_QUESTIONS_SUGGESTION_PROMPT,
            conversation=conversation,
            # Must match the template variable name exactly.
            number_of_questions=number_of_questions,
        )
        return output.questions
|
||||
@@ -8,9 +8,18 @@ from app.engine.engine import create_query_engine, create_summary_query_engine
|
||||
from app.engine.index import get_index
|
||||
#from app.engine.loaders.db import makeDescriptionByEngine
|
||||
from app.engine.tools import ToolFactory
|
||||
from app.api.routers.request.base import ProjectInfo
|
||||
|
||||
def getPrjFalg(params:dict=None)->str:
    """Resolve the project flag for the project selected in the request inputs.

    Args:
        params: Request parameters. The project name is read from
            ``params['inputs']['projectname']``; when there is no ``inputs``
            mapping it is read from ``params['projectname']`` instead.

    Returns:
        The flag stored for that project, or ``''`` when no project name is
        supplied (or ``ProjectInfo`` knows no such project).
    """
    if not params:
        return ''

    inputs = params.get('inputs')
    if not isinstance(inputs, dict):
        # NOTE(review): post_conversations passes `data.inputs` itself as
        # `params`, so the project name normally sits at the top level.
        inputs = params

    project_name = inputs.get('projectname')
    if not project_name:
        return ''

    # prjFalg is an instance method; calling it on the class bound the name to
    # `self` and raised TypeError.
    return ProjectInfo().prjFalg(project_name)
|
||||
|
||||
|
||||
def get_chat_engine(filters=None, params=None,**args):
|
||||
def get_chat_engine(filters=None, params:dict=None):
|
||||
system_prompt = os.getenv("SYSTEM_PROMPT")
|
||||
top_k = int(os.getenv("TOP_K", "3"))
|
||||
use_reranker = os.getenv("RERANK_ENABLED")
|
||||
@@ -24,7 +33,13 @@ def get_chat_engine(filters=None, params=None,**args):
|
||||
#tools.append(sql_query_tool)
|
||||
|
||||
# Add query tool if index exists
|
||||
index = get_index(**args)
|
||||
prjFlag = ''
|
||||
if params is not None:
|
||||
inputs:dict = params.get('inputs')
|
||||
if inputs is not None:
|
||||
prjFlag = inputs.get('projectname')
|
||||
|
||||
index = get_index(prjFlag = getPrjFalg(params))
|
||||
if index is not None:
|
||||
summary_query_engine = create_summary_query_engine(index,top_k,use_reranker,filters)
|
||||
summary_query_tool = QueryEngineTool.from_defaults( query_engine=summary_query_engine, name="summary_query_tool",
|
||||
@@ -57,6 +72,7 @@ def get_chat_engine(filters=None, params=None,**args):
|
||||
verbose=True,
|
||||
)
|
||||
return agentrunner
|
||||
|
||||
# create the function calling worker for reasoning
|
||||
# worker = FunctionCallingAgentWorker.from_tools(
|
||||
# tools, verbose=True
|
||||
|
||||
@@ -7,15 +7,15 @@ logger = logging.getLogger("uvicorn")
|
||||
|
||||
def get_index(**args):
    """Load a ``VectorStoreIndex`` from the vector store.

    When a ``prjFlag`` keyword is supplied, the flag is checked against
    ``get_document_Types()``: an unknown flag falls back to the first known
    one, and ``None`` is returned when there are no known flags at all.
    Without the keyword, the store for the empty flag ``''`` is used.
    """
    logger.info("Connecting vector store...")

    flag = ''
    if 'prjFlag' in args:
        known_flags = get_document_Types()
        if len(known_flags) <= 0:
            return None
        requested = args.get('prjFlag', '')
        flag = requested if requested in known_flags else known_flags[0]

    store = get_vector_store(flag)
    # Load the index from the vector store
    # If you are using a vector store that doesn't store text,
    # you must load the index from both the vector store and the document store
    index = VectorStoreIndex.from_vector_store(store)
    logger.info("Finished load index from vector store.")
    return index
|
||||
|
||||
@@ -58,24 +58,26 @@ def get_document_Types():
|
||||
def get_documents(docType:str):
|
||||
documents = []
|
||||
config = load_configs()
|
||||
|
||||
if config is None or len(config.items()) == 0:
|
||||
return documents
|
||||
return documents
|
||||
|
||||
for loader_type, loader_config in config.items():
|
||||
logger.info(
|
||||
f"Loading documents from loader: {loader_type}, config: {loader_config}"
|
||||
)
|
||||
if loader_config.get('enable', True): # 检查 enable 字段
|
||||
logger.info(
|
||||
f"Loading documents from loader: {loader_type}, config: {loader_config}"
|
||||
)
|
||||
|
||||
loader_config = loader_config or []
|
||||
match loader_type:
|
||||
case "file":
|
||||
document = get_file_documents(FileLoaderConfig(**loader_config),docType)
|
||||
case "web":
|
||||
document = get_web_documents(WebLoaderConfig(**loader_config))
|
||||
case "db":
|
||||
document = get_db_documents(configs=[DBLoaderConfig(**cfg) for cfg in loader_config])
|
||||
case _:
|
||||
raise ValueError(f"Invalid loader type: {loader_type}")
|
||||
documents.extend(document)
|
||||
loader_config = loader_config or []
|
||||
match loader_type:
|
||||
case "file":
|
||||
document = get_file_documents(FileLoaderConfig(**loader_config),docType)
|
||||
case "web":
|
||||
document = get_web_documents(WebLoaderConfig(**loader_config))
|
||||
case "db":
|
||||
document = get_db_documents(configs=[DBLoaderConfig(**cfg) for cfg in loader_config])
|
||||
case _:
|
||||
raise ValueError(f"Invalid loader type: {loader_type}")
|
||||
documents.extend(document)
|
||||
|
||||
return documents
|
||||
+32
-32
@@ -3,46 +3,46 @@ file:
|
||||
# use_llama_parse: Use LlamaParse if `true`. Needs a `LLAMA_CLOUD_API_KEY` from https://cloud.llamaindex.ai set as environment variable
|
||||
use_llama_parse: false
|
||||
|
||||
db:
|
||||
#db:
|
||||
# The configuration for the database loader, only supports MySQL and PostgreSQL databases for now.
|
||||
# uri: The URI for the database. E.g.: mysql+pymysql://user:password@localhost:3306/db or postgresql+psycopg2://user:password@localhost:5432/db
|
||||
# query: The query to fetch data from the database. E.g.: SELECT * FROM table
|
||||
- uri: mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zjinfo1
|
||||
enable: true # 添加 enable 字段
|
||||
queries:
|
||||
- sql: select * from ProjectProperties;
|
||||
explanation: "工程属性表数据,层级关系包含在博微电力造价工程文件格式_ProjectProperties.json文件中。"
|
||||
#- uri: mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zjinfo1
|
||||
#enable: false # 添加 enable 字段
|
||||
#queries:
|
||||
#- sql: select * from ProjectProperties;
|
||||
#explanation: "工程属性表数据,层级关系包含在博微电力造价工程文件格式_ProjectProperties.json文件中。"
|
||||
|
||||
- sql: select Id, ParentId, Level, Name, Code, Amount, Amount_Total from TotalCalculateTable;
|
||||
explanation: "总算表数据,层级关系包含在博微电力造价工程文件格式_TotalCalculateTable.json文件中。"
|
||||
#- sql: select Id, ParentId, Level, Name, Code, Amount, Amount_Total from TotalCalculateTable;
|
||||
#explanation: "总算表数据,层级关系包含在博微电力造价工程文件格式_TotalCalculateTable.json文件中。"
|
||||
|
||||
- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '线路';
|
||||
explanation: "专业类型为线路的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。"
|
||||
- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '余物清理';
|
||||
explanation: "专业类型为余物清理的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。"
|
||||
- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '拆除线路';
|
||||
explanation: "专业类型为拆除线路的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。"
|
||||
#- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '线路';
|
||||
#explanation: "专业类型为线路的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。"
|
||||
#- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '余物清理';
|
||||
#explanation: "专业类型为余物清理的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。"
|
||||
#- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '拆除线路';
|
||||
#explanation: "专业类型为拆除线路的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。"
|
||||
|
||||
- sql: select Id, ParentId, Level, Name, Code, Rate, Amount from OtherFee;
|
||||
explanation: "其他费用表数据,层级关系包含在博微电力造价工程文件格式_OtherFee.json文件中"
|
||||
#- sql: select Id, ParentId, Level, Name, Code, Rate, Amount from OtherFee;
|
||||
#explanation: "其他费用表数据,层级关系包含在博微电力造价工程文件格式_OtherFee.json文件中"
|
||||
|
||||
- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表'
|
||||
explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
|
||||
- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(调试工程)aa'
|
||||
explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
|
||||
- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '大型土石方取费表'
|
||||
explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
|
||||
- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(余物清理)'
|
||||
explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
|
||||
- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(余物清理)(1)'
|
||||
explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
|
||||
- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(拆除)'
|
||||
explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
|
||||
#- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表'
|
||||
#explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
|
||||
#- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(调试工程)aa'
|
||||
#explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
|
||||
#- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '大型土石方取费表'
|
||||
#explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
|
||||
#- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(余物清理)'
|
||||
#explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
|
||||
#- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(余物清理)(1)'
|
||||
#explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
|
||||
#- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(拆除)'
|
||||
#explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
|
||||
|
||||
- sql: select Name, Code, Calculation_Formula, Rate, from ProjectQuantities where Professional_Type = '线路'
|
||||
explanation: "专业类型为线路的工程量表数据,层级关系包含在博微电力造价工程文件格式_ProjectQuantities.json文件中"
|
||||
- sql: select Name, Code, Calculation_Formula, Rate, from ProjectQuantities where Professional_Type = '余物清理'
|
||||
explanation: "专业类型为余物清理的工程量表数据,层级关系包含在博微电力造价工程文件格式_ProjectQuantities.json文件中"
|
||||
#- sql: select Name, Code, Calculation_Formula, Rate, from ProjectQuantities where Professional_Type = '线路'
|
||||
#explanation: "专业类型为线路的工程量表数据,层级关系包含在博微电力造价工程文件格式_ProjectQuantities.json文件中"
|
||||
#- sql: select Name, Code, Calculation_Formula, Rate, from ProjectQuantities where Professional_Type = '余物清理'
|
||||
#explanation: "专业类型为余物清理的工程量表数据,层级关系包含在博微电力造价工程文件格式_ProjectQuantities.json文件中"
|
||||
#web:
|
||||
# driver_arguments:
|
||||
# # The arguments to pass to the webdriver. E.g.: add --headless to run in headless mode
|
||||
|
||||
Reference in New Issue
Block a user