dev #5

Closed
ly wants to merge 93 commits from dev into dev-db
5 changed files with 266 additions and 53 deletions
Showing only changes of commit b4c571cddb - Show all commits
+251 -40
View File
@@ -1,23 +1,264 @@
import os
from typing import Dict, List, Any, Optional, cast
from fastapi import APIRouter,Request
from app.api.routers.request.base import userMng,conversations
from app.api.routers.request.models import ChatRequestData
import asyncio
import json
import logging
from typing import Dict, List, Any, Optional, AsyncGenerator
from aiostream import stream
from fastapi import APIRouter, Request
from fastapi.responses import StreamingResponse
from llama_index.core import BaseCallbackHandler
from llama_index.core.base.llms.types import ChatMessage
from llama_index.core.callbacks import CBEventType
from llama_index.core.chat_engine.types import StreamingAgentChatResponse
from llama_index.core.tools import ToolOutput
from pydantic import BaseModel
from app.api.routers.events import EventCallbackHandler
from app.api.routers.request.base import userMng, conversations
from app.api.routers.request.models import ChatRequestData
from app.engine import get_chat_engine
logger = logging.getLogger("uvicorn")
api_router = r = APIRouter()
v1_router = v = APIRouter()
class ChatCallbackEvent(BaseModel):
event_type: CBEventType
payload: Optional[Dict[str, Any]] = None
event_id: str = ""
def get_retrieval_message(self) -> dict | None:
if self.payload:
nodes = self.payload.get("nodes")
if nodes:
msg = f"根据查询检索到 {len(nodes)} 源文件"
else:
msg = f"查询检索中: '{self.payload.get('query_str')}'"
return {
"type": "events",
"data": {"title": msg},
}
else:
return None
def get_tool_message(self) -> dict | None:
func_call_args = self.payload.get("function_call")
if func_call_args is not None and "tool" in self.payload:
tool = self.payload.get("tool")
return {
"type": "events",
"data": {
"title": f"调用工具 {tool.name} ,参数: {func_call_args}",
},
}
def _is_output_serializable(self, output: Any) -> bool:
try:
json.dumps(output)
return True
except TypeError:
return False
def get_agent_tool_response(self) -> dict | None:
response = self.payload.get("response")
if response is not None:
sources = response.sources
for source in sources:
# Return the tool response here to include the toolCall information
if isinstance(source, ToolOutput):
if self._is_output_serializable(source.raw_output):
output = source.raw_output
else:
output = source.content
return {
"type": "tools",
"data": {
"toolOutput": {
"output": output,
"isError": source.is_error,
},
"toolCall": {
"id": None, # There is no tool id in the ToolOutput
"name": source.tool_name,
"input": source.raw_input,
},
},
}
def to_response(self):
try:
match self.event_type:
case "retrieve":
return self.get_retrieval_message()
case "function_call":
return self.get_tool_message()
case "agent_step":
return self.get_agent_tool_response()
case _:
return None
except Exception as e:
logger.error(f"转换回应时间时发生错误,原因: {e}")
return None
class ChatEventCallbackHandler(BaseCallbackHandler):
_aqueue: asyncio.Queue
is_done: bool = False
def __init__(
self,
):
"""Initialize the base callback handler."""
ignored_events = [
CBEventType.CHUNKING,
CBEventType.NODE_PARSING,
CBEventType.EMBEDDING,
CBEventType.LLM,
CBEventType.TEMPLATING,
]
super().__init__(ignored_events, ignored_events)
self._aqueue = asyncio.Queue()
def on_event_start(
self,
event_type: CBEventType,
payload: Optional[Dict[str, Any]] = None,
event_id: str = "",
**kwargs: Any,
) -> str:
event = ChatCallbackEvent(event_id=event_id, event_type=event_type, payload=payload)
if event.to_response() is not None:
self._aqueue.put_nowait(event)
def on_event_end(
self,
event_type: CBEventType,
payload: Optional[Dict[str, Any]] = None,
event_id: str = "",
**kwargs: Any,
) -> None:
event = ChatCallbackEvent(event_id=event_id, event_type=event_type, payload=payload)
if event.to_response() is not None:
self._aqueue.put_nowait(event)
def start_trace(self, trace_id: Optional[str] = None) -> None:
"""No-op."""
def end_trace(
self,
trace_id: Optional[str] = None,
trace_map: Optional[Dict[str, List[str]]] = None,
) -> None:
"""No-op."""
async def async_event_gen(self) -> AsyncGenerator[ChatCallbackEvent, None]:
while not self._aqueue.empty() or not self.is_done:
try:
yield await asyncio.wait_for(self._aqueue.get(), timeout=0.1)
except asyncio.TimeoutError:
pass
class ChatStreamResponse(StreamingResponse):
TEXT_PREFIX = "data:"
DATA_PREFIX = "data:"
@classmethod
def convert_text(cls, token: str):
# Escape newlines and double quotes to avoid breaking the stream
token = json.dumps(token)
#return f"data: {{"event": "message", "conversation_id": "80d85523-de92-4b9d-aca0-c48a5eacb068", "message_id": "16a06b1b-a89b-49c0-bc15-123bd999f6d6", "created_at": 1724406492, "task_id": "802f3064-030d-42ac-a882-0e1293712d04", "id": "16a06b1b-a89b-49c0-bc15-123bd999f6d6", "answer": "{token}"}}"
return ""
@classmethod
def convert_data(cls, data: dict):
data_str = json.dumps(data)
return f"{cls.DATA_PREFIX}{data_str}\n"
def __init__(
self,
request: Request,
event_handler: EventCallbackHandler,
response: StreamingAgentChatResponse,
data: ChatRequestData
):
content = ChatStreamResponse.content_generator(
request, event_handler, response, data
)
super().__init__(content=content)
@classmethod
async def content_generator(
cls,
request: Request,
event_handler: EventCallbackHandler,
response: StreamingAgentChatResponse,
data: ChatRequestData
):
# Yield the text response
async def _chat_response_generator():
final_response = ""
async for token in response.async_response_gen():
final_response += token
yield ChatStreamResponse.convert_text(token)
# 存储消息历史
#message = Message(data.conversation_id, data.query, answer=final_response)
#messageManager.addmessage(message)
# the text_generator is the leading stream, once it's finished, also finish the event stream
event_handler.is_done = True
# Yield the events from the event handler
async def _event_generator():
async for event in event_handler.async_event_gen():
event_response = event.to_response()
if event_response is not None:
yield ChatStreamResponse.convert_data(event_response)
combine = stream.merge(_chat_response_generator(), _event_generator())
is_stream_started = False
async with combine.stream() as streamer:
async for output in streamer:
if not is_stream_started:
is_stream_started = True
# Stream a blank message to start the stream
yield ChatStreamResponse.convert_text("")
yield output
if await request.is_disconnected():
break
@v.post("/chat-messages")
async def post_conversations(request: Request,data: ChatRequestData):
async def post_conversations(request: Request, data: ChatRequestData):
userMng.findNoExistCreate(data.user)
conversaObj = conversations()
conversationinfo = conversaObj.get(data.user)
conversationinfo = conversaObj.get(data.user, data.conversation_id)
if conversationinfo is None:
conversationinfo = conversaObj.add(data.user, "新建会话")
conversationinfo = conversaObj.add(data.user, "新建会话", data.conversation_id)
return None
# 生成聊天参数
last_message_content = ChatMessage.from_str(data.query)
filters = None
params = data.inputs or {}
# 获取聊天引擎对象
chat_engine = get_chat_engine(filters=filters, params=params)
# 启动聊天事件监听
event_handler = EventCallbackHandler()
chat_engine.callback_manager.handlers.append(event_handler) # type: ignore
# 执行异步聊天
response = await chat_engine.astream_chat(data.query)
# 返回异步消息回应
return ChatStreamResponse(request, event_handler, response, data)
@v.get("/messages")
async def query_messages(user:str, conversation_id:str):
@@ -37,33 +278,3 @@ async def query_conversations(user:str):
"has_more": False,
"data": conversations().gets(user_id)
}
@r.get("/conversations")
async def query_conversations(first_id:int = None, limit:int = None, pinned:bool = None):
pass
#meta查询
@r.get("/meta")
async def query_meta():
pass
#name查询
@r.get("/name查询")
def query_name():
with sessionlocal() as session:
name = session.query(NameOrm).first()
return Name.from_orm(name)
#parameters查询
@r.get("/parameters")
async def query_parameters():
pass
#msite查询
@r.get("/site")
async def query_site():
pass
+4 -5
View File
@@ -1,10 +1,8 @@
import os
from typing import Dict, List, Any, Optional, cast
import json
from app.api.routers.request.dbOrm import DBManager
from app.api.routers.request.baseConfig import BaseConfig
from datetime import datetime
from app.api.routers.request.baseConfig import BaseConfig
from app.api.routers.request.dbOrm import DBManager
dbManage = DBManager()
class conversations:
@@ -31,6 +29,7 @@ class conversations:
if id == '':
id= str(uuid.uuid4())
template = BaseConfig.ConversationCfg
template['id'] = id
template['user_id'] = user_id
template['name'] = name
+4 -5
View File
@@ -1,11 +1,10 @@
import os
from typing import Dict, List, Any, Optional, cast
from typing import Dict, List, Any
from fastapi import APIRouter
from pydantic import BaseModel, Field
from sqlalchemy import create_engine, Column, String, Integer, Boolean, JSON,ForeignKey
from sqlalchemy.orm import sessionmaker, declarative_base,relationship
from pydantic import BaseModel
from sqlalchemy import create_engine, Column, String, Integer, JSON
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.orm import sessionmaker, declarative_base
Base = declarative_base()
+5 -1
View File
@@ -1,5 +1,8 @@
from typing import Dict, Any
from pydantic import BaseModel
from typing import Dict, List, Any, Optional, cast
class ChatRequestData(BaseModel):
inputs: Dict[str,Any]
@@ -7,4 +10,5 @@ class ChatRequestData(BaseModel):
user: str
response_mode: str
files: Any
conversation_id: str = None
+2 -2
View File
@@ -12,7 +12,7 @@ from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse
from app.api.routers.chat import chat_router
from app.api.routers.upload import file_upload_router
from app.api.routers.app import api_router,v1_router
from app.api.routers.app import v1_router
from app.settings import init_settings
from app.observability import init_observability
from fastapi.staticfiles import StaticFiles
@@ -56,7 +56,7 @@ mount_static_files("data", "/api/files/data")
mount_static_files("data_output", "/api/files/output")
app.include_router(chat_router, prefix="/api/chat")
app.include_router(file_upload_router, prefix="/api/chat/upload")
app.include_router(api_router, prefix="/api")
app.include_router(v1_router, prefix="/v1")
@app.get("/")