支持节点元数据输出及代码优化,减少事件重复添加

This commit is contained in:
wanyaokun
2024-09-02 19:58:18 +08:00
parent a4dd385368
commit 728ee06c5a
3 changed files with 64 additions and 37 deletions
+61 -36
View File
@@ -23,11 +23,18 @@ import uuid
from app.api.routers.services.fileServices import PrjFileLoadService,ChatFileService
from app.api.routers.services.suggestion import NextQuestionSuggestion
import time
from llama_index.core.settings import Settings
from llama_index.core.callbacks import CallbackManager
logger = logging.getLogger("uvicorn")
v1_router = v = APIRouter()
Settings.llm.callback_manager = CallbackManager()
gEvent_handler = None
class ChatCallbackEvent(BaseModel):
event_type: ChatEventType
payload: Optional[Dict[str, Any]] = None
@@ -210,7 +217,7 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
_aqueue: asyncio.Queue
is_done: bool = False
def __init__(self,**params):
def __init__(self):
"""Initialize the base callback handler."""
ignored_events = [
# CBEventType.CHUNKING,
@@ -222,13 +229,15 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
super().__init__(ignored_events, ignored_events)
self._aqueue = asyncio.Queue()
self._response: StreamingAgentChatResponse = None
self._params:Dict[str,Any] = params
self._ids:Dict[str,Any] = {}
self._chatData:ChatRequestData = None
self._nodeStack:List[str] = []
#添加工作流开始事件
wf_event = self.makeWorkflow_startEvent()
if wf_event.to_response() is not None:
self._aqueue.put_nowait(wf_event)
self._firstEventID:str = None
def setInitParams(self,ids:dict,data:ChatRequestData):
self._ids = ids
self._chatData = data
self._firstEventID = None
def setResponse(self,response: StreamingAgentChatResponse):
self._response = response
@@ -240,11 +249,15 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
event_id: str = "",
**kwargs: Any,
) -> str:
if self._firstEventID is None:
self._firstEventID = event_id
self.start()
logger.info("event_start:{} type:{} payload:{}\n".format(event_id, event_type, payload))
self._nodeStack.append(event_id)
nindex = len(self._nodeStack) - 1
args:Dict[str,Any] = self._params['ids']
args:Dict[str,Any] = self._ids
args.update(
{
'nodeid':event_id,
@@ -267,7 +280,7 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
logger.info("event_end:{} type:{} payload:{}\n".format(event_id, event_type, payload))
#self.response = payload.get("response","")
args:Dict[str,Any] = self._params['ids']
args:Dict[str,Any] = self._ids
nodeID = self._nodeStack[-1]
if nodeID == event_id:
nindex = len(self._nodeStack) - 1
@@ -284,6 +297,9 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
self._aqueue.put_nowait(nd_event)
self._nodeStack.pop()
if self._firstEventID is not None and self._firstEventID == event_id:
self.finished()
def start_trace(self, trace_id: Optional[str] = None) -> None:
"""No-op."""
logger.info("trace_start:{}\n".format(trace_id))
@@ -294,15 +310,7 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
trace_map: Optional[Dict[str, List[str]]] = None,
) -> None:
"""No-op."""
logger.info("trace_end:{} trace_map:{}\n".format(trace_id, trace_map))
wf_event = self.makeWorkflow_finishedEvent()
if wf_event.to_response() is not None:
self._aqueue.put_nowait(wf_event)
msgEnt_event = self.makeMessage_EndEvent()
if msgEnt_event.to_response() is not None:
self._aqueue.put_nowait(msgEnt_event)
logger.info("trace_end:{} trace_map:{}\n".format(trace_id, trace_map))
async def async_event_gen(self) -> AsyncGenerator[ChatCallbackEvent, None]:
while not self._aqueue.empty() or not self.is_done:
@@ -312,30 +320,28 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
pass
def makeWorkflow_startEvent(self)->ChatCallbackEvent:
data:ChatRequestData = self._params['data']
args:Dict[str,Any] = self._params['ids']
args:Dict[str,Any] = self._ids
args.update(
{
'use_id': data.user,
'query': data.query,
'conversation_id': data.conversation_id
'use_id': self._chatData.user,
'query': self._chatData.query,
'conversation_id': self._chatData.conversation_id
}
)
return ChatCallbackEvent(event_type = ChatEventType.WORKFLOW_START,payload = args)
def makeWorkflow_finishedEvent(self)->ChatCallbackEvent:
data:ChatRequestData = self._params['data']
args:Dict[str,Any] = self._params['ids']
def makeWorkflow_finishedEvent(self)->ChatCallbackEvent:
args:Dict[str,Any] = self._ids
args.update(
{
'response': '',
'conversation_id': data.conversation_id
'conversation_id': self._chatData.conversation_id
}
)
return ChatCallbackEvent(event_type = ChatEventType.WORKFLOW_FINISHED,payload = args)
def makeMessage_EndEvent(self)->ChatCallbackEvent:
args:Dict[str,Any] = self._params['ids']
args:Dict[str,Any] = self._ids
if self._response is not None:
args.update({
'source_node': self._response.source_nodes
@@ -343,6 +349,21 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
msgEnt_event = ChatCallbackEvent(event_type = ChatEventType.MESSAGE_END,payload = args)
return msgEnt_event
def start(self):
#添加工作流开始事件
wf_event = self.makeWorkflow_startEvent()
if wf_event.to_response() is not None:
self._aqueue.put_nowait(wf_event)
def finished(self):
wf_event = self.makeWorkflow_finishedEvent()
if wf_event.to_response() is not None:
self._aqueue.put_nowait(wf_event)
msgEnt_event = self.makeMessage_EndEvent()
if msgEnt_event.to_response() is not None:
self._aqueue.put_nowait(msgEnt_event)
class IDManager:
def createID(self):
return {
@@ -410,6 +431,7 @@ class ChatStreamResponse(StreamingResponse):
# the text_generator is the leading stream, once it's finished, also finish the event stream
event_handler.is_done = True
event_handler.setResponse(response)
# Yield the events from the event handler
async def _event_generator():
@@ -431,7 +453,8 @@ class ChatStreamResponse(StreamingResponse):
break
@v.post("/chat-messages")
async def post_conversations(request: Request, data: ChatRequestData):
async def post_chatmessages(request: Request, data: ChatRequestData):
global gEvent_handler
userMng.findNoExistCreate(data.user)
data.conversation_id = data.conversation_id if data.conversation_id else str(uuid.uuid4())
@@ -445,19 +468,21 @@ async def post_conversations(request: Request, data: ChatRequestData):
filters = None
params = data.inputs or {}
# 获取聊天引擎对象
chat_engine = get_chat_engine(filters=filters, params=params)
# 启动聊天事件监听
ids = IDManager().createID()
event_handler = ChatEventCallbackHandler(ids = ids,data = data)
chat_engine.callback_manager.handlers.append(event_handler) # type: ignore
if gEvent_handler is None:
gEvent_handler = ChatEventCallbackHandler()
Settings.llm.callback_manager.handlers.append(gEvent_handler)
if gEvent_handler is not None:
gEvent_handler.setInitParams(ids = ids,data = data)
# 获取聊天引擎对象
chat_engine = get_chat_engine(filters=filters, params=params)
# 执行异步聊天
response = await chat_engine.astream_chat(data.query)
event_handler.setResponse(response)
# 返回异步消息回应
return ChatStreamResponse(request, event_handler, response, data,ids)
return ChatStreamResponse(request, gEvent_handler, response, data,ids)
@v.get("/messages")
async def query_messages(user:str, conversation_id:str):