dev #5
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Dict, List, Any, Optional, AsyncGenerator
|
||||
|
||||
from aiostream import stream
|
||||
@@ -22,8 +23,6 @@ logger = logging.getLogger("uvicorn")
|
||||
api_router = r = APIRouter()
|
||||
v1_router = v = APIRouter()
|
||||
|
||||
default_conversation_id = '82e8417f-2c3b-4bb5-ab22-2ad318bbd29a'
|
||||
|
||||
class ChatCallbackEvent(BaseModel):
|
||||
event_type: CBEventType
|
||||
payload: Optional[Dict[str, Any]] = None
|
||||
@@ -112,11 +111,11 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
|
||||
):
|
||||
"""Initialize the base callback handler."""
|
||||
ignored_events = [
|
||||
CBEventType.CHUNKING,
|
||||
CBEventType.NODE_PARSING,
|
||||
CBEventType.EMBEDDING,
|
||||
CBEventType.LLM,
|
||||
CBEventType.TEMPLATING,
|
||||
# CBEventType.CHUNKING,
|
||||
# CBEventType.NODE_PARSING,
|
||||
# CBEventType.EMBEDDING,
|
||||
# CBEventType.LLM,
|
||||
# CBEventType.TEMPLATING,
|
||||
]
|
||||
super().__init__(ignored_events, ignored_events)
|
||||
self._aqueue = asyncio.Queue()
|
||||
@@ -128,6 +127,8 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
|
||||
event_id: str = "",
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
logger.info("event_start:{} type:{} payload:{}\n".format(event_id, event_type, payload))
|
||||
|
||||
event = ChatCallbackEvent(event_id=event_id, event_type=event_type, payload=payload)
|
||||
if event.to_response() is not None:
|
||||
self._aqueue.put_nowait(event)
|
||||
@@ -139,12 +140,14 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
|
||||
event_id: str = "",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
logger.info("event_end:{} type:{} payload:{}\n".format(event_id, event_type, payload))
|
||||
event = ChatCallbackEvent(event_id=event_id, event_type=event_type, payload=payload)
|
||||
if event.to_response() is not None:
|
||||
self._aqueue.put_nowait(event)
|
||||
|
||||
def start_trace(self, trace_id: Optional[str] = None) -> None:
|
||||
"""No-op."""
|
||||
logger.info("trace_start:{}\n".format(trace_id))
|
||||
|
||||
def end_trace(
|
||||
self,
|
||||
@@ -152,6 +155,7 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
|
||||
trace_map: Optional[Dict[str, List[str]]] = None,
|
||||
) -> None:
|
||||
"""No-op."""
|
||||
logger.info("trace_end:{} trace_map:{}\n".format(trace_id, trace_map))
|
||||
|
||||
async def async_event_gen(self) -> AsyncGenerator[ChatCallbackEvent, None]:
|
||||
while not self._aqueue.empty() or not self.is_done:
|
||||
@@ -173,7 +177,7 @@ class DifyChatResponseEvent(BaseModel):
|
||||
event: str
|
||||
conversation_id: str
|
||||
message_id: str
|
||||
created_at: int = 1724406492
|
||||
created_at: int = int(time.time())
|
||||
task_id: str
|
||||
|
||||
class Workflow_started_DifyChatResponseEvent(DifyChatResponseEvent):
|
||||
@@ -190,7 +194,7 @@ class Workflow_started_DifyChatResponseEvent(DifyChatResponseEvent):
|
||||
"sys.conversation_id": args['conversation_id'],
|
||||
"sys.user_id": args['use_id']
|
||||
},
|
||||
"created_at": 1724406492
|
||||
"created_at": int(time.time())
|
||||
}
|
||||
args['event'] = 'workflow_started'
|
||||
super().__init__(**args)
|
||||
@@ -216,8 +220,8 @@ class Workflow_finished_DifyChatResponseEvent(DifyChatResponseEvent):
|
||||
"id": str(uuid.uuid4()),
|
||||
"user": args['use_id']
|
||||
},
|
||||
"created_at": 1724406492,
|
||||
"finished_at": 1724406528,
|
||||
"created_at": int(time.time()),
|
||||
"finished_at": int(time.time()),
|
||||
"files": []
|
||||
}
|
||||
super().__init__(**args)
|
||||
@@ -239,26 +243,26 @@ class MessageEnd_DifyChatResponseEvent(DifyChatResponseEvent):
|
||||
super().__init__(**args)
|
||||
|
||||
class ChatStreamResponse(StreamingResponse):
|
||||
TEXT_PREFIX = "data:"
|
||||
DATA_PREFIX = "data:"
|
||||
TEXT_PREFIX = "data: "
|
||||
DATA_PREFIX = "data: "
|
||||
|
||||
@classmethod
|
||||
def convert_text(cls, token: str):
|
||||
# Escape newlines and double quotes to avoid breaking the stream
|
||||
token = json.dumps(token)
|
||||
#token = json.dumps(token)
|
||||
|
||||
#return f"data: {{"event": "message", "conversation_id": "80d85523-de92-4b9d-aca0-c48a5eacb068", "message_id": "16a06b1b-a89b-49c0-bc15-123bd999f6d6", "created_at": 1724406492, "task_id": "802f3064-030d-42ac-a882-0e1293712d04", "id": "16a06b1b-a89b-49c0-bc15-123bd999f6d6", "answer": "{token}"}}"
|
||||
return ""
|
||||
return "\n"
|
||||
|
||||
@classmethod
|
||||
def convert_data(cls, data: dict):
|
||||
data_str = json.dumps(data)
|
||||
return f"{cls.DATA_PREFIX}{data_str}\n"
|
||||
return f"{cls.DATA_PREFIX}{data_str}\n\n"
|
||||
|
||||
@classmethod
|
||||
def convert_event(cls, event: DifyChatResponseEvent):
|
||||
data_str = json.dumps(event.dict())
|
||||
return f"{cls.DATA_PREFIX}{data_str}\n"
|
||||
return f"{cls.DATA_PREFIX}{data_str}\n\n"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -314,7 +318,7 @@ class ChatStreamResponse(StreamingResponse):
|
||||
async for event in event_handler.async_event_gen():
|
||||
event_response = event.to_response()
|
||||
if event_response is not None:
|
||||
yield ChatStreamResponse.convert_data(event_response)
|
||||
yield ChatStreamResponse.convert_text("")
|
||||
|
||||
combine = stream.merge(_chat_response_generator(), _event_generator())
|
||||
is_stream_started = False
|
||||
@@ -345,12 +349,12 @@ class ChatStreamResponse(StreamingResponse):
|
||||
@v.post("/chat-messages")
|
||||
async def post_conversations(request: Request, data: ChatRequestData):
|
||||
userMng.findNoExistCreate(data.user)
|
||||
data.conversation_id = default_conversation_id if data.conversation_id is None else data.conversation_id
|
||||
data.conversation_id = data.conversation_id if data.conversation_id else str(uuid.uuid4())
|
||||
|
||||
conversaObj = conversations()
|
||||
conversationinfo = conversaObj.get(data.user, data.conversation_id)
|
||||
conversationinfo = conversaObj.get(data.conversation_id)
|
||||
if conversationinfo is None:
|
||||
conversationinfo = conversaObj.add(data.user, "新建会话", data.conversation_id)
|
||||
conversationinfo = conversaObj.add(data.conversation_id, data.user, "新建会话")
|
||||
|
||||
# 生成聊天参数
|
||||
last_message_content = ChatMessage.from_str(data.query)
|
||||
@@ -372,9 +376,16 @@ async def post_conversations(request: Request, data: ChatRequestData):
|
||||
|
||||
@v.get("/messages")
|
||||
async def query_messages(user:str, conversation_id:str):
|
||||
conversation_id = default_conversation_id if conversation_id is None else conversation_id
|
||||
#conversation_id = default_conversation_id if conversation_id is None else conversation_id
|
||||
datas = []
|
||||
records = message().gets(user,conversation_id)
|
||||
if records is None:
|
||||
return {
|
||||
"limit": 20,
|
||||
"has_more": False,
|
||||
"data": []
|
||||
}
|
||||
|
||||
for record in records:
|
||||
res = record.dict()
|
||||
res["message_files"] = []
|
||||
@@ -415,7 +426,7 @@ async def post_conversations(request: Request,itemid:str,params:Dict[str,Any]):
|
||||
return 'null'
|
||||
|
||||
@v.get("/conversations")
|
||||
async def query_conversations(user:str):
|
||||
async def query_conversations(user:str, first_id:str = None, limit:str = None, pinned:str = None):
|
||||
user_id = '' if user is None else user
|
||||
userMng.findNoExistCreate(user_id)
|
||||
|
||||
|
||||
@@ -18,13 +18,13 @@ class conversations:
|
||||
|
||||
return datas
|
||||
|
||||
def get(self,user_id:str,id:str = ''):
|
||||
records = dbManage.query(self._tableName,user_id = user_id,id=id)
|
||||
def get(self, id:str):
|
||||
records = dbManage.query(self._tableName, id=id)
|
||||
if len(records) >0:
|
||||
return records[0]
|
||||
return None
|
||||
|
||||
def add(self,user_id:str,name:str,id:str = ''):
|
||||
def add(self,id:str, user_id:str, name:str):
|
||||
template = BaseConfig.ConversationCfg
|
||||
template['id'] = id
|
||||
template['user_id'] = user_id
|
||||
|
||||
Reference in New Issue
Block a user