修改POST和Get请求代码

This commit is contained in:
wanyaokun
2024-08-27 17:48:38 +08:00
parent b4c571cddb
commit 07a3b2a147
5 changed files with 168 additions and 28 deletions
+138 -11
View File
@@ -14,14 +14,16 @@ from llama_index.core.tools import ToolOutput
from pydantic import BaseModel
from app.api.routers.events import EventCallbackHandler
from app.api.routers.request.base import userMng, conversations
from app.api.routers.request.base import userMng, conversations,message
from app.api.routers.request.models import ChatRequestData
from app.engine import get_chat_engine
import uuid
logger = logging.getLogger("uvicorn")
v1_router = v = APIRouter()
default_conversation_id = '82e8417f-2c3b-4bb5-ab22-2ad318bbd29a'
class ChatCallbackEvent(BaseModel):
event_type: CBEventType
@@ -102,7 +104,6 @@ class ChatCallbackEvent(BaseModel):
logger.error(f"转换回应时间时发生错误,原因: {e}")
return None
class ChatEventCallbackHandler(BaseCallbackHandler):
_aqueue: asyncio.Queue
is_done: bool = False
@@ -160,6 +161,84 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
except asyncio.TimeoutError:
pass
class IDManager:
def createID(self):
return {
"message_id" : str(uuid.uuid4()),
'task_id':str(uuid.uuid4()),
'workflow_run_id': str(uuid.uuid4()),
"workflow_id": str(uuid.uuid4())
}
class DifyChatResponseEvent(BaseModel):
event: str
conversation_id: str
message_id: str
created_at: int = 1724406492
task_id: str
class Workflow_started_DifyChatResponseEvent(DifyChatResponseEvent):
workflow_run_id:str
data:Dict[str,Any]
def __init__(self,**args):
args['data'] = {
"id": args['workflow_run_id'],
"workflow_id": args['workflow_id'],
"sequence_number": 1709,
"inputs": {
"sys.query": args['query'],
"sys.files": [],
"sys.conversation_id": args['conversation_id'],
"sys.user_id": args['use_id']
},
"created_at": 1724406492
}
args['event'] = 'workflow_started'
super().__init__(**args)
class Workflow_finished_DifyChatResponseEvent(DifyChatResponseEvent):
workflow_run_id:str
data:Dict[str,Any]
def __init__(self,**args):
args['event'] = 'workflow_finished'
args['data'] = {
"id": args['workflow_run_id'],
"workflow_id": args['workflow_id'],
"sequence_number": 1709,
"status": "succeeded",
"outputs": {
"answer": args['response']
},
"error": '',
"elapsed_time": 36.03764106379822,
"total_tokens": 11707,
"total_steps": 10,
"created_by": {
"id": str(uuid.uuid4()),
"user": args['use_id']
},
"created_at": 1724406492,
"finished_at": 1724406528,
"files": []
}
super().__init__(**args)
class Message_DifyChatResponseEvent(DifyChatResponseEvent):
id:str
answer:str
def __init__(self,**args):
args['id'] = args['message_id']
args['event'] = 'message'
super().__init__(**args)
class MessageEnd_DifyChatResponseEvent(DifyChatResponseEvent):
id:str
metadata:Dict[str,Any] = {}
def __init__(self,**args):
args['id'] = args['message_id']
args['event'] = 'message_end'
super().__init__(**args)
class ChatStreamResponse(StreamingResponse):
TEXT_PREFIX = "data:"
DATA_PREFIX = "data:"
@@ -177,10 +256,15 @@ class ChatStreamResponse(StreamingResponse):
data_str = json.dumps(data)
return f"{cls.DATA_PREFIX}{data_str}\n"
@classmethod
def convert_event(cls, event: DifyChatResponseEvent):
data_str = json.dumps(event.dict())
return f"{cls.DATA_PREFIX}{data_str}\n"
def __init__(
self,
request: Request,
event_handler: EventCallbackHandler,
event_handler: ChatEventCallbackHandler,
response: StreamingAgentChatResponse,
data: ChatRequestData
):
@@ -193,24 +277,38 @@ class ChatStreamResponse(StreamingResponse):
async def content_generator(
cls,
request: Request,
event_handler: EventCallbackHandler,
event_handler: ChatEventCallbackHandler,
response: StreamingAgentChatResponse,
data: ChatRequestData
):
ids = IDManager().createID()
# Yield the text response
async def _chat_response_generator():
final_response = ""
async for token in response.async_response_gen():
final_response += token
yield ChatStreamResponse.convert_text(token)
args = ids
args['answer'] = token
args['conversation_id'] = data.conversation_id
event = Message_DifyChatResponseEvent(**args)
yield ChatStreamResponse.convert_event(event)
#yield ChatStreamResponse.convert_text(token)
# 存储消息历史
#message = Message(data.conversation_id, data.query, answer=final_response)
#messageManager.addmessage(message)
message().add(user_id=data.user,conversation_id=data.conversation_id,query=data.query,answer=final_response)
# the text_generator is the leading stream, once it's finished, also finish the event stream
event_handler.is_done = True
# 发送工作流结束事件
args = ids
args['response'] = final_response
args['conversation_id'] = data.conversation_id
wf_event = Workflow_finished_DifyChatResponseEvent(**args)
yield ChatStreamResponse.convert_event(wf_event)
msgEnt_event = MessageEnd_DifyChatResponseEvent(**ids)
yield ChatStreamResponse.convert_event(msgEnt_event)
# Yield the events from the event handler
async def _event_generator():
@@ -225,8 +323,18 @@ class ChatStreamResponse(StreamingResponse):
async for output in streamer:
if not is_stream_started:
is_stream_started = True
# 发送工作流开始事件
args = ids
args['use_id'] = data.user
args['query'] = data.query
args['conversation_id'] = data.conversation_id
wf_event = Workflow_started_DifyChatResponseEvent(**args)
yield ChatStreamResponse.convert_event(wf_event)
# Stream a blank message to start the stream
yield ChatStreamResponse.convert_text("")
# 发送一个空消息事件
#yield ChatStreamResponse.convert_text("")
yield output
@@ -236,6 +344,7 @@ class ChatStreamResponse(StreamingResponse):
@v.post("/chat-messages")
async def post_conversations(request: Request, data: ChatRequestData):
userMng.findNoExistCreate(data.user)
data.conversation_id = default_conversation_id if data.conversation_id is None else data.conversation_id
conversaObj = conversations()
conversationinfo = conversaObj.get(data.user, data.conversation_id)
@@ -251,7 +360,7 @@ async def post_conversations(request: Request, data: ChatRequestData):
chat_engine = get_chat_engine(filters=filters, params=params)
# 启动聊天事件监听
event_handler = EventCallbackHandler()
event_handler = ChatEventCallbackHandler()
chat_engine.callback_manager.handlers.append(event_handler) # type: ignore
# 执行异步聊天
@@ -262,7 +371,25 @@ async def post_conversations(request: Request, data: ChatRequestData):
@v.get("/messages")
async def query_messages(user:str, conversation_id:str):
pass
conversation_id = default_conversation_id if conversation_id is None else conversation_id
datas = []
records = message().gets(user,conversation_id)
for record in records:
res = record.dict()
res["message_files"] = []
res["feedback"] = ''
res["retriever_resources"] = []
res["created_at"] = 1723444905
res["agent_thoughts"] = []
res["status"] = "normal"
res["error"] = ''
datas.append(res)
return {
"limit": 20,
"has_more": False,
"data": datas
}
@v.post("/conversations/{itemid}/name")
async def post_conversations(user:str):
+16 -12
View File
@@ -1,5 +1,5 @@
from datetime import datetime
import uuid
from app.api.routers.request.baseConfig import BaseConfig
from app.api.routers.request.dbOrm import DBManager
@@ -25,16 +25,11 @@ class conversations:
return None
def add(self,user_id:str,name:str,id:str = ''):
import uuid
if id == '':
id= str(uuid.uuid4())
template = BaseConfig.ConversationCfg
template['id'] = id
template['user_id'] = user_id
template['name'] = name
template['created_at'] = 1724399038
dbManage.addRecord(self._tableName,template)
def delete(self,id:str):
@@ -70,7 +65,7 @@ class userMng:
@classmethod
def findNoExistCreate(cls,user_id:str):
userInfo = cls.userObj.get(user_id)
if userInfo is None:
if len(userInfo) == 0:
cls.userObj.add(user_id)
def remove(cls,user_id:str):
@@ -116,14 +111,23 @@ class message:
self._tableName = 'messages'
dbManage.createTable(self._tableName)
def gets(self,user_id:str):
return dbManage.query(self._tableName,user_id = user_id)
def gets(self,user_id:str,conversation_id:str):
records = dbManage.query(self._tableName,user_id = user_id,conversation_id = conversation_id)
datas = []
for record in records:
datas.append(record)
return datas
def add(self,user_id:str):
dbManage.addRecord(self._tableName,{})
def add(self,user_id:str,conversation_id:str,query:str,answer:str):
template = BaseConfig.MessageCfg
template['id'] = str(uuid.uuid4())
template['user_id'] = user_id
template['conversation_id'] = conversation_id
template['query'] = query
template['answer'] = answer
dbManage.addRecord(self._tableName,template)
def delete(self,user_id:str):
dbManage.delete(self._tableName,user_id = user_id)
@@ -50,3 +50,13 @@ class BaseConfig:
"introduction": ParamterCfg['opening_statement'],
"created_at":''
}
MessageCfg = {
"id": "",
'user_id':'',
"conversation_id": "",
"inputs": {},
"query": "",
"answer": ""
}
+2 -2
View File
@@ -41,7 +41,7 @@ class MessagesOrm(Base):
conversation_id = Column(String)
inputs = Column(JSON)
query = Column(String)
answer = Column(JSON)
answer = Column(String)
#数据结构
class ConversationModel(BaseModel):
@@ -90,7 +90,7 @@ class MessagesModel(BaseModel):
conversation_id :str
inputs : Dict[str, Any]
query : str
answer : Dict[str, Any]
answer : str
class Config:
#orm_mode = True
@@ -11,4 +11,3 @@ class ChatRequestData(BaseModel):
response_mode: str
files: Any
conversation_id: str = None