完善接口,实现对DIFY前端消息流传输的支持
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import time
|
||||||
from typing import Dict, List, Any, Optional, AsyncGenerator
|
from typing import Dict, List, Any, Optional, AsyncGenerator
|
||||||
|
|
||||||
from aiostream import stream
|
from aiostream import stream
|
||||||
@@ -22,8 +23,6 @@ logger = logging.getLogger("uvicorn")
|
|||||||
api_router = r = APIRouter()
|
api_router = r = APIRouter()
|
||||||
v1_router = v = APIRouter()
|
v1_router = v = APIRouter()
|
||||||
|
|
||||||
default_conversation_id = '82e8417f-2c3b-4bb5-ab22-2ad318bbd29a'
|
|
||||||
|
|
||||||
class ChatCallbackEvent(BaseModel):
|
class ChatCallbackEvent(BaseModel):
|
||||||
event_type: CBEventType
|
event_type: CBEventType
|
||||||
payload: Optional[Dict[str, Any]] = None
|
payload: Optional[Dict[str, Any]] = None
|
||||||
@@ -112,11 +111,11 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
|
|||||||
):
|
):
|
||||||
"""Initialize the base callback handler."""
|
"""Initialize the base callback handler."""
|
||||||
ignored_events = [
|
ignored_events = [
|
||||||
CBEventType.CHUNKING,
|
# CBEventType.CHUNKING,
|
||||||
CBEventType.NODE_PARSING,
|
# CBEventType.NODE_PARSING,
|
||||||
CBEventType.EMBEDDING,
|
# CBEventType.EMBEDDING,
|
||||||
CBEventType.LLM,
|
# CBEventType.LLM,
|
||||||
CBEventType.TEMPLATING,
|
# CBEventType.TEMPLATING,
|
||||||
]
|
]
|
||||||
super().__init__(ignored_events, ignored_events)
|
super().__init__(ignored_events, ignored_events)
|
||||||
self._aqueue = asyncio.Queue()
|
self._aqueue = asyncio.Queue()
|
||||||
@@ -128,6 +127,8 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
|
|||||||
event_id: str = "",
|
event_id: str = "",
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> str:
|
) -> str:
|
||||||
|
logger.info("event_start:{} type:{} payload:{}\n".format(event_id, event_type, payload))
|
||||||
|
|
||||||
event = ChatCallbackEvent(event_id=event_id, event_type=event_type, payload=payload)
|
event = ChatCallbackEvent(event_id=event_id, event_type=event_type, payload=payload)
|
||||||
if event.to_response() is not None:
|
if event.to_response() is not None:
|
||||||
self._aqueue.put_nowait(event)
|
self._aqueue.put_nowait(event)
|
||||||
@@ -139,12 +140,14 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
|
|||||||
event_id: str = "",
|
event_id: str = "",
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
logger.info("event_end:{} type:{} payload:{}\n".format(event_id, event_type, payload))
|
||||||
event = ChatCallbackEvent(event_id=event_id, event_type=event_type, payload=payload)
|
event = ChatCallbackEvent(event_id=event_id, event_type=event_type, payload=payload)
|
||||||
if event.to_response() is not None:
|
if event.to_response() is not None:
|
||||||
self._aqueue.put_nowait(event)
|
self._aqueue.put_nowait(event)
|
||||||
|
|
||||||
def start_trace(self, trace_id: Optional[str] = None) -> None:
|
def start_trace(self, trace_id: Optional[str] = None) -> None:
|
||||||
"""No-op."""
|
"""No-op."""
|
||||||
|
logger.info("trace_start:{}\n".format(trace_id))
|
||||||
|
|
||||||
def end_trace(
|
def end_trace(
|
||||||
self,
|
self,
|
||||||
@@ -152,6 +155,7 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
|
|||||||
trace_map: Optional[Dict[str, List[str]]] = None,
|
trace_map: Optional[Dict[str, List[str]]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""No-op."""
|
"""No-op."""
|
||||||
|
logger.info("trace_end:{} trace_map:{}\n".format(trace_id, trace_map))
|
||||||
|
|
||||||
async def async_event_gen(self) -> AsyncGenerator[ChatCallbackEvent, None]:
|
async def async_event_gen(self) -> AsyncGenerator[ChatCallbackEvent, None]:
|
||||||
while not self._aqueue.empty() or not self.is_done:
|
while not self._aqueue.empty() or not self.is_done:
|
||||||
@@ -173,7 +177,7 @@ class DifyChatResponseEvent(BaseModel):
|
|||||||
event: str
|
event: str
|
||||||
conversation_id: str
|
conversation_id: str
|
||||||
message_id: str
|
message_id: str
|
||||||
created_at: int = 1724406492
|
created_at: int = int(time.time())
|
||||||
task_id: str
|
task_id: str
|
||||||
|
|
||||||
class Workflow_started_DifyChatResponseEvent(DifyChatResponseEvent):
|
class Workflow_started_DifyChatResponseEvent(DifyChatResponseEvent):
|
||||||
@@ -190,7 +194,7 @@ class Workflow_started_DifyChatResponseEvent(DifyChatResponseEvent):
|
|||||||
"sys.conversation_id": args['conversation_id'],
|
"sys.conversation_id": args['conversation_id'],
|
||||||
"sys.user_id": args['use_id']
|
"sys.user_id": args['use_id']
|
||||||
},
|
},
|
||||||
"created_at": 1724406492
|
"created_at": int(time.time())
|
||||||
}
|
}
|
||||||
args['event'] = 'workflow_started'
|
args['event'] = 'workflow_started'
|
||||||
super().__init__(**args)
|
super().__init__(**args)
|
||||||
@@ -216,8 +220,8 @@ class Workflow_finished_DifyChatResponseEvent(DifyChatResponseEvent):
|
|||||||
"id": str(uuid.uuid4()),
|
"id": str(uuid.uuid4()),
|
||||||
"user": args['use_id']
|
"user": args['use_id']
|
||||||
},
|
},
|
||||||
"created_at": 1724406492,
|
"created_at": int(time.time()),
|
||||||
"finished_at": 1724406528,
|
"finished_at": int(time.time()),
|
||||||
"files": []
|
"files": []
|
||||||
}
|
}
|
||||||
super().__init__(**args)
|
super().__init__(**args)
|
||||||
@@ -239,26 +243,26 @@ class MessageEnd_DifyChatResponseEvent(DifyChatResponseEvent):
|
|||||||
super().__init__(**args)
|
super().__init__(**args)
|
||||||
|
|
||||||
class ChatStreamResponse(StreamingResponse):
|
class ChatStreamResponse(StreamingResponse):
|
||||||
TEXT_PREFIX = "data:"
|
TEXT_PREFIX = "data: "
|
||||||
DATA_PREFIX = "data:"
|
DATA_PREFIX = "data: "
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert_text(cls, token: str):
|
def convert_text(cls, token: str):
|
||||||
# Escape newlines and double quotes to avoid breaking the stream
|
# Escape newlines and double quotes to avoid breaking the stream
|
||||||
token = json.dumps(token)
|
#token = json.dumps(token)
|
||||||
|
|
||||||
#return f"data: {{"event": "message", "conversation_id": "80d85523-de92-4b9d-aca0-c48a5eacb068", "message_id": "16a06b1b-a89b-49c0-bc15-123bd999f6d6", "created_at": 1724406492, "task_id": "802f3064-030d-42ac-a882-0e1293712d04", "id": "16a06b1b-a89b-49c0-bc15-123bd999f6d6", "answer": "{token}"}}"
|
#return f"data: {{"event": "message", "conversation_id": "80d85523-de92-4b9d-aca0-c48a5eacb068", "message_id": "16a06b1b-a89b-49c0-bc15-123bd999f6d6", "created_at": 1724406492, "task_id": "802f3064-030d-42ac-a882-0e1293712d04", "id": "16a06b1b-a89b-49c0-bc15-123bd999f6d6", "answer": "{token}"}}"
|
||||||
return ""
|
return "\n"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert_data(cls, data: dict):
|
def convert_data(cls, data: dict):
|
||||||
data_str = json.dumps(data)
|
data_str = json.dumps(data)
|
||||||
return f"{cls.DATA_PREFIX}{data_str}\n"
|
return f"{cls.DATA_PREFIX}{data_str}\n\n"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert_event(cls, event: DifyChatResponseEvent):
|
def convert_event(cls, event: DifyChatResponseEvent):
|
||||||
data_str = json.dumps(event.dict())
|
data_str = json.dumps(event.dict())
|
||||||
return f"{cls.DATA_PREFIX}{data_str}\n"
|
return f"{cls.DATA_PREFIX}{data_str}\n\n"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -314,7 +318,7 @@ class ChatStreamResponse(StreamingResponse):
|
|||||||
async for event in event_handler.async_event_gen():
|
async for event in event_handler.async_event_gen():
|
||||||
event_response = event.to_response()
|
event_response = event.to_response()
|
||||||
if event_response is not None:
|
if event_response is not None:
|
||||||
yield ChatStreamResponse.convert_data(event_response)
|
yield ChatStreamResponse.convert_text("")
|
||||||
|
|
||||||
combine = stream.merge(_chat_response_generator(), _event_generator())
|
combine = stream.merge(_chat_response_generator(), _event_generator())
|
||||||
is_stream_started = False
|
is_stream_started = False
|
||||||
@@ -345,12 +349,12 @@ class ChatStreamResponse(StreamingResponse):
|
|||||||
@v.post("/chat-messages")
|
@v.post("/chat-messages")
|
||||||
async def post_conversations(request: Request, data: ChatRequestData):
|
async def post_conversations(request: Request, data: ChatRequestData):
|
||||||
userMng.findNoExistCreate(data.user)
|
userMng.findNoExistCreate(data.user)
|
||||||
data.conversation_id = default_conversation_id if data.conversation_id is None else data.conversation_id
|
data.conversation_id = data.conversation_id if data.conversation_id else str(uuid.uuid4())
|
||||||
|
|
||||||
conversaObj = conversations()
|
conversaObj = conversations()
|
||||||
conversationinfo = conversaObj.get(data.user, data.conversation_id)
|
conversationinfo = conversaObj.get(data.conversation_id)
|
||||||
if conversationinfo is None:
|
if conversationinfo is None:
|
||||||
conversationinfo = conversaObj.add(data.user, "新建会话", data.conversation_id)
|
conversationinfo = conversaObj.add(data.conversation_id, data.user, "新建会话")
|
||||||
|
|
||||||
# 生成聊天参数
|
# 生成聊天参数
|
||||||
last_message_content = ChatMessage.from_str(data.query)
|
last_message_content = ChatMessage.from_str(data.query)
|
||||||
@@ -372,9 +376,16 @@ async def post_conversations(request: Request, data: ChatRequestData):
|
|||||||
|
|
||||||
@v.get("/messages")
|
@v.get("/messages")
|
||||||
async def query_messages(user:str, conversation_id:str):
|
async def query_messages(user:str, conversation_id:str):
|
||||||
conversation_id = default_conversation_id if conversation_id is None else conversation_id
|
#conversation_id = default_conversation_id if conversation_id is None else conversation_id
|
||||||
datas = []
|
datas = []
|
||||||
records = message().gets(user,conversation_id)
|
records = message().gets(user,conversation_id)
|
||||||
|
if records is None:
|
||||||
|
return {
|
||||||
|
"limit": 20,
|
||||||
|
"has_more": False,
|
||||||
|
"data": []
|
||||||
|
}
|
||||||
|
|
||||||
for record in records:
|
for record in records:
|
||||||
res = record.dict()
|
res = record.dict()
|
||||||
res["message_files"] = []
|
res["message_files"] = []
|
||||||
@@ -415,7 +426,7 @@ async def post_conversations(request: Request,itemid:str,params:Dict[str,Any]):
|
|||||||
return 'null'
|
return 'null'
|
||||||
|
|
||||||
@v.get("/conversations")
|
@v.get("/conversations")
|
||||||
async def query_conversations(user:str):
|
async def query_conversations(user:str, first_id:str = None, limit:str = None, pinned:str = None):
|
||||||
user_id = '' if user is None else user
|
user_id = '' if user is None else user
|
||||||
userMng.findNoExistCreate(user_id)
|
userMng.findNoExistCreate(user_id)
|
||||||
|
|
||||||
|
|||||||
@@ -18,13 +18,13 @@ class conversations:
|
|||||||
|
|
||||||
return datas
|
return datas
|
||||||
|
|
||||||
def get(self,user_id:str,id:str = ''):
|
def get(self, id:str):
|
||||||
records = dbManage.query(self._tableName,user_id = user_id,id=id)
|
records = dbManage.query(self._tableName, id=id)
|
||||||
if len(records) >0:
|
if len(records) >0:
|
||||||
return records[0]
|
return records[0]
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def add(self,user_id:str,name:str,id:str = ''):
|
def add(self,id:str, user_id:str, name:str):
|
||||||
template = BaseConfig.ConversationCfg
|
template = BaseConfig.ConversationCfg
|
||||||
template['id'] = id
|
template['id'] = id
|
||||||
template['user_id'] = user_id
|
template['user_id'] = user_id
|
||||||
|
|||||||
Reference in New Issue
Block a user