增加对接DIFY前端支持功能
This commit is contained in:
+251
-40
@@ -1,23 +1,264 @@
|
||||
import os
|
||||
from typing import Dict, List, Any, Optional, cast
|
||||
from fastapi import APIRouter,Request
|
||||
from app.api.routers.request.base import userMng,conversations
|
||||
from app.api.routers.request.models import ChatRequestData
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, List, Any, Optional, AsyncGenerator
|
||||
|
||||
from aiostream import stream
|
||||
from fastapi import APIRouter, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from llama_index.core import BaseCallbackHandler
|
||||
from llama_index.core.base.llms.types import ChatMessage
|
||||
from llama_index.core.callbacks import CBEventType
|
||||
from llama_index.core.chat_engine.types import StreamingAgentChatResponse
|
||||
from llama_index.core.tools import ToolOutput
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.api.routers.events import EventCallbackHandler
|
||||
from app.api.routers.request.base import userMng, conversations
|
||||
from app.api.routers.request.models import ChatRequestData
|
||||
from app.engine import get_chat_engine
|
||||
|
||||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
api_router = r = APIRouter()
|
||||
v1_router = v = APIRouter()
|
||||
|
||||
|
||||
class ChatCallbackEvent(BaseModel):
|
||||
event_type: CBEventType
|
||||
payload: Optional[Dict[str, Any]] = None
|
||||
event_id: str = ""
|
||||
|
||||
def get_retrieval_message(self) -> dict | None:
|
||||
if self.payload:
|
||||
nodes = self.payload.get("nodes")
|
||||
if nodes:
|
||||
msg = f"根据查询检索到 {len(nodes)} 源文件"
|
||||
else:
|
||||
msg = f"查询检索中: '{self.payload.get('query_str')}'"
|
||||
return {
|
||||
"type": "events",
|
||||
"data": {"title": msg},
|
||||
}
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_tool_message(self) -> dict | None:
|
||||
func_call_args = self.payload.get("function_call")
|
||||
if func_call_args is not None and "tool" in self.payload:
|
||||
tool = self.payload.get("tool")
|
||||
return {
|
||||
"type": "events",
|
||||
"data": {
|
||||
"title": f"调用工具 {tool.name} ,参数: {func_call_args}",
|
||||
},
|
||||
}
|
||||
|
||||
def _is_output_serializable(self, output: Any) -> bool:
|
||||
try:
|
||||
json.dumps(output)
|
||||
return True
|
||||
except TypeError:
|
||||
return False
|
||||
|
||||
def get_agent_tool_response(self) -> dict | None:
|
||||
response = self.payload.get("response")
|
||||
if response is not None:
|
||||
sources = response.sources
|
||||
for source in sources:
|
||||
# Return the tool response here to include the toolCall information
|
||||
if isinstance(source, ToolOutput):
|
||||
if self._is_output_serializable(source.raw_output):
|
||||
output = source.raw_output
|
||||
else:
|
||||
output = source.content
|
||||
|
||||
return {
|
||||
"type": "tools",
|
||||
"data": {
|
||||
"toolOutput": {
|
||||
"output": output,
|
||||
"isError": source.is_error,
|
||||
},
|
||||
"toolCall": {
|
||||
"id": None, # There is no tool id in the ToolOutput
|
||||
"name": source.tool_name,
|
||||
"input": source.raw_input,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
def to_response(self):
|
||||
try:
|
||||
match self.event_type:
|
||||
case "retrieve":
|
||||
return self.get_retrieval_message()
|
||||
case "function_call":
|
||||
return self.get_tool_message()
|
||||
case "agent_step":
|
||||
return self.get_agent_tool_response()
|
||||
case _:
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"转换回应时间时发生错误,原因: {e}")
|
||||
return None
|
||||
|
||||
|
||||
class ChatEventCallbackHandler(BaseCallbackHandler):
|
||||
_aqueue: asyncio.Queue
|
||||
is_done: bool = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
):
|
||||
"""Initialize the base callback handler."""
|
||||
ignored_events = [
|
||||
CBEventType.CHUNKING,
|
||||
CBEventType.NODE_PARSING,
|
||||
CBEventType.EMBEDDING,
|
||||
CBEventType.LLM,
|
||||
CBEventType.TEMPLATING,
|
||||
]
|
||||
super().__init__(ignored_events, ignored_events)
|
||||
self._aqueue = asyncio.Queue()
|
||||
|
||||
def on_event_start(
|
||||
self,
|
||||
event_type: CBEventType,
|
||||
payload: Optional[Dict[str, Any]] = None,
|
||||
event_id: str = "",
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
event = ChatCallbackEvent(event_id=event_id, event_type=event_type, payload=payload)
|
||||
if event.to_response() is not None:
|
||||
self._aqueue.put_nowait(event)
|
||||
|
||||
def on_event_end(
|
||||
self,
|
||||
event_type: CBEventType,
|
||||
payload: Optional[Dict[str, Any]] = None,
|
||||
event_id: str = "",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
event = ChatCallbackEvent(event_id=event_id, event_type=event_type, payload=payload)
|
||||
if event.to_response() is not None:
|
||||
self._aqueue.put_nowait(event)
|
||||
|
||||
def start_trace(self, trace_id: Optional[str] = None) -> None:
|
||||
"""No-op."""
|
||||
|
||||
def end_trace(
|
||||
self,
|
||||
trace_id: Optional[str] = None,
|
||||
trace_map: Optional[Dict[str, List[str]]] = None,
|
||||
) -> None:
|
||||
"""No-op."""
|
||||
|
||||
async def async_event_gen(self) -> AsyncGenerator[ChatCallbackEvent, None]:
|
||||
while not self._aqueue.empty() or not self.is_done:
|
||||
try:
|
||||
yield await asyncio.wait_for(self._aqueue.get(), timeout=0.1)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
|
||||
class ChatStreamResponse(StreamingResponse):
|
||||
TEXT_PREFIX = "data:"
|
||||
DATA_PREFIX = "data:"
|
||||
|
||||
@classmethod
|
||||
def convert_text(cls, token: str):
|
||||
# Escape newlines and double quotes to avoid breaking the stream
|
||||
token = json.dumps(token)
|
||||
|
||||
#return f"data: {{"event": "message", "conversation_id": "80d85523-de92-4b9d-aca0-c48a5eacb068", "message_id": "16a06b1b-a89b-49c0-bc15-123bd999f6d6", "created_at": 1724406492, "task_id": "802f3064-030d-42ac-a882-0e1293712d04", "id": "16a06b1b-a89b-49c0-bc15-123bd999f6d6", "answer": "{token}"}}"
|
||||
return ""
|
||||
|
||||
@classmethod
|
||||
def convert_data(cls, data: dict):
|
||||
data_str = json.dumps(data)
|
||||
return f"{cls.DATA_PREFIX}{data_str}\n"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
request: Request,
|
||||
event_handler: EventCallbackHandler,
|
||||
response: StreamingAgentChatResponse,
|
||||
data: ChatRequestData
|
||||
):
|
||||
content = ChatStreamResponse.content_generator(
|
||||
request, event_handler, response, data
|
||||
)
|
||||
super().__init__(content=content)
|
||||
|
||||
@classmethod
|
||||
async def content_generator(
|
||||
cls,
|
||||
request: Request,
|
||||
event_handler: EventCallbackHandler,
|
||||
response: StreamingAgentChatResponse,
|
||||
data: ChatRequestData
|
||||
):
|
||||
# Yield the text response
|
||||
async def _chat_response_generator():
|
||||
final_response = ""
|
||||
async for token in response.async_response_gen():
|
||||
final_response += token
|
||||
yield ChatStreamResponse.convert_text(token)
|
||||
|
||||
|
||||
# 存储消息历史
|
||||
#message = Message(data.conversation_id, data.query, answer=final_response)
|
||||
#messageManager.addmessage(message)
|
||||
|
||||
# the text_generator is the leading stream, once it's finished, also finish the event stream
|
||||
event_handler.is_done = True
|
||||
|
||||
# Yield the events from the event handler
|
||||
async def _event_generator():
|
||||
async for event in event_handler.async_event_gen():
|
||||
event_response = event.to_response()
|
||||
if event_response is not None:
|
||||
yield ChatStreamResponse.convert_data(event_response)
|
||||
|
||||
combine = stream.merge(_chat_response_generator(), _event_generator())
|
||||
is_stream_started = False
|
||||
async with combine.stream() as streamer:
|
||||
async for output in streamer:
|
||||
if not is_stream_started:
|
||||
is_stream_started = True
|
||||
# Stream a blank message to start the stream
|
||||
yield ChatStreamResponse.convert_text("")
|
||||
|
||||
yield output
|
||||
|
||||
if await request.is_disconnected():
|
||||
break
|
||||
|
||||
@v.post("/chat-messages")
|
||||
async def post_conversations(request: Request,data: ChatRequestData):
|
||||
async def post_conversations(request: Request, data: ChatRequestData):
|
||||
userMng.findNoExistCreate(data.user)
|
||||
|
||||
conversaObj = conversations()
|
||||
conversationinfo = conversaObj.get(data.user)
|
||||
conversationinfo = conversaObj.get(data.user, data.conversation_id)
|
||||
if conversationinfo is None:
|
||||
conversationinfo = conversaObj.add(data.user, "新建会话")
|
||||
conversationinfo = conversaObj.add(data.user, "新建会话", data.conversation_id)
|
||||
|
||||
return None
|
||||
# 生成聊天参数
|
||||
last_message_content = ChatMessage.from_str(data.query)
|
||||
filters = None
|
||||
params = data.inputs or {}
|
||||
|
||||
# 获取聊天引擎对象
|
||||
chat_engine = get_chat_engine(filters=filters, params=params)
|
||||
|
||||
# 启动聊天事件监听
|
||||
event_handler = EventCallbackHandler()
|
||||
chat_engine.callback_manager.handlers.append(event_handler) # type: ignore
|
||||
|
||||
# 执行异步聊天
|
||||
response = await chat_engine.astream_chat(data.query)
|
||||
|
||||
# 返回异步消息回应
|
||||
return ChatStreamResponse(request, event_handler, response, data)
|
||||
|
||||
@v.get("/messages")
|
||||
async def query_messages(user:str, conversation_id:str):
|
||||
@@ -37,33 +278,3 @@ async def query_conversations(user:str):
|
||||
"has_more": False,
|
||||
"data": conversations().gets(user_id)
|
||||
}
|
||||
|
||||
|
||||
@r.get("/conversations")
|
||||
async def query_conversations(first_id:int = None, limit:int = None, pinned:bool = None):
|
||||
pass
|
||||
|
||||
#meta查询
|
||||
@r.get("/meta")
|
||||
async def query_meta():
|
||||
pass
|
||||
|
||||
#name查询
|
||||
@r.get("/name查询")
|
||||
def query_name():
|
||||
with sessionlocal() as session:
|
||||
name = session.query(NameOrm).first()
|
||||
|
||||
return Name.from_orm(name)
|
||||
|
||||
#parameters查询
|
||||
@r.get("/parameters")
|
||||
async def query_parameters():
|
||||
pass
|
||||
|
||||
#msite查询
|
||||
@r.get("/site")
|
||||
async def query_site():
|
||||
pass
|
||||
|
||||
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
import os
|
||||
from typing import Dict, List, Any, Optional, cast
|
||||
import json
|
||||
from app.api.routers.request.dbOrm import DBManager
|
||||
from app.api.routers.request.baseConfig import BaseConfig
|
||||
from datetime import datetime
|
||||
|
||||
from app.api.routers.request.baseConfig import BaseConfig
|
||||
from app.api.routers.request.dbOrm import DBManager
|
||||
|
||||
dbManage = DBManager()
|
||||
|
||||
class conversations:
|
||||
@@ -31,6 +29,7 @@ class conversations:
|
||||
if id == '':
|
||||
id= str(uuid.uuid4())
|
||||
template = BaseConfig.ConversationCfg
|
||||
|
||||
template['id'] = id
|
||||
template['user_id'] = user_id
|
||||
template['name'] = name
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
import os
|
||||
from typing import Dict, List, Any, Optional, cast
|
||||
from typing import Dict, List, Any
|
||||
|
||||
from fastapi import APIRouter
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import create_engine, Column, String, Integer, Boolean, JSON,ForeignKey
|
||||
from sqlalchemy.orm import sessionmaker, declarative_base,relationship
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import create_engine, Column, String, Integer, JSON
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
from sqlalchemy.orm import sessionmaker, declarative_base
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
|
||||
from typing import Dict, Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing import Dict, List, Any, Optional, cast
|
||||
|
||||
|
||||
class ChatRequestData(BaseModel):
|
||||
inputs: Dict[str,Any]
|
||||
@@ -7,4 +10,5 @@ class ChatRequestData(BaseModel):
|
||||
user: str
|
||||
response_mode: str
|
||||
files: Any
|
||||
conversation_id: str = None
|
||||
|
||||
Reference in New Issue
Block a user