支持节点元数据输出及代码优化,减少事件重复添加
This commit is contained in:
@@ -23,11 +23,18 @@ import uuid
|
||||
from app.api.routers.services.fileServices import PrjFileLoadService,ChatFileService
|
||||
from app.api.routers.services.suggestion import NextQuestionSuggestion
|
||||
import time
|
||||
from llama_index.core.settings import Settings
|
||||
from llama_index.core.callbacks import CallbackManager
|
||||
|
||||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
v1_router = v = APIRouter()
|
||||
|
||||
Settings.llm.callback_manager = CallbackManager()
|
||||
|
||||
gEvent_handler = None
|
||||
|
||||
|
||||
class ChatCallbackEvent(BaseModel):
|
||||
event_type: ChatEventType
|
||||
payload: Optional[Dict[str, Any]] = None
|
||||
@@ -210,7 +217,7 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
|
||||
_aqueue: asyncio.Queue
|
||||
is_done: bool = False
|
||||
|
||||
def __init__(self,**params):
|
||||
def __init__(self):
|
||||
"""Initialize the base callback handler."""
|
||||
ignored_events = [
|
||||
# CBEventType.CHUNKING,
|
||||
@@ -222,13 +229,15 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
|
||||
super().__init__(ignored_events, ignored_events)
|
||||
self._aqueue = asyncio.Queue()
|
||||
self._response: StreamingAgentChatResponse = None
|
||||
self._params:Dict[str,Any] = params
|
||||
self._ids:Dict[str,Any] = {}
|
||||
self._chatData:ChatRequestData = None
|
||||
self._nodeStack:List[str] = []
|
||||
|
||||
#添加工作流开始事件
|
||||
wf_event = self.makeWorkflow_startEvent()
|
||||
if wf_event.to_response() is not None:
|
||||
self._aqueue.put_nowait(wf_event)
|
||||
self._firstEventID:str = None
|
||||
|
||||
def setInitParams(self,ids:dict,data:ChatRequestData):
|
||||
self._ids = ids
|
||||
self._chatData = data
|
||||
self._firstEventID = None
|
||||
|
||||
def setResponse(self,response: StreamingAgentChatResponse):
|
||||
self._response = response
|
||||
@@ -240,11 +249,15 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
|
||||
event_id: str = "",
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
if self._firstEventID is None:
|
||||
self._firstEventID = event_id
|
||||
self.start()
|
||||
|
||||
logger.info("event_start:{} type:{} payload:{}\n".format(event_id, event_type, payload))
|
||||
|
||||
self._nodeStack.append(event_id)
|
||||
nindex = len(self._nodeStack) - 1
|
||||
args:Dict[str,Any] = self._params['ids']
|
||||
args:Dict[str,Any] = self._ids
|
||||
args.update(
|
||||
{
|
||||
'nodeid':event_id,
|
||||
@@ -267,7 +280,7 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
|
||||
logger.info("event_end:{} type:{} payload:{}\n".format(event_id, event_type, payload))
|
||||
|
||||
#self.response = payload.get("response","")
|
||||
args:Dict[str,Any] = self._params['ids']
|
||||
args:Dict[str,Any] = self._ids
|
||||
nodeID = self._nodeStack[-1]
|
||||
if nodeID == event_id:
|
||||
nindex = len(self._nodeStack) - 1
|
||||
@@ -284,6 +297,9 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
|
||||
self._aqueue.put_nowait(nd_event)
|
||||
self._nodeStack.pop()
|
||||
|
||||
if self._firstEventID is not None and self._firstEventID == event_id:
|
||||
self.finished()
|
||||
|
||||
def start_trace(self, trace_id: Optional[str] = None) -> None:
|
||||
"""No-op."""
|
||||
logger.info("trace_start:{}\n".format(trace_id))
|
||||
@@ -294,15 +310,7 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
|
||||
trace_map: Optional[Dict[str, List[str]]] = None,
|
||||
) -> None:
|
||||
"""No-op."""
|
||||
logger.info("trace_end:{} trace_map:{}\n".format(trace_id, trace_map))
|
||||
|
||||
wf_event = self.makeWorkflow_finishedEvent()
|
||||
if wf_event.to_response() is not None:
|
||||
self._aqueue.put_nowait(wf_event)
|
||||
|
||||
msgEnt_event = self.makeMessage_EndEvent()
|
||||
if msgEnt_event.to_response() is not None:
|
||||
self._aqueue.put_nowait(msgEnt_event)
|
||||
logger.info("trace_end:{} trace_map:{}\n".format(trace_id, trace_map))
|
||||
|
||||
async def async_event_gen(self) -> AsyncGenerator[ChatCallbackEvent, None]:
|
||||
while not self._aqueue.empty() or not self.is_done:
|
||||
@@ -312,30 +320,28 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
|
||||
pass
|
||||
|
||||
def makeWorkflow_startEvent(self)->ChatCallbackEvent:
|
||||
data:ChatRequestData = self._params['data']
|
||||
args:Dict[str,Any] = self._params['ids']
|
||||
args:Dict[str,Any] = self._ids
|
||||
args.update(
|
||||
{
|
||||
'use_id': data.user,
|
||||
'query': data.query,
|
||||
'conversation_id': data.conversation_id
|
||||
'use_id': self._chatData.user,
|
||||
'query': self._chatData.query,
|
||||
'conversation_id': self._chatData.conversation_id
|
||||
}
|
||||
)
|
||||
return ChatCallbackEvent(event_type = ChatEventType.WORKFLOW_START,payload = args)
|
||||
|
||||
def makeWorkflow_finishedEvent(self)->ChatCallbackEvent:
|
||||
data:ChatRequestData = self._params['data']
|
||||
args:Dict[str,Any] = self._params['ids']
|
||||
def makeWorkflow_finishedEvent(self)->ChatCallbackEvent:
|
||||
args:Dict[str,Any] = self._ids
|
||||
args.update(
|
||||
{
|
||||
'response': '',
|
||||
'conversation_id': data.conversation_id
|
||||
'conversation_id': self._chatData.conversation_id
|
||||
}
|
||||
)
|
||||
return ChatCallbackEvent(event_type = ChatEventType.WORKFLOW_FINISHED,payload = args)
|
||||
|
||||
def makeMessage_EndEvent(self)->ChatCallbackEvent:
|
||||
args:Dict[str,Any] = self._params['ids']
|
||||
args:Dict[str,Any] = self._ids
|
||||
if self._response is not None:
|
||||
args.update({
|
||||
'source_node': self._response.source_nodes
|
||||
@@ -343,6 +349,21 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
|
||||
msgEnt_event = ChatCallbackEvent(event_type = ChatEventType.MESSAGE_END,payload = args)
|
||||
return msgEnt_event
|
||||
|
||||
def start(self):
|
||||
#添加工作流开始事件
|
||||
wf_event = self.makeWorkflow_startEvent()
|
||||
if wf_event.to_response() is not None:
|
||||
self._aqueue.put_nowait(wf_event)
|
||||
|
||||
def finished(self):
|
||||
wf_event = self.makeWorkflow_finishedEvent()
|
||||
if wf_event.to_response() is not None:
|
||||
self._aqueue.put_nowait(wf_event)
|
||||
|
||||
msgEnt_event = self.makeMessage_EndEvent()
|
||||
if msgEnt_event.to_response() is not None:
|
||||
self._aqueue.put_nowait(msgEnt_event)
|
||||
|
||||
class IDManager:
|
||||
def createID(self):
|
||||
return {
|
||||
@@ -410,6 +431,7 @@ class ChatStreamResponse(StreamingResponse):
|
||||
|
||||
# the text_generator is the leading stream, once it's finished, also finish the event stream
|
||||
event_handler.is_done = True
|
||||
event_handler.setResponse(response)
|
||||
|
||||
# Yield the events from the event handler
|
||||
async def _event_generator():
|
||||
@@ -431,7 +453,8 @@ class ChatStreamResponse(StreamingResponse):
|
||||
break
|
||||
|
||||
@v.post("/chat-messages")
|
||||
async def post_conversations(request: Request, data: ChatRequestData):
|
||||
async def post_chatmessages(request: Request, data: ChatRequestData):
|
||||
global gEvent_handler
|
||||
userMng.findNoExistCreate(data.user)
|
||||
data.conversation_id = data.conversation_id if data.conversation_id else str(uuid.uuid4())
|
||||
|
||||
@@ -445,19 +468,21 @@ async def post_conversations(request: Request, data: ChatRequestData):
|
||||
filters = None
|
||||
params = data.inputs or {}
|
||||
|
||||
# 获取聊天引擎对象
|
||||
chat_engine = get_chat_engine(filters=filters, params=params)
|
||||
|
||||
# 启动聊天事件监听
|
||||
ids = IDManager().createID()
|
||||
event_handler = ChatEventCallbackHandler(ids = ids,data = data)
|
||||
chat_engine.callback_manager.handlers.append(event_handler) # type: ignore
|
||||
if gEvent_handler is None:
|
||||
gEvent_handler = ChatEventCallbackHandler()
|
||||
Settings.llm.callback_manager.handlers.append(gEvent_handler)
|
||||
|
||||
if gEvent_handler is not None:
|
||||
gEvent_handler.setInitParams(ids = ids,data = data)
|
||||
|
||||
# 获取聊天引擎对象
|
||||
chat_engine = get_chat_engine(filters=filters, params=params)
|
||||
# 执行异步聊天
|
||||
response = await chat_engine.astream_chat(data.query)
|
||||
event_handler.setResponse(response)
|
||||
# 返回异步消息回应
|
||||
return ChatStreamResponse(request, event_handler, response, data,ids)
|
||||
return ChatStreamResponse(request, gEvent_handler, response, data,ids)
|
||||
|
||||
@v.get("/messages")
|
||||
async def query_messages(user:str, conversation_id:str):
|
||||
|
||||
@@ -102,7 +102,7 @@ def create_query_engine(index, top_k=3, use_reranker=False, filters=None, respon
|
||||
simple_template = simple_template,
|
||||
node_postprocessors=postprocess,
|
||||
use_async=True,
|
||||
streaming=True,
|
||||
streaming=False,
|
||||
ResponseMode = response_mode
|
||||
)
|
||||
|
||||
|
||||
@@ -24,6 +24,8 @@ logger = logging.getLogger("uvicorn")
|
||||
usPrj = using_project(os.getenv("PHOENIX_PROJECT_NAME"))
|
||||
usPrj.__enter__()
|
||||
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
|
||||
init_settings()
|
||||
init_observability()
|
||||
|
||||
Reference in New Issue
Block a user