新增工程信息、检索的知识片段节点回传、下一轮建议问题列表
This commit is contained in:
+128
-51
@@ -3,7 +3,6 @@ import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Dict, List, Any, Optional, AsyncGenerator
|
||||
from collections import deque
|
||||
|
||||
from aiostream import stream
|
||||
from fastapi import APIRouter, Request,HTTPException
|
||||
@@ -13,17 +12,19 @@ from llama_index.core.base.llms.types import ChatMessage
|
||||
from llama_index.core.callbacks import CBEventType
|
||||
from llama_index.core.chat_engine.types import StreamingAgentChatResponse
|
||||
from llama_index.core.tools import ToolOutput
|
||||
from llama_index.core.schema import NodeWithScore
|
||||
from pydantic import BaseModel
|
||||
from app.api.routers.request.base import userMng, conversations,message,parameter,feedback
|
||||
from app.api.routers.request.base import userMng, conversations,message,ProjectInfo,feedback
|
||||
from app.api.routers.request.baseConfig import *
|
||||
from app.api.routers.request.models import ChatRequestData,ChatFileUploadRequest
|
||||
from app.engine import get_chat_engine
|
||||
import uuid
|
||||
from app.api.routers.services.fileServices import FileLoadService
|
||||
from app.api.routers.services.fileServices import PrjFileLoadService,ChatFileService
|
||||
from app.api.routers.services.suggestion import NextQuestionSuggestion
|
||||
import time
|
||||
|
||||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
api_router = r = APIRouter()
|
||||
v1_router = v = APIRouter()
|
||||
|
||||
class ChatCallbackEvent(BaseModel):
|
||||
@@ -32,7 +33,7 @@ class ChatCallbackEvent(BaseModel):
|
||||
|
||||
def get_common_param(self)-> dict:
|
||||
return {
|
||||
'event': self.event_type.name,
|
||||
'event': self.event_type.value,
|
||||
'conversation_id':self.payload.get("conversation_id"),
|
||||
'message_id': self.payload.get("message_id"),
|
||||
'created_at': int(time.time()),
|
||||
@@ -48,7 +49,7 @@ class ChatCallbackEvent(BaseModel):
|
||||
"workflow_id": self.payload.get('workflow_id'),
|
||||
"sequence_number": 1709,
|
||||
"inputs": {
|
||||
"sys.query": self.payload.get('query'),
|
||||
"sys.query": f"开始查询 {self.payload.get('query')}",
|
||||
"sys.files": [],
|
||||
"sys.conversation_id": self.payload.get('conversation_id'),
|
||||
"sys.user_id": self.payload.get('use_id')
|
||||
@@ -93,7 +94,7 @@ class ChatCallbackEvent(BaseModel):
|
||||
"id": self.payload.get('nodeid'),
|
||||
"node_id": self.payload.get('nodeid'),
|
||||
"node_type": "http-request",
|
||||
"title": self.payload.get('title'),
|
||||
"title": f"正在执行事件:{self.payload.get('title')}",
|
||||
"index": self.payload.get('index'),
|
||||
"predecessor_node_id": self.payload.get('predecessor_node_id'),
|
||||
"inputs": '',
|
||||
@@ -111,7 +112,7 @@ class ChatCallbackEvent(BaseModel):
|
||||
"id": self.payload.get('nodeid'),
|
||||
"node_id": self.payload.get('nodeid'),
|
||||
"node_type": "http-request",
|
||||
"title": self.payload.get('title'),
|
||||
"title": f"事件执行结束:{self.payload.get('title')}",
|
||||
"index": self.payload.get('index'),
|
||||
"predecessor_node_id": self.payload.get('predecessor_node_id'),
|
||||
"inputs": '',
|
||||
@@ -138,15 +139,54 @@ class ChatCallbackEvent(BaseModel):
|
||||
|
||||
def get_MessageEnd_param(self) -> dict:
|
||||
params = self.get_common_param()
|
||||
nodeInfos = []
|
||||
source_nodes = self.payload.get('source_node')
|
||||
if source_nodes is not None:
|
||||
for i in range(len(source_nodes)):
|
||||
source_node:NodeWithScore = source_nodes[i]
|
||||
metadata:dict = source_node.node.metadata
|
||||
nodeInfo = {
|
||||
"position": i,
|
||||
"dataset_id": metadata.get("pipeline_id"),
|
||||
"dataset_name": metadata.get("file_name"),
|
||||
"document_id": source_node.node_id,
|
||||
"document_name": metadata.get("file_name"),
|
||||
"data_source_type": "upload_file",
|
||||
"segment_id": source_node.node_id,
|
||||
"retriever_from": "workflow",
|
||||
"score": source_node.score,
|
||||
"hit_count": 1,
|
||||
"word_count": 632,
|
||||
"segment_position": i,
|
||||
"index_node_hash": "",
|
||||
"content": source_node.text
|
||||
}
|
||||
nodeInfos.append(nodeInfo)
|
||||
params.update({
|
||||
'id':self.payload.get('message_id'),
|
||||
'metadata':self.payload.get('metadata')
|
||||
'metadata':{
|
||||
"retriever_resources":nodeInfos,
|
||||
"usage":{
|
||||
"prompt_tokens": 4972,
|
||||
"prompt_unit_price": "0.0",
|
||||
"prompt_price_unit": "0.0",
|
||||
"prompt_price": "0.0",
|
||||
"completion_tokens": 332,
|
||||
"completion_unit_price": "0.0",
|
||||
"completion_price_unit": "0.0",
|
||||
"completion_price": "0.0",
|
||||
"total_tokens": 5304,
|
||||
"total_price": "0.0",
|
||||
"currency": "USD",
|
||||
"latency": 4.897703120019287
|
||||
}
|
||||
}
|
||||
})
|
||||
return params
|
||||
|
||||
def to_response(self)-> dict|None:
|
||||
try:
|
||||
match self.event_type:
|
||||
match self.event_type.value:
|
||||
case "workflow_started":
|
||||
return self.get_WorkflowStart_param()
|
||||
case "workflow_finished":
|
||||
@@ -180,24 +220,18 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
|
||||
]
|
||||
super().__init__(ignored_events, ignored_events)
|
||||
self._aqueue = asyncio.Queue()
|
||||
self._response:str = ''
|
||||
self._response: StreamingAgentChatResponse = None
|
||||
self._params:Dict[str,Any] = params
|
||||
self._nodeStack:deque = deque()
|
||||
self._nodeStack:List[str] = []
|
||||
|
||||
#添加工作流开始事件
|
||||
data:ChatRequestData = self._params['data']
|
||||
args:Dict[str,Any] = self._params['ids']
|
||||
args.update(
|
||||
{
|
||||
'use_id': data.user,
|
||||
'query': data.query,
|
||||
'conversation_id': data.conversation_id
|
||||
}
|
||||
)
|
||||
wf_event = ChatCallbackEvent(event_type = ChatEventType.WORKFLOW_START,payload = args)
|
||||
wf_event = self.makeWorkflow_startEvent()
|
||||
if wf_event.to_response() is not None:
|
||||
self._aqueue.put_nowait(wf_event)
|
||||
|
||||
def setResponse(self,response: StreamingAgentChatResponse):
|
||||
self._response = response
|
||||
|
||||
def on_event_start(
|
||||
self,
|
||||
event_type: CBEventType,
|
||||
@@ -208,7 +242,7 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
|
||||
logger.info("event_start:{} type:{} payload:{}\n".format(event_id, event_type, payload))
|
||||
|
||||
self._nodeStack.append(event_id)
|
||||
nindex = self._nodeStack.count() - 1
|
||||
nindex = len(self._nodeStack) - 1
|
||||
args:Dict[str,Any] = self._params['ids']
|
||||
args.update(
|
||||
{
|
||||
@@ -222,7 +256,6 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
|
||||
if nd_event.to_response() is not None:
|
||||
self._aqueue.put_nowait(nd_event)
|
||||
|
||||
|
||||
def on_event_end(
|
||||
self,
|
||||
event_type: CBEventType,
|
||||
@@ -236,7 +269,7 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
|
||||
args:Dict[str,Any] = self._params['ids']
|
||||
nodeID = self._nodeStack[-1]
|
||||
if nodeID == event_id:
|
||||
nindex = self._nodeStack.count() - 1
|
||||
nindex = len(self._nodeStack) - 1
|
||||
args.update(
|
||||
{
|
||||
'nodeid':event_id,
|
||||
@@ -250,7 +283,6 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
|
||||
self._aqueue.put_nowait(nd_event)
|
||||
self._nodeStack.pop()
|
||||
|
||||
|
||||
def start_trace(self, trace_id: Optional[str] = None) -> None:
|
||||
"""No-op."""
|
||||
logger.info("trace_start:{}\n".format(trace_id))
|
||||
@@ -262,23 +294,14 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
|
||||
) -> None:
|
||||
"""No-op."""
|
||||
logger.info("trace_end:{} trace_map:{}\n".format(trace_id, trace_map))
|
||||
data:ChatRequestData = self._params['data']
|
||||
args:Dict[str,Any] = self._params['ids']
|
||||
args.update(
|
||||
{
|
||||
'response':self._response,
|
||||
'conversation_id': data.conversation_id
|
||||
}
|
||||
)
|
||||
wf_event = ChatCallbackEvent(event_type = ChatEventType.WORKFLOW_FINISHED,payload = args)
|
||||
|
||||
wf_event = self.makeWorkflow_finishedEvent()
|
||||
if wf_event.to_response() is not None:
|
||||
self._aqueue.put_nowait(wf_event)
|
||||
|
||||
|
||||
args:Dict[str,Any] = self._params['ids']
|
||||
msgEnt_event = ChatCallbackEvent(event_type = ChatEventType.MESSAGE_END,payload = args)
|
||||
msgEnt_event = self.makeMessage_EndEvent()
|
||||
if msgEnt_event.to_response() is not None:
|
||||
self._aqueue.put_nowait(msgEnt_event)
|
||||
self._aqueue.put_nowait(msgEnt_event)
|
||||
|
||||
async def async_event_gen(self) -> AsyncGenerator[ChatCallbackEvent, None]:
|
||||
while not self._aqueue.empty() or not self.is_done:
|
||||
@@ -287,6 +310,38 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
|
||||
def makeWorkflow_startEvent(self)->ChatCallbackEvent:
|
||||
data:ChatRequestData = self._params['data']
|
||||
args:Dict[str,Any] = self._params['ids']
|
||||
args.update(
|
||||
{
|
||||
'use_id': data.user,
|
||||
'query': data.query,
|
||||
'conversation_id': data.conversation_id
|
||||
}
|
||||
)
|
||||
return ChatCallbackEvent(event_type = ChatEventType.WORKFLOW_START,payload = args)
|
||||
|
||||
def makeWorkflow_finishedEvent(self)->ChatCallbackEvent:
|
||||
data:ChatRequestData = self._params['data']
|
||||
args:Dict[str,Any] = self._params['ids']
|
||||
args.update(
|
||||
{
|
||||
'response': '',
|
||||
'conversation_id': data.conversation_id
|
||||
}
|
||||
)
|
||||
return ChatCallbackEvent(event_type = ChatEventType.WORKFLOW_FINISHED,payload = args)
|
||||
|
||||
def makeMessage_EndEvent(self)->ChatCallbackEvent:
|
||||
args:Dict[str,Any] = self._params['ids']
|
||||
if self._response is not None:
|
||||
args.update({
|
||||
'source_node': self._response.source_nodes
|
||||
})
|
||||
msgEnt_event = ChatCallbackEvent(event_type = ChatEventType.MESSAGE_END,payload = args)
|
||||
return msgEnt_event
|
||||
|
||||
class IDManager:
|
||||
def createID(self):
|
||||
return {
|
||||
@@ -390,7 +445,7 @@ async def post_conversations(request: Request, data: ChatRequestData):
|
||||
params = data.inputs or {}
|
||||
|
||||
# 获取聊天引擎对象
|
||||
chat_engine = get_chat_engine(filters=filters, params=params,prjFlag = data.prjFlag)
|
||||
chat_engine = get_chat_engine(filters=filters, params=params)
|
||||
|
||||
# 启动聊天事件监听
|
||||
ids = IDManager().createID()
|
||||
@@ -399,7 +454,7 @@ async def post_conversations(request: Request, data: ChatRequestData):
|
||||
|
||||
# 执行异步聊天
|
||||
response = await chat_engine.astream_chat(data.query)
|
||||
|
||||
event_handler.setResponse(response)
|
||||
# 返回异步消息回应
|
||||
return ChatStreamResponse(request, event_handler, response, data,ids)
|
||||
|
||||
@@ -468,29 +523,51 @@ async def query_conversations(user:str, first_id:str = None, limit:str = None, p
|
||||
|
||||
@v.get("/parameters")
|
||||
async def query_parameters(user:str):
|
||||
params = parameter().get(user)
|
||||
if len(params) == 0:
|
||||
params = BaseConfig().ParamterCfg()
|
||||
return params
|
||||
prjObj = ProjectInfo()
|
||||
return BaseConfig().ParamterCfg(projectInfo = prjObj.projectNames())
|
||||
|
||||
@v.post("/messages/{message_id}/feedbacks")
|
||||
async def post_feedbacks(request: Request,message_id:str,params:Dict[str,Any]):
|
||||
if params['rating'] =='null':
|
||||
if params['rating'] is None:
|
||||
feedback().delete(message_id)
|
||||
else:
|
||||
condition = {'id':message_id}
|
||||
results = message().query(**condition)
|
||||
results = message().query(message_id)
|
||||
if len(results) > 0:
|
||||
result = results[0]
|
||||
feedback().add(message_id=message_id,query=result['query'],
|
||||
answer=result['answer'],rating=params['rating'])
|
||||
|
||||
@v.post("")
|
||||
def upload_file(request: ChatFileUploadRequest) -> List[str]:
|
||||
@v.post("/files/upload")
|
||||
def upload_file(request: ChatFileUploadRequest):
|
||||
try:
|
||||
logger.info("Processing file")
|
||||
return FileLoadService.process_file(request.base64)
|
||||
resluts = ChatFileService.process_file(request.base64)
|
||||
return {
|
||||
'id':resluts.get('id'),
|
||||
'name': resluts.get('name'),
|
||||
'size': resluts.get('size'),
|
||||
'extension':resluts.get('extension'),
|
||||
'mime_type':resluts.get('mime_type'),
|
||||
'created_by':str(uuid.uuid4()),
|
||||
'created_at':int(time.time())
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing file: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Error processing file")
|
||||
|
||||
@v.post("/applications")
|
||||
def upload_file(request: ChatFileUploadRequest):
|
||||
try:
|
||||
logger.info("Processing file")
|
||||
return PrjFileLoadService.process_file(request.base64)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing file: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Error processing file")
|
||||
|
||||
@v.post("/messages/{message_id}/suggested")
|
||||
async def post_suggested(request: Request,message_id:str,user:str):
|
||||
questions = await NextQuestionSuggestion.suggest_next_questions(message_id)
|
||||
return {
|
||||
"result": "success",
|
||||
"data":questions
|
||||
}
|
||||
Reference in New Issue
Block a user