优化了提示词

This commit is contained in:
chentianrui
2024-08-23 18:35:19 +08:00
parent 7691b22274
commit a200e8adfc
39 changed files with 3083 additions and 21 deletions
@@ -0,0 +1,150 @@
import logging
import os
from typing import List
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request, status
from llama_index.core.chat_engine.types import BaseChatEngine, NodeWithScore
from llama_index.core.llms import MessageRole
from llama_index.core.vector_stores.types import MetadataFilter, MetadataFilters
from app.api.routers.events import EventCallbackHandler
from app.api.routers.models import (
ChatConfig,
ChatData,
Message,
Result,
SourceNodes,
)
from app.api.routers.vercel_response import VercelStreamResponse
from app.api.services.llama_cloud import LLamaCloudFileService
from app.engine import get_chat_engine
chat_router = r = APIRouter()
logger = logging.getLogger("uvicorn")
def process_response_nodes(
nodes: List[NodeWithScore],
background_tasks: BackgroundTasks,
):
"""
Start background tasks on the source nodes if needed.
"""
files_to_download = SourceNodes.get_download_files(nodes)
for file in files_to_download:
background_tasks.add_task(
LLamaCloudFileService.download_llamacloud_pipeline_file, file
)
# streaming endpoint - delete if not needed
@r.post("")
async def chat(
request: Request,
data: ChatData,
background_tasks: BackgroundTasks,
chat_engine: BaseChatEngine = Depends(get_chat_engine),
):
try:
last_message_content = data.get_last_message_content()
# 由于基于历史消息的提示词没有调整好,所以暂时屏蔽历史消息
data.messages.clear()
messages = data.get_history_messages()
doc_ids = data.get_chat_document_ids()
filters = generate_filters(doc_ids)
params = data.data or {}
logger.info("Creating chat engine with filters", filters.dict())
chat_engine = get_chat_engine(filters=filters, params=params)
event_handler = EventCallbackHandler()
chat_engine.callback_manager.handlers.append(event_handler) # type: ignore
response = await chat_engine.astream_chat(last_message_content, messages)
process_response_nodes(response.source_nodes, background_tasks)
return VercelStreamResponse(request, event_handler, response, data)
except Exception as e:
logger.exception("Error in chat engine", exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error in chat engine: {e}",
) from e
def generate_filters(doc_ids):
if len(doc_ids) > 0:
filters = MetadataFilters(
filters=[
MetadataFilter(
key="private",
value=["true"],
operator="nin", # type: ignore
),
MetadataFilter(
key="doc_id",
value=doc_ids,
operator="in", # type: ignore
),
],
condition="or", # type: ignore
)
else:
filters = MetadataFilters(
# Use the "NIN" - "not in" operator to include all public documents (don't have the private key set)
filters=[
MetadataFilter(
key="private",
value=["true"],
operator="nin", # type: ignore
),
]
)
return filters
# non-streaming endpoint - delete if not needed
@r.post("/request")
async def chat_request(
data: ChatData,
chat_engine: BaseChatEngine = Depends(get_chat_engine),
) -> Result:
last_message_content = data.get_last_message_content()
messages = data.get_history_messages()
response = await chat_engine.achat(last_message_content, messages)
return Result(
result=Message(role=MessageRole.ASSISTANT, content=response.response),
nodes=SourceNodes.from_source_nodes(response.source_nodes),
)
@r.get("/config")
async def chat_config() -> ChatConfig:
starter_questions = None
conversation_starters = os.getenv("CONVERSATION_STARTERS")
if conversation_starters and conversation_starters.strip():
starter_questions = conversation_starters.strip().split("\\n")
return ChatConfig(starter_questions=starter_questions)
@r.get("/config/llamacloud")
async def chat_llama_cloud_config():
projects = LLamaCloudFileService.get_all_projects_with_pipelines()
pipeline = os.getenv("LLAMA_CLOUD_INDEX_NAME")
project = os.getenv("LLAMA_CLOUD_PROJECT_NAME")
pipeline_config = (
pipeline
and project
and {
"pipeline": pipeline,
"project": project,
}
or None
)
return {
"projects": projects,
"pipeline": pipeline_config,
}