优化ChatCallbackEvent事件代码
This commit is contained in:
+132
-191
@@ -26,199 +26,144 @@ api_router = r = APIRouter()
|
||||
v1_router = v = APIRouter()
|
||||
|
||||
class ChatCallbackEvent(BaseModel):
|
||||
event_type: CBEventType
|
||||
event_type: ChatEventType
|
||||
payload: Optional[Dict[str, Any]] = None
|
||||
event_id: str = ""
|
||||
|
||||
def get_retrieval_message(self) -> dict | None:
|
||||
if self.payload:
|
||||
nodes = self.payload.get("nodes")
|
||||
if nodes:
|
||||
msg = f"根据查询检索到 {len(nodes)} 源文件"
|
||||
else:
|
||||
msg = f"查询检索中: '{self.payload.get('query_str')}'"
|
||||
return {
|
||||
"type": "events",
|
||||
"data": {"title": msg},
|
||||
}
|
||||
else:
|
||||
return None
|
||||
def get_common_param(self)-> dict:
|
||||
return {
|
||||
'event': self.event_type.name,
|
||||
'conversation_id':self.payload.get("conversation_id"),
|
||||
'message_id': self.payload.get("message_id"),
|
||||
'created_at': int(time.time()),
|
||||
'task_id': self.payload.get("task_id")
|
||||
}
|
||||
|
||||
def get_tool_message(self) -> dict | None:
|
||||
func_call_args = self.payload.get("function_call")
|
||||
if func_call_args is not None and "tool" in self.payload:
|
||||
tool = self.payload.get("tool")
|
||||
return {
|
||||
"type": "events",
|
||||
"data": {
|
||||
"title": f"调用工具 {tool.name} ,参数: {func_call_args}",
|
||||
def get_WorkflowStart_param(self) -> dict:
|
||||
params = self.get_common_param()
|
||||
params.update({
|
||||
'workflow_run_id':self.payload.get('workflow_run_id'),
|
||||
'data':{
|
||||
"id": self.payload.get('workflow_run_id'),
|
||||
"workflow_id": self.payload.get('workflow_id'),
|
||||
"sequence_number": 1709,
|
||||
"inputs": {
|
||||
"sys.query": self.payload.get('query'),
|
||||
"sys.files": [],
|
||||
"sys.conversation_id": self.payload.get('conversation_id'),
|
||||
"sys.user_id": self.payload.get('use_id')
|
||||
},
|
||||
"created_at": int(time.time())
|
||||
}
|
||||
})
|
||||
return params
|
||||
|
||||
def _is_output_serializable(self, output: Any) -> bool:
|
||||
try:
|
||||
json.dumps(output)
|
||||
return True
|
||||
except TypeError:
|
||||
return False
|
||||
def get_WorkflowFinished_param(self) -> dict:
|
||||
params = self.get_common_param()
|
||||
params.update({
|
||||
'workflow_run_id':self.payload.get('workflow_run_id'),
|
||||
'data':{
|
||||
"id": self.payload.get('workflow_run_id'),
|
||||
"workflow_id": self.payload.get('workflow_id'),
|
||||
"sequence_number": 1709,
|
||||
"status": "succeeded",
|
||||
"outputs": {
|
||||
"answer": self.payload.get('response')
|
||||
},
|
||||
"error": '',
|
||||
"elapsed_time": 36.03764106379822,
|
||||
"total_tokens": 11707,
|
||||
"total_steps": 10,
|
||||
"created_by": {
|
||||
"id": str(uuid.uuid4()),
|
||||
"user": self.payload.get('use_id')
|
||||
},
|
||||
"created_at": int(time.time()),
|
||||
"finished_at": int(time.time()),
|
||||
"files": []
|
||||
}
|
||||
})
|
||||
return params
|
||||
|
||||
def get_agent_tool_response(self) -> dict | None:
|
||||
response = self.payload.get("response")
|
||||
if response is not None:
|
||||
sources = response.sources
|
||||
for source in sources:
|
||||
# Return the tool response here to include the toolCall information
|
||||
if isinstance(source, ToolOutput):
|
||||
if self._is_output_serializable(source.raw_output):
|
||||
output = source.raw_output
|
||||
else:
|
||||
output = source.content
|
||||
def get_NodeStart_param(self) -> dict:
|
||||
params = self.get_common_param()
|
||||
params.update({
|
||||
'workflow_run_id':self.payload.get('workflow_run_id'),
|
||||
'data':{
|
||||
"id": self.payload.get('nodeid'),
|
||||
"node_id": self.payload.get('nodeid'),
|
||||
"node_type": "http-request",
|
||||
"title": self.payload.get('title'),
|
||||
"index": self.payload.get('index'),
|
||||
"predecessor_node_id": self.payload.get('predecessor_node_id'),
|
||||
"inputs": '',
|
||||
"created_at": 1724398751,
|
||||
"extras": {}
|
||||
}
|
||||
})
|
||||
return params
|
||||
|
||||
return {
|
||||
"type": "tools",
|
||||
"data": {
|
||||
"toolOutput": {
|
||||
"output": output,
|
||||
"isError": source.is_error,
|
||||
},
|
||||
"toolCall": {
|
||||
"id": None, # There is no tool id in the ToolOutput
|
||||
"name": source.tool_name,
|
||||
"input": source.raw_input,
|
||||
},
|
||||
},
|
||||
}
|
||||
def get_NodeFinished_param(self) -> dict:
|
||||
params = self.get_common_param()
|
||||
params.update({
|
||||
'workflow_run_id':self.payload.get('workflow_run_id'),
|
||||
'data':{
|
||||
"id": self.payload.get('nodeid'),
|
||||
"node_id": self.payload.get('nodeid'),
|
||||
"node_type": "http-request",
|
||||
"title": self.payload.get('title'),
|
||||
"index": self.payload.get('index'),
|
||||
"predecessor_node_id": self.payload.get('predecessor_node_id'),
|
||||
"inputs": '',
|
||||
"process_data": '',
|
||||
"outputs": '',
|
||||
"status": "succeeded",
|
||||
"error": '',
|
||||
"elapsed_time": 0.10402441816404462,
|
||||
"execution_metadata": '',
|
||||
"created_at": 1724398751,
|
||||
"finished_at": 1724398751,
|
||||
"files": []
|
||||
}
|
||||
})
|
||||
return params
|
||||
|
||||
def to_response(self):
|
||||
def get_Message_param(self) -> dict:
|
||||
params = self.get_common_param()
|
||||
params.update({
|
||||
'id':self.payload.get('message_id'),
|
||||
'answer':self.payload.get('answer')
|
||||
})
|
||||
return params
|
||||
|
||||
def get_MessageEnd_param(self) -> dict:
|
||||
params = self.get_common_param()
|
||||
params.update({
|
||||
'id':self.payload.get('message_id'),
|
||||
'metadata':self.payload.get('metadata')
|
||||
})
|
||||
return params
|
||||
|
||||
def to_response(self)-> dict|None:
|
||||
try:
|
||||
match self.event_type:
|
||||
case "retrieve":
|
||||
return self.get_retrieval_message()
|
||||
case "function_call":
|
||||
return self.get_tool_message()
|
||||
case "agent_step":
|
||||
return self.get_agent_tool_response()
|
||||
case "workflow_started":
|
||||
return self.get_WorkflowStart_param()
|
||||
case "workflow_finished":
|
||||
return self.get_WorkflowFinished_param()
|
||||
case "node_started":
|
||||
return self.get_NodeStart_param()
|
||||
case 'node_finished':
|
||||
return self.get_NodeFinished_param()
|
||||
case 'message':
|
||||
return self.get_Message_param()
|
||||
case 'message_end':
|
||||
return self.get_MessageEnd_param()
|
||||
case _:
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"转换回应时间时发生错误,原因: {e}")
|
||||
return None
|
||||
|
||||
class DifyChatResponseEvent(BaseModel):
|
||||
event: str
|
||||
conversation_id: str
|
||||
message_id: str
|
||||
created_at: int = int(time.time())
|
||||
task_id: str
|
||||
|
||||
def to_response(self):
|
||||
return self.dict()
|
||||
|
||||
class Workflow_started_DifyChatResponseEvent(DifyChatResponseEvent):
|
||||
event: str = 'workflow_started'
|
||||
workflow_run_id:str
|
||||
data:Dict[str,Any]
|
||||
def __init__(self,**args):
|
||||
args['data'] = {
|
||||
"id": args['workflow_run_id'],
|
||||
"workflow_id": args['workflow_id'],
|
||||
"sequence_number": 1709,
|
||||
"inputs": {
|
||||
"sys.query": args['query'],
|
||||
"sys.files": [],
|
||||
"sys.conversation_id": args['conversation_id'],
|
||||
"sys.user_id": args['use_id']
|
||||
},
|
||||
"created_at": int(time.time())
|
||||
}
|
||||
super().__init__(**args)
|
||||
|
||||
class Workflow_finished_DifyChatResponseEvent(DifyChatResponseEvent):
|
||||
event: str = 'workflow_finished'
|
||||
workflow_run_id:str
|
||||
data:Dict[str,Any]
|
||||
def __init__(self,**args):
|
||||
args['data'] = {
|
||||
"id": args['workflow_run_id'],
|
||||
"workflow_id": args['workflow_id'],
|
||||
"sequence_number": 1709,
|
||||
"status": "succeeded",
|
||||
"outputs": {
|
||||
"answer": args['response']
|
||||
},
|
||||
"error": '',
|
||||
"elapsed_time": 36.03764106379822,
|
||||
"total_tokens": 11707,
|
||||
"total_steps": 10,
|
||||
"created_by": {
|
||||
"id": str(uuid.uuid4()),
|
||||
"user": args['use_id']
|
||||
},
|
||||
"created_at": int(time.time()),
|
||||
"finished_at": int(time.time()),
|
||||
"files": []
|
||||
}
|
||||
super().__init__(**args)
|
||||
|
||||
class Message_DifyChatResponseEvent(DifyChatResponseEvent):
|
||||
event: str = 'message'
|
||||
id:str
|
||||
answer:str
|
||||
def __init__(self,**args):
|
||||
args['id'] = args['message_id']
|
||||
super().__init__(**args)
|
||||
|
||||
class MessageEnd_DifyChatResponseEvent(DifyChatResponseEvent):
|
||||
event: str = 'message_end'
|
||||
id:str
|
||||
metadata:Dict[str,Any] = {}
|
||||
def __init__(self,**args):
|
||||
args['id'] = args['message_id']
|
||||
super().__init__(**args)
|
||||
|
||||
class Node_started_DifyChatResponseEvent(DifyChatResponseEvent):
|
||||
event: str = 'node_started'
|
||||
workflow_run_id:str
|
||||
data:Dict[str,Any]
|
||||
def __init__(self,**args):
|
||||
args['data'] = {
|
||||
"id": args['nodeid'],
|
||||
"node_id": args['nodeid'],
|
||||
"node_type": "http-request",
|
||||
"title": args['title'],
|
||||
"index": args['index'],
|
||||
"predecessor_node_id": args['predecessor_node_id'],
|
||||
"inputs": '',
|
||||
"created_at": 1724398751,
|
||||
"extras": {}
|
||||
}
|
||||
super().__init__(**args)
|
||||
|
||||
class Node_finished_DifyChatResponseEvent(DifyChatResponseEvent):
|
||||
event: str = 'node_finished'
|
||||
workflow_run_id:str
|
||||
data:Dict[str,Any]
|
||||
def __init__(self,**args):
|
||||
args['data'] = {
|
||||
"id": args['nodeid'],
|
||||
"node_id": args['nodeid'],
|
||||
"node_type": "http-request",
|
||||
"title": args['title'],
|
||||
"index": args['index'],
|
||||
"predecessor_node_id": args['predecessor_node_id'],
|
||||
"inputs": '',
|
||||
"process_data": '',
|
||||
"outputs": '',
|
||||
"status": "succeeded",
|
||||
"error": '',
|
||||
"elapsed_time": 0.10402441816404462,
|
||||
"execution_metadata": '',
|
||||
"created_at": 1724398751,
|
||||
"finished_at": 1724398751,
|
||||
"files": []
|
||||
}
|
||||
super().__init__(**args)
|
||||
|
||||
class ChatEventCallbackHandler(BaseCallbackHandler):
|
||||
_aqueue: asyncio.Queue
|
||||
is_done: bool = False
|
||||
@@ -239,9 +184,8 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
|
||||
self._nodeStack:deque = deque()
|
||||
|
||||
#添加工作流开始事件
|
||||
ids:Dict[str,Any] = self._params['ids']
|
||||
data:ChatRequestData = self._params['data']
|
||||
args = ids
|
||||
args:Dict[str,Any] = self._params['ids']
|
||||
args.update(
|
||||
{
|
||||
'use_id': data.user,
|
||||
@@ -249,7 +193,7 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
|
||||
'conversation_id': data.conversation_id
|
||||
}
|
||||
)
|
||||
wf_event = Workflow_started_DifyChatResponseEvent(**args)
|
||||
wf_event = ChatCallbackEvent(event_type = ChatEventType.WORKFLOW_START,payload = args)
|
||||
if wf_event.to_response() is not None:
|
||||
self._aqueue.put_nowait(wf_event)
|
||||
|
||||
@@ -264,9 +208,7 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
self._nodeStack.append(event_id)
|
||||
nindex = self._nodeStack.count() - 1
|
||||
|
||||
ids:Dict[str,Any] = self._params['ids']
|
||||
args = ids
|
||||
args:Dict[str,Any] = self._params['ids']
|
||||
args.update(
|
||||
{
|
||||
'nodeid':event_id,
|
||||
@@ -275,7 +217,7 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
|
||||
'predecessor_node_id': self._nodeStack[nindex - 1] if nindex > 0 else ''
|
||||
}
|
||||
)
|
||||
nd_event = Node_started_DifyChatResponseEvent(**args)
|
||||
nd_event = ChatCallbackEvent(event_type = ChatEventType.NODE_START,payload = args)
|
||||
if nd_event.to_response() is not None:
|
||||
self._aqueue.put_nowait(nd_event)
|
||||
|
||||
@@ -302,7 +244,7 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
|
||||
'predecessor_node_id':self._nodeStack[nindex - 1] if nindex > 0 else ''
|
||||
}
|
||||
)
|
||||
nd_event = Node_finished_DifyChatResponseEvent(**args)
|
||||
nd_event = ChatCallbackEvent(event_type = ChatEventType.NODE_FINISHED,payload = args)
|
||||
if nd_event.to_response() is not None:
|
||||
self._aqueue.put_nowait(nd_event)
|
||||
self._nodeStack.pop()
|
||||
@@ -319,22 +261,21 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
|
||||
) -> None:
|
||||
"""No-op."""
|
||||
logger.info("trace_end:{} trace_map:{}\n".format(trace_id, trace_map))
|
||||
ids:Dict[str,Any] = self._params['ids']
|
||||
data:ChatRequestData = self._params['data']
|
||||
args = ids
|
||||
args:Dict[str,Any] = self._params['ids']
|
||||
args.update(
|
||||
{
|
||||
'response':self._response,
|
||||
'conversation_id': data.conversation_id
|
||||
}
|
||||
)
|
||||
wf_event = Workflow_finished_DifyChatResponseEvent(**args)
|
||||
wf_event = ChatCallbackEvent(event_type = ChatEventType.WORKFLOW_FINISHED,payload = args)
|
||||
if wf_event.to_response() is not None:
|
||||
self._aqueue.put_nowait(wf_event)
|
||||
|
||||
|
||||
args = ids
|
||||
msgEnt_event = MessageEnd_DifyChatResponseEvent(**args)
|
||||
args:Dict[str,Any] = self._params['ids']
|
||||
msgEnt_event = ChatCallbackEvent(event_type = ChatEventType.MESSAGE_END,payload = args)
|
||||
if msgEnt_event.to_response() is not None:
|
||||
self._aqueue.put_nowait(msgEnt_event)
|
||||
|
||||
@@ -367,8 +308,8 @@ class ChatStreamResponse(StreamingResponse):
|
||||
'answer':token,
|
||||
'conversation_id':cls.data.conversation_id
|
||||
})
|
||||
event = Message_DifyChatResponseEvent(**params)
|
||||
data_str = json.dumps(event.dict())
|
||||
event = ChatCallbackEvent(event_type = ChatEventType.MESSAGE,payload = params)
|
||||
data_str = json.dumps(event.to_response())
|
||||
return f"{cls.DATA_PREFIX}{data_str}\n\n"
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from pydantic import BaseModel
|
||||
import os
|
||||
from enum import Enum
|
||||
|
||||
class BaseConfig(BaseModel):
|
||||
projectInfo:str = os.getenv("PROJECT_TITLE","您好,我是博微工程理解小助手,您可以问我有关[线路工程]工程数据的相关问题!")
|
||||
|
||||
@@ -69,3 +71,10 @@ class BaseConfig(BaseModel):
|
||||
}
|
||||
|
||||
|
||||
class ChatEventType(str, Enum):
|
||||
WORKFLOW_START = "workflow_started"
|
||||
WORKFLOW_FINISHED = "workflow_finished"
|
||||
NODE_START = "node_started"
|
||||
NODE_FINISHED = "node_finished"
|
||||
MESSAGE = "message"
|
||||
MESSAGE_END = "message_end"
|
||||
Reference in New Issue
Block a user