diff --git a/backend/app/api/routers/app.py b/backend/app/api/routers/app.py index e806726..101740a 100644 --- a/backend/app/api/routers/app.py +++ b/backend/app/api/routers/app.py @@ -1,23 +1,264 @@ -import os -from typing import Dict, List, Any, Optional, cast -from fastapi import APIRouter,Request -from app.api.routers.request.base import userMng,conversations -from app.api.routers.request.models import ChatRequestData +import asyncio +import json +import logging +from typing import Dict, List, Any, Optional, AsyncGenerator + +from aiostream import stream +from fastapi import APIRouter, Request +from fastapi.responses import StreamingResponse +from llama_index.core import BaseCallbackHandler +from llama_index.core.base.llms.types import ChatMessage +from llama_index.core.callbacks import CBEventType +from llama_index.core.chat_engine.types import StreamingAgentChatResponse +from llama_index.core.tools import ToolOutput +from pydantic import BaseModel + +from app.api.routers.events import EventCallbackHandler +from app.api.routers.request.base import userMng, conversations +from app.api.routers.request.models import ChatRequestData +from app.engine import get_chat_engine + +logger = logging.getLogger("uvicorn") -api_router = r = APIRouter() v1_router = v = APIRouter() +class ChatCallbackEvent(BaseModel): + event_type: CBEventType + payload: Optional[Dict[str, Any]] = None + event_id: str = "" + + def get_retrieval_message(self) -> dict | None: + if self.payload: + nodes = self.payload.get("nodes") + if nodes: + msg = f"根据查询检索到 {len(nodes)} 源文件" + else: + msg = f"查询检索中: '{self.payload.get('query_str')}'" + return { + "type": "events", + "data": {"title": msg}, + } + else: + return None + + def get_tool_message(self) -> dict | None: + func_call_args = self.payload.get("function_call") + if func_call_args is not None and "tool" in self.payload: + tool = self.payload.get("tool") + return { + "type": "events", + "data": { + "title": f"调用工具 {tool.name} ,参数: {func_call_args}", + }, + } + + def _is_output_serializable(self, output: Any) -> bool: + try: + json.dumps(output) + return True + except TypeError: + return False + + def get_agent_tool_response(self) -> dict | None: + response = self.payload.get("response") + if response is not None: + sources = response.sources + for source in sources: + # Return the tool response here to include the toolCall information + if isinstance(source, ToolOutput): + if self._is_output_serializable(source.raw_output): + output = source.raw_output + else: + output = source.content + + return { + "type": "tools", + "data": { + "toolOutput": { + "output": output, + "isError": source.is_error, + }, + "toolCall": { + "id": None, # There is no tool id in the ToolOutput + "name": source.tool_name, + "input": source.raw_input, + }, + }, + } + + def to_response(self): + try: + match self.event_type: + case "retrieve": + return self.get_retrieval_message() + case "function_call": + return self.get_tool_message() + case "agent_step": + return self.get_agent_tool_response() + case _: + return None + except Exception as e: + logger.error(f"转换回应时间时发生错误,原因: {e}") + return None + + +class ChatEventCallbackHandler(BaseCallbackHandler): + _aqueue: asyncio.Queue + is_done: bool = False + + def __init__( + self, + ): + """Initialize the base callback handler.""" + ignored_events = [ + CBEventType.CHUNKING, + CBEventType.NODE_PARSING, + CBEventType.EMBEDDING, + CBEventType.LLM, + CBEventType.TEMPLATING, + ] + super().__init__(ignored_events, ignored_events) + self._aqueue = asyncio.Queue() + + def on_event_start( + self, + event_type: CBEventType, + payload: Optional[Dict[str, Any]] = None, + event_id: str = "", + **kwargs: Any, + ) -> str: + event = ChatCallbackEvent(event_id=event_id, event_type=event_type, payload=payload) + if event.to_response() is not None: + self._aqueue.put_nowait(event) + + def on_event_end( + self, + event_type: CBEventType, + payload: Optional[Dict[str, Any]] = None, + event_id: str = "", + **kwargs: Any, + ) -> None: + event = ChatCallbackEvent(event_id=event_id, event_type=event_type, payload=payload) + if event.to_response() is not None: + self._aqueue.put_nowait(event) + + def start_trace(self, trace_id: Optional[str] = None) -> None: + """No-op.""" + + def end_trace( + self, + trace_id: Optional[str] = None, + trace_map: Optional[Dict[str, List[str]]] = None, + ) -> None: + """No-op.""" + + async def async_event_gen(self) -> AsyncGenerator[ChatCallbackEvent, None]: + while not self._aqueue.empty() or not self.is_done: + try: + yield await asyncio.wait_for(self._aqueue.get(), timeout=0.1) + except asyncio.TimeoutError: + pass + +class ChatStreamResponse(StreamingResponse): + TEXT_PREFIX = "data:" + DATA_PREFIX = "data:" + + @classmethod + def convert_text(cls, token: str): + # Escape newlines and double quotes to avoid breaking the stream + token = json.dumps(token) + + #return f"data: {{"event": "message", "conversation_id": "80d85523-de92-4b9d-aca0-c48a5eacb068", "message_id": "16a06b1b-a89b-49c0-bc15-123bd999f6d6", "created_at": 1724406492, "task_id": "802f3064-030d-42ac-a882-0e1293712d04", "id": "16a06b1b-a89b-49c0-bc15-123bd999f6d6", "answer": "{token}"}}" + return "" + + @classmethod + def convert_data(cls, data: dict): + data_str = json.dumps(data) + return f"{cls.DATA_PREFIX}{data_str}\n" + + def __init__( + self, + request: Request, + event_handler: EventCallbackHandler, + response: StreamingAgentChatResponse, + data: ChatRequestData + ): + content = ChatStreamResponse.content_generator( + request, event_handler, response, data + ) + super().__init__(content=content) + + @classmethod + async def content_generator( + cls, + request: Request, + event_handler: EventCallbackHandler, + response: StreamingAgentChatResponse, + data: ChatRequestData + ): + # Yield the text response + async def _chat_response_generator(): + final_response = "" + async for token in response.async_response_gen(): + final_response += token + yield ChatStreamResponse.convert_text(token) + + + # 存储消息历史 + #message = Message(data.conversation_id, data.query, answer=final_response) + #messageManager.addmessage(message) + + # the text_generator is the leading stream, once it's finished, also finish the event stream + event_handler.is_done = True + + # Yield the events from the event handler + async def _event_generator(): + async for event in event_handler.async_event_gen(): + event_response = event.to_response() + if event_response is not None: + yield ChatStreamResponse.convert_data(event_response) + + combine = stream.merge(_chat_response_generator(), _event_generator()) + is_stream_started = False + async with combine.stream() as streamer: + async for output in streamer: + if not is_stream_started: + is_stream_started = True + # Stream a blank message to start the stream + yield ChatStreamResponse.convert_text("") + + yield output + + if await request.is_disconnected(): + break + @v.post("/chat-messages") -async def post_conversations(request: Request,data: ChatRequestData): +async def post_conversations(request: Request, data: ChatRequestData): userMng.findNoExistCreate(data.user) conversaObj = conversations() - conversationinfo = conversaObj.get(data.user) + conversationinfo = conversaObj.get(data.user, data.conversation_id) if conversationinfo is None: - conversationinfo = conversaObj.add(data.user, "新建会话") + conversationinfo = conversaObj.add(data.user, "新建会话", data.conversation_id) - return None + # 生成聊天参数 + last_message_content = ChatMessage.from_str(data.query) + filters = None + params = data.inputs or {} + + # 获取聊天引擎对象 + chat_engine = get_chat_engine(filters=filters, params=params) + + # 启动聊天事件监听 + event_handler = EventCallbackHandler() + chat_engine.callback_manager.handlers.append(event_handler) # type: ignore + + # 执行异步聊天 + response = await chat_engine.astream_chat(data.query) + + # 返回异步消息回应 + return ChatStreamResponse(request, event_handler, response, data) @v.get("/messages") async def query_messages(user:str, conversation_id:str): @@ -37,33 +278,3 @@ async def query_conversations(user:str): "has_more": False, "data": conversations().gets(user_id) } - - -@r.get("/conversations") -async def query_conversations(first_id:int = None, limit:int = None, pinned:bool = None): - pass - -#meta查询 -@r.get("/meta") -async def query_meta(): - pass - -#name查询 -@r.get("/name查询") -def query_name(): - with sessionlocal() as session: - name = session.query(NameOrm).first() - - return Name.from_orm(name) - -#parameters查询 -@r.get("/parameters") -async def query_parameters(): - pass - -#msite查询 -@r.get("/site") -async def query_site(): - pass - - diff --git a/backend/app/api/routers/request/base.py b/backend/app/api/routers/request/base.py index 7723d49..00f86b1 100644 --- a/backend/app/api/routers/request/base.py +++ b/backend/app/api/routers/request/base.py @@ -1,10 +1,8 @@ -import os -from typing import Dict, List, Any, Optional, cast -import json -from app.api.routers.request.dbOrm import DBManager -from app.api.routers.request.baseConfig import BaseConfig from datetime import datetime +from app.api.routers.request.baseConfig import BaseConfig +from app.api.routers.request.dbOrm import DBManager + dbManage = DBManager() class conversations: @@ -31,6 +29,7 @@ class conversations: if id == '': id= str(uuid.uuid4()) template = BaseConfig.ConversationCfg + template['id'] = id template['user_id'] = user_id template['name'] = name diff --git a/backend/app/api/routers/request/dbOrm.py b/backend/app/api/routers/request/dbOrm.py index 539a669..bb7c1be 100644 --- a/backend/app/api/routers/request/dbOrm.py +++ b/backend/app/api/routers/request/dbOrm.py @@ -1,11 +1,10 @@ import os -from typing import Dict, List, Any, Optional, cast +from typing import Dict, List, Any -from fastapi import APIRouter -from pydantic import BaseModel, Field -from sqlalchemy import create_engine, Column, String, Integer, Boolean, JSON,ForeignKey -from sqlalchemy.orm import sessionmaker, declarative_base,relationship +from pydantic import BaseModel +from sqlalchemy import create_engine, Column, String, Integer, JSON from sqlalchemy.engine.reflection import Inspector +from sqlalchemy.orm import sessionmaker, declarative_base Base = declarative_base() diff --git a/backend/app/api/routers/request/models.py b/backend/app/api/routers/request/models.py index 493883c..43b06d4 100644 --- a/backend/app/api/routers/request/models.py +++ b/backend/app/api/routers/request/models.py @@ -1,5 +1,8 @@ + +from typing import Dict, Any + from pydantic import BaseModel -from typing import Dict, List, Any, Optional, cast + class ChatRequestData(BaseModel): inputs: Dict[str,Any] @@ -7,4 +10,5 @@ class ChatRequestData(BaseModel): user: str response_mode: str files: Any + conversation_id: str = None \ No newline at end of file diff --git a/backend/main.py b/backend/main.py index f4bca3e..dd12002 100644 --- a/backend/main.py +++ b/backend/main.py @@ -12,7 +12,7 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import RedirectResponse from app.api.routers.chat import chat_router from app.api.routers.upload import file_upload_router -from app.api.routers.app import api_router,v1_router +from app.api.routers.app import v1_router from app.settings import init_settings from app.observability import init_observability from fastapi.staticfiles import StaticFiles @@ -56,7 +56,7 @@ mount_static_files("data", "/api/files/data") mount_static_files("data_output", "/api/files/output") app.include_router(chat_router, prefix="/api/chat") app.include_router(file_upload_router, prefix="/api/chat/upload") -app.include_router(api_router, prefix="/api") + app.include_router(v1_router, prefix="/v1") @app.get("/")