优化了提示词
This commit is contained in:
@@ -0,0 +1,113 @@
|
||||
import base64
|
||||
import mimetypes
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
from uuid import uuid4
|
||||
|
||||
from app.engine.index import get_index
|
||||
from llama_index.core import VectorStoreIndex
|
||||
from llama_index.core.ingestion import IngestionPipeline
|
||||
from llama_index.core.readers.file.base import (
|
||||
_try_loading_included_file_formats as get_file_loaders_map,
|
||||
)
|
||||
from llama_index.core.readers.file.base import (
|
||||
default_file_metadata_func,
|
||||
)
|
||||
from llama_index.core.schema import Document
|
||||
from llama_index.indices.managed.llama_cloud.base import LlamaCloudIndex
|
||||
from llama_index.readers.file import FlatReader
|
||||
|
||||
|
||||
def get_llamaparse_parser():
|
||||
from app.engine.loaders import load_configs
|
||||
from app.engine.loaders.file import FileLoaderConfig, llama_parse_parser
|
||||
|
||||
config = load_configs()
|
||||
file_loader_config = FileLoaderConfig(**config["file"])
|
||||
if file_loader_config.use_llama_parse:
|
||||
return llama_parse_parser()
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def default_file_loaders_map():
|
||||
default_loaders = get_file_loaders_map()
|
||||
default_loaders[".txt"] = FlatReader
|
||||
return default_loaders
|
||||
|
||||
|
||||
class PrivateFileService:
|
||||
PRIVATE_STORE_PATH = "output/uploaded"
|
||||
|
||||
@staticmethod
|
||||
def preprocess_base64_file(base64_content: str) -> tuple:
|
||||
header, data = base64_content.split(",", 1)
|
||||
mime_type = header.split(";")[0].split(":", 1)[1]
|
||||
extension = mimetypes.guess_extension(mime_type)
|
||||
# File data as bytes
|
||||
return base64.b64decode(data), extension
|
||||
|
||||
@staticmethod
|
||||
def store_and_parse_file(file_data, extension) -> List[Document]:
|
||||
# Store file to the private directory
|
||||
os.makedirs(PrivateFileService.PRIVATE_STORE_PATH, exist_ok=True)
|
||||
|
||||
# random file name
|
||||
file_name = f"{uuid4().hex}{extension}"
|
||||
file_path = Path(os.path.join(PrivateFileService.PRIVATE_STORE_PATH, file_name))
|
||||
|
||||
# write file
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(file_data)
|
||||
|
||||
# Load file to documents
|
||||
# If LlamaParse is enabled, use it to parse the file
|
||||
# Otherwise, use the default file loaders
|
||||
reader = get_llamaparse_parser()
|
||||
if reader is None:
|
||||
reader_cls = default_file_loaders_map().get(extension)
|
||||
if reader_cls is None:
|
||||
raise ValueError(f"File extension {extension} is not supported")
|
||||
reader = reader_cls()
|
||||
documents = reader.load_data(file_path)
|
||||
# Add custom metadata
|
||||
for doc in documents:
|
||||
doc.metadata["file_name"] = file_name
|
||||
doc.metadata["private"] = "true"
|
||||
return documents
|
||||
|
||||
@staticmethod
|
||||
def process_file(base64_content: str) -> List[str]:
|
||||
file_data, extension = PrivateFileService.preprocess_base64_file(base64_content)
|
||||
documents = PrivateFileService.store_and_parse_file(file_data, extension)
|
||||
|
||||
# Only process nodes, no store the index
|
||||
pipeline = IngestionPipeline()
|
||||
nodes = pipeline.run(documents=documents)
|
||||
|
||||
# Add the nodes to the index and persist it
|
||||
current_index = get_index()
|
||||
|
||||
# Insert the documents into the index
|
||||
if isinstance(current_index, LlamaCloudIndex):
|
||||
# LlamaCloudIndex is a managed index so we don't need to process the nodes
|
||||
# just insert the documents
|
||||
for doc in documents:
|
||||
current_index.insert(doc)
|
||||
else:
|
||||
# Only process nodes, no store the index
|
||||
pipeline = IngestionPipeline()
|
||||
nodes = pipeline.run(documents=documents)
|
||||
|
||||
# Add the nodes to the index and persist it
|
||||
if current_index is None:
|
||||
current_index = VectorStoreIndex(nodes=nodes)
|
||||
else:
|
||||
current_index.insert_nodes(nodes=nodes)
|
||||
current_index.storage_context.persist(
|
||||
persist_dir=os.environ.get("STORAGE_DIR", "storage")
|
||||
)
|
||||
|
||||
# Return the document ids
|
||||
return [doc.doc_id for doc in documents]
|
||||
@@ -0,0 +1,114 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
from app.api.routers.models import LlamaCloudFile
|
||||
|
||||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
|
||||
class LLamaCloudFileService:
|
||||
LLAMA_CLOUD_URL = "https://cloud.llamaindex.ai/api/v1"
|
||||
LOCAL_STORE_PATH = "output/llamacloud"
|
||||
|
||||
DOWNLOAD_FILE_NAME_TPL = "{pipeline_id}${filename}"
|
||||
|
||||
@classmethod
|
||||
def get_all_projects(cls) -> List[Dict[str, Any]]:
|
||||
url = f"{cls.LLAMA_CLOUD_URL}/projects"
|
||||
return cls._make_request(url)
|
||||
|
||||
@classmethod
|
||||
def get_all_pipelines(cls) -> List[Dict[str, Any]]:
|
||||
url = f"{cls.LLAMA_CLOUD_URL}/pipelines"
|
||||
return cls._make_request(url)
|
||||
|
||||
@classmethod
|
||||
def get_all_projects_with_pipelines(cls) -> List[Dict[str, Any]]:
|
||||
try:
|
||||
projects = cls.get_all_projects()
|
||||
pipelines = cls.get_all_pipelines()
|
||||
return [
|
||||
{
|
||||
**project,
|
||||
"pipelines": [p for p in pipelines if p["project_id"] == project["id"]],
|
||||
}
|
||||
for project in projects
|
||||
]
|
||||
except Exception as error:
|
||||
logger.error(f"Error listing projects and pipelines: {error}")
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def _get_files(cls, pipeline_id: str) -> List[Dict[str, Any]]:
|
||||
url = f"{cls.LLAMA_CLOUD_URL}/pipelines/{pipeline_id}/files"
|
||||
return cls._make_request(url)
|
||||
|
||||
@classmethod
|
||||
def _get_file_detail(cls, project_id: str, file_id: str) -> Dict[str, Any]:
|
||||
url = f"{cls.LLAMA_CLOUD_URL}/files/{file_id}/content?project_id={project_id}"
|
||||
return cls._make_request(url)
|
||||
|
||||
@classmethod
|
||||
def _download_file(cls, url: str, local_file_path: str):
|
||||
logger.info(f"Downloading file to {local_file_path}")
|
||||
# Create directory if it doesn't exist
|
||||
os.makedirs(cls.LOCAL_STORE_PATH, exist_ok=True)
|
||||
# Download the file
|
||||
with requests.get(url, stream=True) as r:
|
||||
r.raise_for_status()
|
||||
with open(local_file_path, "wb") as f:
|
||||
for chunk in r.iter_content(chunk_size=8192):
|
||||
f.write(chunk)
|
||||
logger.info("File downloaded successfully")
|
||||
|
||||
@classmethod
|
||||
def download_llamacloud_pipeline_file(
|
||||
cls,
|
||||
file: LlamaCloudFile,
|
||||
force_download: bool = False,
|
||||
):
|
||||
file_name = file.file_name
|
||||
pipeline_id = file.pipeline_id
|
||||
|
||||
# Check is the file already exists
|
||||
downloaded_file_path = cls.get_file_path(file_name, pipeline_id)
|
||||
if os.path.exists(downloaded_file_path) and not force_download:
|
||||
logger.debug(f"File {file_name} already exists in local storage")
|
||||
return
|
||||
try:
|
||||
logger.info(f"Downloading file {file_name} for pipeline {pipeline_id}")
|
||||
files = cls._get_files(pipeline_id)
|
||||
if not files or not isinstance(files, list):
|
||||
raise Exception("No files found in LlamaCloud")
|
||||
for file_entry in files:
|
||||
if file_entry["name"] == file_name:
|
||||
file_id = file_entry["file_id"]
|
||||
project_id = file_entry["project_id"]
|
||||
file_detail = cls._get_file_detail(project_id, file_id)
|
||||
cls._download_file(file_detail["url"], downloaded_file_path)
|
||||
break
|
||||
except Exception as error:
|
||||
logger.info(f"Error fetching file from LlamaCloud: {error}")
|
||||
|
||||
@classmethod
|
||||
def get_file_name(cls, name: str, pipeline_id: str) -> str:
|
||||
return cls.DOWNLOAD_FILE_NAME_TPL.format(pipeline_id=pipeline_id, filename=name)
|
||||
|
||||
@classmethod
|
||||
def get_file_path(cls, name: str, pipeline_id: str) -> str:
|
||||
return os.path.join(cls.LOCAL_STORE_PATH, cls.get_file_name(name, pipeline_id))
|
||||
|
||||
@staticmethod
|
||||
def _make_request(
|
||||
url: str, data=None, headers: Optional[Dict] = None, method: str = "get"
|
||||
):
|
||||
if headers is None:
|
||||
headers = {
|
||||
"Accept": "application/json",
|
||||
"Authorization": f'Bearer {os.getenv("LLAMA_CLOUD_API_KEY")}',
|
||||
}
|
||||
response = requests.request(method, url, headers=headers, data=data)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
@@ -0,0 +1,48 @@
|
||||
from typing import List
|
||||
|
||||
from app.api.routers.models import Message
|
||||
from llama_index.core.prompts import PromptTemplate
|
||||
from llama_index.core.settings import Settings
|
||||
from pydantic import BaseModel
|
||||
|
||||
NEXT_QUESTIONS_SUGGESTION_PROMPT = PromptTemplate(
|
||||
"你是一个乐于助人的助手!你的任务是对用户可能会问的下一个问题给出建议。 "
|
||||
"\n这是对话历史记录"
|
||||
"\n---------------------\n{conversation}\n---------------------"
|
||||
"考虑到对话历史记录,仅限于现在知识库已有内容, 请给我 $number_of_questions 个你接下来可能会问题的问题!"
|
||||
)
|
||||
N_QUESTION_TO_GENERATE = 3
|
||||
|
||||
|
||||
class NextQuestions(BaseModel):
|
||||
"""A list of questions that user might ask next"""
|
||||
|
||||
questions: List[str]
|
||||
|
||||
|
||||
class NextQuestionSuggestion:
|
||||
@staticmethod
|
||||
async def suggest_next_questions(
|
||||
messages: List[Message],
|
||||
number_of_questions: int = N_QUESTION_TO_GENERATE,
|
||||
) -> List[str]:
|
||||
# Reduce the cost by only using the last two messages
|
||||
last_user_message = None
|
||||
last_assistant_message = None
|
||||
for message in reversed(messages):
|
||||
if message.role == "user":
|
||||
last_user_message = f"User: {message.content}"
|
||||
elif message.role == "assistant":
|
||||
last_assistant_message = f"Assistant: {message.content}"
|
||||
if last_user_message and last_assistant_message:
|
||||
break
|
||||
conversation: str = f"{last_user_message}\n{last_assistant_message}"
|
||||
|
||||
output: NextQuestions = await Settings.llm.astructured_predict(
|
||||
NextQuestions,
|
||||
prompt=NEXT_QUESTIONS_SUGGESTION_PROMPT,
|
||||
conversation=conversation,
|
||||
nun_questions=number_of_questions,
|
||||
)
|
||||
|
||||
return output.questions
|
||||
Reference in New Issue
Block a user