删除误上传的文件
This commit is contained in:
@@ -1,150 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request, status
|
||||
from llama_index.core.chat_engine.types import BaseChatEngine, NodeWithScore
|
||||
from llama_index.core.llms import MessageRole
|
||||
from llama_index.core.vector_stores.types import MetadataFilter, MetadataFilters
|
||||
|
||||
from app.api.routers.events import EventCallbackHandler
|
||||
from app.api.routers.models import (
|
||||
ChatConfig,
|
||||
ChatData,
|
||||
Message,
|
||||
Result,
|
||||
SourceNodes,
|
||||
)
|
||||
from app.api.routers.vercel_response import VercelStreamResponse
|
||||
from app.api.services.llama_cloud import LLamaCloudFileService
|
||||
from app.engine import get_chat_engine
|
||||
|
||||
# Router for the chat endpoints; `r` is the short alias used by the route decorators in this module.
chat_router = r = APIRouter()
|
||||
|
||||
# Module-level logger; messages go through the logger named "uvicorn".
logger = logging.getLogger("uvicorn")
|
||||
|
||||
|
||||
def process_response_nodes(
    nodes: List[NodeWithScore],
    background_tasks: BackgroundTasks,
):
    """
    Queue one background download per file referenced by the source nodes.

    The files to fetch come from `SourceNodes.get_download_files(nodes)`; each
    one is handed to `LLamaCloudFileService.download_llamacloud_pipeline_file`
    through `background_tasks.add_task`.
    """
    download = LLamaCloudFileService.download_llamacloud_pipeline_file
    for pending_file in SourceNodes.get_download_files(nodes):
        background_tasks.add_task(download, pending_file)
|
||||
|
||||
|
||||
# streaming endpoint - delete if not needed
|
||||
@r.post("")
async def chat(
    request: Request,
    data: ChatData,
    background_tasks: BackgroundTasks,
    chat_engine: BaseChatEngine = Depends(get_chat_engine),
):
    """
    Stream a chat answer for the last message in `data`.

    A chat engine is created with metadata filters derived from the documents
    referenced in the conversation, the answer is streamed back through
    `VercelStreamResponse`, and downloads of the source files are scheduled as
    background tasks.

    Raises:
        HTTPException: with status 500 when any step of the pipeline fails.
    """
    try:
        last_message_content = data.get_last_message_content()
        # Collect the referenced document ids from the untouched conversation.
        doc_ids = data.get_chat_document_ids()
        # The history-aware prompt is not tuned yet, so no history is sent to
        # the engine for now. An empty list is passed instead of clearing
        # `data.messages`: clearing also removed the annotations that
        # `get_chat_document_ids` reads (so the document filter was always
        # empty) and the user message that the streamed response needs later.
        messages = []

        filters = generate_filters(doc_ids)
        params = data.data or {}
        # Lazy %-formatting: the original passed an argument without a
        # placeholder, which makes the logging module report a format error.
        logger.info("Creating chat engine with filters: %s", filters.dict())
        # Build a request-specific engine; it replaces the injected default.
        chat_engine = get_chat_engine(filters=filters, params=params)

        event_handler = EventCallbackHandler()
        chat_engine.callback_manager.handlers.append(event_handler)  # type: ignore

        response = await chat_engine.astream_chat(last_message_content, messages)
        process_response_nodes(response.source_nodes, background_tasks)

        return VercelStreamResponse(request, event_handler, response, data)
    except Exception as e:
        # logger.exception already attaches the active traceback.
        logger.exception("Error in chat engine")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error in chat engine: {e}",
        ) from e
|
||||
|
||||
|
||||
def generate_filters(doc_ids):
    """
    Build the metadata filters used when creating the chat engine.

    Without document ids the result holds a single "nin" ("not in") filter on
    the `private` key with value ["true"]. With document ids an "in" filter on
    `doc_id` is added and the two filters are combined with the "or" condition.
    """
    # "nin" is meant to keep every document that is not marked private="true".
    public_only = MetadataFilter(
        key="private",
        value=["true"],
        operator="nin",  # type: ignore
    )
    if len(doc_ids) == 0:
        return MetadataFilters(filters=[public_only])

    selected_docs = MetadataFilter(
        key="doc_id",
        value=doc_ids,
        operator="in",  # type: ignore
    )
    return MetadataFilters(
        filters=[public_only, selected_docs],
        condition="or",  # type: ignore
    )
|
||||
|
||||
|
||||
# non-streaming endpoint - delete if not needed
|
||||
@r.post("/request")
async def chat_request(
    data: ChatData,
    chat_engine: BaseChatEngine = Depends(get_chat_engine),
) -> Result:
    """
    Non-streaming chat endpoint.

    Sends the last message together with the earlier messages to the injected
    chat engine and returns the assistant reply plus its source nodes.
    """
    question = data.get_last_message_content()
    history = data.get_history_messages()
    answer = await chat_engine.achat(question, history)

    reply = Message(role=MessageRole.ASSISTANT, content=answer.response)
    source_nodes = SourceNodes.from_source_nodes(answer.source_nodes)
    return Result(result=reply, nodes=source_nodes)
|
||||
|
||||
|
||||
@r.get("/config")
async def chat_config() -> ChatConfig:
    """
    Return the chat configuration read from the environment.

    `CONVERSATION_STARTERS` is stripped and split on the two-character
    sequence backslash + "n"; when the variable is unset or blank the starter
    questions are `None`.
    """
    raw_starters = os.getenv("CONVERSATION_STARTERS")
    if not raw_starters or not raw_starters.strip():
        return ChatConfig(starter_questions=None)
    return ChatConfig(starter_questions=raw_starters.strip().split("\\n"))
|
||||
|
||||
|
||||
@r.get("/config/llamacloud")
async def chat_llama_cloud_config():
    """
    Return the LlamaCloud projects (with their pipelines) and the configured pipeline.

    The "pipeline" entry is a dict with the `LLAMA_CLOUD_INDEX_NAME` and
    `LLAMA_CLOUD_PROJECT_NAME` values when both are set and non-empty,
    otherwise `None`.
    """
    projects = LLamaCloudFileService.get_all_projects_with_pipelines()
    pipeline = os.getenv("LLAMA_CLOUD_INDEX_NAME")
    project = os.getenv("LLAMA_CLOUD_PROJECT_NAME")

    if pipeline and project:
        pipeline_config = {
            "pipeline": pipeline,
            "project": project,
        }
    else:
        pipeline_config = None

    return {
        "projects": projects,
        "pipeline": pipeline_config,
    }
|
||||
@@ -1,149 +0,0 @@
|
||||
import json
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import AsyncGenerator, Dict, Any, List, Optional
|
||||
from llama_index.core.callbacks.base import BaseCallbackHandler
|
||||
from llama_index.core.callbacks.schema import CBEventType
|
||||
from llama_index.core.tools.types import ToolOutput
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CallbackEvent(BaseModel):
|
||||
event_type: CBEventType
|
||||
payload: Optional[Dict[str, Any]] = None
|
||||
event_id: str = ""
|
||||
|
||||
def get_retrieval_message(self) -> dict | None:
|
||||
if self.payload:
|
||||
nodes = self.payload.get("nodes")
|
||||
if nodes:
|
||||
msg = f"根据查询检索到 {len(nodes)} 源文件"
|
||||
else:
|
||||
msg = f"查询检索中: '{self.payload.get('query_str')}'"
|
||||
return {
|
||||
"type": "events",
|
||||
"data": {"title": msg},
|
||||
}
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_tool_message(self) -> dict | None:
|
||||
func_call_args = self.payload.get("function_call")
|
||||
if func_call_args is not None and "tool" in self.payload:
|
||||
tool = self.payload.get("tool")
|
||||
return {
|
||||
"type": "events",
|
||||
"data": {
|
||||
"title": f"调用工具 {tool.name} ,参数: {func_call_args}",
|
||||
},
|
||||
}
|
||||
|
||||
def _is_output_serializable(self, output: Any) -> bool:
|
||||
try:
|
||||
json.dumps(output)
|
||||
return True
|
||||
except TypeError:
|
||||
return False
|
||||
|
||||
def get_agent_tool_response(self) -> dict | None:
|
||||
response = self.payload.get("response")
|
||||
if response is not None:
|
||||
sources = response.sources
|
||||
for source in sources:
|
||||
# Return the tool response here to include the toolCall information
|
||||
if isinstance(source, ToolOutput):
|
||||
if self._is_output_serializable(source.raw_output):
|
||||
output = source.raw_output
|
||||
else:
|
||||
output = source.content
|
||||
|
||||
return {
|
||||
"type": "tools",
|
||||
"data": {
|
||||
"toolOutput": {
|
||||
"output": output,
|
||||
"isError": source.is_error,
|
||||
},
|
||||
"toolCall": {
|
||||
"id": None, # There is no tool id in the ToolOutput
|
||||
"name": source.tool_name,
|
||||
"input": source.raw_input,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
def to_response(self):
|
||||
try:
|
||||
match self.event_type:
|
||||
case "retrieve":
|
||||
return self.get_retrieval_message()
|
||||
case "function_call":
|
||||
return self.get_tool_message()
|
||||
case "agent_step":
|
||||
return self.get_agent_tool_response()
|
||||
case _:
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"转换回应时间时发生错误,原因: {e}")
|
||||
return None
|
||||
|
||||
|
||||
class EventCallbackHandler(BaseCallbackHandler):
    """
    Callback handler that queues the events worth showing to the client.

    Events are kept in an asyncio queue and consumed through
    `async_event_gen`; the consumer stops once `is_done` is set and the queue
    is empty.
    """

    _aqueue: asyncio.Queue
    # Set to True by the consumer of the response to end `async_event_gen`.
    is_done: bool = False

    def __init__(
        self,
    ):
        """Initialize the base callback handler."""
        ignored_events = [
            CBEventType.CHUNKING,
            CBEventType.NODE_PARSING,
            CBEventType.EMBEDDING,
            CBEventType.LLM,
            CBEventType.TEMPLATING,
        ]
        # The same list is ignored for both event starts and event ends.
        super().__init__(ignored_events, ignored_events)
        self._aqueue = asyncio.Queue()

    def _enqueue_if_relevant(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]],
        event_id: str,
    ) -> None:
        """Queue the event when it converts to a non-empty client message."""
        event = CallbackEvent(event_id=event_id, event_type=event_type, payload=payload)
        if event.to_response() is not None:
            self._aqueue.put_nowait(event)

    def on_event_start(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        **kwargs: Any,
    ) -> str:
        """Queue the starting event if relevant and return its id."""
        self._enqueue_if_relevant(event_type, payload, event_id)
        # The signature promises a str; the original implicitly returned None.
        return event_id

    def on_event_end(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        **kwargs: Any,
    ) -> None:
        """Queue the ending event if relevant."""
        self._enqueue_if_relevant(event_type, payload, event_id)

    def start_trace(self, trace_id: Optional[str] = None) -> None:
        """No-op."""

    def end_trace(
        self,
        trace_id: Optional[str] = None,
        trace_map: Optional[Dict[str, List[str]]] = None,
    ) -> None:
        """No-op."""

    async def async_event_gen(self) -> AsyncGenerator[CallbackEvent, None]:
        """Yield queued events until `is_done` is set and the queue is drained."""
        while not self._aqueue.empty() or not self.is_done:
            try:
                # Short timeout so the loop re-checks `is_done` regularly.
                yield await asyncio.wait_for(self._aqueue.get(), timeout=0.1)
            except asyncio.TimeoutError:
                pass
|
||||
@@ -1,253 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, List, Literal, Optional, Set
|
||||
|
||||
from llama_index.core.llms import ChatMessage, MessageRole
|
||||
from llama_index.core.schema import NodeWithScore
|
||||
from pydantic import BaseModel, Field, validator, field_validator
|
||||
from pydantic.alias_generators import to_camel
|
||||
|
||||
# Module-level logger; messages go through the logger named "uvicorn".
logger = logging.getLogger("uvicorn")
|
||||
|
||||
|
||||
class FileContent(BaseModel):
    """Content of an attached file: inline text or references to documents."""

    type: Literal["text", "ref"]
    # A plain string when the file is pure text; otherwise a list of document IDs.
    value: str | List[str]
|
||||
|
||||
|
||||
class File(BaseModel):
    """Description of a file attached to a chat message."""

    id: str
    content: FileContent
    filename: str
    filesize: int
    filetype: str
|
||||
|
||||
|
||||
class AnnotationFileData(BaseModel):
    """Data of a file annotation: the list of attached files."""

    files: List[File] = Field(
        default=[],
        description="List of files",
    )

    class Config:
        # NOTE(review): the example uses a "csvFiles" key and a flat "content"
        # string, while the model field is `files` with a nested content
        # object - confirm whether the example is still current.
        json_schema_extra = {
            "example": {
                "csvFiles": [
                    {
                        "content": "Name, Age\nAlice, 25\nBob, 30",
                        "filename": "example.csv",
                        "filesize": 123,
                        "id": "123",
                        "type": "text/csv",
                    }
                ]
            }
        }
        # Field names are exposed in camelCase.
        alias_generator = to_camel
|
||||
|
||||
|
||||
class Annotation(BaseModel):
    """An annotation attached to a chat message."""

    type: str
    data: AnnotationFileData | List[str]

    def to_content(self) -> str | None:
        """
        Render the annotation as extra context text for the prompt.

        Only "document_file" annotations carrying CSV files produce content;
        their raw contents are wrapped in csv code fences. Any other
        annotation type logs a warning. Returns None when there is nothing
        to add.
        """
        if self.type != "document_file":
            logger.warning(
                f"The annotation {self.type} is not supported for generating context content"
            )
            return None
        # `data` may also be a plain list of strings, which has no `files`.
        if not isinstance(self.data, AnnotationFileData):
            return None
        # We only support generating context content for CSV files for now
        csv_files = [file for file in self.data.files if file.filetype == "csv"]
        if not csv_files:
            return None
        return "Use data from following CSV raw content\n" + "\n".join(
            [f"```csv\n{csv_file.content.value}\n```" for csv_file in csv_files]
        )
|
||||
|
||||
|
||||
class Message(BaseModel):
    """A single chat message with an optional list of annotations."""

    role: MessageRole
    content: str
    annotations: List[Annotation] | None = None
|
||||
|
||||
|
||||
class ChatData(BaseModel):
    """Request body of the chat endpoints: the conversation plus free-form extra data."""

    messages: List[Message]
    data: Any = None

    class Config:
        json_schema_extra = {
            "example": {
                "messages": [
                    {
                        "role": "user",
                        "content": "What standards for letters exist?",
                    }
                ]
            }
        }

    @field_validator("messages")
    @classmethod
    def messages_must_not_be_empty(cls, v):
        """Reject requests that carry no message."""
        if len(v) == 0:
            raise ValueError("Messages must not be empty")
        return v

    def get_last_message_content(self) -> str:
        """
        Get the content of the last message along with the data content if available.
        Fallback to use data content from previous messages

        Raises:
            ValueError: when there is no message at all.
        """
        if len(self.messages) == 0:
            raise ValueError("There is not any message in the chat")
        message_content = self.messages[-1].content
        for message in reversed(self.messages):
            if message.role != MessageRole.USER or message.annotations is None:
                continue
            # Materialize the list: a bare `filter` object is always truthy, so
            # the fallback to earlier messages below would never trigger.
            annotation_contents = [
                content
                for content in (
                    annotation.to_content() for annotation in message.annotations
                )
                if content
            ]
            if not annotation_contents:
                continue
            annotation_text = "\n".join(annotation_contents)
            message_content = f"{message_content}\n{annotation_text}"
            break
        return message_content

    def get_history_messages(self) -> List[ChatMessage]:
        """
        Get the history messages (every message except the last one)
        """
        return [
            ChatMessage(role=message.role, content=message.content)
            for message in self.messages[:-1]
        ]

    def is_last_message_from_user(self) -> bool:
        """Tell whether the last message has the user role."""
        return self.messages[-1].role == MessageRole.USER

    def get_chat_document_ids(self) -> List[str]:
        """
        Get the document IDs from the chat messages

        Only "document_file" annotations of user messages are considered, and
        only files whose content is of type "ref". The result has no
        duplicates and keeps the order of first appearance.
        """
        document_ids: List[str] = []
        for message in self.messages:
            if message.role != MessageRole.USER or message.annotations is None:
                continue
            for annotation in message.annotations:
                # `data` may be a plain list of strings, which has no `files`.
                if annotation.type != "document_file" or not isinstance(
                    annotation.data, AnnotationFileData
                ):
                    continue
                for fi in annotation.data.files:
                    if fi.content.type != "ref":
                        continue
                    value = fi.content.value
                    if isinstance(value, str):
                        # Extending with a str would add single characters.
                        document_ids.append(value)
                    else:
                        document_ids.extend(value)
        # dict.fromkeys removes duplicates while keeping a stable order.
        return list(dict.fromkeys(document_ids))
|
||||
|
||||
|
||||
class LlamaCloudFile(BaseModel):
    """A file identified by its name and its pipeline id; hashable so it can be kept in sets."""

    file_name: str
    pipeline_id: str

    def __eq__(self, other):
        """Two files are equal when both name and pipeline id match."""
        if isinstance(other, LlamaCloudFile):
            return (self.file_name, self.pipeline_id) == (
                other.file_name,
                other.pipeline_id,
            )
        return NotImplemented

    def __hash__(self):
        """Hash on the same pair of fields that `__eq__` compares."""
        return hash((self.file_name, self.pipeline_id))
|
||||
|
||||
|
||||
class SourceNodes(BaseModel):
    """Client-facing view of a retrieved node: id, metadata, score, text and URL."""

    id: str
    metadata: Dict[str, Any]
    score: Optional[float]
    text: str
    url: Optional[str]

    @classmethod
    def from_source_node(cls, source_node: NodeWithScore):
        """Build a `SourceNodes` entry from a scored node."""
        metadata = source_node.node.metadata
        return cls(
            id=source_node.node.node_id,
            metadata=metadata,
            score=source_node.score,
            text=source_node.node.text,  # type: ignore
            url=cls.get_url_from_metadata(metadata),
        )

    @classmethod
    def get_url_from_metadata(cls, metadata: Dict[str, Any]) -> Optional[str]:
        """
        Work out the URL under which the node's file can be fetched.

        With a `file_name` and `FILESERVER_URL_PREFIX` configured, the URL
        points below the prefix: "output/llamacloud" for files that have a
        pipeline id and no `private` key, "output/uploaded" for files with
        private == "true", "data" otherwise. In every other case the "URL"
        metadata entry is returned, which may be None.
        """
        url_prefix = os.getenv("FILESERVER_URL_PREFIX")
        if not url_prefix:
            logger.warning(
                "Warning: FILESERVER_URL_PREFIX not set in environment variables. Can't use file server"
            )
        file_name = metadata.get("file_name")
        if not (file_name and url_prefix):
            # fallback to URL in metadata (e.g. for websites)
            return metadata.get("URL")

        # file_name exists and file server is configured
        pipeline_id = metadata.get("pipeline_id")
        if pipeline_id and metadata.get("private") is None:
            # file is from LlamaCloud and was not ingested locally
            file_name = f"{pipeline_id}${file_name}"
            return f"{url_prefix}/output/llamacloud/{file_name}"
        is_private = metadata.get("private", "false") == "true"
        if is_private:
            return f"{url_prefix}/output/uploaded/{file_name}"
        return f"{url_prefix}/data/{file_name}"

    @classmethod
    def from_source_nodes(cls, source_nodes: List[NodeWithScore]):
        """Convert every scored node with `from_source_node`."""
        return [cls.from_source_node(node) for node in source_nodes]

    @staticmethod
    def get_download_files(nodes: List[NodeWithScore]) -> Set[LlamaCloudFile]:
        """
        Collect the distinct files that have to be downloaded for `nodes`.

        A node qualifies when its metadata has a `pipeline_id` and a
        `file_name` but no `private` key, i.e. the file was not ingested
        locally.
        """
        files: Set[LlamaCloudFile] = set()
        for source_node in nodes:
            # Only the metadata is needed, so read it directly instead of
            # building (and validating) a full SourceNodes model per node.
            metadata = source_node.node.metadata
            file_name = metadata.get("file_name")
            pipeline_id = metadata.get("pipeline_id")
            if (
                metadata.get("private") is None
                and pipeline_id is not None
                and file_name is not None
            ):
                # The set removes duplicates.
                files.add(LlamaCloudFile(file_name=file_name, pipeline_id=pipeline_id))
        return files
|
||||
|
||||
|
||||
class Result(BaseModel):
    """Response of the non-streaming chat endpoint: the reply and its source nodes."""

    result: Message
    nodes: List[SourceNodes]
|
||||
|
||||
|
||||
class ChatConfig(BaseModel):
    """Chat configuration sent to the client."""

    # Serialized under the camelCase key "starterQuestions".
    starter_questions: Optional[List[str]] = Field(
        default=None,
        description="List of starter questions",
        serialization_alias="starterQuestions",
    )

    class Config:
        json_schema_extra = {
            "example": {
                "starterQuestions": [
                    "What standards for letters exist?",
                    "What are the requirements for a letter to be considered a letter?",
                ]
            }
        }
|
||||
@@ -1,25 +0,0 @@
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.api.services.file import PrivateFileService
|
||||
|
||||
# Router for the file upload endpoint; `r` is the short alias used by the route decorator below.
file_upload_router = r = APIRouter()
|
||||
|
||||
# Module-level logger; messages go through the logger named "uvicorn".
logger = logging.getLogger("uvicorn")
|
||||
|
||||
|
||||
class FileUploadRequest(BaseModel):
    """Request body of the upload endpoint."""

    # Passed unchanged to `PrivateFileService.process_file`.
    base64: str
|
||||
|
||||
|
||||
@r.post("")
def upload_file(request: FileUploadRequest) -> List[str]:
    """
    Process an uploaded file and return what `PrivateFileService.process_file` yields.

    Raises:
        HTTPException: with status 500 when processing fails; the original
            error is logged and kept as the exception cause.
    """
    try:
        logger.info("Processing file")
        return PrivateFileService.process_file(request.base64)
    except Exception as e:
        logger.error(f"Error processing file: {e}", exc_info=True)
        # Chain the cause so the traceback shows the real failure.
        raise HTTPException(status_code=500, detail="Error processing file") from e
|
||||
@@ -1,109 +0,0 @@
|
||||
import json
|
||||
|
||||
from aiostream import stream
|
||||
from fastapi import Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from llama_index.core.chat_engine.types import StreamingAgentChatResponse
|
||||
|
||||
from app.api.routers.events import EventCallbackHandler
|
||||
from app.api.routers.models import ChatData, Message, SourceNodes
|
||||
from app.api.services.suggestion import NextQuestionSuggestion
|
||||
|
||||
|
||||
class VercelStreamResponse(StreamingResponse):
    """
    Class to convert the response from the chat engine to the streaming format expected by Vercel
    """

    # Prefix of a streamed text chunk.
    TEXT_PREFIX = "0:"
    # Prefix of a streamed data message.
    DATA_PREFIX = "8:"

    @classmethod
    def convert_text(cls, token: str):
        """Encode a text token as one stream line."""
        # Escape newlines and double quotes to avoid breaking the stream
        token = json.dumps(token)
        return f"{cls.TEXT_PREFIX}{token}\n"

    @classmethod
    def convert_data(cls, data: dict):
        """Encode a data dict as one stream line holding a single-item JSON array."""
        data_str = json.dumps(data)
        return f"{cls.DATA_PREFIX}[{data_str}]\n"

    def __init__(
        self,
        request: Request,
        event_handler: EventCallbackHandler,
        response: StreamingAgentChatResponse,
        chat_data: ChatData,
    ):
        """Create the streaming response backed by `content_generator`."""
        content = VercelStreamResponse.content_generator(
            request, event_handler, response, chat_data
        )
        super().__init__(content=content)

    @classmethod
    async def content_generator(
        cls,
        request: Request,
        event_handler: EventCallbackHandler,
        response: StreamingAgentChatResponse,
        chat_data: ChatData,
    ):
        """
        Yield the merged stream of answer text, suggested questions, source nodes and events.

        Streaming stops early when the client disconnects.
        """

        # Yield the text response
        async def _chat_response_generator():
            try:
                tokens = []
                async for token in response.async_response_gen():
                    tokens.append(token)
                    yield VercelStreamResponse.convert_text(token)
                final_response = "".join(tokens)

                # Generate questions that user might interested to
                conversation = chat_data.messages + [
                    Message(role="assistant", content=final_response)
                ]
                questions = await NextQuestionSuggestion.suggest_next_questions(
                    conversation
                )
                if len(questions) > 0:
                    yield VercelStreamResponse.convert_data(
                        {
                            "type": "suggested_questions",
                            "data": questions,
                        }
                    )
            finally:
                # The text generator is the leading stream: once it is finished,
                # successfully or not, the event stream must finish as well.
                event_handler.is_done = True

            # Yield the source nodes
            yield cls.convert_data(
                {
                    "type": "sources",
                    "data": {
                        "nodes": [
                            SourceNodes.from_source_node(node).dict()
                            for node in response.source_nodes
                        ]
                    },
                }
            )

        # Yield the events from the event handler
        async def _event_generator():
            async for event in event_handler.async_event_gen():
                event_response = event.to_response()
                if event_response is not None:
                    yield VercelStreamResponse.convert_data(event_response)

        combine = stream.merge(_chat_response_generator(), _event_generator())
        is_stream_started = False
        async with combine.stream() as streamer:
            async for output in streamer:
                if not is_stream_started:
                    is_stream_started = True
                    # Stream a blank message to start the stream
                    yield VercelStreamResponse.convert_text("")

                yield output

                if await request.is_disconnected():
                    break
|
||||
@@ -1,113 +0,0 @@
|
||||
import base64
|
||||
import mimetypes
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
from uuid import uuid4
|
||||
|
||||
from app.engine.index import get_index
|
||||
from llama_index.core import VectorStoreIndex
|
||||
from llama_index.core.ingestion import IngestionPipeline
|
||||
from llama_index.core.readers.file.base import (
|
||||
_try_loading_included_file_formats as get_file_loaders_map,
|
||||
)
|
||||
from llama_index.core.readers.file.base import (
|
||||
default_file_metadata_func,
|
||||
)
|
||||
from llama_index.core.schema import Document
|
||||
from llama_index.indices.managed.llama_cloud.base import LlamaCloudIndex
|
||||
from llama_index.readers.file import FlatReader
|
||||
|
||||
|
||||
def get_llamaparse_parser():
    """
    Return the LlamaParse parser when the file loader config enables it, else None.

    The loader modules are imported inside the function, as in the original.
    """
    from app.engine.loaders import load_configs
    from app.engine.loaders.file import FileLoaderConfig, llama_parse_parser

    file_settings = FileLoaderConfig(**load_configs()["file"])
    if not file_settings.use_llama_parse:
        return None
    return llama_parse_parser()
|
||||
|
||||
|
||||
def default_file_loaders_map():
    """Return the built-in extension-to-reader map with ".txt" handled by `FlatReader`."""
    loaders_by_extension = get_file_loaders_map()
    loaders_by_extension[".txt"] = FlatReader
    return loaders_by_extension
|
||||
|
||||
|
||||
class PrivateFileService:
    """Stores uploaded files under `PRIVATE_STORE_PATH` and adds them to the index."""

    PRIVATE_STORE_PATH = "output/uploaded"

    @staticmethod
    def preprocess_base64_file(base64_content: str) -> tuple:
        """
        Split a data URL ("data:<mime>;base64,<data>") into file bytes and extension.

        Returns:
            (bytes, extension) where extension is what `mimetypes` guesses for
            the MIME type and may be None for an unknown type.

        Raises:
            ValueError: when the header/data separator or the MIME type is missing.
        """
        header, comma, data = base64_content.partition(",")
        if not comma:
            raise ValueError("Invalid file content: missing ',' between header and data")
        _, colon, mime_type = header.split(";")[0].partition(":")
        if not colon:
            raise ValueError("Invalid file content: missing MIME type in header")
        extension = mimetypes.guess_extension(mime_type)
        # File data as bytes
        return base64.b64decode(data), extension

    @staticmethod
    def store_and_parse_file(file_data, extension) -> "List[Document]":
        """
        Store the file in the private directory and load it into documents.

        If LlamaParse is enabled it parses the file; otherwise the default
        loader registered for `extension` is used. Every document gets the
        stored `file_name` and private="true" in its metadata.

        Raises:
            ValueError: when no loader supports `extension`.
        """
        # Pick the reader first so an unsupported upload leaves no file behind.
        reader = get_llamaparse_parser()
        if reader is None:
            reader_cls = default_file_loaders_map().get(extension)
            if reader_cls is None:
                raise ValueError(f"File extension {extension} is not supported")
            reader = reader_cls()

        # Store file to the private directory
        os.makedirs(PrivateFileService.PRIVATE_STORE_PATH, exist_ok=True)

        # random file name; an unknown extension must not become the text "None"
        file_name = f"{uuid4().hex}{extension or ''}"
        file_path = Path(PrivateFileService.PRIVATE_STORE_PATH) / file_name

        # write file
        with open(file_path, "wb") as f:
            f.write(file_data)

        documents = reader.load_data(file_path)
        # Add custom metadata
        for doc in documents:
            doc.metadata["file_name"] = file_name
            doc.metadata["private"] = "true"
        return documents

    @staticmethod
    def process_file(base64_content: str) -> List[str]:
        """
        Store, parse and index an uploaded file.

        Returns:
            The ids of the documents created from the file.
        """
        file_data, extension = PrivateFileService.preprocess_base64_file(base64_content)
        documents = PrivateFileService.store_and_parse_file(file_data, extension)

        current_index = get_index()

        # Insert the documents into the index
        if isinstance(current_index, LlamaCloudIndex):
            # LlamaCloudIndex is a managed index so we don't need to process the nodes
            # just insert the documents
            for doc in documents:
                current_index.insert(doc)
        else:
            # Run the ingestion pipeline once, and only on this path where
            # its nodes are actually used.
            pipeline = IngestionPipeline()
            nodes = pipeline.run(documents=documents)

            # Add the nodes to the index and persist it
            if current_index is None:
                current_index = VectorStoreIndex(nodes=nodes)
            else:
                current_index.insert_nodes(nodes=nodes)
            current_index.storage_context.persist(
                persist_dir=os.environ.get("STORAGE_DIR", "storage")
            )

        # Return the document ids
        return [doc.doc_id for doc in documents]
|
||||
@@ -1,114 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
from app.api.routers.models import LlamaCloudFile
|
||||
|
||||
# Module-level logger; messages go through the logger named "uvicorn".
logger = logging.getLogger("uvicorn")
|
||||
|
||||
|
||||
class LLamaCloudFileService:
    """Client for the LlamaCloud REST API plus a local cache of downloaded pipeline files."""

    LLAMA_CLOUD_URL = "https://cloud.llamaindex.ai/api/v1"
    LOCAL_STORE_PATH = "output/llamacloud"

    DOWNLOAD_FILE_NAME_TPL = "{pipeline_id}${filename}"

    # Seconds to wait on an HTTP call; without a timeout a stalled connection
    # would block the caller indefinitely.
    REQUEST_TIMEOUT = 30

    @classmethod
    def get_all_projects(cls) -> List[Dict[str, Any]]:
        """Return the JSON body of the `/projects` endpoint."""
        url = f"{cls.LLAMA_CLOUD_URL}/projects"
        return cls._make_request(url)

    @classmethod
    def get_all_pipelines(cls) -> List[Dict[str, Any]]:
        """Return the JSON body of the `/pipelines` endpoint."""
        url = f"{cls.LLAMA_CLOUD_URL}/pipelines"
        return cls._make_request(url)

    @classmethod
    def get_all_projects_with_pipelines(cls) -> List[Dict[str, Any]]:
        """
        Return every project with its pipelines under the "pipelines" key.

        Best effort: any error is logged and an empty list is returned.
        """
        try:
            projects = cls.get_all_projects()
            pipelines = cls.get_all_pipelines()
            return [
                {
                    **project,
                    "pipelines": [p for p in pipelines if p["project_id"] == project["id"]],
                }
                for project in projects
            ]
        except Exception as error:
            logger.error(f"Error listing projects and pipelines: {error}")
            return []

    @classmethod
    def _get_files(cls, pipeline_id: str) -> List[Dict[str, Any]]:
        """Return the JSON body listing the files of a pipeline."""
        url = f"{cls.LLAMA_CLOUD_URL}/pipelines/{pipeline_id}/files"
        return cls._make_request(url)

    @classmethod
    def _get_file_detail(cls, project_id: str, file_id: str) -> Dict[str, Any]:
        """Return the JSON body of the file content endpoint (read for its "url" entry)."""
        url = f"{cls.LLAMA_CLOUD_URL}/files/{file_id}/content?project_id={project_id}"
        return cls._make_request(url)

    @classmethod
    def _download_file(cls, url: str, local_file_path: str):
        """
        Download `url` to `local_file_path`.

        The data is written to a temporary ".part" file that is moved into
        place only after a complete download, so a failed transfer never
        leaves a truncated file that later looks like a finished one.
        """
        logger.info(f"Downloading file to {local_file_path}")
        # Create directory if it doesn't exist
        os.makedirs(cls.LOCAL_STORE_PATH, exist_ok=True)
        partial_path = f"{local_file_path}.part"
        try:
            # Download the file
            with requests.get(url, stream=True, timeout=cls.REQUEST_TIMEOUT) as r:
                r.raise_for_status()
                with open(partial_path, "wb") as f:
                    for chunk in r.iter_content(chunk_size=8192):
                        f.write(chunk)
            os.replace(partial_path, local_file_path)
        except Exception:
            if os.path.exists(partial_path):
                os.remove(partial_path)
            raise
        logger.info("File downloaded successfully")

    @classmethod
    def download_llamacloud_pipeline_file(
        cls,
        file: "LlamaCloudFile",
        force_download: bool = False,
    ):
        """
        Download a pipeline file into the local store unless it is already there.

        Best effort: errors are logged and never raised, so the method is safe
        to run as a background task.

        Args:
            file: name and pipeline id of the file to fetch.
            force_download: download even when a local copy exists.
        """
        file_name = file.file_name
        pipeline_id = file.pipeline_id

        # Check is the file already exists
        downloaded_file_path = cls.get_file_path(file_name, pipeline_id)
        if os.path.exists(downloaded_file_path) and not force_download:
            logger.debug(f"File {file_name} already exists in local storage")
            return
        try:
            logger.info(f"Downloading file {file_name} for pipeline {pipeline_id}")
            files = cls._get_files(pipeline_id)
            if not files or not isinstance(files, list):
                raise Exception("No files found in LlamaCloud")
            for file_entry in files:
                if file_entry["name"] == file_name:
                    file_id = file_entry["file_id"]
                    project_id = file_entry["project_id"]
                    file_detail = cls._get_file_detail(project_id, file_id)
                    cls._download_file(file_detail["url"], downloaded_file_path)
                    break
            else:
                logger.warning(
                    f"File {file_name} not found in pipeline {pipeline_id}"
                )
        except Exception as error:
            # Logged at error level so failed downloads are not mistaken for routine output.
            logger.error(f"Error fetching file from LlamaCloud: {error}")

    @classmethod
    def get_file_name(cls, name: str, pipeline_id: str) -> str:
        """Return the local file name "<pipeline_id>$<name>"."""
        return cls.DOWNLOAD_FILE_NAME_TPL.format(pipeline_id=pipeline_id, filename=name)

    @classmethod
    def get_file_path(cls, name: str, pipeline_id: str) -> str:
        """Return the path of the local copy inside `LOCAL_STORE_PATH`."""
        return os.path.join(cls.LOCAL_STORE_PATH, cls.get_file_name(name, pipeline_id))

    @staticmethod
    def _make_request(
        url: str, data=None, headers: Optional[Dict] = None, method: str = "get"
    ):
        """
        Send an HTTP request and return the decoded JSON body.

        When `headers` is None, JSON is requested and the bearer token is read
        from `LLAMA_CLOUD_API_KEY`. Raises the `requests` error for HTTP error
        statuses.
        """
        if headers is None:
            headers = {
                "Accept": "application/json",
                "Authorization": f'Bearer {os.getenv("LLAMA_CLOUD_API_KEY")}',
            }
        response = requests.request(
            method,
            url,
            headers=headers,
            data=data,
            timeout=LLamaCloudFileService.REQUEST_TIMEOUT,
        )
        response.raise_for_status()
        return response.json()
|
||||
@@ -1,48 +0,0 @@
|
||||
from typing import List
|
||||
|
||||
from app.api.routers.models import Message
|
||||
from llama_index.core.prompts import PromptTemplate
|
||||
from llama_index.core.settings import Settings
|
||||
from pydantic import BaseModel
|
||||
|
||||
# Prompt asking the LLM for follow-up questions. Both `{conversation}` and
# `{number_of_questions}` are template variables filled in by
# `NextQuestionSuggestion.suggest_next_questions`.
NEXT_QUESTIONS_SUGGESTION_PROMPT = PromptTemplate(
    "你是一个乐于助人的助手!你的任务是对用户可能会问的下一个问题给出建议。 "
    "\n这是对话历史记录"
    "\n---------------------\n{conversation}\n---------------------"
    "考虑到对话历史记录,仅限于现在知识库已有内容, 请给我 {number_of_questions} 个你接下来可能会问题的问题!"
)
# Default number of questions to ask for.
N_QUESTION_TO_GENERATE = 3


class NextQuestions(BaseModel):
    """A list of questions that user might ask next"""

    questions: List[str]


class NextQuestionSuggestion:
    """Suggests follow-up questions based on the end of a conversation."""

    @staticmethod
    async def suggest_next_questions(
        messages: List[Message],
        number_of_questions: int = N_QUESTION_TO_GENERATE,
    ) -> List[str]:
        """
        Ask the configured LLM for questions the user might ask next.

        Args:
            messages: the conversation, oldest message first.
            number_of_questions: how many questions to request.

        Returns:
            The questions from the structured LLM output.
        """
        # Reduce the cost by only using the last two messages
        last_user_message = None
        last_assistant_message = None
        for message in reversed(messages):
            # Keep only the most recent message of each role; older ones must
            # not overwrite it while the scan continues backwards.
            if message.role == "user" and last_user_message is None:
                last_user_message = f"User: {message.content}"
            elif message.role == "assistant" and last_assistant_message is None:
                last_assistant_message = f"Assistant: {message.content}"
            if last_user_message and last_assistant_message:
                break
        # Leave out a missing role instead of rendering the text "None".
        conversation: str = "\n".join(
            part
            for part in (last_user_message, last_assistant_message)
            if part is not None
        )

        output: NextQuestions = await Settings.llm.astructured_predict(
            NextQuestions,
            prompt=NEXT_QUESTIONS_SUGGESTION_PROMPT,
            conversation=conversation,
            # Must match the `{number_of_questions}` variable of the prompt.
            number_of_questions=number_of_questions,
        )

        return output.questions
|
||||
Reference in New Issue
Block a user