import asyncio import json import logging from typing import Dict, List, Any, Optional, AsyncGenerator from aiostream import stream from fastapi import APIRouter, Request from fastapi.responses import StreamingResponse from llama_index.core import BaseCallbackHandler from llama_index.core.base.llms.types import ChatMessage from llama_index.core.callbacks import CBEventType from llama_index.core.chat_engine.types import StreamingAgentChatResponse from llama_index.core.tools import ToolOutput from pydantic import BaseModel from app.api.routers.events import EventCallbackHandler from app.api.routers.request.base import userMng, conversations from app.api.routers.request.models import ChatRequestData from app.engine import get_chat_engine logger = logging.getLogger("uvicorn") v1_router = v = APIRouter() class ChatCallbackEvent(BaseModel): event_type: CBEventType payload: Optional[Dict[str, Any]] = None event_id: str = "" def get_retrieval_message(self) -> dict | None: if self.payload: nodes = self.payload.get("nodes") if nodes: msg = f"根据查询检索到 {len(nodes)} 源文件" else: msg = f"查询检索中: '{self.payload.get('query_str')}'" return { "type": "events", "data": {"title": msg}, } else: return None def get_tool_message(self) -> dict | None: func_call_args = self.payload.get("function_call") if func_call_args is not None and "tool" in self.payload: tool = self.payload.get("tool") return { "type": "events", "data": { "title": f"调用工具 {tool.name} ,参数: {func_call_args}", }, } def _is_output_serializable(self, output: Any) -> bool: try: json.dumps(output) return True except TypeError: return False def get_agent_tool_response(self) -> dict | None: response = self.payload.get("response") if response is not None: sources = response.sources for source in sources: # Return the tool response here to include the toolCall information if isinstance(source, ToolOutput): if self._is_output_serializable(source.raw_output): output = source.raw_output else: output = source.content return { "type": "tools", "data": { "toolOutput": { "output": output, "isError": source.is_error, }, "toolCall": { "id": None, # There is no tool id in the ToolOutput "name": source.tool_name, "input": source.raw_input, }, }, } def to_response(self): try: match self.event_type: case "retrieve": return self.get_retrieval_message() case "function_call": return self.get_tool_message() case "agent_step": return self.get_agent_tool_response() case _: return None except Exception as e: logger.error(f"转换回应时间时发生错误,原因: {e}") return None class ChatEventCallbackHandler(BaseCallbackHandler): _aqueue: asyncio.Queue is_done: bool = False def __init__( self, ): """Initialize the base callback handler.""" ignored_events = [ CBEventType.CHUNKING, CBEventType.NODE_PARSING, CBEventType.EMBEDDING, CBEventType.LLM, CBEventType.TEMPLATING, ] super().__init__(ignored_events, ignored_events) self._aqueue = asyncio.Queue() def on_event_start( self, event_type: CBEventType, payload: Optional[Dict[str, Any]] = None, event_id: str = "", **kwargs: Any, ) -> str: event = ChatCallbackEvent(event_id=event_id, event_type=event_type, payload=payload) if event.to_response() is not None: self._aqueue.put_nowait(event) def on_event_end( self, event_type: CBEventType, payload: Optional[Dict[str, Any]] = None, event_id: str = "", **kwargs: Any, ) -> None: event = ChatCallbackEvent(event_id=event_id, event_type=event_type, payload=payload) if event.to_response() is not None: self._aqueue.put_nowait(event) def start_trace(self, trace_id: Optional[str] = None) -> None: """No-op.""" def end_trace( self, trace_id: Optional[str] = None, trace_map: Optional[Dict[str, List[str]]] = None, ) -> None: """No-op.""" async def async_event_gen(self) -> AsyncGenerator[ChatCallbackEvent, None]: while not self._aqueue.empty() or not self.is_done: try: yield await asyncio.wait_for(self._aqueue.get(), timeout=0.1) except asyncio.TimeoutError: pass class ChatStreamResponse(StreamingResponse): TEXT_PREFIX = "data:" DATA_PREFIX = "data:" @classmethod def convert_text(cls, token: str): # Escape newlines and double quotes to avoid breaking the stream token = json.dumps(token) #return f"data: {{"event": "message", "conversation_id": "80d85523-de92-4b9d-aca0-c48a5eacb068", "message_id": "16a06b1b-a89b-49c0-bc15-123bd999f6d6", "created_at": 1724406492, "task_id": "802f3064-030d-42ac-a882-0e1293712d04", "id": "16a06b1b-a89b-49c0-bc15-123bd999f6d6", "answer": "{token}"}}" return "" @classmethod def convert_data(cls, data: dict): data_str = json.dumps(data) return f"{cls.DATA_PREFIX}{data_str}\n" def __init__( self, request: Request, event_handler: EventCallbackHandler, response: StreamingAgentChatResponse, data: ChatRequestData ): content = ChatStreamResponse.content_generator( request, event_handler, response, data ) super().__init__(content=content) @classmethod async def content_generator( cls, request: Request, event_handler: EventCallbackHandler, response: StreamingAgentChatResponse, data: ChatRequestData ): # Yield the text response async def _chat_response_generator(): final_response = "" async for token in response.async_response_gen(): final_response += token yield ChatStreamResponse.convert_text(token) # 存储消息历史 #message = Message(data.conversation_id, data.query, answer=final_response) #messageManager.addmessage(message) # the text_generator is the leading stream, once it's finished, also finish the event stream event_handler.is_done = True # Yield the events from the event handler async def _event_generator(): async for event in event_handler.async_event_gen(): event_response = event.to_response() if event_response is not None: yield ChatStreamResponse.convert_data(event_response) combine = stream.merge(_chat_response_generator(), _event_generator()) is_stream_started = False async with combine.stream() as streamer: async for output in streamer: if not is_stream_started: is_stream_started = True # Stream a blank message to start the stream yield ChatStreamResponse.convert_text("") yield output if await request.is_disconnected(): break @v.post("/chat-messages") async def post_conversations(request: Request, data: ChatRequestData): userMng.findNoExistCreate(data.user) conversaObj = conversations() conversationinfo = conversaObj.get(data.user, data.conversation_id) if conversationinfo is None: conversationinfo = conversaObj.add(data.user, "新建会话", data.conversation_id) # 生成聊天参数 last_message_content = ChatMessage.from_str(data.query) filters = None params = data.inputs or {} # 获取聊天引擎对象 chat_engine = get_chat_engine(filters=filters, params=params) # 启动聊天事件监听 event_handler = EventCallbackHandler() chat_engine.callback_manager.handlers.append(event_handler) # type: ignore # 执行异步聊天 response = await chat_engine.astream_chat(data.query) # 返回异步消息回应 return ChatStreamResponse(request, event_handler, response, data) @v.get("/messages") async def query_messages(user:str, conversation_id:str): pass @v.post("/conversations/{itemid}/name") async def post_conversations(user:str): pass @v.get("/conversations") async def query_conversations(user:str): user_id = '' if user is None else user userMng.findNoExistCreate(user_id) return { "limit": 20, "has_more": False, "data": conversations().gets(user_id) }