删除误上传的文件
This commit is contained in:
@@ -1,150 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request, status
|
||||
from llama_index.core.chat_engine.types import BaseChatEngine, NodeWithScore
|
||||
from llama_index.core.llms import MessageRole
|
||||
from llama_index.core.vector_stores.types import MetadataFilter, MetadataFilters
|
||||
|
||||
from app.api.routers.events import EventCallbackHandler
|
||||
from app.api.routers.models import (
|
||||
ChatConfig,
|
||||
ChatData,
|
||||
Message,
|
||||
Result,
|
||||
SourceNodes,
|
||||
)
|
||||
from app.api.routers.vercel_response import VercelStreamResponse
|
||||
from app.api.services.llama_cloud import LLamaCloudFileService
|
||||
from app.engine import get_chat_engine
|
||||
|
||||
chat_router = r = APIRouter()
|
||||
|
||||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
|
||||
def process_response_nodes(
|
||||
nodes: List[NodeWithScore],
|
||||
background_tasks: BackgroundTasks,
|
||||
):
|
||||
"""
|
||||
Start background tasks on the source nodes if needed.
|
||||
"""
|
||||
files_to_download = SourceNodes.get_download_files(nodes)
|
||||
for file in files_to_download:
|
||||
background_tasks.add_task(
|
||||
LLamaCloudFileService.download_llamacloud_pipeline_file, file
|
||||
)
|
||||
|
||||
|
||||
# streaming endpoint - delete if not needed
|
||||
@r.post("")
|
||||
async def chat(
|
||||
request: Request,
|
||||
data: ChatData,
|
||||
background_tasks: BackgroundTasks,
|
||||
chat_engine: BaseChatEngine = Depends(get_chat_engine),
|
||||
):
|
||||
try:
|
||||
last_message_content = data.get_last_message_content()
|
||||
# 由于基于历史消息的提示词没有调整好,所以暂时屏蔽历史消息
|
||||
data.messages.clear()
|
||||
messages = data.get_history_messages()
|
||||
|
||||
doc_ids = data.get_chat_document_ids()
|
||||
filters = generate_filters(doc_ids)
|
||||
params = data.data or {}
|
||||
logger.info("Creating chat engine with filters", filters.dict())
|
||||
chat_engine = get_chat_engine(filters=filters, params=params)
|
||||
|
||||
event_handler = EventCallbackHandler()
|
||||
chat_engine.callback_manager.handlers.append(event_handler) # type: ignore
|
||||
|
||||
response = await chat_engine.astream_chat(last_message_content, messages)
|
||||
process_response_nodes(response.source_nodes, background_tasks)
|
||||
|
||||
return VercelStreamResponse(request, event_handler, response, data)
|
||||
except Exception as e:
|
||||
logger.exception("Error in chat engine", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error in chat engine: {e}",
|
||||
) from e
|
||||
|
||||
|
||||
def generate_filters(doc_ids):
|
||||
if len(doc_ids) > 0:
|
||||
filters = MetadataFilters(
|
||||
filters=[
|
||||
MetadataFilter(
|
||||
key="private",
|
||||
value=["true"],
|
||||
operator="nin", # type: ignore
|
||||
),
|
||||
MetadataFilter(
|
||||
key="doc_id",
|
||||
value=doc_ids,
|
||||
operator="in", # type: ignore
|
||||
),
|
||||
],
|
||||
condition="or", # type: ignore
|
||||
)
|
||||
else:
|
||||
filters = MetadataFilters(
|
||||
# Use the "NIN" - "not in" operator to include all public documents (don't have the private key set)
|
||||
filters=[
|
||||
MetadataFilter(
|
||||
key="private",
|
||||
value=["true"],
|
||||
operator="nin", # type: ignore
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
return filters
|
||||
|
||||
|
||||
# non-streaming endpoint - delete if not needed
|
||||
@r.post("/request")
|
||||
async def chat_request(
|
||||
data: ChatData,
|
||||
chat_engine: BaseChatEngine = Depends(get_chat_engine),
|
||||
) -> Result:
|
||||
last_message_content = data.get_last_message_content()
|
||||
messages = data.get_history_messages()
|
||||
|
||||
response = await chat_engine.achat(last_message_content, messages)
|
||||
return Result(
|
||||
result=Message(role=MessageRole.ASSISTANT, content=response.response),
|
||||
nodes=SourceNodes.from_source_nodes(response.source_nodes),
|
||||
)
|
||||
|
||||
|
||||
@r.get("/config")
|
||||
async def chat_config() -> ChatConfig:
|
||||
starter_questions = None
|
||||
conversation_starters = os.getenv("CONVERSATION_STARTERS")
|
||||
if conversation_starters and conversation_starters.strip():
|
||||
starter_questions = conversation_starters.strip().split("\\n")
|
||||
return ChatConfig(starter_questions=starter_questions)
|
||||
|
||||
|
||||
@r.get("/config/llamacloud")
|
||||
async def chat_llama_cloud_config():
|
||||
projects = LLamaCloudFileService.get_all_projects_with_pipelines()
|
||||
pipeline = os.getenv("LLAMA_CLOUD_INDEX_NAME")
|
||||
project = os.getenv("LLAMA_CLOUD_PROJECT_NAME")
|
||||
pipeline_config = (
|
||||
pipeline
|
||||
and project
|
||||
and {
|
||||
"pipeline": pipeline,
|
||||
"project": project,
|
||||
}
|
||||
or None
|
||||
)
|
||||
return {
|
||||
"projects": projects,
|
||||
"pipeline": pipeline_config,
|
||||
}
|
||||
@@ -1,149 +0,0 @@
|
||||
import json
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import AsyncGenerator, Dict, Any, List, Optional
|
||||
from llama_index.core.callbacks.base import BaseCallbackHandler
|
||||
from llama_index.core.callbacks.schema import CBEventType
|
||||
from llama_index.core.tools.types import ToolOutput
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CallbackEvent(BaseModel):
|
||||
event_type: CBEventType
|
||||
payload: Optional[Dict[str, Any]] = None
|
||||
event_id: str = ""
|
||||
|
||||
def get_retrieval_message(self) -> dict | None:
|
||||
if self.payload:
|
||||
nodes = self.payload.get("nodes")
|
||||
if nodes:
|
||||
msg = f"根据查询检索到 {len(nodes)} 源文件"
|
||||
else:
|
||||
msg = f"查询检索中: '{self.payload.get('query_str')}'"
|
||||
return {
|
||||
"type": "events",
|
||||
"data": {"title": msg},
|
||||
}
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_tool_message(self) -> dict | None:
|
||||
func_call_args = self.payload.get("function_call")
|
||||
if func_call_args is not None and "tool" in self.payload:
|
||||
tool = self.payload.get("tool")
|
||||
return {
|
||||
"type": "events",
|
||||
"data": {
|
||||
"title": f"调用工具 {tool.name} ,参数: {func_call_args}",
|
||||
},
|
||||
}
|
||||
|
||||
def _is_output_serializable(self, output: Any) -> bool:
|
||||
try:
|
||||
json.dumps(output)
|
||||
return True
|
||||
except TypeError:
|
||||
return False
|
||||
|
||||
def get_agent_tool_response(self) -> dict | None:
|
||||
response = self.payload.get("response")
|
||||
if response is not None:
|
||||
sources = response.sources
|
||||
for source in sources:
|
||||
# Return the tool response here to include the toolCall information
|
||||
if isinstance(source, ToolOutput):
|
||||
if self._is_output_serializable(source.raw_output):
|
||||
output = source.raw_output
|
||||
else:
|
||||
output = source.content
|
||||
|
||||
return {
|
||||
"type": "tools",
|
||||
"data": {
|
||||
"toolOutput": {
|
||||
"output": output,
|
||||
"isError": source.is_error,
|
||||
},
|
||||
"toolCall": {
|
||||
"id": None, # There is no tool id in the ToolOutput
|
||||
"name": source.tool_name,
|
||||
"input": source.raw_input,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
def to_response(self):
|
||||
try:
|
||||
match self.event_type:
|
||||
case "retrieve":
|
||||
return self.get_retrieval_message()
|
||||
case "function_call":
|
||||
return self.get_tool_message()
|
||||
case "agent_step":
|
||||
return self.get_agent_tool_response()
|
||||
case _:
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"转换回应时间时发生错误,原因: {e}")
|
||||
return None
|
||||
|
||||
|
||||
class EventCallbackHandler(BaseCallbackHandler):
|
||||
_aqueue: asyncio.Queue
|
||||
is_done: bool = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
):
|
||||
"""Initialize the base callback handler."""
|
||||
ignored_events = [
|
||||
CBEventType.CHUNKING,
|
||||
CBEventType.NODE_PARSING,
|
||||
CBEventType.EMBEDDING,
|
||||
CBEventType.LLM,
|
||||
CBEventType.TEMPLATING,
|
||||
]
|
||||
super().__init__(ignored_events, ignored_events)
|
||||
self._aqueue = asyncio.Queue()
|
||||
|
||||
def on_event_start(
|
||||
self,
|
||||
event_type: CBEventType,
|
||||
payload: Optional[Dict[str, Any]] = None,
|
||||
event_id: str = "",
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
event = CallbackEvent(event_id=event_id, event_type=event_type, payload=payload)
|
||||
if event.to_response() is not None:
|
||||
self._aqueue.put_nowait(event)
|
||||
|
||||
def on_event_end(
|
||||
self,
|
||||
event_type: CBEventType,
|
||||
payload: Optional[Dict[str, Any]] = None,
|
||||
event_id: str = "",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
event = CallbackEvent(event_id=event_id, event_type=event_type, payload=payload)
|
||||
if event.to_response() is not None:
|
||||
self._aqueue.put_nowait(event)
|
||||
|
||||
def start_trace(self, trace_id: Optional[str] = None) -> None:
|
||||
"""No-op."""
|
||||
|
||||
def end_trace(
|
||||
self,
|
||||
trace_id: Optional[str] = None,
|
||||
trace_map: Optional[Dict[str, List[str]]] = None,
|
||||
) -> None:
|
||||
"""No-op."""
|
||||
|
||||
async def async_event_gen(self) -> AsyncGenerator[CallbackEvent, None]:
|
||||
while not self._aqueue.empty() or not self.is_done:
|
||||
try:
|
||||
yield await asyncio.wait_for(self._aqueue.get(), timeout=0.1)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
@@ -1,253 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, List, Literal, Optional, Set
|
||||
|
||||
from llama_index.core.llms import ChatMessage, MessageRole
|
||||
from llama_index.core.schema import NodeWithScore
|
||||
from pydantic import BaseModel, Field, validator, field_validator
|
||||
from pydantic.alias_generators import to_camel
|
||||
|
||||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
|
||||
class FileContent(BaseModel):
|
||||
type: Literal["text", "ref"]
|
||||
# If the file is pure text then the value is be a string
|
||||
# otherwise, it's a list of document IDs
|
||||
value: str | List[str]
|
||||
|
||||
|
||||
class File(BaseModel):
|
||||
id: str
|
||||
content: FileContent
|
||||
filename: str
|
||||
filesize: int
|
||||
filetype: str
|
||||
|
||||
|
||||
class AnnotationFileData(BaseModel):
|
||||
files: List[File] = Field(
|
||||
default=[],
|
||||
description="List of files",
|
||||
)
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"csvFiles": [
|
||||
{
|
||||
"content": "Name, Age\nAlice, 25\nBob, 30",
|
||||
"filename": "example.csv",
|
||||
"filesize": 123,
|
||||
"id": "123",
|
||||
"type": "text/csv",
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
alias_generator = to_camel
|
||||
|
||||
|
||||
class Annotation(BaseModel):
|
||||
type: str
|
||||
data: AnnotationFileData | List[str]
|
||||
|
||||
def to_content(self) -> str | None:
|
||||
if self.type == "document_file":
|
||||
# We only support generating context content for CSV files for now
|
||||
csv_files = [file for file in self.data.files if file.filetype == "csv"]
|
||||
if len(csv_files) > 0:
|
||||
return "Use data from following CSV raw content\n" + "\n".join(
|
||||
[f"```csv\n{csv_file.content.value}\n```" for csv_file in csv_files]
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"The annotation {self.type} is not supported for generating context content"
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
role: MessageRole
|
||||
content: str
|
||||
annotations: List[Annotation] | None = None
|
||||
|
||||
|
||||
class ChatData(BaseModel):
|
||||
messages: List[Message]
|
||||
data: Any = None
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What standards for letters exist?",
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
@field_validator("messages")
|
||||
def messages_must_not_be_empty(cls, v):
|
||||
if len(v) == 0:
|
||||
raise ValueError("Messages must not be empty")
|
||||
return v
|
||||
|
||||
def get_last_message_content(self) -> str:
|
||||
"""
|
||||
Get the content of the last message along with the data content if available.
|
||||
Fallback to use data content from previous messages
|
||||
"""
|
||||
if len(self.messages) == 0:
|
||||
raise ValueError("There is not any message in the chat")
|
||||
last_message = self.messages[-1]
|
||||
message_content = last_message.content
|
||||
for message in reversed(self.messages):
|
||||
if message.role == MessageRole.USER and message.annotations is not None:
|
||||
annotation_contents = filter(
|
||||
None,
|
||||
[annotation.to_content() for annotation in message.annotations],
|
||||
)
|
||||
if not annotation_contents:
|
||||
continue
|
||||
annotation_text = "\n".join(annotation_contents)
|
||||
message_content = f"{message_content}\n{annotation_text}"
|
||||
break
|
||||
return message_content
|
||||
|
||||
def get_history_messages(self) -> List[ChatMessage]:
|
||||
"""
|
||||
Get the history messages
|
||||
"""
|
||||
return [
|
||||
ChatMessage(role=message.role, content=message.content)
|
||||
for message in self.messages[:-1]
|
||||
]
|
||||
|
||||
def is_last_message_from_user(self) -> bool:
|
||||
return self.messages[-1].role == MessageRole.USER
|
||||
|
||||
def get_chat_document_ids(self) -> List[str]:
|
||||
"""
|
||||
Get the document IDs from the chat messages
|
||||
"""
|
||||
document_ids: List[str] = []
|
||||
for message in self.messages:
|
||||
if message.role == MessageRole.USER and message.annotations is not None:
|
||||
for annotation in message.annotations:
|
||||
if (
|
||||
annotation.type == "document_file"
|
||||
and annotation.data.files is not None
|
||||
):
|
||||
for fi in annotation.data.files:
|
||||
if fi.content.type == "ref":
|
||||
document_ids += fi.content.value
|
||||
return list(set(document_ids))
|
||||
|
||||
|
||||
class LlamaCloudFile(BaseModel):
|
||||
file_name: str
|
||||
pipeline_id: str
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, LlamaCloudFile):
|
||||
return NotImplemented
|
||||
return (
|
||||
self.file_name == other.file_name and self.pipeline_id == other.pipeline_id
|
||||
)
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.file_name, self.pipeline_id))
|
||||
|
||||
|
||||
class SourceNodes(BaseModel):
|
||||
id: str
|
||||
metadata: Dict[str, Any]
|
||||
score: Optional[float]
|
||||
text: str
|
||||
url: Optional[str]
|
||||
|
||||
@classmethod
|
||||
def from_source_node(cls, source_node: NodeWithScore):
|
||||
metadata = source_node.node.metadata
|
||||
url = cls.get_url_from_metadata(metadata)
|
||||
#text = 'filename' in metadata and metadata['filename'] or source_node.node.node_id
|
||||
text = source_node.node.text
|
||||
return cls(
|
||||
id=source_node.node.node_id,
|
||||
metadata=metadata,
|
||||
score=source_node.score,
|
||||
text=text, # type: ignore
|
||||
url=url,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_url_from_metadata(cls, metadata: Dict[str, Any]) -> str:
|
||||
url_prefix = os.getenv("FILESERVER_URL_PREFIX")
|
||||
if not url_prefix:
|
||||
logger.warning(
|
||||
"Warning: FILESERVER_URL_PREFIX not set in environment variables. Can't use file server"
|
||||
)
|
||||
file_name = metadata.get("file_name")
|
||||
if file_name and url_prefix:
|
||||
# file_name exists and file server is configured
|
||||
pipeline_id = metadata.get("pipeline_id")
|
||||
if pipeline_id and metadata.get("private") is None:
|
||||
# file is from LlamaCloud and was not ingested locally
|
||||
file_name = f"{pipeline_id}${file_name}"
|
||||
return f"{url_prefix}/output/llamacloud/{file_name}"
|
||||
is_private = metadata.get("private", "false") == "true"
|
||||
if is_private:
|
||||
return f"{url_prefix}/output/uploaded/{file_name}"
|
||||
return f"{url_prefix}/data/{file_name}"
|
||||
else:
|
||||
# fallback to URL in metadata (e.g. for websites)
|
||||
return metadata.get("URL")
|
||||
|
||||
@classmethod
|
||||
def from_source_nodes(cls, source_nodes: List[NodeWithScore]):
|
||||
return [cls.from_source_node(node) for node in source_nodes]
|
||||
|
||||
@staticmethod
|
||||
def get_download_files(nodes: List[NodeWithScore]) -> Set[LlamaCloudFile]:
|
||||
source_nodes = SourceNodes.from_source_nodes(nodes)
|
||||
llama_cloud_files = [
|
||||
LlamaCloudFile(
|
||||
file_name=node.metadata.get("file_name"),
|
||||
pipeline_id=node.metadata.get("pipeline_id"),
|
||||
)
|
||||
for node in source_nodes
|
||||
if (
|
||||
node.metadata.get("private")
|
||||
is None # Only download files are from LlamaCloud and were not ingested locally
|
||||
and node.metadata.get("pipeline_id") is not None
|
||||
and node.metadata.get("file_name") is not None
|
||||
)
|
||||
]
|
||||
# Remove duplicates and return
|
||||
return set(llama_cloud_files)
|
||||
|
||||
|
||||
class Result(BaseModel):
|
||||
result: Message
|
||||
nodes: List[SourceNodes]
|
||||
|
||||
|
||||
class ChatConfig(BaseModel):
|
||||
starter_questions: Optional[List[str]] = Field(
|
||||
default=None,
|
||||
description="List of starter questions",
|
||||
serialization_alias="starterQuestions",
|
||||
)
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"starterQuestions": [
|
||||
"What standards for letters exist?",
|
||||
"What are the requirements for a letter to be considered a letter?",
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -1,25 +0,0 @@
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.api.services.file import PrivateFileService
|
||||
|
||||
file_upload_router = r = APIRouter()
|
||||
|
||||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
|
||||
class FileUploadRequest(BaseModel):
|
||||
base64: str
|
||||
|
||||
|
||||
@r.post("")
|
||||
def upload_file(request: FileUploadRequest) -> List[str]:
|
||||
try:
|
||||
logger.info("Processing file")
|
||||
return PrivateFileService.process_file(request.base64)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing file: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Error processing file")
|
||||
@@ -1,109 +0,0 @@
|
||||
import json
|
||||
|
||||
from aiostream import stream
|
||||
from fastapi import Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from llama_index.core.chat_engine.types import StreamingAgentChatResponse
|
||||
|
||||
from app.api.routers.events import EventCallbackHandler
|
||||
from app.api.routers.models import ChatData, Message, SourceNodes
|
||||
from app.api.services.suggestion import NextQuestionSuggestion
|
||||
|
||||
|
||||
class VercelStreamResponse(StreamingResponse):
|
||||
"""
|
||||
Class to convert the response from the chat engine to the streaming format expected by Vercel
|
||||
"""
|
||||
|
||||
TEXT_PREFIX = "0:"
|
||||
DATA_PREFIX = "8:"
|
||||
|
||||
@classmethod
|
||||
def convert_text(cls, token: str):
|
||||
# Escape newlines and double quotes to avoid breaking the stream
|
||||
token = json.dumps(token)
|
||||
return f"{cls.TEXT_PREFIX}{token}\n"
|
||||
|
||||
@classmethod
|
||||
def convert_data(cls, data: dict):
|
||||
data_str = json.dumps(data)
|
||||
return f"{cls.DATA_PREFIX}[{data_str}]\n"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
request: Request,
|
||||
event_handler: EventCallbackHandler,
|
||||
response: StreamingAgentChatResponse,
|
||||
chat_data: ChatData,
|
||||
):
|
||||
content = VercelStreamResponse.content_generator(
|
||||
request, event_handler, response, chat_data
|
||||
)
|
||||
super().__init__(content=content)
|
||||
|
||||
@classmethod
|
||||
async def content_generator(
|
||||
cls,
|
||||
request: Request,
|
||||
event_handler: EventCallbackHandler,
|
||||
response: StreamingAgentChatResponse,
|
||||
chat_data: ChatData,
|
||||
):
|
||||
# Yield the text response
|
||||
async def _chat_response_generator():
|
||||
final_response = ""
|
||||
async for token in response.async_response_gen():
|
||||
final_response += token
|
||||
yield VercelStreamResponse.convert_text(token)
|
||||
|
||||
# Generate questions that user might interested to
|
||||
conversation = chat_data.messages + [
|
||||
Message(role="assistant", content=final_response)
|
||||
]
|
||||
questions = await NextQuestionSuggestion.suggest_next_questions(
|
||||
conversation
|
||||
)
|
||||
if len(questions) > 0:
|
||||
yield VercelStreamResponse.convert_data(
|
||||
{
|
||||
"type": "suggested_questions",
|
||||
"data": questions,
|
||||
}
|
||||
)
|
||||
|
||||
# the text_generator is the leading stream, once it's finished, also finish the event stream
|
||||
event_handler.is_done = True
|
||||
|
||||
# Yield the source nodes
|
||||
yield cls.convert_data(
|
||||
{
|
||||
"type": "sources",
|
||||
"data": {
|
||||
"nodes": [
|
||||
SourceNodes.from_source_node(node).dict()
|
||||
for node in response.source_nodes
|
||||
]
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# Yield the events from the event handler
|
||||
async def _event_generator():
|
||||
async for event in event_handler.async_event_gen():
|
||||
event_response = event.to_response()
|
||||
if event_response is not None:
|
||||
yield VercelStreamResponse.convert_data(event_response)
|
||||
|
||||
combine = stream.merge(_chat_response_generator(), _event_generator())
|
||||
is_stream_started = False
|
||||
async with combine.stream() as streamer:
|
||||
async for output in streamer:
|
||||
if not is_stream_started:
|
||||
is_stream_started = True
|
||||
# Stream a blank message to start the stream
|
||||
yield VercelStreamResponse.convert_text("")
|
||||
|
||||
yield output
|
||||
|
||||
if await request.is_disconnected():
|
||||
break
|
||||
Reference in New Issue
Block a user