优化ChatCallbackEvent事件代码

This commit is contained in:
wanyaokun
2024-08-29 19:52:53 +08:00
parent 480a1f7fdc
commit 03c4eb1af1
2 changed files with 141 additions and 191 deletions
+132 -191
View File
@@ -26,199 +26,144 @@ api_router = r = APIRouter()
v1_router = v = APIRouter() v1_router = v = APIRouter()
class ChatCallbackEvent(BaseModel): class ChatCallbackEvent(BaseModel):
event_type: CBEventType event_type: ChatEventType
payload: Optional[Dict[str, Any]] = None payload: Optional[Dict[str, Any]] = None
event_id: str = ""
def get_retrieval_message(self) -> dict | None: def get_common_param(self)-> dict:
if self.payload: return {
nodes = self.payload.get("nodes") 'event': self.event_type.name,
if nodes: 'conversation_id':self.payload.get("conversation_id"),
msg = f"根据查询检索到 {len(nodes)} 源文件" 'message_id': self.payload.get("message_id"),
else: 'created_at': int(time.time()),
msg = f"查询检索中: '{self.payload.get('query_str')}'" 'task_id': self.payload.get("task_id")
return { }
"type": "events",
"data": {"title": msg},
}
else:
return None
def get_tool_message(self) -> dict | None: def get_WorkflowStart_param(self) -> dict:
func_call_args = self.payload.get("function_call") params = self.get_common_param()
if func_call_args is not None and "tool" in self.payload: params.update({
tool = self.payload.get("tool") 'workflow_run_id':self.payload.get('workflow_run_id'),
return { 'data':{
"type": "events", "id": self.payload.get('workflow_run_id'),
"data": { "workflow_id": self.payload.get('workflow_id'),
"title": f"调用工具 {tool.name} ,参数: {func_call_args}", "sequence_number": 1709,
"inputs": {
"sys.query": self.payload.get('query'),
"sys.files": [],
"sys.conversation_id": self.payload.get('conversation_id'),
"sys.user_id": self.payload.get('use_id')
}, },
"created_at": int(time.time())
} }
})
return params
def _is_output_serializable(self, output: Any) -> bool: def get_WorkflowFinished_param(self) -> dict:
try: params = self.get_common_param()
json.dumps(output) params.update({
return True 'workflow_run_id':self.payload.get('workflow_run_id'),
except TypeError: 'data':{
return False "id": self.payload.get('workflow_run_id'),
"workflow_id": self.payload.get('workflow_id'),
"sequence_number": 1709,
"status": "succeeded",
"outputs": {
"answer": self.payload.get('response')
},
"error": '',
"elapsed_time": 36.03764106379822,
"total_tokens": 11707,
"total_steps": 10,
"created_by": {
"id": str(uuid.uuid4()),
"user": self.payload.get('use_id')
},
"created_at": int(time.time()),
"finished_at": int(time.time()),
"files": []
}
})
return params
def get_NodeStart_param(self) -> dict:
params = self.get_common_param()
params.update({
'workflow_run_id':self.payload.get('workflow_run_id'),
'data':{
"id": self.payload.get('nodeid'),
"node_id": self.payload.get('nodeid'),
"node_type": "http-request",
"title": self.payload.get('title'),
"index": self.payload.get('index'),
"predecessor_node_id": self.payload.get('predecessor_node_id'),
"inputs": '',
"created_at": 1724398751,
"extras": {}
}
})
return params
def get_agent_tool_response(self) -> dict | None: def get_NodeFinished_param(self) -> dict:
response = self.payload.get("response") params = self.get_common_param()
if response is not None: params.update({
sources = response.sources 'workflow_run_id':self.payload.get('workflow_run_id'),
for source in sources: 'data':{
# Return the tool response here to include the toolCall information "id": self.payload.get('nodeid'),
if isinstance(source, ToolOutput): "node_id": self.payload.get('nodeid'),
if self._is_output_serializable(source.raw_output): "node_type": "http-request",
output = source.raw_output "title": self.payload.get('title'),
else: "index": self.payload.get('index'),
output = source.content "predecessor_node_id": self.payload.get('predecessor_node_id'),
"inputs": '',
"process_data": '',
"outputs": '',
"status": "succeeded",
"error": '',
"elapsed_time": 0.10402441816404462,
"execution_metadata": '',
"created_at": 1724398751,
"finished_at": 1724398751,
"files": []
}
})
return params
return { def get_Message_param(self) -> dict:
"type": "tools", params = self.get_common_param()
"data": { params.update({
"toolOutput": { 'id':self.payload.get('message_id'),
"output": output, 'answer':self.payload.get('answer')
"isError": source.is_error, })
}, return params
"toolCall": {
"id": None, # There is no tool id in the ToolOutput def get_MessageEnd_param(self) -> dict:
"name": source.tool_name, params = self.get_common_param()
"input": source.raw_input, params.update({
}, 'id':self.payload.get('message_id'),
}, 'metadata':self.payload.get('metadata')
} })
return params
def to_response(self): def to_response(self)-> dict|None:
try: try:
match self.event_type: match self.event_type:
case "retrieve": case "workflow_started":
return self.get_retrieval_message() return self.get_WorkflowStart_param()
case "function_call": case "workflow_finished":
return self.get_tool_message() return self.get_WorkflowFinished_param()
case "agent_step": case "node_started":
return self.get_agent_tool_response() return self.get_NodeStart_param()
case 'node_finished':
return self.get_NodeFinished_param()
case 'message':
return self.get_Message_param()
case 'message_end':
return self.get_MessageEnd_param()
case _: case _:
return None return None
except Exception as e: except Exception as e:
logger.error(f"转换回应时间时发生错误,原因: {e}") logger.error(f"转换回应时间时发生错误,原因: {e}")
return None return None
class DifyChatResponseEvent(BaseModel):
event: str
conversation_id: str
message_id: str
created_at: int = int(time.time())
task_id: str
def to_response(self):
return self.dict()
class Workflow_started_DifyChatResponseEvent(DifyChatResponseEvent):
event: str = 'workflow_started'
workflow_run_id:str
data:Dict[str,Any]
def __init__(self,**args):
args['data'] = {
"id": args['workflow_run_id'],
"workflow_id": args['workflow_id'],
"sequence_number": 1709,
"inputs": {
"sys.query": args['query'],
"sys.files": [],
"sys.conversation_id": args['conversation_id'],
"sys.user_id": args['use_id']
},
"created_at": int(time.time())
}
super().__init__(**args)
class Workflow_finished_DifyChatResponseEvent(DifyChatResponseEvent):
event: str = 'workflow_finished'
workflow_run_id:str
data:Dict[str,Any]
def __init__(self,**args):
args['data'] = {
"id": args['workflow_run_id'],
"workflow_id": args['workflow_id'],
"sequence_number": 1709,
"status": "succeeded",
"outputs": {
"answer": args['response']
},
"error": '',
"elapsed_time": 36.03764106379822,
"total_tokens": 11707,
"total_steps": 10,
"created_by": {
"id": str(uuid.uuid4()),
"user": args['use_id']
},
"created_at": int(time.time()),
"finished_at": int(time.time()),
"files": []
}
super().__init__(**args)
class Message_DifyChatResponseEvent(DifyChatResponseEvent):
event: str = 'message'
id:str
answer:str
def __init__(self,**args):
args['id'] = args['message_id']
super().__init__(**args)
class MessageEnd_DifyChatResponseEvent(DifyChatResponseEvent):
event: str = 'message_end'
id:str
metadata:Dict[str,Any] = {}
def __init__(self,**args):
args['id'] = args['message_id']
super().__init__(**args)
class Node_started_DifyChatResponseEvent(DifyChatResponseEvent):
event: str = 'node_started'
workflow_run_id:str
data:Dict[str,Any]
def __init__(self,**args):
args['data'] = {
"id": args['nodeid'],
"node_id": args['nodeid'],
"node_type": "http-request",
"title": args['title'],
"index": args['index'],
"predecessor_node_id": args['predecessor_node_id'],
"inputs": '',
"created_at": 1724398751,
"extras": {}
}
super().__init__(**args)
class Node_finished_DifyChatResponseEvent(DifyChatResponseEvent):
event: str = 'node_finished'
workflow_run_id:str
data:Dict[str,Any]
def __init__(self,**args):
args['data'] = {
"id": args['nodeid'],
"node_id": args['nodeid'],
"node_type": "http-request",
"title": args['title'],
"index": args['index'],
"predecessor_node_id": args['predecessor_node_id'],
"inputs": '',
"process_data": '',
"outputs": '',
"status": "succeeded",
"error": '',
"elapsed_time": 0.10402441816404462,
"execution_metadata": '',
"created_at": 1724398751,
"finished_at": 1724398751,
"files": []
}
super().__init__(**args)
class ChatEventCallbackHandler(BaseCallbackHandler): class ChatEventCallbackHandler(BaseCallbackHandler):
_aqueue: asyncio.Queue _aqueue: asyncio.Queue
is_done: bool = False is_done: bool = False
@@ -239,9 +184,8 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
self._nodeStack:deque = deque() self._nodeStack:deque = deque()
#添加工作流开始事件 #添加工作流开始事件
ids:Dict[str,Any] = self._params['ids']
data:ChatRequestData = self._params['data'] data:ChatRequestData = self._params['data']
args = ids args:Dict[str,Any] = self._params['ids']
args.update( args.update(
{ {
'use_id': data.user, 'use_id': data.user,
@@ -249,7 +193,7 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
'conversation_id': data.conversation_id 'conversation_id': data.conversation_id
} }
) )
wf_event = Workflow_started_DifyChatResponseEvent(**args) wf_event = ChatCallbackEvent(event_type = ChatEventType.WORKFLOW_START,payload = args)
if wf_event.to_response() is not None: if wf_event.to_response() is not None:
self._aqueue.put_nowait(wf_event) self._aqueue.put_nowait(wf_event)
@@ -264,9 +208,7 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
self._nodeStack.append(event_id) self._nodeStack.append(event_id)
nindex = self._nodeStack.count() - 1 nindex = self._nodeStack.count() - 1
args:Dict[str,Any] = self._params['ids']
ids:Dict[str,Any] = self._params['ids']
args = ids
args.update( args.update(
{ {
'nodeid':event_id, 'nodeid':event_id,
@@ -275,7 +217,7 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
'predecessor_node_id': self._nodeStack[nindex - 1] if nindex > 0 else '' 'predecessor_node_id': self._nodeStack[nindex - 1] if nindex > 0 else ''
} }
) )
nd_event = Node_started_DifyChatResponseEvent(**args) nd_event = ChatCallbackEvent(event_type = ChatEventType.NODE_START,payload = args)
if nd_event.to_response() is not None: if nd_event.to_response() is not None:
self._aqueue.put_nowait(nd_event) self._aqueue.put_nowait(nd_event)
@@ -302,7 +244,7 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
'predecessor_node_id':self._nodeStack[nindex - 1] if nindex > 0 else '' 'predecessor_node_id':self._nodeStack[nindex - 1] if nindex > 0 else ''
} }
) )
nd_event = Node_finished_DifyChatResponseEvent(**args) nd_event = ChatCallbackEvent(event_type = ChatEventType.NODE_FINISHED,payload = args)
if nd_event.to_response() is not None: if nd_event.to_response() is not None:
self._aqueue.put_nowait(nd_event) self._aqueue.put_nowait(nd_event)
self._nodeStack.pop() self._nodeStack.pop()
@@ -319,22 +261,21 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
) -> None: ) -> None:
"""No-op.""" """No-op."""
logger.info("trace_end:{} trace_map:{}\n".format(trace_id, trace_map)) logger.info("trace_end:{} trace_map:{}\n".format(trace_id, trace_map))
ids:Dict[str,Any] = self._params['ids']
data:ChatRequestData = self._params['data'] data:ChatRequestData = self._params['data']
args = ids args:Dict[str,Any] = self._params['ids']
args.update( args.update(
{ {
'response':self._response, 'response':self._response,
'conversation_id': data.conversation_id 'conversation_id': data.conversation_id
} }
) )
wf_event = Workflow_finished_DifyChatResponseEvent(**args) wf_event = ChatCallbackEvent(event_type = ChatEventType.WORKFLOW_FINISHED,payload = args)
if wf_event.to_response() is not None: if wf_event.to_response() is not None:
self._aqueue.put_nowait(wf_event) self._aqueue.put_nowait(wf_event)
args = ids args:Dict[str,Any] = self._params['ids']
msgEnt_event = MessageEnd_DifyChatResponseEvent(**args) msgEnt_event = ChatCallbackEvent(event_type = ChatEventType.MESSAGE_END,payload = args)
if msgEnt_event.to_response() is not None: if msgEnt_event.to_response() is not None:
self._aqueue.put_nowait(msgEnt_event) self._aqueue.put_nowait(msgEnt_event)
@@ -367,8 +308,8 @@ class ChatStreamResponse(StreamingResponse):
'answer':token, 'answer':token,
'conversation_id':cls.data.conversation_id 'conversation_id':cls.data.conversation_id
}) })
event = Message_DifyChatResponseEvent(**params) event = ChatCallbackEvent(event_type = ChatEventType.MESSAGE,payload = params)
data_str = json.dumps(event.dict()) data_str = json.dumps(event.to_response())
return f"{cls.DATA_PREFIX}{data_str}\n\n" return f"{cls.DATA_PREFIX}{data_str}\n\n"
@classmethod @classmethod
@@ -1,5 +1,7 @@
from pydantic import BaseModel from pydantic import BaseModel
import os import os
from enum import Enum
class BaseConfig(BaseModel): class BaseConfig(BaseModel):
projectInfo:str = os.getenv("PROJECT_TITLE","您好,我是博微工程理解小助手,您可以问我有关[线路工程]工程数据的相关问题!") projectInfo:str = os.getenv("PROJECT_TITLE","您好,我是博微工程理解小助手,您可以问我有关[线路工程]工程数据的相关问题!")
@@ -69,3 +71,10 @@ class BaseConfig(BaseModel):
} }
class ChatEventType(str, Enum):
WORKFLOW_START = "workflow_started"
WORKFLOW_FINISHED = "workflow_finished"
NODE_START = "node_started"
NODE_FINISHED = "node_finished"
MESSAGE = "message"
MESSAGE_END = "message_end"