476 lines
16 KiB
Python
476 lines
16 KiB
Python
import asyncio
|
|
import json
|
|
import logging
|
|
from typing import Dict, List, Any, Optional, AsyncGenerator
|
|
|
|
from aiostream import stream
|
|
from fastapi import APIRouter, Request
|
|
from fastapi.responses import StreamingResponse
|
|
from llama_index.core import BaseCallbackHandler
|
|
from llama_index.core.base.llms.types import ChatMessage
|
|
from llama_index.core.callbacks import CBEventType
|
|
from llama_index.core.chat_engine.types import StreamingAgentChatResponse
|
|
from llama_index.core.tools import ToolOutput
|
|
from pydantic import BaseModel
|
|
from app.api.routers.request.base import userMng, conversations,message,parameter
|
|
from app.api.routers.request.models import ChatRequestData,ChatFileUploadRequest
|
|
from app.engine import get_chat_engine
|
|
import uuid
|
|
|
|
# Reuse the logger registered under the name "uvicorn".
logger = logging.getLogger("uvicorn")

# Router used by the upload endpoint; ``r`` is a short alias for it.
api_router = APIRouter()
r = api_router

# Router used by the v1 chat/conversation endpoints; ``v`` is a short alias.
v1_router = APIRouter()
v = v1_router

# Conversation id substituted when a request does not supply one.
default_conversation_id = '82e8417f-2c3b-4bb5-ab22-2ad318bbd29a'
|
|
|
|
class ChatCallbackEvent(BaseModel):
|
|
event_type: CBEventType
|
|
payload: Optional[Dict[str, Any]] = None
|
|
event_id: str = ""
|
|
|
|
def get_retrieval_message(self) -> dict | None:
|
|
if self.payload:
|
|
nodes = self.payload.get("nodes")
|
|
if nodes:
|
|
msg = f"根据查询检索到 {len(nodes)} 源文件"
|
|
else:
|
|
msg = f"查询检索中: '{self.payload.get('query_str')}'"
|
|
return {
|
|
"type": "events",
|
|
"data": {"title": msg},
|
|
}
|
|
else:
|
|
return None
|
|
|
|
def get_tool_message(self) -> dict | None:
|
|
func_call_args = self.payload.get("function_call")
|
|
if func_call_args is not None and "tool" in self.payload:
|
|
tool = self.payload.get("tool")
|
|
return {
|
|
"type": "events",
|
|
"data": {
|
|
"title": f"调用工具 {tool.name} ,参数: {func_call_args}",
|
|
},
|
|
}
|
|
|
|
def _is_output_serializable(self, output: Any) -> bool:
|
|
try:
|
|
json.dumps(output)
|
|
return True
|
|
except TypeError:
|
|
return False
|
|
|
|
def get_agent_tool_response(self) -> dict | None:
|
|
response = self.payload.get("response")
|
|
if response is not None:
|
|
sources = response.sources
|
|
for source in sources:
|
|
# Return the tool response here to include the toolCall information
|
|
if isinstance(source, ToolOutput):
|
|
if self._is_output_serializable(source.raw_output):
|
|
output = source.raw_output
|
|
else:
|
|
output = source.content
|
|
|
|
return {
|
|
"type": "tools",
|
|
"data": {
|
|
"toolOutput": {
|
|
"output": output,
|
|
"isError": source.is_error,
|
|
},
|
|
"toolCall": {
|
|
"id": None, # There is no tool id in the ToolOutput
|
|
"name": source.tool_name,
|
|
"input": source.raw_input,
|
|
},
|
|
},
|
|
}
|
|
|
|
def to_response(self):
|
|
try:
|
|
match self.event_type:
|
|
case "retrieve":
|
|
return self.get_retrieval_message()
|
|
case "function_call":
|
|
return self.get_tool_message()
|
|
case "agent_step":
|
|
return self.get_agent_tool_response()
|
|
case _:
|
|
return None
|
|
except Exception as e:
|
|
logger.error(f"转换回应时间时发生错误,原因: {e}")
|
|
return None
|
|
|
|
class ChatEventCallbackHandler(BaseCallbackHandler):
    """Callback handler that queues client-relevant llama-index events.

    Every start/end event whose ``ChatCallbackEvent.to_response()`` is not None
    is pushed onto an asyncio queue. ``async_event_gen`` drains that queue until
    ``is_done`` is set and the queue is empty.
    """

    _aqueue: asyncio.Queue
    # Set to True by the response streamer once the text stream has finished.
    is_done: bool = False

    def __init__(
        self,
    ):
        """Initialize the handler, ignoring low-level events that are never streamed."""
        ignored_events = [
            CBEventType.CHUNKING,
            CBEventType.NODE_PARSING,
            CBEventType.EMBEDDING,
            CBEventType.LLM,
            CBEventType.TEMPLATING,
        ]
        # The same list is ignored for both event starts and event ends.
        super().__init__(ignored_events, ignored_events)
        self._aqueue = asyncio.Queue()

    def _enqueue(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]],
        event_id: str,
    ) -> None:
        """Queue the event if it converts to a client-facing message."""
        event = ChatCallbackEvent(event_id=event_id, event_type=event_type, payload=payload)
        if event.to_response() is not None:
            # NOTE(review): asyncio.Queue is not thread-safe; this assumes the
            # callback fires on the event-loop thread — confirm for sync tools.
            self._aqueue.put_nowait(event)

    def on_event_start(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        **kwargs: Any,
    ) -> str:
        """Handle an event start and return its id, as the annotation declares."""
        self._enqueue(event_type, payload, event_id)
        return event_id

    def on_event_end(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        **kwargs: Any,
    ) -> None:
        """Handle an event end."""
        self._enqueue(event_type, payload, event_id)

    def start_trace(self, trace_id: Optional[str] = None) -> None:
        """No-op."""

    def end_trace(
        self,
        trace_id: Optional[str] = None,
        trace_map: Optional[Dict[str, List[str]]] = None,
    ) -> None:
        """No-op."""

    async def async_event_gen(self) -> AsyncGenerator[ChatCallbackEvent, None]:
        """Yield queued events until ``is_done`` is set and the queue is drained.

        Polls with a 0.1 s timeout so a change of ``is_done`` is noticed even
        when no further events arrive.
        """
        while not self._aqueue.empty() or not self.is_done:
            try:
                yield await asyncio.wait_for(self._aqueue.get(), timeout=0.1)
            except asyncio.TimeoutError:
                pass
|
|
|
|
class IDManager:
    """Produces a fresh set of random identifiers for one chat stream."""

    def createID(self):
        """Return a dict of new UUID4 strings.

        Keys, in order: ``message_id``, ``task_id``, ``workflow_run_id``,
        ``workflow_id``.
        """
        id_names = ("message_id", "task_id", "workflow_run_id", "workflow_id")
        return {name: str(uuid.uuid4()) for name in id_names}
|
|
|
|
class DifyChatResponseEvent(BaseModel):
    """Common fields shared by every streamed chat event model in this file.

    Subclasses set ``event`` themselves and add their own payload fields; the
    model is serialized with ``.dict()`` by ``ChatStreamResponse.convert_event``.
    """

    # Event name, e.g. 'message', 'message_end', 'workflow_started', 'workflow_finished'.
    event: str
    # Conversation the event belongs to.
    conversation_id: str
    # Id of the message being streamed.
    message_id: str
    # NOTE(review): fixed placeholder timestamp, not the real creation time — confirm intended.
    created_at: int = 1724406492
    # Id of the streaming task.
    task_id: str
|
|
|
|
class Workflow_started_DifyChatResponseEvent(DifyChatResponseEvent):
    """``workflow_started`` event emitted before the first streamed chunk.

    Besides the base fields, the constructor reads the keyword arguments
    ``workflow_run_id``, ``workflow_id``, ``query``, ``conversation_id`` and
    ``use_id`` (the user id) to build ``data``.
    """

    workflow_run_id: str
    data: Dict[str, Any]

    def __init__(self, **args):
        # sequence_number and created_at are fixed placeholder values.
        payload = {
            "id": args['workflow_run_id'],
            "workflow_id": args['workflow_id'],
            "sequence_number": 1709,
            "inputs": {
                "sys.query": args['query'],
                "sys.files": [],
                "sys.conversation_id": args['conversation_id'],
                "sys.user_id": args['use_id'],
            },
            "created_at": 1724406492,
        }
        super().__init__(**{**args, 'data': payload, 'event': 'workflow_started'})
|
|
|
|
class Workflow_finished_DifyChatResponseEvent(DifyChatResponseEvent):
    """``workflow_finished`` event emitted after the full answer has streamed.

    Besides the base fields, the constructor reads ``workflow_run_id``,
    ``workflow_id``, ``response`` (the complete answer text) and ``use_id``
    (the user id) to build ``data``. Timing and token statistics are fixed
    placeholder values; ``created_by.id`` is a fresh random UUID.
    """

    workflow_run_id: str
    data: Dict[str, Any]

    def __init__(self, **args):
        payload = {
            "id": args['workflow_run_id'],
            "workflow_id": args['workflow_id'],
            "sequence_number": 1709,
            "status": "succeeded",
            "outputs": {"answer": args['response']},
            "error": '',
            "elapsed_time": 36.03764106379822,
            "total_tokens": 11707,
            "total_steps": 10,
            "created_by": {"id": str(uuid.uuid4()), "user": args['use_id']},
            "created_at": 1724406492,
            "finished_at": 1724406528,
            "files": [],
        }
        super().__init__(**{**args, 'event': 'workflow_finished', 'data': payload})
|
|
|
|
class Message_DifyChatResponseEvent(DifyChatResponseEvent):
    """``message`` event carrying one streamed answer token.

    ``id`` is copied from the required ``message_id`` keyword argument.
    """

    id: str
    answer: str

    def __init__(self, **args):
        super().__init__(**{**args, 'id': args['message_id'], 'event': 'message'})
|
|
|
|
class MessageEnd_DifyChatResponseEvent(DifyChatResponseEvent):
    """``message_end`` event that closes a streamed message.

    ``id`` is copied from the required ``message_id`` keyword argument;
    ``metadata`` defaults to an empty dict.
    """

    id: str
    metadata: Dict[str, Any] = {}

    def __init__(self, **args):
        super().__init__(**{**args, 'id': args['message_id'], 'event': 'message_end'})
|
|
|
|
class ChatStreamResponse(StreamingResponse):
    """Streaming response that emits chat events as ``data:<json>`` lines.

    The body is produced by ``content_generator``. It merges the LLM token
    stream with the callback-handler event stream and emits the events
    workflow_started, message (one per token), workflow_finished and
    message_end.
    """

    TEXT_PREFIX = "data:"
    DATA_PREFIX = "data:"

    @classmethod
    def convert_text(cls, token: str):
        """Disabled text-token encoder; always returns an empty string.

        Tokens are streamed as ``message`` events via ``convert_event`` instead.
        Kept only for interface compatibility.
        """
        return ""

    @classmethod
    def convert_data(cls, data: dict):
        """Encode a plain dict as one ``data:<json>`` line."""
        data_str = json.dumps(data)
        # NOTE(review): lines end with a single "\n" (not the SSE "\n\n"
        # event terminator); assumes the client parses line by line.
        return f"{cls.DATA_PREFIX}{data_str}\n"

    @classmethod
    def convert_event(cls, event: DifyChatResponseEvent):
        """Encode a chat event model as one ``data:<json>`` line."""
        data_str = json.dumps(event.dict())
        return f"{cls.DATA_PREFIX}{data_str}\n"

    def __init__(
        self,
        request: Request,
        event_handler: ChatEventCallbackHandler,
        response: StreamingAgentChatResponse,
        data: ChatRequestData
    ):
        content = ChatStreamResponse.content_generator(
            request, event_handler, response, data
        )
        super().__init__(content=content)

    @classmethod
    async def content_generator(
        cls,
        request: Request,
        event_handler: ChatEventCallbackHandler,
        response: StreamingAgentChatResponse,
        data: ChatRequestData
    ):
        """Yield the encoded event lines for one chat request.

        The text stream and the callback events are merged. A workflow_started
        event precedes the first chunk, and streaming stops early if the client
        disconnects.
        """
        ids = IDManager().createID()
        conversation_id = data.conversation_id

        def _event_args(**extra: Any) -> Dict[str, Any]:
            # Build a fresh kwargs dict per event. The original code aliased
            # and mutated ``ids``, so the finished event only received
            # ``use_id`` if workflow_started had already been emitted. An
            # empty answer therefore crashed with KeyError.
            return {**ids, "conversation_id": conversation_id, **extra}

        # Yield the text response.
        async def _chat_response_generator():
            parts: List[str] = []
            try:
                async for token in response.async_response_gen():
                    parts.append(token)
                    event = Message_DifyChatResponseEvent(**_event_args(answer=token))
                    yield cls.convert_event(event)
                final_response = "".join(parts)
                # Store the exchange in the message history.
                message().add(user_id=data.user, conversation_id=conversation_id, query=data.query, answer=final_response)
            finally:
                # The text stream is the leading stream: once it ends (or fails),
                # let the event stream finish too so the merge can complete.
                event_handler.is_done = True

            # Send the workflow-finished event, then close the message.
            wf_event = Workflow_finished_DifyChatResponseEvent(
                **_event_args(use_id=data.user, response=final_response)
            )
            yield cls.convert_event(wf_event)
            yield cls.convert_event(MessageEnd_DifyChatResponseEvent(**_event_args()))

        # Yield the events from the event handler.
        async def _event_generator():
            async for event in event_handler.async_event_gen():
                event_response = event.to_response()
                if event_response is not None:
                    yield cls.convert_data(event_response)

        combine = stream.merge(_chat_response_generator(), _event_generator())
        is_stream_started = False
        async with combine.stream() as streamer:
            async for output in streamer:
                if not is_stream_started:
                    is_stream_started = True
                    # Send the workflow-started event ahead of the first chunk.
                    wf_event = Workflow_started_DifyChatResponseEvent(
                        **_event_args(use_id=data.user, query=data.query)
                    )
                    yield cls.convert_event(wf_event)

                yield output

                if await request.is_disconnected():
                    break
|
|
|
|
|
|
|
|
@v.post("/chat-messages")
async def post_conversations(request: Request, data: ChatRequestData):
    """Run a streaming chat turn and return it as a ``ChatStreamResponse``.

    Ensures the user exists. Falls back to ``default_conversation_id`` when
    none is given and creates the conversation record if it is missing.

    NOTE(review): the rename endpoint below reuses the name
    ``post_conversations``. Both routes are registered, but this module-level
    name ends up bound to the later function.
    """
    userMng.findNoExistCreate(data.user)
    if data.conversation_id is None:
        data.conversation_id = default_conversation_id

    conversation_store = conversations()
    if conversation_store.get(data.user, data.conversation_id) is None:
        conversation_store.add(data.user, "新建会话", data.conversation_id)

    # Build the chat engine from the request's input parameters.
    chat_engine = get_chat_engine(filters=None, params=data.inputs or {})

    # Attach a per-request event listener.
    # NOTE(review): if get_chat_engine ever returns a shared/cached engine,
    # handlers would accumulate across requests — confirm it builds a new one.
    event_handler = ChatEventCallbackHandler()
    chat_engine.callback_manager.handlers.append(event_handler)  # type: ignore

    # Start the asynchronous streaming chat.
    response = await chat_engine.astream_chat(data.query)

    return ChatStreamResponse(request, event_handler, response, data)
|
|
|
|
@v.get("/messages")
async def query_messages(user: str, conversation_id: Optional[str] = None):
    """List a user's stored messages for one conversation.

    ``conversation_id`` is now optional. Before, it was a required query
    parameter, which made its None fallback unreachable. When it is omitted,
    ``default_conversation_id`` is used. Each record is returned with fixed
    placeholder fields (files, feedback, timestamps, status) added.
    """
    if conversation_id is None:
        conversation_id = default_conversation_id

    datas = [
        {
            **record.dict(),
            "message_files": [],
            "feedback": '',
            "retriever_resources": [],
            # NOTE(review): fixed placeholder timestamp, not the stored time.
            "created_at": 1723444905,
            "agent_thoughts": [],
            "status": "normal",
            "error": '',
        }
        for record in message().gets(user, conversation_id)
    ]

    return {
        "limit": 20,
        "has_more": False,
        "data": datas
    }
|
|
|
|
@v.post("/conversations/{itemid}/name")
async def post_conversations(request: Request, itemid: str, params: Dict[str, Any]):
    """Rename a conversation owned by ``params['user']`` and return its record.

    The name is always set to the fixed value '知识问答'. Returns None (JSON
    ``null``) when the conversation does not exist for that user.
    """
    store = conversations()
    cond = {
        'id': itemid,
        'user_id': params['user']
    }
    # Check ownership first: the original renamed any conversation id,
    # regardless of which user made the request.
    if not store.query(**cond):
        return None

    store.rename(itemid, '知识问答')

    # Re-read so the response reflects the new name.
    results = store.query(**cond)
    if not results:
        return None
    res = results[0]
    return {
        "id": res['id'],
        "name": res['name'],
        "inputs": res['inputs'],
        "status": res['status'],
        "introduction": res['introduction'],
        "created_at": res['created_at'],
    }
|
|
|
|
@v.get("/conversations")
async def query_conversations(user: str):
    """List the conversations of ``user``, creating the user if unknown."""
    # NOTE(review): ``user`` is a required str, so this None fallback never
    # triggers under FastAPI validation.
    user_id = user if user is not None else ''
    userMng.findNoExistCreate(user_id)

    return {"limit": 20, "has_more": False, "data": conversations().gets(user_id)}
|
|
|
|
@v.get("/parameters")
async def query_parameters(user: str):
    """Return the app parameters stored for ``user``, or a built-in default set.

    Defaults are used when the stored value is empty or None. The original
    ``len(params) == 0`` raised TypeError on None.
    """
    params = parameter().get(user)
    if not params:
        params = {
            "opening_statement": "您好,我是配网D3造价软件小助手,您可以问我有关配网造价软件的相关问题!",
            "suggested_questions": [],
            "suggested_questions_after_answer": {
                "enabled": False
            },
            "speech_to_text": {
                "enabled": False
            },
            "text_to_speech": {
                "enabled": False,
                "language": "",
                "voice": ""
            },
            "retriever_resource": {
                "enabled": True
            },
            "annotation_reply": {
                "enabled": False
            },
            "more_like_this": {
                "enabled": False
            },
            "user_input_form": [],
            "sensitive_word_avoidance": {
                "enabled": False
            },
            "file_upload": {
                "image": {
                    "enabled": False,
                    "number_limits": 3,
                    "transfer_methods": [
                        "remote_url"
                    ]
                }
            },
            "system_parameters": {
                "image_file_size_limit": "10"
            }
        }
    return params
|
|
|
|
@r.post("")
def upload_file(request: ChatFileUploadRequest) -> List[str]:
    """File-upload endpoint placeholder; not implemented yet.

    Raises:
        HTTPException: always, with status 501 (Not Implemented). The former
            ``pass`` body returned None. That violates the ``List[str]``
            response model FastAPI derives from the return annotation and
            surfaced as a 500 error.
    """
    from fastapi import HTTPException

    raise HTTPException(status_code=501, detail="File upload is not implemented")