删除误上传的文件

This commit is contained in:
2024-08-26 09:54:33 +08:00
parent 7462244f01
commit b052d373f1
36 changed files with 0 additions and 3048 deletions
@@ -1,150 +0,0 @@
import logging
import os
from typing import List
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request, status
from llama_index.core.chat_engine.types import BaseChatEngine, NodeWithScore
from llama_index.core.llms import MessageRole
from llama_index.core.vector_stores.types import MetadataFilter, MetadataFilters
from app.api.routers.events import EventCallbackHandler
from app.api.routers.models import (
ChatConfig,
ChatData,
Message,
Result,
SourceNodes,
)
from app.api.routers.vercel_response import VercelStreamResponse
from app.api.services.llama_cloud import LLamaCloudFileService
from app.engine import get_chat_engine
chat_router = r = APIRouter()
logger = logging.getLogger("uvicorn")
def process_response_nodes(
nodes: List[NodeWithScore],
background_tasks: BackgroundTasks,
):
"""
Start background tasks on the source nodes if needed.
"""
files_to_download = SourceNodes.get_download_files(nodes)
for file in files_to_download:
background_tasks.add_task(
LLamaCloudFileService.download_llamacloud_pipeline_file, file
)
# streaming endpoint - delete if not needed
@r.post("")
async def chat(
request: Request,
data: ChatData,
background_tasks: BackgroundTasks,
chat_engine: BaseChatEngine = Depends(get_chat_engine),
):
try:
last_message_content = data.get_last_message_content()
# 由于基于历史消息的提示词没有调整好,所以暂时屏蔽历史消息
data.messages.clear()
messages = data.get_history_messages()
doc_ids = data.get_chat_document_ids()
filters = generate_filters(doc_ids)
params = data.data or {}
logger.info("Creating chat engine with filters", filters.dict())
chat_engine = get_chat_engine(filters=filters, params=params)
event_handler = EventCallbackHandler()
chat_engine.callback_manager.handlers.append(event_handler) # type: ignore
response = await chat_engine.astream_chat(last_message_content, messages)
process_response_nodes(response.source_nodes, background_tasks)
return VercelStreamResponse(request, event_handler, response, data)
except Exception as e:
logger.exception("Error in chat engine", exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error in chat engine: {e}",
) from e
def generate_filters(doc_ids):
if len(doc_ids) > 0:
filters = MetadataFilters(
filters=[
MetadataFilter(
key="private",
value=["true"],
operator="nin", # type: ignore
),
MetadataFilter(
key="doc_id",
value=doc_ids,
operator="in", # type: ignore
),
],
condition="or", # type: ignore
)
else:
filters = MetadataFilters(
# Use the "NIN" - "not in" operator to include all public documents (don't have the private key set)
filters=[
MetadataFilter(
key="private",
value=["true"],
operator="nin", # type: ignore
),
]
)
return filters
# non-streaming endpoint - delete if not needed
@r.post("/request")
async def chat_request(
data: ChatData,
chat_engine: BaseChatEngine = Depends(get_chat_engine),
) -> Result:
last_message_content = data.get_last_message_content()
messages = data.get_history_messages()
response = await chat_engine.achat(last_message_content, messages)
return Result(
result=Message(role=MessageRole.ASSISTANT, content=response.response),
nodes=SourceNodes.from_source_nodes(response.source_nodes),
)
@r.get("/config")
async def chat_config() -> ChatConfig:
starter_questions = None
conversation_starters = os.getenv("CONVERSATION_STARTERS")
if conversation_starters and conversation_starters.strip():
starter_questions = conversation_starters.strip().split("\\n")
return ChatConfig(starter_questions=starter_questions)
@r.get("/config/llamacloud")
async def chat_llama_cloud_config():
projects = LLamaCloudFileService.get_all_projects_with_pipelines()
pipeline = os.getenv("LLAMA_CLOUD_INDEX_NAME")
project = os.getenv("LLAMA_CLOUD_PROJECT_NAME")
pipeline_config = (
pipeline
and project
and {
"pipeline": pipeline,
"project": project,
}
or None
)
return {
"projects": projects,
"pipeline": pipeline_config,
}
@@ -1,149 +0,0 @@
import json
import asyncio
import logging
from typing import AsyncGenerator, Dict, Any, List, Optional
from llama_index.core.callbacks.base import BaseCallbackHandler
from llama_index.core.callbacks.schema import CBEventType
from llama_index.core.tools.types import ToolOutput
from pydantic import BaseModel
logger = logging.getLogger(__name__)
class CallbackEvent(BaseModel):
event_type: CBEventType
payload: Optional[Dict[str, Any]] = None
event_id: str = ""
def get_retrieval_message(self) -> dict | None:
if self.payload:
nodes = self.payload.get("nodes")
if nodes:
msg = f"根据查询检索到 {len(nodes)} 源文件"
else:
msg = f"查询检索中: '{self.payload.get('query_str')}'"
return {
"type": "events",
"data": {"title": msg},
}
else:
return None
def get_tool_message(self) -> dict | None:
func_call_args = self.payload.get("function_call")
if func_call_args is not None and "tool" in self.payload:
tool = self.payload.get("tool")
return {
"type": "events",
"data": {
"title": f"调用工具 {tool.name} ,参数: {func_call_args}",
},
}
def _is_output_serializable(self, output: Any) -> bool:
try:
json.dumps(output)
return True
except TypeError:
return False
def get_agent_tool_response(self) -> dict | None:
response = self.payload.get("response")
if response is not None:
sources = response.sources
for source in sources:
# Return the tool response here to include the toolCall information
if isinstance(source, ToolOutput):
if self._is_output_serializable(source.raw_output):
output = source.raw_output
else:
output = source.content
return {
"type": "tools",
"data": {
"toolOutput": {
"output": output,
"isError": source.is_error,
},
"toolCall": {
"id": None, # There is no tool id in the ToolOutput
"name": source.tool_name,
"input": source.raw_input,
},
},
}
def to_response(self):
try:
match self.event_type:
case "retrieve":
return self.get_retrieval_message()
case "function_call":
return self.get_tool_message()
case "agent_step":
return self.get_agent_tool_response()
case _:
return None
except Exception as e:
logger.error(f"转换回应时间时发生错误,原因: {e}")
return None
class EventCallbackHandler(BaseCallbackHandler):
_aqueue: asyncio.Queue
is_done: bool = False
def __init__(
self,
):
"""Initialize the base callback handler."""
ignored_events = [
CBEventType.CHUNKING,
CBEventType.NODE_PARSING,
CBEventType.EMBEDDING,
CBEventType.LLM,
CBEventType.TEMPLATING,
]
super().__init__(ignored_events, ignored_events)
self._aqueue = asyncio.Queue()
def on_event_start(
self,
event_type: CBEventType,
payload: Optional[Dict[str, Any]] = None,
event_id: str = "",
**kwargs: Any,
) -> str:
event = CallbackEvent(event_id=event_id, event_type=event_type, payload=payload)
if event.to_response() is not None:
self._aqueue.put_nowait(event)
def on_event_end(
self,
event_type: CBEventType,
payload: Optional[Dict[str, Any]] = None,
event_id: str = "",
**kwargs: Any,
) -> None:
event = CallbackEvent(event_id=event_id, event_type=event_type, payload=payload)
if event.to_response() is not None:
self._aqueue.put_nowait(event)
def start_trace(self, trace_id: Optional[str] = None) -> None:
"""No-op."""
def end_trace(
self,
trace_id: Optional[str] = None,
trace_map: Optional[Dict[str, List[str]]] = None,
) -> None:
"""No-op."""
async def async_event_gen(self) -> AsyncGenerator[CallbackEvent, None]:
while not self._aqueue.empty() or not self.is_done:
try:
yield await asyncio.wait_for(self._aqueue.get(), timeout=0.1)
except asyncio.TimeoutError:
pass
@@ -1,253 +0,0 @@
import logging
import os
from typing import Any, Dict, List, Literal, Optional, Set
from llama_index.core.llms import ChatMessage, MessageRole
from llama_index.core.schema import NodeWithScore
from pydantic import BaseModel, Field, validator, field_validator
from pydantic.alias_generators import to_camel
logger = logging.getLogger("uvicorn")
class FileContent(BaseModel):
type: Literal["text", "ref"]
# If the file is pure text then the value is be a string
# otherwise, it's a list of document IDs
value: str | List[str]
class File(BaseModel):
id: str
content: FileContent
filename: str
filesize: int
filetype: str
class AnnotationFileData(BaseModel):
files: List[File] = Field(
default=[],
description="List of files",
)
class Config:
json_schema_extra = {
"example": {
"csvFiles": [
{
"content": "Name, Age\nAlice, 25\nBob, 30",
"filename": "example.csv",
"filesize": 123,
"id": "123",
"type": "text/csv",
}
]
}
}
alias_generator = to_camel
class Annotation(BaseModel):
type: str
data: AnnotationFileData | List[str]
def to_content(self) -> str | None:
if self.type == "document_file":
# We only support generating context content for CSV files for now
csv_files = [file for file in self.data.files if file.filetype == "csv"]
if len(csv_files) > 0:
return "Use data from following CSV raw content\n" + "\n".join(
[f"```csv\n{csv_file.content.value}\n```" for csv_file in csv_files]
)
else:
logger.warning(
f"The annotation {self.type} is not supported for generating context content"
)
return None
class Message(BaseModel):
role: MessageRole
content: str
annotations: List[Annotation] | None = None
class ChatData(BaseModel):
messages: List[Message]
data: Any = None
class Config:
json_schema_extra = {
"example": {
"messages": [
{
"role": "user",
"content": "What standards for letters exist?",
}
]
}
}
@field_validator("messages")
def messages_must_not_be_empty(cls, v):
if len(v) == 0:
raise ValueError("Messages must not be empty")
return v
def get_last_message_content(self) -> str:
"""
Get the content of the last message along with the data content if available.
Fallback to use data content from previous messages
"""
if len(self.messages) == 0:
raise ValueError("There is not any message in the chat")
last_message = self.messages[-1]
message_content = last_message.content
for message in reversed(self.messages):
if message.role == MessageRole.USER and message.annotations is not None:
annotation_contents = filter(
None,
[annotation.to_content() for annotation in message.annotations],
)
if not annotation_contents:
continue
annotation_text = "\n".join(annotation_contents)
message_content = f"{message_content}\n{annotation_text}"
break
return message_content
def get_history_messages(self) -> List[ChatMessage]:
"""
Get the history messages
"""
return [
ChatMessage(role=message.role, content=message.content)
for message in self.messages[:-1]
]
def is_last_message_from_user(self) -> bool:
return self.messages[-1].role == MessageRole.USER
def get_chat_document_ids(self) -> List[str]:
"""
Get the document IDs from the chat messages
"""
document_ids: List[str] = []
for message in self.messages:
if message.role == MessageRole.USER and message.annotations is not None:
for annotation in message.annotations:
if (
annotation.type == "document_file"
and annotation.data.files is not None
):
for fi in annotation.data.files:
if fi.content.type == "ref":
document_ids += fi.content.value
return list(set(document_ids))
class LlamaCloudFile(BaseModel):
file_name: str
pipeline_id: str
def __eq__(self, other):
if not isinstance(other, LlamaCloudFile):
return NotImplemented
return (
self.file_name == other.file_name and self.pipeline_id == other.pipeline_id
)
def __hash__(self):
return hash((self.file_name, self.pipeline_id))
class SourceNodes(BaseModel):
id: str
metadata: Dict[str, Any]
score: Optional[float]
text: str
url: Optional[str]
@classmethod
def from_source_node(cls, source_node: NodeWithScore):
metadata = source_node.node.metadata
url = cls.get_url_from_metadata(metadata)
#text = 'filename' in metadata and metadata['filename'] or source_node.node.node_id
text = source_node.node.text
return cls(
id=source_node.node.node_id,
metadata=metadata,
score=source_node.score,
text=text, # type: ignore
url=url,
)
@classmethod
def get_url_from_metadata(cls, metadata: Dict[str, Any]) -> str:
url_prefix = os.getenv("FILESERVER_URL_PREFIX")
if not url_prefix:
logger.warning(
"Warning: FILESERVER_URL_PREFIX not set in environment variables. Can't use file server"
)
file_name = metadata.get("file_name")
if file_name and url_prefix:
# file_name exists and file server is configured
pipeline_id = metadata.get("pipeline_id")
if pipeline_id and metadata.get("private") is None:
# file is from LlamaCloud and was not ingested locally
file_name = f"{pipeline_id}${file_name}"
return f"{url_prefix}/output/llamacloud/{file_name}"
is_private = metadata.get("private", "false") == "true"
if is_private:
return f"{url_prefix}/output/uploaded/{file_name}"
return f"{url_prefix}/data/{file_name}"
else:
# fallback to URL in metadata (e.g. for websites)
return metadata.get("URL")
@classmethod
def from_source_nodes(cls, source_nodes: List[NodeWithScore]):
return [cls.from_source_node(node) for node in source_nodes]
@staticmethod
def get_download_files(nodes: List[NodeWithScore]) -> Set[LlamaCloudFile]:
source_nodes = SourceNodes.from_source_nodes(nodes)
llama_cloud_files = [
LlamaCloudFile(
file_name=node.metadata.get("file_name"),
pipeline_id=node.metadata.get("pipeline_id"),
)
for node in source_nodes
if (
node.metadata.get("private")
is None # Only download files are from LlamaCloud and were not ingested locally
and node.metadata.get("pipeline_id") is not None
and node.metadata.get("file_name") is not None
)
]
# Remove duplicates and return
return set(llama_cloud_files)
class Result(BaseModel):
result: Message
nodes: List[SourceNodes]
class ChatConfig(BaseModel):
starter_questions: Optional[List[str]] = Field(
default=None,
description="List of starter questions",
serialization_alias="starterQuestions",
)
class Config:
json_schema_extra = {
"example": {
"starterQuestions": [
"What standards for letters exist?",
"What are the requirements for a letter to be considered a letter?",
]
}
}
@@ -1,25 +0,0 @@
import logging
from typing import List
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from app.api.services.file import PrivateFileService
file_upload_router = r = APIRouter()
logger = logging.getLogger("uvicorn")
class FileUploadRequest(BaseModel):
base64: str
@r.post("")
def upload_file(request: FileUploadRequest) -> List[str]:
try:
logger.info("Processing file")
return PrivateFileService.process_file(request.base64)
except Exception as e:
logger.error(f"Error processing file: {e}", exc_info=True)
raise HTTPException(status_code=500, detail="Error processing file")
@@ -1,109 +0,0 @@
import json
from aiostream import stream
from fastapi import Request
from fastapi.responses import StreamingResponse
from llama_index.core.chat_engine.types import StreamingAgentChatResponse
from app.api.routers.events import EventCallbackHandler
from app.api.routers.models import ChatData, Message, SourceNodes
from app.api.services.suggestion import NextQuestionSuggestion
class VercelStreamResponse(StreamingResponse):
"""
Class to convert the response from the chat engine to the streaming format expected by Vercel
"""
TEXT_PREFIX = "0:"
DATA_PREFIX = "8:"
@classmethod
def convert_text(cls, token: str):
# Escape newlines and double quotes to avoid breaking the stream
token = json.dumps(token)
return f"{cls.TEXT_PREFIX}{token}\n"
@classmethod
def convert_data(cls, data: dict):
data_str = json.dumps(data)
return f"{cls.DATA_PREFIX}[{data_str}]\n"
def __init__(
self,
request: Request,
event_handler: EventCallbackHandler,
response: StreamingAgentChatResponse,
chat_data: ChatData,
):
content = VercelStreamResponse.content_generator(
request, event_handler, response, chat_data
)
super().__init__(content=content)
@classmethod
async def content_generator(
cls,
request: Request,
event_handler: EventCallbackHandler,
response: StreamingAgentChatResponse,
chat_data: ChatData,
):
# Yield the text response
async def _chat_response_generator():
final_response = ""
async for token in response.async_response_gen():
final_response += token
yield VercelStreamResponse.convert_text(token)
# Generate questions that user might interested to
conversation = chat_data.messages + [
Message(role="assistant", content=final_response)
]
questions = await NextQuestionSuggestion.suggest_next_questions(
conversation
)
if len(questions) > 0:
yield VercelStreamResponse.convert_data(
{
"type": "suggested_questions",
"data": questions,
}
)
# the text_generator is the leading stream, once it's finished, also finish the event stream
event_handler.is_done = True
# Yield the source nodes
yield cls.convert_data(
{
"type": "sources",
"data": {
"nodes": [
SourceNodes.from_source_node(node).dict()
for node in response.source_nodes
]
},
}
)
# Yield the events from the event handler
async def _event_generator():
async for event in event_handler.async_event_gen():
event_response = event.to_response()
if event_response is not None:
yield VercelStreamResponse.convert_data(event_response)
combine = stream.merge(_chat_response_generator(), _event_generator())
is_stream_started = False
async with combine.stream() as streamer:
async for output in streamer:
if not is_stream_started:
is_stream_started = True
# Stream a blank message to start the stream
yield VercelStreamResponse.convert_text("")
yield output
if await request.is_disconnected():
break
@@ -1,113 +0,0 @@
import base64
import mimetypes
import os
from pathlib import Path
from typing import Dict, List
from uuid import uuid4
from app.engine.index import get_index
from llama_index.core import VectorStoreIndex
from llama_index.core.ingestion import IngestionPipeline
from llama_index.core.readers.file.base import (
_try_loading_included_file_formats as get_file_loaders_map,
)
from llama_index.core.readers.file.base import (
default_file_metadata_func,
)
from llama_index.core.schema import Document
from llama_index.indices.managed.llama_cloud.base import LlamaCloudIndex
from llama_index.readers.file import FlatReader
def get_llamaparse_parser():
from app.engine.loaders import load_configs
from app.engine.loaders.file import FileLoaderConfig, llama_parse_parser
config = load_configs()
file_loader_config = FileLoaderConfig(**config["file"])
if file_loader_config.use_llama_parse:
return llama_parse_parser()
else:
return None
def default_file_loaders_map():
default_loaders = get_file_loaders_map()
default_loaders[".txt"] = FlatReader
return default_loaders
class PrivateFileService:
PRIVATE_STORE_PATH = "output/uploaded"
@staticmethod
def preprocess_base64_file(base64_content: str) -> tuple:
header, data = base64_content.split(",", 1)
mime_type = header.split(";")[0].split(":", 1)[1]
extension = mimetypes.guess_extension(mime_type)
# File data as bytes
return base64.b64decode(data), extension
@staticmethod
def store_and_parse_file(file_data, extension) -> List[Document]:
# Store file to the private directory
os.makedirs(PrivateFileService.PRIVATE_STORE_PATH, exist_ok=True)
# random file name
file_name = f"{uuid4().hex}{extension}"
file_path = Path(os.path.join(PrivateFileService.PRIVATE_STORE_PATH, file_name))
# write file
with open(file_path, "wb") as f:
f.write(file_data)
# Load file to documents
# If LlamaParse is enabled, use it to parse the file
# Otherwise, use the default file loaders
reader = get_llamaparse_parser()
if reader is None:
reader_cls = default_file_loaders_map().get(extension)
if reader_cls is None:
raise ValueError(f"File extension {extension} is not supported")
reader = reader_cls()
documents = reader.load_data(file_path)
# Add custom metadata
for doc in documents:
doc.metadata["file_name"] = file_name
doc.metadata["private"] = "true"
return documents
@staticmethod
def process_file(base64_content: str) -> List[str]:
file_data, extension = PrivateFileService.preprocess_base64_file(base64_content)
documents = PrivateFileService.store_and_parse_file(file_data, extension)
# Only process nodes, no store the index
pipeline = IngestionPipeline()
nodes = pipeline.run(documents=documents)
# Add the nodes to the index and persist it
current_index = get_index()
# Insert the documents into the index
if isinstance(current_index, LlamaCloudIndex):
# LlamaCloudIndex is a managed index so we don't need to process the nodes
# just insert the documents
for doc in documents:
current_index.insert(doc)
else:
# Only process nodes, no store the index
pipeline = IngestionPipeline()
nodes = pipeline.run(documents=documents)
# Add the nodes to the index and persist it
if current_index is None:
current_index = VectorStoreIndex(nodes=nodes)
else:
current_index.insert_nodes(nodes=nodes)
current_index.storage_context.persist(
persist_dir=os.environ.get("STORAGE_DIR", "storage")
)
# Return the document ids
return [doc.doc_id for doc in documents]
@@ -1,114 +0,0 @@
import logging
import os
from typing import Any, Dict, List, Optional
import requests
from app.api.routers.models import LlamaCloudFile
logger = logging.getLogger("uvicorn")
class LLamaCloudFileService:
LLAMA_CLOUD_URL = "https://cloud.llamaindex.ai/api/v1"
LOCAL_STORE_PATH = "output/llamacloud"
DOWNLOAD_FILE_NAME_TPL = "{pipeline_id}${filename}"
@classmethod
def get_all_projects(cls) -> List[Dict[str, Any]]:
url = f"{cls.LLAMA_CLOUD_URL}/projects"
return cls._make_request(url)
@classmethod
def get_all_pipelines(cls) -> List[Dict[str, Any]]:
url = f"{cls.LLAMA_CLOUD_URL}/pipelines"
return cls._make_request(url)
@classmethod
def get_all_projects_with_pipelines(cls) -> List[Dict[str, Any]]:
try:
projects = cls.get_all_projects()
pipelines = cls.get_all_pipelines()
return [
{
**project,
"pipelines": [p for p in pipelines if p["project_id"] == project["id"]],
}
for project in projects
]
except Exception as error:
logger.error(f"Error listing projects and pipelines: {error}")
return []
@classmethod
def _get_files(cls, pipeline_id: str) -> List[Dict[str, Any]]:
url = f"{cls.LLAMA_CLOUD_URL}/pipelines/{pipeline_id}/files"
return cls._make_request(url)
@classmethod
def _get_file_detail(cls, project_id: str, file_id: str) -> Dict[str, Any]:
url = f"{cls.LLAMA_CLOUD_URL}/files/{file_id}/content?project_id={project_id}"
return cls._make_request(url)
@classmethod
def _download_file(cls, url: str, local_file_path: str):
logger.info(f"Downloading file to {local_file_path}")
# Create directory if it doesn't exist
os.makedirs(cls.LOCAL_STORE_PATH, exist_ok=True)
# Download the file
with requests.get(url, stream=True) as r:
r.raise_for_status()
with open(local_file_path, "wb") as f:
for chunk in r.iter_content(chunk_size=8192):
f.write(chunk)
logger.info("File downloaded successfully")
@classmethod
def download_llamacloud_pipeline_file(
cls,
file: LlamaCloudFile,
force_download: bool = False,
):
file_name = file.file_name
pipeline_id = file.pipeline_id
# Check is the file already exists
downloaded_file_path = cls.get_file_path(file_name, pipeline_id)
if os.path.exists(downloaded_file_path) and not force_download:
logger.debug(f"File {file_name} already exists in local storage")
return
try:
logger.info(f"Downloading file {file_name} for pipeline {pipeline_id}")
files = cls._get_files(pipeline_id)
if not files or not isinstance(files, list):
raise Exception("No files found in LlamaCloud")
for file_entry in files:
if file_entry["name"] == file_name:
file_id = file_entry["file_id"]
project_id = file_entry["project_id"]
file_detail = cls._get_file_detail(project_id, file_id)
cls._download_file(file_detail["url"], downloaded_file_path)
break
except Exception as error:
logger.info(f"Error fetching file from LlamaCloud: {error}")
@classmethod
def get_file_name(cls, name: str, pipeline_id: str) -> str:
return cls.DOWNLOAD_FILE_NAME_TPL.format(pipeline_id=pipeline_id, filename=name)
@classmethod
def get_file_path(cls, name: str, pipeline_id: str) -> str:
return os.path.join(cls.LOCAL_STORE_PATH, cls.get_file_name(name, pipeline_id))
@staticmethod
def _make_request(
url: str, data=None, headers: Optional[Dict] = None, method: str = "get"
):
if headers is None:
headers = {
"Accept": "application/json",
"Authorization": f'Bearer {os.getenv("LLAMA_CLOUD_API_KEY")}',
}
response = requests.request(method, url, headers=headers, data=data)
response.raise_for_status()
return response.json()
@@ -1,48 +0,0 @@
from typing import List
from app.api.routers.models import Message
from llama_index.core.prompts import PromptTemplate
from llama_index.core.settings import Settings
from pydantic import BaseModel
NEXT_QUESTIONS_SUGGESTION_PROMPT = PromptTemplate(
"你是一个乐于助人的助手!你的任务是对用户可能会问的下一个问题给出建议。 "
"\n这是对话历史记录"
"\n---------------------\n{conversation}\n---------------------"
"考虑到对话历史记录,仅限于现在知识库已有内容, 请给我 $number_of_questions 个你接下来可能会问题的问题!"
)
N_QUESTION_TO_GENERATE = 3
class NextQuestions(BaseModel):
"""A list of questions that user might ask next"""
questions: List[str]
class NextQuestionSuggestion:
@staticmethod
async def suggest_next_questions(
messages: List[Message],
number_of_questions: int = N_QUESTION_TO_GENERATE,
) -> List[str]:
# Reduce the cost by only using the last two messages
last_user_message = None
last_assistant_message = None
for message in reversed(messages):
if message.role == "user":
last_user_message = f"User: {message.content}"
elif message.role == "assistant":
last_assistant_message = f"Assistant: {message.content}"
if last_user_message and last_assistant_message:
break
conversation: str = f"{last_user_message}\n{last_assistant_message}"
output: NextQuestions = await Settings.llm.astructured_predict(
NextQuestions,
prompt=NEXT_QUESTIONS_SUGGESTION_PROMPT,
conversation=conversation,
nun_questions=number_of_questions,
)
return output.questions