From 728ee06c5aecf8606306795085112f4d84146a00 Mon Sep 17 00:00:00 2001 From: wanyaokun <12345678> Date: Mon, 2 Sep 2024 19:58:18 +0800 Subject: [PATCH] =?UTF-8?q?=E6=94=AF=E6=8C=81=E8=8A=82=E7=82=B9=E5=85=83?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E8=BE=93=E5=87=BA=E5=8F=8A=E4=BB=A3=E7=A0=81?= =?UTF-8?q?=E4=BC=98=E5=8C=96=EF=BC=8C=E5=87=8F=E5=B0=91=E4=BA=8B=E4=BB=B6?= =?UTF-8?q?=E9=87=8D=E5=A4=8D=E6=B7=BB=E5=8A=A0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/app/api/routers/app.py | 97 +++++++++++++++++++++------------- backend/app/engine/engine.py | 2 +- backend/main.py | 2 + 3 files changed, 64 insertions(+), 37 deletions(-) diff --git a/backend/app/api/routers/app.py b/backend/app/api/routers/app.py index 02f2984..fa61dac 100644 --- a/backend/app/api/routers/app.py +++ b/backend/app/api/routers/app.py @@ -23,11 +23,18 @@ import uuid from app.api.routers.services.fileServices import PrjFileLoadService,ChatFileService from app.api.routers.services.suggestion import NextQuestionSuggestion import time +from llama_index.core.settings import Settings +from llama_index.core.callbacks import CallbackManager logger = logging.getLogger("uvicorn") v1_router = v = APIRouter() +Settings.llm.callback_manager = CallbackManager() + +gEvent_handler = None + + class ChatCallbackEvent(BaseModel): event_type: ChatEventType payload: Optional[Dict[str, Any]] = None @@ -210,7 +217,7 @@ class ChatEventCallbackHandler(BaseCallbackHandler): _aqueue: asyncio.Queue is_done: bool = False - def __init__(self,**params): + def __init__(self): """Initialize the base callback handler.""" ignored_events = [ # CBEventType.CHUNKING, @@ -222,13 +229,15 @@ class ChatEventCallbackHandler(BaseCallbackHandler): super().__init__(ignored_events, ignored_events) self._aqueue = asyncio.Queue() self._response: StreamingAgentChatResponse = None - self._params:Dict[str,Any] = params + self._ids:Dict[str,Any] = {} + self._chatData:ChatRequestData = None self._nodeStack:List[str] = [] - - #添加工作流开始事件 - wf_event = self.makeWorkflow_startEvent() - if wf_event.to_response() is not None: - self._aqueue.put_nowait(wf_event) + self._firstEventID:str = None + + def setInitParams(self,ids:dict,data:ChatRequestData): + self._ids = ids + self._chatData = data + self._firstEventID = None def setResponse(self,response: StreamingAgentChatResponse): self._response = response @@ -240,11 +249,15 @@ class ChatEventCallbackHandler(BaseCallbackHandler): event_id: str = "", **kwargs: Any, ) -> str: + if self._firstEventID is None: + self._firstEventID = event_id + self.start() + logger.info("event_start:{} type:{} payload:{}\n".format(event_id, event_type, payload)) self._nodeStack.append(event_id) nindex = len(self._nodeStack) - 1 - args:Dict[str,Any] = self._params['ids'] + args:Dict[str,Any] = self._ids args.update( { 'nodeid':event_id, @@ -267,7 +280,7 @@ class ChatEventCallbackHandler(BaseCallbackHandler): logger.info("event_end:{} type:{} payload:{}\n".format(event_id, event_type, payload)) #self.response = payload.get("response","") - args:Dict[str,Any] = self._params['ids'] + args:Dict[str,Any] = self._ids nodeID = self._nodeStack[-1] if nodeID == event_id: nindex = len(self._nodeStack) - 1 @@ -284,6 +297,9 @@ class ChatEventCallbackHandler(BaseCallbackHandler): self._aqueue.put_nowait(nd_event) self._nodeStack.pop() + if self._firstEventID is not None and self._firstEventID == event_id: + self.finished() + def start_trace(self, trace_id: Optional[str] = None) -> None: """No-op.""" logger.info("trace_start:{}\n".format(trace_id)) @@ -294,15 +310,7 @@ class ChatEventCallbackHandler(BaseCallbackHandler): trace_map: Optional[Dict[str, List[str]]] = None, ) -> None: """No-op.""" - logger.info("trace_end:{} trace_map:{}\n".format(trace_id, trace_map)) - - wf_event = self.makeWorkflow_finishedEvent() - if wf_event.to_response() is not None: - self._aqueue.put_nowait(wf_event) - - msgEnt_event = self.makeMessage_EndEvent() - if msgEnt_event.to_response() is not None: - self._aqueue.put_nowait(msgEnt_event) + logger.info("trace_end:{} trace_map:{}\n".format(trace_id, trace_map)) async def async_event_gen(self) -> AsyncGenerator[ChatCallbackEvent, None]: while not self._aqueue.empty() or not self.is_done: @@ -312,30 +320,28 @@ class ChatEventCallbackHandler(BaseCallbackHandler): pass def makeWorkflow_startEvent(self)->ChatCallbackEvent: - data:ChatRequestData = self._params['data'] - args:Dict[str,Any] = self._params['ids'] + args:Dict[str,Any] = self._ids args.update( { - 'use_id': data.user, - 'query': data.query, - 'conversation_id': data.conversation_id + 'use_id': self._chatData.user, + 'query': self._chatData.query, + 'conversation_id': self._chatData.conversation_id } ) return ChatCallbackEvent(event_type = ChatEventType.WORKFLOW_START,payload = args) - def makeWorkflow_finishedEvent(self)->ChatCallbackEvent: - data:ChatRequestData = self._params['data'] - args:Dict[str,Any] = self._params['ids'] + def makeWorkflow_finishedEvent(self)->ChatCallbackEvent: + args:Dict[str,Any] = self._ids args.update( { 'response': '', - 'conversation_id': data.conversation_id + 'conversation_id': self._chatData.conversation_id } ) return ChatCallbackEvent(event_type = ChatEventType.WORKFLOW_FINISHED,payload = args) def makeMessage_EndEvent(self)->ChatCallbackEvent: - args:Dict[str,Any] = self._params['ids'] + args:Dict[str,Any] = self._ids if self._response is not None: args.update({ 'source_node': self._response.source_nodes @@ -343,6 +349,21 @@ class ChatEventCallbackHandler(BaseCallbackHandler): msgEnt_event = ChatCallbackEvent(event_type = ChatEventType.MESSAGE_END,payload = args) return msgEnt_event + def start(self): + #添加工作流开始事件 + wf_event = self.makeWorkflow_startEvent() + if wf_event.to_response() is not None: + self._aqueue.put_nowait(wf_event) + + def finished(self): + wf_event = self.makeWorkflow_finishedEvent() + if wf_event.to_response() is not None: + self._aqueue.put_nowait(wf_event) + + msgEnt_event = self.makeMessage_EndEvent() + if msgEnt_event.to_response() is not None: + self._aqueue.put_nowait(msgEnt_event) + class IDManager: def createID(self): return { @@ -410,6 +431,7 @@ class ChatStreamResponse(StreamingResponse): # the text_generator is the leading stream, once it's finished, also finish the event stream event_handler.is_done = True + event_handler.setResponse(response) # Yield the events from the event handler async def _event_generator(): @@ -431,7 +453,8 @@ class ChatStreamResponse(StreamingResponse): break @v.post("/chat-messages") -async def post_conversations(request: Request, data: ChatRequestData): +async def post_chatmessages(request: Request, data: ChatRequestData): + global gEvent_handler userMng.findNoExistCreate(data.user) data.conversation_id = data.conversation_id if data.conversation_id else str(uuid.uuid4()) @@ -445,19 +468,21 @@ async def post_conversations(request: Request, data: ChatRequestData): filters = None params = data.inputs or {} - # 获取聊天引擎对象 - chat_engine = get_chat_engine(filters=filters, params=params) - # 启动聊天事件监听 ids = IDManager().createID() - event_handler = ChatEventCallbackHandler(ids = ids,data = data) - chat_engine.callback_manager.handlers.append(event_handler) # type: ignore + if gEvent_handler is None: + gEvent_handler = ChatEventCallbackHandler() + Settings.llm.callback_manager.handlers.append(gEvent_handler) + + if gEvent_handler is not None: + gEvent_handler.setInitParams(ids = ids,data = data) + # 获取聊天引擎对象 + chat_engine = get_chat_engine(filters=filters, params=params) # 执行异步聊天 response = await chat_engine.astream_chat(data.query) - event_handler.setResponse(response) # 返回异步消息回应 - return ChatStreamResponse(request, event_handler, response, data,ids) + return ChatStreamResponse(request, gEvent_handler, response, data,ids) @v.get("/messages") async def query_messages(user:str, conversation_id:str): diff --git a/backend/app/engine/engine.py b/backend/app/engine/engine.py index 4bbd993..4d44ce6 100644 --- a/backend/app/engine/engine.py +++ b/backend/app/engine/engine.py @@ -102,7 +102,7 @@ def create_query_engine(index, top_k=3, use_reranker=False, filters=None, respon simple_template = simple_template, node_postprocessors=postprocess, use_async=True, - streaming=True, + streaming=False, ResponseMode = response_mode ) diff --git a/backend/main.py b/backend/main.py index 00dca98..9782ddc 100644 --- a/backend/main.py +++ b/backend/main.py @@ -24,6 +24,8 @@ logger = logging.getLogger("uvicorn") usPrj = using_project(os.getenv("PHOENIX_PROJECT_NAME")) usPrj.__enter__() +import nest_asyncio +nest_asyncio.apply() init_settings() init_observability()