支持节点元数据输出及代码优化,减少事件重复添加
This commit is contained in:
@@ -23,11 +23,18 @@ import uuid
|
|||||||
from app.api.routers.services.fileServices import PrjFileLoadService,ChatFileService
|
from app.api.routers.services.fileServices import PrjFileLoadService,ChatFileService
|
||||||
from app.api.routers.services.suggestion import NextQuestionSuggestion
|
from app.api.routers.services.suggestion import NextQuestionSuggestion
|
||||||
import time
|
import time
|
||||||
|
from llama_index.core.settings import Settings
|
||||||
|
from llama_index.core.callbacks import CallbackManager
|
||||||
|
|
||||||
logger = logging.getLogger("uvicorn")
|
logger = logging.getLogger("uvicorn")
|
||||||
|
|
||||||
v1_router = v = APIRouter()
|
v1_router = v = APIRouter()
|
||||||
|
|
||||||
|
Settings.llm.callback_manager = CallbackManager()
|
||||||
|
|
||||||
|
gEvent_handler = None
|
||||||
|
|
||||||
|
|
||||||
class ChatCallbackEvent(BaseModel):
|
class ChatCallbackEvent(BaseModel):
|
||||||
event_type: ChatEventType
|
event_type: ChatEventType
|
||||||
payload: Optional[Dict[str, Any]] = None
|
payload: Optional[Dict[str, Any]] = None
|
||||||
@@ -210,7 +217,7 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
|
|||||||
_aqueue: asyncio.Queue
|
_aqueue: asyncio.Queue
|
||||||
is_done: bool = False
|
is_done: bool = False
|
||||||
|
|
||||||
def __init__(self,**params):
|
def __init__(self):
|
||||||
"""Initialize the base callback handler."""
|
"""Initialize the base callback handler."""
|
||||||
ignored_events = [
|
ignored_events = [
|
||||||
# CBEventType.CHUNKING,
|
# CBEventType.CHUNKING,
|
||||||
@@ -222,13 +229,15 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
|
|||||||
super().__init__(ignored_events, ignored_events)
|
super().__init__(ignored_events, ignored_events)
|
||||||
self._aqueue = asyncio.Queue()
|
self._aqueue = asyncio.Queue()
|
||||||
self._response: StreamingAgentChatResponse = None
|
self._response: StreamingAgentChatResponse = None
|
||||||
self._params:Dict[str,Any] = params
|
self._ids:Dict[str,Any] = {}
|
||||||
|
self._chatData:ChatRequestData = None
|
||||||
self._nodeStack:List[str] = []
|
self._nodeStack:List[str] = []
|
||||||
|
self._firstEventID:str = None
|
||||||
|
|
||||||
#添加工作流开始事件
|
def setInitParams(self,ids:dict,data:ChatRequestData):
|
||||||
wf_event = self.makeWorkflow_startEvent()
|
self._ids = ids
|
||||||
if wf_event.to_response() is not None:
|
self._chatData = data
|
||||||
self._aqueue.put_nowait(wf_event)
|
self._firstEventID = None
|
||||||
|
|
||||||
def setResponse(self,response: StreamingAgentChatResponse):
|
def setResponse(self,response: StreamingAgentChatResponse):
|
||||||
self._response = response
|
self._response = response
|
||||||
@@ -240,11 +249,15 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
|
|||||||
event_id: str = "",
|
event_id: str = "",
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> str:
|
) -> str:
|
||||||
|
if self._firstEventID is None:
|
||||||
|
self._firstEventID = event_id
|
||||||
|
self.start()
|
||||||
|
|
||||||
logger.info("event_start:{} type:{} payload:{}\n".format(event_id, event_type, payload))
|
logger.info("event_start:{} type:{} payload:{}\n".format(event_id, event_type, payload))
|
||||||
|
|
||||||
self._nodeStack.append(event_id)
|
self._nodeStack.append(event_id)
|
||||||
nindex = len(self._nodeStack) - 1
|
nindex = len(self._nodeStack) - 1
|
||||||
args:Dict[str,Any] = self._params['ids']
|
args:Dict[str,Any] = self._ids
|
||||||
args.update(
|
args.update(
|
||||||
{
|
{
|
||||||
'nodeid':event_id,
|
'nodeid':event_id,
|
||||||
@@ -267,7 +280,7 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
|
|||||||
logger.info("event_end:{} type:{} payload:{}\n".format(event_id, event_type, payload))
|
logger.info("event_end:{} type:{} payload:{}\n".format(event_id, event_type, payload))
|
||||||
|
|
||||||
#self.response = payload.get("response","")
|
#self.response = payload.get("response","")
|
||||||
args:Dict[str,Any] = self._params['ids']
|
args:Dict[str,Any] = self._ids
|
||||||
nodeID = self._nodeStack[-1]
|
nodeID = self._nodeStack[-1]
|
||||||
if nodeID == event_id:
|
if nodeID == event_id:
|
||||||
nindex = len(self._nodeStack) - 1
|
nindex = len(self._nodeStack) - 1
|
||||||
@@ -284,6 +297,9 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
|
|||||||
self._aqueue.put_nowait(nd_event)
|
self._aqueue.put_nowait(nd_event)
|
||||||
self._nodeStack.pop()
|
self._nodeStack.pop()
|
||||||
|
|
||||||
|
if self._firstEventID is not None and self._firstEventID == event_id:
|
||||||
|
self.finished()
|
||||||
|
|
||||||
def start_trace(self, trace_id: Optional[str] = None) -> None:
|
def start_trace(self, trace_id: Optional[str] = None) -> None:
|
||||||
"""No-op."""
|
"""No-op."""
|
||||||
logger.info("trace_start:{}\n".format(trace_id))
|
logger.info("trace_start:{}\n".format(trace_id))
|
||||||
@@ -296,14 +312,6 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
|
|||||||
"""No-op."""
|
"""No-op."""
|
||||||
logger.info("trace_end:{} trace_map:{}\n".format(trace_id, trace_map))
|
logger.info("trace_end:{} trace_map:{}\n".format(trace_id, trace_map))
|
||||||
|
|
||||||
wf_event = self.makeWorkflow_finishedEvent()
|
|
||||||
if wf_event.to_response() is not None:
|
|
||||||
self._aqueue.put_nowait(wf_event)
|
|
||||||
|
|
||||||
msgEnt_event = self.makeMessage_EndEvent()
|
|
||||||
if msgEnt_event.to_response() is not None:
|
|
||||||
self._aqueue.put_nowait(msgEnt_event)
|
|
||||||
|
|
||||||
async def async_event_gen(self) -> AsyncGenerator[ChatCallbackEvent, None]:
|
async def async_event_gen(self) -> AsyncGenerator[ChatCallbackEvent, None]:
|
||||||
while not self._aqueue.empty() or not self.is_done:
|
while not self._aqueue.empty() or not self.is_done:
|
||||||
try:
|
try:
|
||||||
@@ -312,30 +320,28 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def makeWorkflow_startEvent(self)->ChatCallbackEvent:
|
def makeWorkflow_startEvent(self)->ChatCallbackEvent:
|
||||||
data:ChatRequestData = self._params['data']
|
args:Dict[str,Any] = self._ids
|
||||||
args:Dict[str,Any] = self._params['ids']
|
|
||||||
args.update(
|
args.update(
|
||||||
{
|
{
|
||||||
'use_id': data.user,
|
'use_id': self._chatData.user,
|
||||||
'query': data.query,
|
'query': self._chatData.query,
|
||||||
'conversation_id': data.conversation_id
|
'conversation_id': self._chatData.conversation_id
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return ChatCallbackEvent(event_type = ChatEventType.WORKFLOW_START,payload = args)
|
return ChatCallbackEvent(event_type = ChatEventType.WORKFLOW_START,payload = args)
|
||||||
|
|
||||||
def makeWorkflow_finishedEvent(self)->ChatCallbackEvent:
|
def makeWorkflow_finishedEvent(self)->ChatCallbackEvent:
|
||||||
data:ChatRequestData = self._params['data']
|
args:Dict[str,Any] = self._ids
|
||||||
args:Dict[str,Any] = self._params['ids']
|
|
||||||
args.update(
|
args.update(
|
||||||
{
|
{
|
||||||
'response': '',
|
'response': '',
|
||||||
'conversation_id': data.conversation_id
|
'conversation_id': self._chatData.conversation_id
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return ChatCallbackEvent(event_type = ChatEventType.WORKFLOW_FINISHED,payload = args)
|
return ChatCallbackEvent(event_type = ChatEventType.WORKFLOW_FINISHED,payload = args)
|
||||||
|
|
||||||
def makeMessage_EndEvent(self)->ChatCallbackEvent:
|
def makeMessage_EndEvent(self)->ChatCallbackEvent:
|
||||||
args:Dict[str,Any] = self._params['ids']
|
args:Dict[str,Any] = self._ids
|
||||||
if self._response is not None:
|
if self._response is not None:
|
||||||
args.update({
|
args.update({
|
||||||
'source_node': self._response.source_nodes
|
'source_node': self._response.source_nodes
|
||||||
@@ -343,6 +349,21 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
|
|||||||
msgEnt_event = ChatCallbackEvent(event_type = ChatEventType.MESSAGE_END,payload = args)
|
msgEnt_event = ChatCallbackEvent(event_type = ChatEventType.MESSAGE_END,payload = args)
|
||||||
return msgEnt_event
|
return msgEnt_event
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
#添加工作流开始事件
|
||||||
|
wf_event = self.makeWorkflow_startEvent()
|
||||||
|
if wf_event.to_response() is not None:
|
||||||
|
self._aqueue.put_nowait(wf_event)
|
||||||
|
|
||||||
|
def finished(self):
|
||||||
|
wf_event = self.makeWorkflow_finishedEvent()
|
||||||
|
if wf_event.to_response() is not None:
|
||||||
|
self._aqueue.put_nowait(wf_event)
|
||||||
|
|
||||||
|
msgEnt_event = self.makeMessage_EndEvent()
|
||||||
|
if msgEnt_event.to_response() is not None:
|
||||||
|
self._aqueue.put_nowait(msgEnt_event)
|
||||||
|
|
||||||
class IDManager:
|
class IDManager:
|
||||||
def createID(self):
|
def createID(self):
|
||||||
return {
|
return {
|
||||||
@@ -410,6 +431,7 @@ class ChatStreamResponse(StreamingResponse):
|
|||||||
|
|
||||||
# the text_generator is the leading stream, once it's finished, also finish the event stream
|
# the text_generator is the leading stream, once it's finished, also finish the event stream
|
||||||
event_handler.is_done = True
|
event_handler.is_done = True
|
||||||
|
event_handler.setResponse(response)
|
||||||
|
|
||||||
# Yield the events from the event handler
|
# Yield the events from the event handler
|
||||||
async def _event_generator():
|
async def _event_generator():
|
||||||
@@ -431,7 +453,8 @@ class ChatStreamResponse(StreamingResponse):
|
|||||||
break
|
break
|
||||||
|
|
||||||
@v.post("/chat-messages")
|
@v.post("/chat-messages")
|
||||||
async def post_conversations(request: Request, data: ChatRequestData):
|
async def post_chatmessages(request: Request, data: ChatRequestData):
|
||||||
|
global gEvent_handler
|
||||||
userMng.findNoExistCreate(data.user)
|
userMng.findNoExistCreate(data.user)
|
||||||
data.conversation_id = data.conversation_id if data.conversation_id else str(uuid.uuid4())
|
data.conversation_id = data.conversation_id if data.conversation_id else str(uuid.uuid4())
|
||||||
|
|
||||||
@@ -445,19 +468,21 @@ async def post_conversations(request: Request, data: ChatRequestData):
|
|||||||
filters = None
|
filters = None
|
||||||
params = data.inputs or {}
|
params = data.inputs or {}
|
||||||
|
|
||||||
# 获取聊天引擎对象
|
|
||||||
chat_engine = get_chat_engine(filters=filters, params=params)
|
|
||||||
|
|
||||||
# 启动聊天事件监听
|
# 启动聊天事件监听
|
||||||
ids = IDManager().createID()
|
ids = IDManager().createID()
|
||||||
event_handler = ChatEventCallbackHandler(ids = ids,data = data)
|
if gEvent_handler is None:
|
||||||
chat_engine.callback_manager.handlers.append(event_handler) # type: ignore
|
gEvent_handler = ChatEventCallbackHandler()
|
||||||
|
Settings.llm.callback_manager.handlers.append(gEvent_handler)
|
||||||
|
|
||||||
|
if gEvent_handler is not None:
|
||||||
|
gEvent_handler.setInitParams(ids = ids,data = data)
|
||||||
|
|
||||||
|
# 获取聊天引擎对象
|
||||||
|
chat_engine = get_chat_engine(filters=filters, params=params)
|
||||||
# 执行异步聊天
|
# 执行异步聊天
|
||||||
response = await chat_engine.astream_chat(data.query)
|
response = await chat_engine.astream_chat(data.query)
|
||||||
event_handler.setResponse(response)
|
|
||||||
# 返回异步消息回应
|
# 返回异步消息回应
|
||||||
return ChatStreamResponse(request, event_handler, response, data,ids)
|
return ChatStreamResponse(request, gEvent_handler, response, data,ids)
|
||||||
|
|
||||||
@v.get("/messages")
|
@v.get("/messages")
|
||||||
async def query_messages(user:str, conversation_id:str):
|
async def query_messages(user:str, conversation_id:str):
|
||||||
|
|||||||
@@ -102,7 +102,7 @@ def create_query_engine(index, top_k=3, use_reranker=False, filters=None, respon
|
|||||||
simple_template = simple_template,
|
simple_template = simple_template,
|
||||||
node_postprocessors=postprocess,
|
node_postprocessors=postprocess,
|
||||||
use_async=True,
|
use_async=True,
|
||||||
streaming=True,
|
streaming=False,
|
||||||
ResponseMode = response_mode
|
ResponseMode = response_mode
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -24,6 +24,8 @@ logger = logging.getLogger("uvicorn")
|
|||||||
usPrj = using_project(os.getenv("PHOENIX_PROJECT_NAME"))
|
usPrj = using_project(os.getenv("PHOENIX_PROJECT_NAME"))
|
||||||
usPrj.__enter__()
|
usPrj.__enter__()
|
||||||
|
|
||||||
|
import nest_asyncio
|
||||||
|
nest_asyncio.apply()
|
||||||
|
|
||||||
init_settings()
|
init_settings()
|
||||||
init_observability()
|
init_observability()
|
||||||
|
|||||||
Reference in New Issue
Block a user