599 lines
21 KiB
Python
599 lines
21 KiB
Python
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import time
|
|
from typing import Dict, List, Any, Optional, AsyncGenerator
|
|
|
|
from aiostream import stream
|
|
from fastapi import APIRouter, Request,HTTPException
|
|
from fastapi.responses import StreamingResponse
|
|
from llama_index.core import BaseCallbackHandler
|
|
from llama_index.core.base.llms.types import ChatMessage
|
|
from llama_index.core.callbacks import CBEventType
|
|
from llama_index.core.chat_engine.types import StreamingAgentChatResponse
|
|
from llama_index.core.tools import ToolOutput
|
|
from llama_index.core.schema import NodeWithScore
|
|
from pydantic import BaseModel
|
|
from app.api.routers.request.base import userMng, conversations,message,ProjectInfo,feedback
|
|
from app.api.routers.request.baseConfig import *
|
|
from app.api.routers.request.models import ChatRequestData,ChatFileUploadRequest
|
|
from app.engine import get_chat_engine
|
|
import uuid
|
|
from app.api.routers.services.fileServices import PrjFileLoadService,ChatFileService
|
|
from app.api.routers.services.suggestion import NextQuestionSuggestion
|
|
import time
|
|
from llama_index.core.settings import Settings
|
|
from llama_index.core.callbacks import CallbackManager
|
|
|
|
# Log through the "uvicorn" logger rather than a module-specific one.
logger = logging.getLogger("uvicorn")

# A single router object published under two names: ``v`` is the short alias
# used by the route decorators in this file, ``v1_router`` is the long name.
v = APIRouter()
v1_router = v

# Install a fresh callback manager on the globally configured LLM; the chat
# endpoint later appends its event handler to this manager.
Settings.llm.callback_manager = CallbackManager()

# Process-wide chat event handler, created on the first /chat-messages call.
gEvent_handler = None
|
|
|
|
|
|
class ChatCallbackEvent(BaseModel):
    """One chat-stream event plus the data needed to serialise it.

    ``event_type`` selects the response layout built by :meth:`to_response`;
    ``payload`` supplies the ids and per-event values those layouts read.
    Keys missing from the payload (or a missing payload altogether) appear
    as ``None`` in the generated dictionaries.
    """

    event_type: ChatEventType
    payload: Optional[Dict[str, Any]] = None

    def get_common_param(self) -> dict:
        """Return the fields shared by every event layout."""
        # ``payload`` defaults to None, so never call ``.get`` on it directly.
        payload = self.payload or {}
        return {
            'event': self.event_type.value,
            'conversation_id': payload.get("conversation_id"),
            'message_id': payload.get("message_id"),
            'created_at': int(time.time()),
            'task_id': payload.get("task_id"),
        }

    def get_WorkflowStart_param(self) -> dict:
        """Return the layout for a "workflow_started" event."""
        payload = self.payload or {}
        params = self.get_common_param()
        params.update({
            'workflow_run_id': payload.get('workflow_run_id'),
            'data': {
                "id": payload.get('workflow_run_id'),
                "workflow_id": payload.get('workflow_id'),
                # Fixed placeholder value; no real sequence counter exists here.
                "sequence_number": 1709,
                "inputs": {
                    "sys.query": f"开始查询 {payload.get('query')}",
                    "sys.files": [],
                    "sys.conversation_id": payload.get('conversation_id'),
                    # The payload key really is spelled 'use_id'.
                    "sys.user_id": payload.get('use_id'),
                },
                "created_at": int(time.time()),
            },
        })
        return params

    def get_WorkflowFinished_param(self) -> dict:
        """Return the layout for a "workflow_finished" event."""
        payload = self.payload or {}
        now = int(time.time())
        params = self.get_common_param()
        params.update({
            'workflow_run_id': payload.get('workflow_run_id'),
            'data': {
                "id": payload.get('workflow_run_id'),
                "workflow_id": payload.get('workflow_id'),
                "sequence_number": 1709,
                "status": "succeeded",
                "outputs": {
                    "answer": payload.get('response'),
                },
                "error": '',
                # Placeholder statistics: nothing in this class measures them.
                "elapsed_time": 36.03764106379822,
                "total_tokens": 11707,
                "total_steps": 10,
                "created_by": {
                    "id": str(uuid.uuid4()),
                    "user": payload.get('use_id'),
                },
                "created_at": now,
                "finished_at": now,
                "files": [],
            },
        })
        return params

    def get_NodeStart_param(self) -> dict:
        """Return the layout for a "node_started" event."""
        payload = self.payload or {}
        params = self.get_common_param()
        params.update({
            'workflow_run_id': payload.get('workflow_run_id'),
            'data': {
                "id": payload.get('nodeid'),
                "node_id": payload.get('nodeid'),
                "node_type": "http-request",
                "title": f"正在执行事件:{payload.get('title')}",
                "index": payload.get('index'),
                "predecessor_node_id": payload.get('predecessor_node_id'),
                "inputs": '',
                # Current time, consistent with the workflow-level events
                # (previously a fixed epoch value left over from sample data).
                "created_at": int(time.time()),
                "extras": {},
            },
        })
        return params

    def get_NodeFinished_param(self) -> dict:
        """Return the layout for a "node_finished" event."""
        payload = self.payload or {}
        now = int(time.time())
        params = self.get_common_param()
        params.update({
            'workflow_run_id': payload.get('workflow_run_id'),
            'data': {
                "id": payload.get('nodeid'),
                "node_id": payload.get('nodeid'),
                "node_type": "http-request",
                "title": f"事件执行结束:{payload.get('title')}",
                "index": payload.get('index'),
                "predecessor_node_id": payload.get('predecessor_node_id'),
                "inputs": '',
                "process_data": '',
                "outputs": '',
                "status": "succeeded",
                "error": '',
                # Placeholder: the node's real duration is not tracked.
                "elapsed_time": 0.10402441816404462,
                "execution_metadata": '',
                # Current time instead of a fixed sample epoch value.
                "created_at": now,
                "finished_at": now,
                "files": [],
            },
        })
        return params

    def get_Message_param(self) -> dict:
        """Return the layout for a "message" event (one answer fragment)."""
        payload = self.payload or {}
        params = self.get_common_param()
        params.update({
            'id': payload.get('message_id'),
            'answer': payload.get('answer'),
        })
        return params

    def get_MessageEnd_param(self) -> dict:
        """Return the layout for a "message_end" event.

        Each entry of ``payload['source_node']`` (if any) is described in
        ``metadata.retriever_resources``; the ``usage`` figures are fixed
        placeholder values.
        """
        payload = self.payload or {}
        params = self.get_common_param()

        node_infos = []
        for position, source_node in enumerate(payload.get('source_node') or []):
            metadata: dict = source_node.node.metadata
            node_infos.append({
                "position": position,
                "dataset_id": metadata.get("pipeline_id"),
                "dataset_name": metadata.get("file_name"),
                "document_id": source_node.node_id,
                "document_name": metadata.get("file_name"),
                "data_source_type": "upload_file",
                "segment_id": source_node.node_id,
                "retriever_from": "workflow",
                "score": source_node.score,
                "hit_count": 1,
                "word_count": 632,
                "segment_position": position,
                "index_node_hash": "",
                "content": source_node.text,
            })

        params.update({
            'id': payload.get('message_id'),
            'metadata': {
                "retriever_resources": node_infos,
                "usage": {
                    "prompt_tokens": 4972,
                    "prompt_unit_price": "0.0",
                    "prompt_price_unit": "0.0",
                    "prompt_price": "0.0",
                    "completion_tokens": 332,
                    "completion_unit_price": "0.0",
                    "completion_price_unit": "0.0",
                    "completion_price": "0.0",
                    "total_tokens": 5304,
                    "total_price": "0.0",
                    "currency": "USD",
                    "latency": 4.897703120019287,
                },
            },
        })
        return params

    def to_response(self) -> dict | None:
        """Build the response dictionary matching ``event_type``.

        Returns:
            The layout for the event, or ``None`` when the event type has no
            layout or building it failed (the failure is logged with its
            traceback).
        """
        builders = {
            "workflow_started": self.get_WorkflowStart_param,
            "workflow_finished": self.get_WorkflowFinished_param,
            "node_started": self.get_NodeStart_param,
            "node_finished": self.get_NodeFinished_param,
            "message": self.get_Message_param,
            "message_end": self.get_MessageEnd_param,
        }
        try:
            builder = builders.get(self.event_type.value)
            return builder() if builder is not None else None
        except Exception as e:
            # Best effort: a broken event must not break the whole stream.
            logger.error(f"转换回应时间时发生错误,原因: {e}", exc_info=True)
            return None
|
|
|
|
class ChatEventCallbackHandler(BaseCallbackHandler):
    """Turn llama-index callback events into queued ``ChatCallbackEvent`` objects.

    The first event seen after :meth:`setInitParams` brackets the "workflow":
    a workflow-start event is queued when it starts, and workflow-finished
    plus message-end events are queued when it ends.  Every event start also
    queues a node-start event, and an event that ends while it is the
    innermost open one queues a node-finished event.  Consumers drain the
    queue through :meth:`async_event_gen`.
    """

    _aqueue: asyncio.Queue
    # Set to True by the consumer of the answer stream once it has finished.
    is_done: bool = False

    def __init__(self):
        """Create the handler with an empty queue and no ignored event types."""
        # Nothing is filtered.  Candidates if filtering is ever wanted:
        # CHUNKING, NODE_PARSING, EMBEDDING, LLM, TEMPLATING.
        ignored_events: List[CBEventType] = []
        super().__init__(ignored_events, ignored_events)
        self._aqueue = asyncio.Queue()
        self._response: Optional[StreamingAgentChatResponse] = None
        self._ids: Dict[str, Any] = {}
        self._chatData: Optional[ChatRequestData] = None
        self._nodeStack: List[str] = []
        self._firstEventID: Optional[str] = None

    def setInitParams(self, ids: dict, data: ChatRequestData):
        """Prepare the handler for a new chat request.

        Args:
            ids: Identifiers copied into every event payload.
            data: The chat request; its user, query and conversation id are
                read when building events.
        """
        self._ids = ids
        self._chatData = data
        self._firstEventID = None
        # The same handler object is reused across requests, so per-request
        # state must be cleared.  Without resetting ``is_done`` the event
        # generator of every later request would stop as soon as the queue
        # is momentarily empty.
        self.is_done = False
        self._nodeStack = []
        self._response = None

    def setResponse(self, response: StreamingAgentChatResponse):
        """Remember the chat response whose ``source_nodes`` feed message-end."""
        self._response = response

    def _base_payload(self) -> Dict[str, Any]:
        """Return a fresh payload holding the request ids and conversation id.

        A copy is returned so that per-event keys never leak into the shared
        ids dictionary or into the payload of other events.
        """
        payload = dict(self._ids)
        if self._chatData is not None:
            payload['conversation_id'] = self._chatData.conversation_id
        return payload

    def _node_payload(self, event_type: CBEventType, event_id: str) -> Dict[str, Any]:
        """Return the payload for a node event about the innermost open event."""
        depth = len(self._nodeStack)
        payload = self._base_payload()
        payload.update({
            'nodeid': event_id,
            'title': event_type.name,
            'index': depth,
            'predecessor_node_id': self._nodeStack[-2] if depth > 1 else '',
        })
        return payload

    def _enqueue(self, event: ChatCallbackEvent) -> None:
        """Queue ``event`` unless it cannot be rendered into a response."""
        if event.to_response() is not None:
            self._aqueue.put_nowait(event)

    def on_event_start(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        **kwargs: Any,
    ) -> str:
        """Queue a node-start event (and workflow-start for the first event).

        Returns:
            The ``event_id`` that was passed in.
        """
        if self._firstEventID is None:
            self._firstEventID = event_id
            self.start()

        logger.info("event_start:{} type:{} payload:{}\n".format(event_id, event_type, payload))

        self._nodeStack.append(event_id)
        self._enqueue(ChatCallbackEvent(
            event_type=ChatEventType.NODE_START,
            payload=self._node_payload(event_type, event_id),
        ))
        return event_id

    def on_event_end(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        **kwargs: Any,
    ) -> None:
        """Queue a node-finished event and close the workflow if appropriate.

        A node-finished event is only produced when ``event_id`` is the
        innermost open event; an end without a matching start is ignored
        instead of raising ``IndexError`` on the empty stack.
        """
        logger.info("event_end:{} type:{} payload:{}\n".format(event_id, event_type, payload))

        if self._nodeStack and self._nodeStack[-1] == event_id:
            self._enqueue(ChatCallbackEvent(
                event_type=ChatEventType.NODE_FINISHED,
                payload=self._node_payload(event_type, event_id),
            ))
            self._nodeStack.pop()

        if self._firstEventID is not None and self._firstEventID == event_id:
            self.finished()

    def start_trace(self, trace_id: Optional[str] = None) -> None:
        """Log the start of a trace; no event is queued."""
        logger.info("trace_start:{}\n".format(trace_id))

    def end_trace(
        self,
        trace_id: Optional[str] = None,
        trace_map: Optional[Dict[str, List[str]]] = None,
    ) -> None:
        """Log the end of a trace; no event is queued."""
        logger.info("trace_end:{} trace_map:{}\n".format(trace_id, trace_map))

    async def async_event_gen(self) -> AsyncGenerator[ChatCallbackEvent, None]:
        """Yield queued events until the queue is empty and ``is_done`` is set."""
        while not self._aqueue.empty() or not self.is_done:
            try:
                # Short timeout so the loop re-checks ``is_done`` regularly.
                yield await asyncio.wait_for(self._aqueue.get(), timeout=0.1)
            except asyncio.TimeoutError:
                pass

    def makeWorkflow_startEvent(self) -> ChatCallbackEvent:
        """Build the workflow-start event for the current request."""
        args = self._base_payload()
        args.update({
            'use_id': self._chatData.user,
            'query': self._chatData.query,
            'conversation_id': self._chatData.conversation_id,
        })
        return ChatCallbackEvent(event_type=ChatEventType.WORKFLOW_START, payload=args)

    def makeWorkflow_finishedEvent(self) -> ChatCallbackEvent:
        """Build the workflow-finished event for the current request."""
        args = self._base_payload()
        args.update({
            'response': '',
            # Set explicitly because payloads are no longer shared between
            # events; the finished layout reads 'use_id'.
            'use_id': self._chatData.user,
            'conversation_id': self._chatData.conversation_id,
        })
        return ChatCallbackEvent(event_type=ChatEventType.WORKFLOW_FINISHED, payload=args)

    def makeMessage_EndEvent(self) -> ChatCallbackEvent:
        """Build the message-end event, attaching source nodes when known."""
        args = self._base_payload()
        if self._response is not None:
            args['source_node'] = self._response.source_nodes
        return ChatCallbackEvent(event_type=ChatEventType.MESSAGE_END, payload=args)

    def start(self):
        """Queue the workflow-start event."""
        self._enqueue(self.makeWorkflow_startEvent())

    def finished(self):
        """Queue the workflow-finished event followed by the message-end event."""
        self._enqueue(self.makeWorkflow_finishedEvent())
        self._enqueue(self.makeMessage_EndEvent())
|
|
|
|
class IDManager:
    """Factory for the identifiers attached to one chat exchange."""

    def createID(self):
        """Return a dict with a fresh random UUID string for each id name."""
        id_names = ("message_id", "task_id", "workflow_run_id", "workflow_id")
        return {id_name: str(uuid.uuid4()) for id_name in id_names}
|
|
|
|
class ChatStreamResponse(StreamingResponse):
    """Streaming response that interleaves answer tokens with handler events.

    Every chunk has the form ``data: <json>`` followed by a blank line.
    """

    TEXT_PREFIX = "data: "
    DATA_PREFIX = "data: "
    # Class-level copies of the most recent request's ids/data.  They are only
    # a fallback for callers that use ``convert_Message(token)`` without
    # passing request state; the streaming path passes its own ids and data so
    # that concurrent requests cannot overwrite each other.
    ids: Dict[str, Any] = {}
    data: ChatRequestData = None

    @classmethod
    def convert_Message(
        cls,
        token: str,
        ids: Optional[Dict[str, Any]] = None,
        data: Optional[ChatRequestData] = None,
    ):
        """Format one answer token as a "message" event chunk.

        Args:
            token: The answer fragment.
            ids: Request identifiers; defaults to the class-level ``ids``.
            data: The chat request; defaults to the class-level ``data``.
        """
        ids = cls.ids if ids is None else ids
        data = cls.data if data is None else data
        # Build a new dict rather than updating the shared ids in place.
        payload = {
            **ids,
            'answer': token,
            'conversation_id': data.conversation_id,
        }
        event = ChatCallbackEvent(event_type=ChatEventType.MESSAGE, payload=payload)
        return cls.convert_Event(event.to_response())

    @classmethod
    def convert_Event(cls, data: dict):
        """Format an already-built event dictionary as a stream chunk."""
        data_str = json.dumps(data)
        return f"{cls.DATA_PREFIX}{data_str}\n\n"

    def __init__(
        self,
        request: Request,
        event_handler: ChatEventCallbackHandler,
        response: StreamingAgentChatResponse,
        data: ChatRequestData,
        ids: Dict[str, Any],
    ):
        """Create the response and start generating its content lazily."""
        # Still published at class level for backward compatibility.
        ChatStreamResponse.ids = ids
        ChatStreamResponse.data = data
        content = ChatStreamResponse.content_generator(
            request, event_handler, response, data, ids
        )
        super().__init__(content=content)

    @classmethod
    async def content_generator(
        cls,
        request: Request,
        event_handler: ChatEventCallbackHandler,
        response: StreamingAgentChatResponse,
        data: ChatRequestData,
        ids: Optional[Dict[str, Any]] = None,
    ):
        """Yield stream chunks until the answer ends or the client disconnects.

        Args:
            request: Used to detect client disconnects.
            event_handler: Source of workflow/node events.
            response: The streaming chat response providing answer tokens.
            data: The chat request (user, query, conversation id).
            ids: Request identifiers; defaults to the class-level ``ids``.
        """
        if ids is None:
            ids = cls.ids

        # Hand the response over before streaming: the handler reads
        # ``source_nodes`` from it when it builds the message-end event, which
        # can happen before the token stream below has finished.
        event_handler.setResponse(response)

        # Yield the text response
        async def _chat_response_generator():
            tokens = []
            try:
                async for token in response.async_response_gen():
                    tokens.append(token)
                    yield cls.convert_Message(token, ids, data)

                # Persist the exchange only after the answer completed.
                message().add(
                    user_id=data.user,
                    conversation_id=data.conversation_id,
                    query=data.query,
                    answer="".join(tokens),
                )
            finally:
                # The text generator is the leading stream: once it ends, for
                # whatever reason, the event stream must be allowed to end too.
                event_handler.is_done = True
                event_handler.setResponse(response)

        # Yield the events from the event handler
        async def _event_generator():
            async for event in event_handler.async_event_gen():
                event_response = event.to_response()
                if event_response is not None:
                    yield cls.convert_Event(event_response)

        combine = stream.merge(_chat_response_generator(), _event_generator())
        async with combine.stream() as streamer:
            async for output in streamer:
                yield output

                if await request.is_disconnected():
                    break
|
|
|
|
@v.post("/chat-messages")
async def post_chatmessages(request: Request, data: ChatRequestData):
    """Answer a chat query as a stream of answer tokens and workflow events.

    Makes sure the user and the conversation exist (creating a conversation
    id when none was sent), prepares the shared event handler for this
    request, starts the streaming chat and wraps it in a ChatStreamResponse.
    """
    global gEvent_handler

    userMng.findNoExistCreate(data.user)
    data.conversation_id = data.conversation_id or str(uuid.uuid4())

    # Create the conversation record on first use of this conversation id.
    conversation_store = conversations()
    if conversation_store.get(data.conversation_id) is None:
        conversation_store.add(data.conversation_id, data.user, "新建会话", inputs=data.inputs)

    # Built from the query but not consumed further down.
    last_message_content = ChatMessage.from_str(data.query)

    # Start listening for chat events.
    # NOTE(review): a single handler object serves every request, so
    # overlapping requests share its state.
    ids = IDManager().createID()
    handler = gEvent_handler
    if handler is None:
        handler = gEvent_handler = ChatEventCallbackHandler()
        Settings.llm.callback_manager.handlers.append(handler)
    handler.setInitParams(ids=ids, data=data)

    # Obtain the chat engine and run the chat asynchronously.
    chat_engine = get_chat_engine(filters=None, params=data.inputs or {})
    response = await chat_engine.astream_chat(data.query)

    # Return the streaming reply.
    return ChatStreamResponse(request, handler, response, data, ids)
|
|
|
|
@v.get("/messages")
async def query_messages(user: str, conversation_id: str):
    """Return the stored messages of a conversation together with feedback.

    Args:
        user: Identifier of the user the messages belong to.
        conversation_id: Conversation whose messages are listed.

    Returns:
        A dict with ``limit``, ``has_more`` and ``data``; ``data`` is empty
        when the lookup yields nothing.
    """
    records = message().gets(user, conversation_id)

    datas = []
    # ``records`` may be None; treat that exactly like an empty result.
    for record in records or []:
        res = record.dict()
        feeds = feedback().query(res['id'])
        res.update({
            "message_files": [],
            "feedback": {'rating': feeds['rating']} if feeds is not None else '',
            "retriever_resources": [],
            # NOTE(review): fixed timestamp, apparently a placeholder.
            "created_at": 1723444905,
            "agent_thoughts": [],
            "status": "normal",
            "error": '',
        })
        datas.append(res)

    return {
        "limit": 20,
        "has_more": False,
        "data": datas,
    }
|
|
|
|
@v.post("/conversations/{itemid}/name")
async def post_conversations(request: Request, itemid: str, params: Dict[str, Any]):
    """Rename a conversation and return its stored record.

    Args:
        request: The incoming request (unused).
        itemid: Id of the conversation to rename.
        params: Request body; must contain ``user``.

    Returns:
        The conversation's fields, or the string ``'null'`` when no
        conversation with this id belongs to the user.

    Raises:
        HTTPException: 400 when ``user`` is missing from the body.
    """
    # Validate the body before touching stored data, so a malformed request
    # is rejected cleanly instead of failing after the rename.
    if 'user' not in params:
        raise HTTPException(status_code=400, detail="Missing required field: user")

    consaObj = conversations()
    # The new name is fixed; nothing from the request body is used for it.
    consaObj.rename(itemid, '知识问答')

    results = consaObj.query(id=itemid, user_id=params['user'])
    if not results:
        return 'null'

    res = results[0]
    fields = ("id", "name", "inputs", "status", "introduction", "created_at")
    return {field: res[field] for field in fields}
|
|
|
|
@v.get("/conversations")
async def query_conversations(user: str, first_id: str = None, limit: str = None, pinned: str = None):
    """List the conversations of ``user``, registering the user if needed.

    ``first_id``, ``limit`` and ``pinned`` are accepted but not used.
    """
    user_id = user if user is not None else ''
    userMng.findNoExistCreate(user_id)

    conversation_list = conversations().gets(user_id)
    return dict(limit=20, has_more=False, data=conversation_list)
|
|
|
|
@v.get("/parameters")
async def query_parameters(user: str):
    """Return the parameter configuration built from the known project names.

    ``user`` is accepted but not used.
    """
    projects = ProjectInfo()
    build_config = BaseConfig().ParamterCfg
    return build_config(projectInfo=projects.projectNames())
|
|
|
|
@v.post("/messages/{message_id}/feedbacks")
async def post_feedbacks(request: Request, message_id: str, params: Dict[str, Any]):
    """Store or remove the rating for a message.

    A ``rating`` of None deletes the feedback; any other value is saved
    together with the message's query and answer, provided the message exists.
    """
    rating = params['rating']
    if rating is None:
        feedback().delete(message_id)
        return

    results = message().query(message_id)
    if len(results) > 0:
        result = results[0]
        feedback().add(
            message_id=message_id,
            query=result['query'],
            answer=result['answer'],
            rating=rating,
        )
|
|
|
|
@v.post("/files/upload")
def upload_file(request: ChatFileUploadRequest):
    """Process an uploaded chat file and describe the stored result.

    Any failure is logged and reported to the client as HTTP 500.
    """
    try:
        logger.info("Processing file")
        file_info = ChatFileService.process_file(request.base64)

        # Copy the descriptive fields, then add generated creator id and time.
        copied_fields = ('id', 'name', 'size', 'extension', 'mime_type')
        reply = {field: file_info.get(field) for field in copied_fields}
        reply['created_by'] = str(uuid.uuid4())
        reply['created_at'] = int(time.time())
        return reply
    except Exception as e:
        logger.error(f"Error processing file: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail="Error processing file")
|
|
|
|
@v.post("/project")
def upload_file(request: ChatFileUploadRequest):
    """Load an uploaded project file and return the loader's result.

    Any failure is logged and reported to the client as HTTP 500.
    """
    # NOTE(review): this function reuses the name of the /files/upload handler
    # defined earlier in the module.
    try:
        logger.info("Processing file")
        loaded = PrjFileLoadService.process_file(request.base64)
    except Exception as error:
        logger.error(f"Error processing file: {error}", exc_info=True)
        raise HTTPException(status_code=500, detail="Error processing file")
    else:
        return loaded
|
|
|
|
@v.post("/messages/{message_id}/suggested")
async def post_suggested(request: Request, message_id: str, user: str):
    """Return suggested follow-up questions for a message.

    ``request`` and ``user`` are accepted but not used.
    """
    suggestions = await NextQuestionSuggestion.suggest_next_questions(message_id)
    return dict(result="success", data=suggestions)