优化ChatCallbackEvent事件代码

This commit is contained in:
wanyaokun
2024-08-29 19:52:53 +08:00
parent 480a1f7fdc
commit 03c4eb1af1
2 changed files with 141 additions and 191 deletions
+102 -161
View File
@@ -26,125 +26,48 @@ api_router = r = APIRouter()
v1_router = v = APIRouter() v1_router = v = APIRouter()
class ChatCallbackEvent(BaseModel): class ChatCallbackEvent(BaseModel):
event_type: CBEventType event_type: ChatEventType
payload: Optional[Dict[str, Any]] = None payload: Optional[Dict[str, Any]] = None
event_id: str = ""
def get_retrieval_message(self) -> dict | None: def get_common_param(self)-> dict:
if self.payload:
nodes = self.payload.get("nodes")
if nodes:
msg = f"根据查询检索到 {len(nodes)} 源文件"
else:
msg = f"查询检索中: '{self.payload.get('query_str')}'"
return { return {
"type": "events", 'event': self.event_type.name,
"data": {"title": msg}, 'conversation_id':self.payload.get("conversation_id"),
} 'message_id': self.payload.get("message_id"),
else: 'created_at': int(time.time()),
return None 'task_id': self.payload.get("task_id")
def get_tool_message(self) -> dict | None:
func_call_args = self.payload.get("function_call")
if func_call_args is not None and "tool" in self.payload:
tool = self.payload.get("tool")
return {
"type": "events",
"data": {
"title": f"调用工具 {tool.name} ,参数: {func_call_args}",
},
} }
def _is_output_serializable(self, output: Any) -> bool: def get_WorkflowStart_param(self) -> dict:
try: params = self.get_common_param()
json.dumps(output) params.update({
return True 'workflow_run_id':self.payload.get('workflow_run_id'),
except TypeError: 'data':{
return False "id": self.payload.get('workflow_run_id'),
"workflow_id": self.payload.get('workflow_id'),
def get_agent_tool_response(self) -> dict | None:
response = self.payload.get("response")
if response is not None:
sources = response.sources
for source in sources:
# Return the tool response here to include the toolCall information
if isinstance(source, ToolOutput):
if self._is_output_serializable(source.raw_output):
output = source.raw_output
else:
output = source.content
return {
"type": "tools",
"data": {
"toolOutput": {
"output": output,
"isError": source.is_error,
},
"toolCall": {
"id": None, # There is no tool id in the ToolOutput
"name": source.tool_name,
"input": source.raw_input,
},
},
}
def to_response(self):
try:
match self.event_type:
case "retrieve":
return self.get_retrieval_message()
case "function_call":
return self.get_tool_message()
case "agent_step":
return self.get_agent_tool_response()
case _:
return None
except Exception as e:
logger.error(f"转换回应时间时发生错误,原因: {e}")
return None
class DifyChatResponseEvent(BaseModel):
event: str
conversation_id: str
message_id: str
created_at: int = int(time.time())
task_id: str
def to_response(self):
return self.dict()
class Workflow_started_DifyChatResponseEvent(DifyChatResponseEvent):
event: str = 'workflow_started'
workflow_run_id:str
data:Dict[str,Any]
def __init__(self,**args):
args['data'] = {
"id": args['workflow_run_id'],
"workflow_id": args['workflow_id'],
"sequence_number": 1709, "sequence_number": 1709,
"inputs": { "inputs": {
"sys.query": args['query'], "sys.query": self.payload.get('query'),
"sys.files": [], "sys.files": [],
"sys.conversation_id": args['conversation_id'], "sys.conversation_id": self.payload.get('conversation_id'),
"sys.user_id": args['use_id'] "sys.user_id": self.payload.get('use_id')
}, },
"created_at": int(time.time()) "created_at": int(time.time())
} }
super().__init__(**args) })
return params
class Workflow_finished_DifyChatResponseEvent(DifyChatResponseEvent): def get_WorkflowFinished_param(self) -> dict:
event: str = 'workflow_finished' params = self.get_common_param()
workflow_run_id:str params.update({
data:Dict[str,Any] 'workflow_run_id':self.payload.get('workflow_run_id'),
def __init__(self,**args): 'data':{
args['data'] = { "id": self.payload.get('workflow_run_id'),
"id": args['workflow_run_id'], "workflow_id": self.payload.get('workflow_id'),
"workflow_id": args['workflow_id'],
"sequence_number": 1709, "sequence_number": 1709,
"status": "succeeded", "status": "succeeded",
"outputs": { "outputs": {
"answer": args['response'] "answer": self.payload.get('response')
}, },
"error": '', "error": '',
"elapsed_time": 36.03764106379822, "elapsed_time": 36.03764106379822,
@@ -152,60 +75,44 @@ class Workflow_finished_DifyChatResponseEvent(DifyChatResponseEvent):
"total_steps": 10, "total_steps": 10,
"created_by": { "created_by": {
"id": str(uuid.uuid4()), "id": str(uuid.uuid4()),
"user": args['use_id'] "user": self.payload.get('use_id')
}, },
"created_at": int(time.time()), "created_at": int(time.time()),
"finished_at": int(time.time()), "finished_at": int(time.time()),
"files": [] "files": []
} }
super().__init__(**args) })
return params
class Message_DifyChatResponseEvent(DifyChatResponseEvent): def get_NodeStart_param(self) -> dict:
event: str = 'message' params = self.get_common_param()
id:str params.update({
answer:str 'workflow_run_id':self.payload.get('workflow_run_id'),
def __init__(self,**args): 'data':{
args['id'] = args['message_id'] "id": self.payload.get('nodeid'),
super().__init__(**args) "node_id": self.payload.get('nodeid'),
class MessageEnd_DifyChatResponseEvent(DifyChatResponseEvent):
event: str = 'message_end'
id:str
metadata:Dict[str,Any] = {}
def __init__(self,**args):
args['id'] = args['message_id']
super().__init__(**args)
class Node_started_DifyChatResponseEvent(DifyChatResponseEvent):
event: str = 'node_started'
workflow_run_id:str
data:Dict[str,Any]
def __init__(self,**args):
args['data'] = {
"id": args['nodeid'],
"node_id": args['nodeid'],
"node_type": "http-request", "node_type": "http-request",
"title": args['title'], "title": self.payload.get('title'),
"index": args['index'], "index": self.payload.get('index'),
"predecessor_node_id": args['predecessor_node_id'], "predecessor_node_id": self.payload.get('predecessor_node_id'),
"inputs": '', "inputs": '',
"created_at": 1724398751, "created_at": 1724398751,
"extras": {} "extras": {}
} }
super().__init__(**args) })
return params
class Node_finished_DifyChatResponseEvent(DifyChatResponseEvent): def get_NodeFinished_param(self) -> dict:
event: str = 'node_finished' params = self.get_common_param()
workflow_run_id:str params.update({
data:Dict[str,Any] 'workflow_run_id':self.payload.get('workflow_run_id'),
def __init__(self,**args): 'data':{
args['data'] = { "id": self.payload.get('nodeid'),
"id": args['nodeid'], "node_id": self.payload.get('nodeid'),
"node_id": args['nodeid'],
"node_type": "http-request", "node_type": "http-request",
"title": args['title'], "title": self.payload.get('title'),
"index": args['index'], "index": self.payload.get('index'),
"predecessor_node_id": args['predecessor_node_id'], "predecessor_node_id": self.payload.get('predecessor_node_id'),
"inputs": '', "inputs": '',
"process_data": '', "process_data": '',
"outputs": '', "outputs": '',
@@ -217,7 +124,45 @@ class Node_finished_DifyChatResponseEvent(DifyChatResponseEvent):
"finished_at": 1724398751, "finished_at": 1724398751,
"files": [] "files": []
} }
super().__init__(**args) })
return params
def get_Message_param(self) -> dict:
params = self.get_common_param()
params.update({
'id':self.payload.get('message_id'),
'answer':self.payload.get('answer')
})
return params
def get_MessageEnd_param(self) -> dict:
params = self.get_common_param()
params.update({
'id':self.payload.get('message_id'),
'metadata':self.payload.get('metadata')
})
return params
def to_response(self)-> dict|None:
try:
match self.event_type:
case "workflow_started":
return self.get_WorkflowStart_param()
case "workflow_finished":
return self.get_WorkflowFinished_param()
case "node_started":
return self.get_NodeStart_param()
case 'node_finished':
return self.get_NodeFinished_param()
case 'message':
return self.get_Message_param()
case 'message_end':
return self.get_MessageEnd_param()
case _:
return None
except Exception as e:
logger.error(f"转换回应时间时发生错误,原因: {e}")
return None
class ChatEventCallbackHandler(BaseCallbackHandler): class ChatEventCallbackHandler(BaseCallbackHandler):
_aqueue: asyncio.Queue _aqueue: asyncio.Queue
@@ -239,9 +184,8 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
self._nodeStack:deque = deque() self._nodeStack:deque = deque()
#添加工作流开始事件 #添加工作流开始事件
ids:Dict[str,Any] = self._params['ids']
data:ChatRequestData = self._params['data'] data:ChatRequestData = self._params['data']
args = ids args:Dict[str,Any] = self._params['ids']
args.update( args.update(
{ {
'use_id': data.user, 'use_id': data.user,
@@ -249,7 +193,7 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
'conversation_id': data.conversation_id 'conversation_id': data.conversation_id
} }
) )
wf_event = Workflow_started_DifyChatResponseEvent(**args) wf_event = ChatCallbackEvent(event_type = ChatEventType.WORKFLOW_START,payload = args)
if wf_event.to_response() is not None: if wf_event.to_response() is not None:
self._aqueue.put_nowait(wf_event) self._aqueue.put_nowait(wf_event)
@@ -264,9 +208,7 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
self._nodeStack.append(event_id) self._nodeStack.append(event_id)
nindex = self._nodeStack.count() - 1 nindex = self._nodeStack.count() - 1
args:Dict[str,Any] = self._params['ids']
ids:Dict[str,Any] = self._params['ids']
args = ids
args.update( args.update(
{ {
'nodeid':event_id, 'nodeid':event_id,
@@ -275,7 +217,7 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
'predecessor_node_id': self._nodeStack[nindex - 1] if nindex > 0 else '' 'predecessor_node_id': self._nodeStack[nindex - 1] if nindex > 0 else ''
} }
) )
nd_event = Node_started_DifyChatResponseEvent(**args) nd_event = ChatCallbackEvent(event_type = ChatEventType.NODE_START,payload = args)
if nd_event.to_response() is not None: if nd_event.to_response() is not None:
self._aqueue.put_nowait(nd_event) self._aqueue.put_nowait(nd_event)
@@ -302,7 +244,7 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
'predecessor_node_id':self._nodeStack[nindex - 1] if nindex > 0 else '' 'predecessor_node_id':self._nodeStack[nindex - 1] if nindex > 0 else ''
} }
) )
nd_event = Node_finished_DifyChatResponseEvent(**args) nd_event = ChatCallbackEvent(event_type = ChatEventType.NODE_FINISHED,payload = args)
if nd_event.to_response() is not None: if nd_event.to_response() is not None:
self._aqueue.put_nowait(nd_event) self._aqueue.put_nowait(nd_event)
self._nodeStack.pop() self._nodeStack.pop()
@@ -319,22 +261,21 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
) -> None: ) -> None:
"""No-op.""" """No-op."""
logger.info("trace_end:{} trace_map:{}\n".format(trace_id, trace_map)) logger.info("trace_end:{} trace_map:{}\n".format(trace_id, trace_map))
ids:Dict[str,Any] = self._params['ids']
data:ChatRequestData = self._params['data'] data:ChatRequestData = self._params['data']
args = ids args:Dict[str,Any] = self._params['ids']
args.update( args.update(
{ {
'response':self._response, 'response':self._response,
'conversation_id': data.conversation_id 'conversation_id': data.conversation_id
} }
) )
wf_event = Workflow_finished_DifyChatResponseEvent(**args) wf_event = ChatCallbackEvent(event_type = ChatEventType.WORKFLOW_FINISHED,payload = args)
if wf_event.to_response() is not None: if wf_event.to_response() is not None:
self._aqueue.put_nowait(wf_event) self._aqueue.put_nowait(wf_event)
args = ids args:Dict[str,Any] = self._params['ids']
msgEnt_event = MessageEnd_DifyChatResponseEvent(**args) msgEnt_event = ChatCallbackEvent(event_type = ChatEventType.MESSAGE_END,payload = args)
if msgEnt_event.to_response() is not None: if msgEnt_event.to_response() is not None:
self._aqueue.put_nowait(msgEnt_event) self._aqueue.put_nowait(msgEnt_event)
@@ -367,8 +308,8 @@ class ChatStreamResponse(StreamingResponse):
'answer':token, 'answer':token,
'conversation_id':cls.data.conversation_id 'conversation_id':cls.data.conversation_id
}) })
event = Message_DifyChatResponseEvent(**params) event = ChatCallbackEvent(event_type = ChatEventType.MESSAGE,payload = params)
data_str = json.dumps(event.dict()) data_str = json.dumps(event.to_response())
return f"{cls.DATA_PREFIX}{data_str}\n\n" return f"{cls.DATA_PREFIX}{data_str}\n\n"
@classmethod @classmethod
@@ -1,5 +1,7 @@
from pydantic import BaseModel from pydantic import BaseModel
import os import os
from enum import Enum
class BaseConfig(BaseModel): class BaseConfig(BaseModel):
projectInfo:str = os.getenv("PROJECT_TITLE","您好,我是博微工程理解小助手,您可以问我有关[线路工程]工程数据的相关问题!") projectInfo:str = os.getenv("PROJECT_TITLE","您好,我是博微工程理解小助手,您可以问我有关[线路工程]工程数据的相关问题!")
@@ -69,3 +71,10 @@ class BaseConfig(BaseModel):
} }
class ChatEventType(str, Enum):
WORKFLOW_START = "workflow_started"
WORKFLOW_FINISHED = "workflow_finished"
NODE_START = "node_started"
NODE_FINISHED = "node_finished"
MESSAGE = "message"
MESSAGE_END = "message_end"