149 lines
4.6 KiB
Python
149 lines
4.6 KiB
Python
import logging
|
|
import os
|
|
from typing import List
|
|
|
|
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request, status
|
|
from llama_index.core.chat_engine.types import BaseChatEngine, NodeWithScore
|
|
from llama_index.core.llms import MessageRole
|
|
from llama_index.core.vector_stores.types import MetadataFilter, MetadataFilters
|
|
|
|
from app.api.routers.events import EventCallbackHandler
|
|
from app.api.routers.models import (
|
|
ChatConfig,
|
|
ChatData,
|
|
Message,
|
|
Result,
|
|
SourceNodes,
|
|
)
|
|
from app.api.routers.vercel_response import VercelStreamResponse
|
|
from app.api.services.llama_cloud import LLamaCloudFileService
|
|
from app.engine import get_chat_engine
|
|
|
|
# Router holding the chat endpoints. ``r`` is the short alias used by the
# route decorators in this module; ``chat_router`` is the descriptive name
# (presumably the one imported where the router is mounted — verify).
r = APIRouter()
chat_router = r

# Log through the "uvicorn" logger, presumably so messages go through the
# handlers uvicorn configures for the server.
logger = logging.getLogger("uvicorn")
|
|
|
|
|
|
def process_response_nodes(
    nodes: List[NodeWithScore],
    background_tasks: BackgroundTasks,
):
    """Queue background downloads for the files referenced by ``nodes``.

    ``SourceNodes.get_download_files`` selects which files need fetching.
    Each selected file is scheduled as its own ``BackgroundTasks`` task that
    calls ``LLamaCloudFileService.download_llamacloud_pipeline_file``.
    Nothing is returned.
    """
    for pending_file in SourceNodes.get_download_files(nodes):
        background_tasks.add_task(
            LLamaCloudFileService.download_llamacloud_pipeline_file,
            pending_file,
        )
|
|
|
|
|
|
# streaming endpoint - delete if not needed
@r.post("")
async def chat(
    request: Request,
    data: ChatData,
    background_tasks: BackgroundTasks,
    chat_engine: BaseChatEngine = Depends(get_chat_engine),
):
    """Stream a chat answer to the latest user message.

    The chat engine is rebuilt for every request with metadata filters
    derived from the document ids attached to the chat (``generate_filters``)
    and the optional ``data.data`` params. Source nodes in the response are
    passed to ``process_response_nodes`` so referenced files are queued for
    background download.

    Args:
        request: Incoming HTTP request, passed through to the stream response.
        data: Chat payload providing the last message, the history, attached
            document ids and extra params.
        background_tasks: FastAPI background task queue used for downloads.
        chat_engine: Engine injected via ``Depends(get_chat_engine)``.
            NOTE(review): it is replaced by a filtered engine below and never
            used; kept only to preserve the endpoint signature.

    Returns:
        VercelStreamResponse: Streaming response over the engine output.

    Raises:
        HTTPException: 500 when anything fails while preparing the stream.
    """
    try:
        last_message_content = data.get_last_message_content()
        messages = data.get_history_messages()

        doc_ids = data.get_chat_document_ids()
        filters = generate_filters(doc_ids)
        params = data.data or {}
        # The message needs a placeholder for the lazy %-style argument;
        # without one, logging reports a formatting error instead of the message.
        logger.info("Creating chat engine with filters: %s", filters.dict())
        chat_engine = get_chat_engine(filters=filters, params=params)

        # Register an event handler on the engine; it is handed to
        # VercelStreamResponse below together with the response.
        event_handler = EventCallbackHandler()
        chat_engine.callback_manager.handlers.append(event_handler)  # type: ignore

        response = await chat_engine.astream_chat(last_message_content, messages)
        process_response_nodes(response.source_nodes, background_tasks)

        return VercelStreamResponse(request, event_handler, response, data)
    except Exception as e:
        # logger.exception already attaches the active traceback.
        logger.exception("Error in chat engine")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error in chat engine: {e}",
        ) from e
|
|
|
|
|
|
def generate_filters(doc_ids):
    """Build the vector-store metadata filters for a chat request.

    Public documents are always included: the filter keeps every document
    whose ``private`` metadata is not ``"true"``. When ``doc_ids`` is
    non-empty, the documents with those ids are included as well, because
    the two filters are combined with ``condition="or"``.

    Args:
        doc_ids: Ids of the documents attached to the chat. ``None`` or an
            empty sequence means "public documents only".

    Returns:
        MetadataFilters: Filters to pass to ``get_chat_engine``.
    """
    # Use the "NIN" - "not in" operator to include all public documents
    # (documents that don't have the private key set).
    public_filter = MetadataFilter(
        key="private",
        value=["true"],
        operator="nin",  # type: ignore
    )
    if not doc_ids:
        return MetadataFilters(filters=[public_filter])
    return MetadataFilters(
        filters=[
            public_filter,
            MetadataFilter(
                key="doc_id",
                value=doc_ids,
                operator="in",  # type: ignore
            ),
        ],
        condition="or",  # type: ignore
    )
|
|
|
|
|
|
# non-streaming endpoint - delete if not needed
@r.post("/request")
async def chat_request(
    data: ChatData,
    chat_engine: BaseChatEngine = Depends(get_chat_engine),
) -> Result:
    """Answer the latest user message in a single, non-streamed response.

    The engine is built with the same document filters and params as the
    streaming ``chat`` endpoint, so both endpoints search the same documents.

    Args:
        data: Chat payload providing the last message, the history, attached
            document ids and extra params.
        chat_engine: Engine injected via ``Depends(get_chat_engine)``.
            NOTE(review): it is replaced by a filtered engine below; kept only
            to preserve the endpoint signature.

    Returns:
        Result: The assistant message plus the source nodes used.
    """
    last_message_content = data.get_last_message_content()
    messages = data.get_history_messages()

    # NOTE(review): the injected engine comes from get_chat_engine() with no
    # filters argument. Unless its default filters private documents, this
    # endpoint would bypass the exclusion the streaming endpoint applies, so
    # rebuild the engine exactly as `chat` does.
    filters = generate_filters(data.get_chat_document_ids())
    params = data.data or {}
    chat_engine = get_chat_engine(filters=filters, params=params)

    response = await chat_engine.achat(last_message_content, messages)
    return Result(
        result=Message(role=MessageRole.ASSISTANT, content=response.response),
        nodes=SourceNodes.from_source_nodes(response.source_nodes),
    )
|
|
|
|
|
|
@r.get("/config")
async def chat_config() -> ChatConfig:
    """Return the chat UI configuration read from the environment.

    ``CONVERSATION_STARTERS`` holds starter questions, one per line. Lines
    are split with ``str.splitlines`` (so Windows CRLF endings do not leave
    a trailing carriage return) and trimmed. Blank lines are skipped, so no
    empty starter question is produced. If the variable is unset or blank,
    ``starter_questions`` is ``None``.
    """
    starter_questions = None
    conversation_starters = os.getenv("CONVERSATION_STARTERS")
    if conversation_starters and conversation_starters.strip():
        starter_questions = [
            question.strip()
            for question in conversation_starters.splitlines()
            if question.strip()
        ]
    return ChatConfig(starter_questions=starter_questions)
|
|
|
|
|
|
@r.get("/config/llamacloud")
async def chat_llama_cloud_config():
    """Return the LlamaCloud projects and the configured pipeline.

    Returns:
        dict: ``{"projects": ..., "pipeline": ...}``. ``projects`` is the
        result of ``LLamaCloudFileService.get_all_projects_with_pipelines()``.
        ``pipeline`` is ``{"pipeline": <LLAMA_CLOUD_INDEX_NAME>,
        "project": <LLAMA_CLOUD_PROJECT_NAME>}`` when both environment
        variables are set to non-empty values, otherwise ``None``.
    """
    # NOTE(review): if get_all_projects_with_pipelines() performs blocking
    # I/O, it runs on the event loop inside this async endpoint; consider a
    # plain `def` endpoint or a threadpool call — confirm.
    projects = LLamaCloudFileService.get_all_projects_with_pipelines()
    pipeline = os.getenv("LLAMA_CLOUD_INDEX_NAME")
    project = os.getenv("LLAMA_CLOUD_PROJECT_NAME")
    # Only report a pipeline when both names are configured.
    pipeline_config = (
        {"pipeline": pipeline, "project": project}
        if pipeline and project
        else None
    )
    return {
        "projects": projects,
        "pipeline": pipeline_config,
    }
|