新增工程信息、检索的知识片段节点回传、下一轮建议问题列表

This commit is contained in:
wanyaokun
2024-08-30 16:42:36 +08:00
parent 73565b26e4
commit e7628809ad
11 changed files with 411 additions and 134 deletions
+128 -51
View File
@@ -3,7 +3,6 @@ import json
import logging
import time
from typing import Dict, List, Any, Optional, AsyncGenerator
from collections import deque
from aiostream import stream
from fastapi import APIRouter, Request,HTTPException
@@ -13,17 +12,19 @@ from llama_index.core.base.llms.types import ChatMessage
from llama_index.core.callbacks import CBEventType
from llama_index.core.chat_engine.types import StreamingAgentChatResponse
from llama_index.core.tools import ToolOutput
from llama_index.core.schema import NodeWithScore
from pydantic import BaseModel
from app.api.routers.request.base import userMng, conversations,message,parameter,feedback
from app.api.routers.request.base import userMng, conversations,message,ProjectInfo,feedback
from app.api.routers.request.baseConfig import *
from app.api.routers.request.models import ChatRequestData,ChatFileUploadRequest
from app.engine import get_chat_engine
import uuid
from app.api.routers.services.fileServices import FileLoadService
from app.api.routers.services.fileServices import PrjFileLoadService,ChatFileService
from app.api.routers.services.suggestion import NextQuestionSuggestion
import time
logger = logging.getLogger("uvicorn")
api_router = r = APIRouter()
v1_router = v = APIRouter()
class ChatCallbackEvent(BaseModel):
@@ -32,7 +33,7 @@ class ChatCallbackEvent(BaseModel):
def get_common_param(self)-> dict:
return {
'event': self.event_type.name,
'event': self.event_type.value,
'conversation_id':self.payload.get("conversation_id"),
'message_id': self.payload.get("message_id"),
'created_at': int(time.time()),
@@ -48,7 +49,7 @@ class ChatCallbackEvent(BaseModel):
"workflow_id": self.payload.get('workflow_id'),
"sequence_number": 1709,
"inputs": {
"sys.query": self.payload.get('query'),
"sys.query": f"开始查询 {self.payload.get('query')}",
"sys.files": [],
"sys.conversation_id": self.payload.get('conversation_id'),
"sys.user_id": self.payload.get('use_id')
@@ -93,7 +94,7 @@ class ChatCallbackEvent(BaseModel):
"id": self.payload.get('nodeid'),
"node_id": self.payload.get('nodeid'),
"node_type": "http-request",
"title": self.payload.get('title'),
"title": f"正在执行事件:{self.payload.get('title')}",
"index": self.payload.get('index'),
"predecessor_node_id": self.payload.get('predecessor_node_id'),
"inputs": '',
@@ -111,7 +112,7 @@ class ChatCallbackEvent(BaseModel):
"id": self.payload.get('nodeid'),
"node_id": self.payload.get('nodeid'),
"node_type": "http-request",
"title": self.payload.get('title'),
"title": f"事件执行结束:{self.payload.get('title')}",
"index": self.payload.get('index'),
"predecessor_node_id": self.payload.get('predecessor_node_id'),
"inputs": '',
@@ -138,15 +139,54 @@ class ChatCallbackEvent(BaseModel):
def get_MessageEnd_param(self) -> dict:
params = self.get_common_param()
nodeInfos = []
source_nodes = self.payload.get('source_node')
if source_nodes is not None:
for i in range(len(source_nodes)):
source_node:NodeWithScore = source_nodes[i]
metadata:dict = source_node.node.metadata
nodeInfo = {
"position": i,
"dataset_id": metadata.get("pipeline_id"),
"dataset_name": metadata.get("file_name"),
"document_id": source_node.node_id,
"document_name": metadata.get("file_name"),
"data_source_type": "upload_file",
"segment_id": source_node.node_id,
"retriever_from": "workflow",
"score": source_node.score,
"hit_count": 1,
"word_count": 632,
"segment_position": i,
"index_node_hash": "",
"content": source_node.text
}
nodeInfos.append(nodeInfo)
params.update({
'id':self.payload.get('message_id'),
'metadata':self.payload.get('metadata')
'metadata':{
"retriever_resources":nodeInfos,
"usage":{
"prompt_tokens": 4972,
"prompt_unit_price": "0.0",
"prompt_price_unit": "0.0",
"prompt_price": "0.0",
"completion_tokens": 332,
"completion_unit_price": "0.0",
"completion_price_unit": "0.0",
"completion_price": "0.0",
"total_tokens": 5304,
"total_price": "0.0",
"currency": "USD",
"latency": 4.897703120019287
}
}
})
return params
def to_response(self)-> dict|None:
try:
match self.event_type:
match self.event_type.value:
case "workflow_started":
return self.get_WorkflowStart_param()
case "workflow_finished":
@@ -180,24 +220,18 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
]
super().__init__(ignored_events, ignored_events)
self._aqueue = asyncio.Queue()
self._response:str = ''
self._response: StreamingAgentChatResponse = None
self._params:Dict[str,Any] = params
self._nodeStack:deque = deque()
self._nodeStack:List[str] = []
#添加工作流开始事件
data:ChatRequestData = self._params['data']
args:Dict[str,Any] = self._params['ids']
args.update(
{
'use_id': data.user,
'query': data.query,
'conversation_id': data.conversation_id
}
)
wf_event = ChatCallbackEvent(event_type = ChatEventType.WORKFLOW_START,payload = args)
wf_event = self.makeWorkflow_startEvent()
if wf_event.to_response() is not None:
self._aqueue.put_nowait(wf_event)
def setResponse(self,response: StreamingAgentChatResponse):
self._response = response
def on_event_start(
self,
event_type: CBEventType,
@@ -208,7 +242,7 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
logger.info("event_start:{} type:{} payload:{}\n".format(event_id, event_type, payload))
self._nodeStack.append(event_id)
nindex = self._nodeStack.count() - 1
nindex = len(self._nodeStack) - 1
args:Dict[str,Any] = self._params['ids']
args.update(
{
@@ -222,7 +256,6 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
if nd_event.to_response() is not None:
self._aqueue.put_nowait(nd_event)
def on_event_end(
self,
event_type: CBEventType,
@@ -236,7 +269,7 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
args:Dict[str,Any] = self._params['ids']
nodeID = self._nodeStack[-1]
if nodeID == event_id:
nindex = self._nodeStack.count() - 1
nindex = len(self._nodeStack) - 1
args.update(
{
'nodeid':event_id,
@@ -250,7 +283,6 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
self._aqueue.put_nowait(nd_event)
self._nodeStack.pop()
def start_trace(self, trace_id: Optional[str] = None) -> None:
"""No-op."""
logger.info("trace_start:{}\n".format(trace_id))
@@ -262,23 +294,14 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
) -> None:
"""No-op."""
logger.info("trace_end:{} trace_map:{}\n".format(trace_id, trace_map))
data:ChatRequestData = self._params['data']
args:Dict[str,Any] = self._params['ids']
args.update(
{
'response':self._response,
'conversation_id': data.conversation_id
}
)
wf_event = ChatCallbackEvent(event_type = ChatEventType.WORKFLOW_FINISHED,payload = args)
wf_event = self.makeWorkflow_finishedEvent()
if wf_event.to_response() is not None:
self._aqueue.put_nowait(wf_event)
args:Dict[str,Any] = self._params['ids']
msgEnt_event = ChatCallbackEvent(event_type = ChatEventType.MESSAGE_END,payload = args)
msgEnt_event = self.makeMessage_EndEvent()
if msgEnt_event.to_response() is not None:
self._aqueue.put_nowait(msgEnt_event)
self._aqueue.put_nowait(msgEnt_event)
async def async_event_gen(self) -> AsyncGenerator[ChatCallbackEvent, None]:
while not self._aqueue.empty() or not self.is_done:
@@ -287,6 +310,38 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
except asyncio.TimeoutError:
pass
def makeWorkflow_startEvent(self)->ChatCallbackEvent:
data:ChatRequestData = self._params['data']
args:Dict[str,Any] = self._params['ids']
args.update(
{
'use_id': data.user,
'query': data.query,
'conversation_id': data.conversation_id
}
)
return ChatCallbackEvent(event_type = ChatEventType.WORKFLOW_START,payload = args)
def makeWorkflow_finishedEvent(self)->ChatCallbackEvent:
data:ChatRequestData = self._params['data']
args:Dict[str,Any] = self._params['ids']
args.update(
{
'response': '',
'conversation_id': data.conversation_id
}
)
return ChatCallbackEvent(event_type = ChatEventType.WORKFLOW_FINISHED,payload = args)
def makeMessage_EndEvent(self)->ChatCallbackEvent:
args:Dict[str,Any] = self._params['ids']
if self._response is not None:
args.update({
'source_node': self._response.source_nodes
})
msgEnt_event = ChatCallbackEvent(event_type = ChatEventType.MESSAGE_END,payload = args)
return msgEnt_event
class IDManager:
def createID(self):
return {
@@ -390,7 +445,7 @@ async def post_conversations(request: Request, data: ChatRequestData):
params = data.inputs or {}
# 获取聊天引擎对象
chat_engine = get_chat_engine(filters=filters, params=params,prjFlag = data.prjFlag)
chat_engine = get_chat_engine(filters=filters, params=params)
# 启动聊天事件监听
ids = IDManager().createID()
@@ -399,7 +454,7 @@ async def post_conversations(request: Request, data: ChatRequestData):
# 执行异步聊天
response = await chat_engine.astream_chat(data.query)
event_handler.setResponse(response)
# 返回异步消息回应
return ChatStreamResponse(request, event_handler, response, data,ids)
@@ -468,29 +523,51 @@ async def query_conversations(user:str, first_id:str = None, limit:str = None, p
@v.get("/parameters")
async def query_parameters(user:str):
params = parameter().get(user)
if len(params) == 0:
params = BaseConfig().ParamterCfg()
return params
prjObj = ProjectInfo()
return BaseConfig().ParamterCfg(projectInfo = prjObj.projectNames())
@v.post("/messages/{message_id}/feedbacks")
async def post_feedbacks(request: Request,message_id:str,params:Dict[str,Any]):
if params['rating'] =='null':
if params['rating'] is None:
feedback().delete(message_id)
else:
condition = {'id':message_id}
results = message().query(**condition)
results = message().query(message_id)
if len(results) > 0:
result = results[0]
feedback().add(message_id=message_id,query=result['query'],
answer=result['answer'],rating=params['rating'])
@v.post("")
def upload_file(request: ChatFileUploadRequest) -> List[str]:
@v.post("/files/upload")
def upload_file(request: ChatFileUploadRequest):
try:
logger.info("Processing file")
return FileLoadService.process_file(request.base64)
resluts = ChatFileService.process_file(request.base64)
return {
'id':resluts.get('id'),
'name': resluts.get('name'),
'size': resluts.get('size'),
'extension':resluts.get('extension'),
'mime_type':resluts.get('mime_type'),
'created_by':str(uuid.uuid4()),
'created_at':int(time.time())
}
except Exception as e:
logger.error(f"Error processing file: {e}", exc_info=True)
raise HTTPException(status_code=500, detail="Error processing file")
@v.post("/applications")
def upload_file(request: ChatFileUploadRequest):
try:
logger.info("Processing file")
return PrjFileLoadService.process_file(request.base64)
except Exception as e:
logger.error(f"Error processing file: {e}", exc_info=True)
raise HTTPException(status_code=500, detail="Error processing file")
@v.post("/messages/{message_id}/suggested")
async def post_suggested(request: Request,message_id:str,user:str):
questions = await NextQuestionSuggestion.suggest_next_questions(message_id)
return {
"result": "success",
"data":questions
}