完善接口,实现对DIFY前端消息流传输的支持

This commit is contained in:
2024-08-29 08:26:59 +08:00
parent 9b47e1a6e1
commit 131d6ef1d1
2 changed files with 37 additions and 26 deletions
+32 -21
View File
@@ -1,6 +1,7 @@
import asyncio
import json
import logging
import time
from typing import Dict, List, Any, Optional, AsyncGenerator
from aiostream import stream
@@ -22,8 +23,6 @@ logger = logging.getLogger("uvicorn")
api_router = r = APIRouter()
v1_router = v = APIRouter()
default_conversation_id = '82e8417f-2c3b-4bb5-ab22-2ad318bbd29a'
class ChatCallbackEvent(BaseModel):
event_type: CBEventType
payload: Optional[Dict[str, Any]] = None
@@ -112,11 +111,11 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
):
"""Initialize the base callback handler."""
ignored_events = [
CBEventType.CHUNKING,
CBEventType.NODE_PARSING,
CBEventType.EMBEDDING,
CBEventType.LLM,
CBEventType.TEMPLATING,
# CBEventType.CHUNKING,
# CBEventType.NODE_PARSING,
# CBEventType.EMBEDDING,
# CBEventType.LLM,
# CBEventType.TEMPLATING,
]
super().__init__(ignored_events, ignored_events)
self._aqueue = asyncio.Queue()
@@ -128,6 +127,8 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
event_id: str = "",
**kwargs: Any,
) -> str:
logger.info("event_start:{} type:{} payload:{}\n".format(event_id, event_type, payload))
event = ChatCallbackEvent(event_id=event_id, event_type=event_type, payload=payload)
if event.to_response() is not None:
self._aqueue.put_nowait(event)
@@ -139,12 +140,14 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
event_id: str = "",
**kwargs: Any,
) -> None:
logger.info("event_end:{} type:{} payload:{}\n".format(event_id, event_type, payload))
event = ChatCallbackEvent(event_id=event_id, event_type=event_type, payload=payload)
if event.to_response() is not None:
self._aqueue.put_nowait(event)
def start_trace(self, trace_id: Optional[str] = None) -> None:
"""No-op."""
logger.info("trace_start:{}\n".format(trace_id))
def end_trace(
self,
@@ -152,6 +155,7 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
trace_map: Optional[Dict[str, List[str]]] = None,
) -> None:
"""No-op."""
logger.info("trace_end:{} trace_map:{}\n".format(trace_id, trace_map))
async def async_event_gen(self) -> AsyncGenerator[ChatCallbackEvent, None]:
while not self._aqueue.empty() or not self.is_done:
@@ -173,7 +177,7 @@ class DifyChatResponseEvent(BaseModel):
event: str
conversation_id: str
message_id: str
created_at: int = 1724406492
created_at: int = int(time.time())
task_id: str
class Workflow_started_DifyChatResponseEvent(DifyChatResponseEvent):
@@ -190,7 +194,7 @@ class Workflow_started_DifyChatResponseEvent(DifyChatResponseEvent):
"sys.conversation_id": args['conversation_id'],
"sys.user_id": args['use_id']
},
"created_at": 1724406492
"created_at": int(time.time())
}
args['event'] = 'workflow_started'
super().__init__(**args)
@@ -216,8 +220,8 @@ class Workflow_finished_DifyChatResponseEvent(DifyChatResponseEvent):
"id": str(uuid.uuid4()),
"user": args['use_id']
},
"created_at": 1724406492,
"finished_at": 1724406528,
"created_at": int(time.time()),
"finished_at": int(time.time()),
"files": []
}
super().__init__(**args)
@@ -245,20 +249,20 @@ class ChatStreamResponse(StreamingResponse):
@classmethod
def convert_text(cls, token: str):
# Escape newlines and double quotes to avoid breaking the stream
token = json.dumps(token)
#token = json.dumps(token)
#return f"data: {{"event": "message", "conversation_id": "80d85523-de92-4b9d-aca0-c48a5eacb068", "message_id": "16a06b1b-a89b-49c0-bc15-123bd999f6d6", "created_at": 1724406492, "task_id": "802f3064-030d-42ac-a882-0e1293712d04", "id": "16a06b1b-a89b-49c0-bc15-123bd999f6d6", "answer": "{token}"}}"
return ""
return "\n"
@classmethod
def convert_data(cls, data: dict):
data_str = json.dumps(data)
return f"{cls.DATA_PREFIX}{data_str}\n"
return f"{cls.DATA_PREFIX}{data_str}\n\n"
@classmethod
def convert_event(cls, event: DifyChatResponseEvent):
data_str = json.dumps(event.dict())
return f"{cls.DATA_PREFIX}{data_str}\n"
return f"{cls.DATA_PREFIX}{data_str}\n\n"
def __init__(
self,
@@ -314,7 +318,7 @@ class ChatStreamResponse(StreamingResponse):
async for event in event_handler.async_event_gen():
event_response = event.to_response()
if event_response is not None:
yield ChatStreamResponse.convert_data(event_response)
yield ChatStreamResponse.convert_text("")
combine = stream.merge(_chat_response_generator(), _event_generator())
is_stream_started = False
@@ -345,12 +349,12 @@ class ChatStreamResponse(StreamingResponse):
@v.post("/chat-messages")
async def post_conversations(request: Request, data: ChatRequestData):
userMng.findNoExistCreate(data.user)
data.conversation_id = default_conversation_id if data.conversation_id is None else data.conversation_id
data.conversation_id = data.conversation_id if data.conversation_id else str(uuid.uuid4())
conversaObj = conversations()
conversationinfo = conversaObj.get(data.user, data.conversation_id)
conversationinfo = conversaObj.get(data.conversation_id)
if conversationinfo is None:
conversationinfo = conversaObj.add(data.user, "新建会话", data.conversation_id)
conversationinfo = conversaObj.add(data.conversation_id, data.user, "新建会话")
# 生成聊天参数
last_message_content = ChatMessage.from_str(data.query)
@@ -372,9 +376,16 @@ async def post_conversations(request: Request, data: ChatRequestData):
@v.get("/messages")
async def query_messages(user:str, conversation_id:str):
conversation_id = default_conversation_id if conversation_id is None else conversation_id
#conversation_id = default_conversation_id if conversation_id is None else conversation_id
datas = []
records = message().gets(user,conversation_id)
if records is None:
return {
"limit": 20,
"has_more": False,
"data": []
}
for record in records:
res = record.dict()
res["message_files"] = []
@@ -415,7 +426,7 @@ async def post_conversations(request: Request,itemid:str,params:Dict[str,Any]):
return 'null'
@v.get("/conversations")
async def query_conversations(user:str):
async def query_conversations(user:str, first_id:str = None, limit:str = None, pinned:str = None):
user_id = '' if user is None else user
userMng.findNoExistCreate(user_id)
+3 -3
View File
@@ -18,13 +18,13 @@ class conversations:
return datas
def get(self,user_id:str,id:str = ''):
records = dbManage.query(self._tableName,user_id = user_id,id=id)
def get(self, id:str):
records = dbManage.query(self._tableName, id=id)
if len(records) >0:
return records[0]
return None
def add(self,user_id:str,name:str,id:str = ''):
def add(self,id:str, user_id:str, name:str):
template = BaseConfig.ConversationCfg
template['id'] = id
template['user_id'] = user_id