优化了提示词
This commit is contained in:
@@ -0,0 +1,61 @@
|
|||||||
|
from llama_index.embeddings.openai import OpenAIEmbedding
|
||||||
|
from llama_index.core.settings import Settings
|
||||||
|
from typing import Dict
|
||||||
|
import os
|
||||||
|
|
||||||
|
DEFAULT_MODEL = "gpt-3.5-turbo"
|
||||||
|
DEFAULT_EMBEDDING_MODEL = "text-embedding-3-large"
|
||||||
|
|
||||||
|
class TSIEmbedding(OpenAIEmbedding):
    """OpenAIEmbedding variant whose query and text engines both use ``model_name``."""

    def __init__(self, **kwargs):
        """Forward every keyword argument to ``OpenAIEmbedding``, then pin both engines."""
        super().__init__(**kwargs)
        # Point both private engine attributes at the configured model name.
        engine = self.model_name
        self._query_engine = engine
        self._text_engine = engine
|
||||||
|
|
||||||
|
def llm_config_from_env() -> Dict:
    """Build the LLM keyword arguments from environment variables.

    Reads ``MODEL`` (falls back to ``DEFAULT_MODEL``), ``LLM_TEMPERATURE``
    (falls back to llama-index's ``DEFAULT_TEMPERATURE``), ``LLM_MAX_TOKENS``,
    ``T_SYSTEMS_LLMHUB_API_KEY`` and ``T_SYSTEMS_LLMHUB_BASE_URL``.

    Returns:
        A dict with the keys ``model``, ``api_key``, ``api_base``,
        ``temperature`` (float) and ``max_tokens`` (int, or None when
        ``LLM_MAX_TOKENS`` is unset).

    Raises:
        ValueError: if ``LLM_TEMPERATURE`` or ``LLM_MAX_TOKENS`` is not numeric.
    """
    from llama_index.core.constants import DEFAULT_TEMPERATURE

    raw_max_tokens = os.getenv("LLM_MAX_TOKENS")
    return {
        "model": os.getenv("MODEL", DEFAULT_MODEL),
        "api_key": os.getenv("T_SYSTEMS_LLMHUB_API_KEY"),
        "api_base": os.getenv("T_SYSTEMS_LLMHUB_BASE_URL"),
        "temperature": float(os.getenv("LLM_TEMPERATURE", DEFAULT_TEMPERATURE)),
        "max_tokens": None if raw_max_tokens is None else int(raw_max_tokens),
    }
|
||||||
|
|
||||||
|
|
||||||
|
def embedding_config_from_env() -> Dict:
    """Build the embedding-model keyword arguments from environment variables.

    Reads ``EMBEDDING_MODEL`` (falls back to ``DEFAULT_EMBEDDING_MODEL``),
    ``EMBEDDING_DIM`` (falls back to llama-index's ``DEFAULT_EMBEDDING_DIM``),
    ``T_SYSTEMS_LLMHUB_API_KEY`` and ``T_SYSTEMS_LLMHUB_BASE_URL``.

    Returns:
        A dict with the keys ``model_name``, ``dimension``, ``api_key`` and
        ``api_base``.

    Raises:
        ValueError: if ``EMBEDDING_DIM`` is not an integer.
    """
    from llama_index.core.constants import DEFAULT_EMBEDDING_DIM

    raw_dimension = os.getenv("EMBEDDING_DIM", DEFAULT_EMBEDDING_DIM)
    return {
        "model_name": os.getenv("EMBEDDING_MODEL", DEFAULT_EMBEDDING_MODEL),
        "dimension": None if raw_dimension is None else int(raw_dimension),
        "api_key": os.getenv("T_SYSTEMS_LLMHUB_API_KEY"),
        "api_base": os.getenv("T_SYSTEMS_LLMHUB_BASE_URL"),
    }
|
||||||
|
|
||||||
|
def init_llmhub():
    """Install the LLM and embedding model for the LLM hub on the global ``Settings``.

    The embedding model is a ``TSIEmbedding`` and the LLM an ``OpenAILike``
    chat model without function calling and with a 4096-token context window.
    Both are configured from the environment-derived config helpers.
    """
    from llama_index.llms.openai_like import OpenAILike

    llm_kwargs = llm_config_from_env()
    embedding_kwargs = embedding_config_from_env()

    Settings.embed_model = TSIEmbedding(**embedding_kwargs)
    Settings.llm = OpenAILike(
        is_chat_model=True,
        is_function_calling_model=False,
        context_window=4096,
        **llm_kwargs,
    )
|
||||||
@@ -0,0 +1,20 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
import llama_index.core
|
||||||
|
|
||||||
|
def init_observability():
    """Register the "arize_phoenix" global handler with llama-index.

    Exports the Phoenix API key through ``OTEL_EXPORTER_OTLP_HEADERS`` and
    passes ``PHOENIX_URL`` as the handler endpoint.

    Raises:
        ValueError: if ``PHOENIX_API_KEY`` is unset or empty.
    """
    api_key = os.getenv("PHOENIX_API_KEY")
    if not api_key:
        raise ValueError("PHOENIX_API_KEY environment variable is not set")

    os.environ["OTEL_EXPORTER_OTLP_HEADERS"] = f"api_key={api_key}"
    endpoint = os.getenv("PHOENIX_URL")
    llama_index.core.set_global_handler(
        "arize_phoenix",
        endpoint=endpoint,
        eval_params={},
    )
|
||||||
|
|
||||||
|
#debugHandle=[]
|
||||||
|
# llama_debug = LlamaDebugHandler(print_trace_on_end=True)
|
||||||
|
# debugHandle.append(llama_debug)
|
||||||
|
# callback_manager = CallbackManager(debugHandle)
|
||||||
|
# settings.Settings.callback_manager = callback_manager
|
||||||
@@ -0,0 +1,235 @@
|
|||||||
|
import os
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from llama_index.core.constants import DEFAULT_TEMPERATURE
|
||||||
|
from llama_index.core.settings import Settings
|
||||||
|
from llama_index.llms.xinference import Xinference
|
||||||
|
from llama_index.llms.xinference.base import DEFAULT_XINFERENCE_TEMP
|
||||||
|
|
||||||
|
from app.xinference.base import XinferenceEmbedding, XinferenceRerank
|
||||||
|
|
||||||
|
|
||||||
|
def get_node_postprocessors():
    """Return the node postprocessors configured through the environment.

    Reranking is controlled by ``RERANK_ENABLED``. When the variable is unset
    or equals "false" (case-insensitive), no postprocessors are used.
    Otherwise a single ``XinferenceRerank`` is built from ``RERANK_MODEL``,
    ``RERANK_BASE_URL``, ``RERANK_TOP_N`` and ``RERANK_THRESHOLD``.

    Returns:
        ``[]`` when reranking is disabled, ``None`` when it is enabled but
        ``RERANK_MODEL`` is unset, otherwise a one-element list holding the
        reranker.
    """
    # Default to "False" so an unset variable no longer crashes on
    # ``None.title()``; ``title()`` normalizes "false"/"FALSE" to "False".
    rerank_enabled = os.getenv("RERANK_ENABLED", "False").title()
    if rerank_enabled == "False":
        return []

    rerank_model = os.getenv("RERANK_MODEL")
    rerank_url = os.getenv("RERANK_BASE_URL")
    # NOTE(review): top_n and threshold are forwarded as raw strings (or
    # None); confirm XinferenceRerank converts them itself.
    rerank_top_n = os.getenv("RERANK_TOP_N")
    rerank_threshold = os.getenv("RERANK_THRESHOLD")

    if rerank_model is None:
        # Kept as None (not []) to preserve what existing callers receive.
        return None
    return [
        XinferenceRerank(
            rerank_model,
            rerank_url,
            top_n=rerank_top_n,
            threshold=rerank_threshold,
        )
    ]
|
||||||
|
|
||||||
|
def init_settings():
    """Initialize the global llama-index ``Settings`` for the selected provider.

    The provider is read from ``MODEL_PROVIDER`` and dispatched to the
    matching ``init_*`` function. Afterwards the chunk size and overlap are
    taken from ``CHUNK_SIZE`` (default 1024) and ``CHUNK_OVERLAP`` (default 20).

    Raises:
        ValueError: if ``MODEL_PROVIDER`` is unset or not a known provider.
    """
    provider = os.getenv("MODEL_PROVIDER")

    # The initializers are looked up lazily, branch by branch, so only the
    # selected provider's function is ever touched.
    if provider == "openai":
        init_openai()
    elif provider == "dashscope":
        init_dashscope()
    elif provider == "groq":
        init_groq()
    elif provider == "ollama":
        init_ollama()
    elif provider == "anthropic":
        init_anthropic()
    elif provider == "gemini":
        init_gemini()
    elif provider == "mistral":
        init_mistral()
    elif provider == "azure-openai":
        init_azure_openai()
    elif provider == "t-systems":
        from .llmhub import init_llmhub

        init_llmhub()
    elif provider == "xinference":
        init_xinference()
    else:
        raise ValueError(f"Invalid model provider: {provider}")

    Settings.chunk_size = int(os.getenv("CHUNK_SIZE", "1024"))
    Settings.chunk_overlap = int(os.getenv("CHUNK_OVERLAP", "20"))
|
||||||
|
|
||||||
|
|
||||||
|
def init_ollama():
    """Placeholder for the Ollama provider; it currently configures nothing.

    The disabled reference implementation is kept below as comments.
    """
    # from llama_index.embeddings.ollama import OllamaEmbedding
    # from llama_index.llms.ollama.base import DEFAULT_REQUEST_TIMEOUT, Ollama
    #
    # base_url = os.getenv("OLLAMA_BASE_URL") or "http://127.0.0.1:11434"
    # request_timeout = float(
    #     os.getenv("OLLAMA_REQUEST_TIMEOUT", DEFAULT_REQUEST_TIMEOUT)
    # )
    # Settings.embed_model = OllamaEmbedding(
    #     base_url=base_url,
    #     model_name=os.getenv("EMBEDDING_MODEL"),
    # )
    # Settings.llm = Ollama(
    #     base_url=base_url, model=os.getenv("MODEL"), request_timeout=request_timeout
    # )
    return None
|
||||||
|
|
||||||
|
def init_xinference():
    """Configure the global ``Settings`` with an Xinference LLM and embedding model.

    The LLM uses ``MODEL``, ``BASE_URL``, ``LLM_TEMPERATURE`` (default
    ``DEFAULT_XINFERENCE_TEMP``) and ``LLM_MAX_TOKENS``. The embedding model
    uses ``EMBEDDING_MODEL`` and ``EMBEDDING_DIM``, served from
    ``EMBEDDING_BASE_URL`` or, when that is unset or empty, from ``BASE_URL``.

    Raises:
        ValueError: if a numeric environment variable cannot be parsed.
    """
    base_url = os.getenv("BASE_URL")
    model = os.getenv("MODEL")

    raw_max_tokens = os.getenv("LLM_MAX_TOKENS")
    max_tokens = None if raw_max_tokens is None else int(raw_max_tokens)
    temperature = float(os.getenv("LLM_TEMPERATURE", DEFAULT_XINFERENCE_TEMP))

    Settings.llm = Xinference(model, base_url, temperature, max_tokens)

    # An unset or empty embedding URL falls back to the LLM endpoint.
    embedding_url = os.getenv("EMBEDDING_BASE_URL") or base_url
    embedding_model = os.getenv("EMBEDDING_MODEL")

    raw_dimensions = os.getenv("EMBEDDING_DIM")
    dimensions = None if raw_dimensions is None else int(raw_dimensions)

    Settings.embed_model = XinferenceEmbedding(
        embedding_model,
        embedding_url,
        dimensions=dimensions,
    )
|
||||||
|
|
||||||
|
def init_openai():
    """Configure the global ``Settings`` with OpenAI's LLM and embedding model.

    The LLM is built from ``MODEL``, ``LLM_TEMPERATURE`` (default llama-index
    ``DEFAULT_TEMPERATURE``) and ``LLM_MAX_TOKENS``; the embedding model from
    ``EMBEDDING_MODEL`` and ``EMBEDDING_DIM``.

    Raises:
        ValueError: if a numeric environment variable cannot be parsed.
    """
    from llama_index.core.constants import DEFAULT_TEMPERATURE
    from llama_index.embeddings.openai import OpenAIEmbedding
    from llama_index.llms.openai import OpenAI

    raw_max_tokens = os.getenv("LLM_MAX_TOKENS")
    Settings.llm = OpenAI(
        model=os.getenv("MODEL"),
        temperature=float(os.getenv("LLM_TEMPERATURE", DEFAULT_TEMPERATURE)),
        max_tokens=None if raw_max_tokens is None else int(raw_max_tokens),
    )

    raw_dimensions = os.getenv("EMBEDDING_DIM")
    Settings.embed_model = OpenAIEmbedding(
        model=os.getenv("EMBEDDING_MODEL"),
        dimensions=None if raw_dimensions is None else int(raw_dimensions),
    )
|
||||||
|
|
||||||
|
def init_dashscope():
    """Configure the global ``Settings`` with DashScope's LLM and embedding model.

    The LLM is fixed to ``DashScopeGenerationModels.QWEN_MAX`` and the
    embedding model to ``DashScopeTextEmbeddingModels.TEXT_EMBEDDING_V2`` with
    the query text type.
    """
    from llama_index.embeddings.dashscope import (
        DashScopeEmbedding,
        DashScopeTextEmbeddingModels,
        DashScopeTextEmbeddingType,
    )
    from llama_index.llms.dashscope import DashScope, DashScopeGenerationModels

    # NOTE(review): this function used to read MODEL, LLM_TEMPERATURE,
    # LLM_MAX_TOKENS, EMBEDDING_MODEL and EMBEDDING_DIM into two config dicts
    # that were never passed to either constructor. Those dead reads could
    # only raise on malformed values, so they were removed. If the
    # environment is meant to drive the DashScope models, wire the values in
    # here explicitly.
    Settings.llm = DashScope(model_name=DashScopeGenerationModels.QWEN_MAX)
    Settings.embed_model = DashScopeEmbedding(
        model_name=DashScopeTextEmbeddingModels.TEXT_EMBEDDING_V2,
        text_type=DashScopeTextEmbeddingType.TEXT_TYPE_QUERY,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def init_azure_openai():
    """Placeholder for the Azure OpenAI provider; it currently configures nothing.

    The disabled reference implementation is kept below as comments.
    """
    # from llama_index.core.constants import DEFAULT_TEMPERATURE
    # from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
    # from llama_index.llms.azure_openai import AzureOpenAI
    #
    # llm_deployment = os.environ["AZURE_OPENAI_LLM_DEPLOYMENT"]
    # embedding_deployment = os.environ["AZURE_OPENAI_EMBEDDING_DEPLOYMENT"]
    # max_tokens = os.getenv("LLM_MAX_TOKENS")
    # temperature = os.getenv("LLM_TEMPERATURE", DEFAULT_TEMPERATURE)
    # dimensions = os.getenv("EMBEDDING_DIM")
    #
    # azure_config = {
    #     "api_key": os.environ["AZURE_OPENAI_KEY"],
    #     "azure_endpoint": os.environ["AZURE_OPENAI_ENDPOINT"],
    #     "api_version": os.getenv("AZURE_OPENAI_API_VERSION")
    #     or os.getenv("OPENAI_API_VERSION"),
    # }
    #
    # Settings.llm = AzureOpenAI(
    #     model=os.getenv("MODEL"),
    #     max_tokens=int(max_tokens) if max_tokens is not None else None,
    #     temperature=float(temperature),
    #     deployment_name=llm_deployment,
    #     **azure_config,
    # )
    #
    # Settings.embed_model = AzureOpenAIEmbedding(
    #     model=os.getenv("EMBEDDING_MODEL"),
    #     dimensions=int(dimensions) if dimensions is not None else None,
    #     deployment_name=embedding_deployment,
    #     **azure_config,
    # )
    return None
|
||||||
|
|
||||||
|
|
||||||
|
def init_fastembed():
    """Placeholder for Qdrant FastEmbed as the local embedding provider.

    It currently configures nothing; the disabled reference implementation
    is kept below as comments.
    """
    # from llama_index.embeddings.fastembed import FastEmbedEmbedding
    #
    # embed_model_map: Dict[str, str] = {
    #     # Small and multilingual
    #     "all-MiniLM-L6-v2": "sentence-transformers/all-MiniLM-L6-v2",
    #     # Large and multilingual
    #     "paraphrase-multilingual-mpnet-base-v2": "sentence-transformers/paraphrase-multilingual-mpnet-base-v2",  # noqa: E501
    # }
    #
    # # This will download the model automatically if it is not already downloaded
    # Settings.embed_model = FastEmbedEmbedding(
    #     model_name=embed_model_map[os.getenv("EMBEDDING_MODEL")]
    # )
    return None
|
||||||
|
|
||||||
|
|
||||||
|
def init_groq():
    """Placeholder for the Groq provider; it currently configures nothing.

    The disabled reference implementation is kept below as comments.
    """
    # from llama_index.llms.groq import Groq
    #
    # model_map: Dict[str, str] = {
    #     "llama3-8b": "llama3-8b-8192",
    #     "llama3-70b": "llama3-70b-8192",
    #     "mixtral-8x7b": "mixtral-8x7b-32768",
    # }
    #
    # Settings.llm = Groq(model=model_map[os.getenv("MODEL")])
    # # Groq does not provide embeddings, so we use FastEmbed instead
    # init_fastembed()
    return None
|
||||||
|
|
||||||
|
|
||||||
|
def init_anthropic():
    """Placeholder for the Anthropic provider; it currently configures nothing.

    The disabled reference implementation is kept below as comments.
    """
    # from llama_index.llms.anthropic import Anthropic
    #
    # model_map: Dict[str, str] = {
    #     "claude-3-opus": "claude-3-opus-20240229",
    #     "claude-3-sonnet": "claude-3-sonnet-20240229",
    #     "claude-3-haiku": "claude-3-haiku-20240307",
    #     "claude-2.1": "claude-2.1",
    #     "claude-instant-1.2": "claude-instant-1.2",
    # }
    #
    # Settings.llm = Anthropic(model=model_map[os.getenv("MODEL")])
    # # Anthropic does not provide embeddings, so we use FastEmbed instead
    # init_fastembed()
    return None
|
||||||
|
|
||||||
|
|
||||||
|
def init_gemini():
    """Placeholder for the Gemini provider; it currently configures nothing.

    The disabled reference implementation is kept below as comments.
    """
    # from llama_index.embeddings.gemini import GeminiEmbedding
    # from llama_index.llms.gemini import Gemini
    #
    # model_name = f"models/{os.getenv('MODEL')}"
    # embed_model_name = f"models/{os.getenv('EMBEDDING_MODEL')}"
    #
    # Settings.llm = Gemini(model=model_name)
    # Settings.embed_model = GeminiEmbedding(model_name=embed_model_name)
    return None
|
||||||
|
|
||||||
|
def init_mistral():
    """Placeholder for the Mistral provider; it currently configures nothing.

    The disabled reference implementation is kept below as comments.
    """
    # from llama_index.embeddings.mistralai import MistralAIEmbedding
    # from llama_index.llms.mistralai import MistralAI
    #
    # Settings.llm = MistralAI(model=os.getenv("MODEL"))
    # Settings.embed_model = MistralAIEmbedding(model_name=os.getenv("EMBEDDING_MODEL"))
    return None
|
||||||
@@ -0,0 +1,150 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request, status
|
||||||
|
from llama_index.core.chat_engine.types import BaseChatEngine, NodeWithScore
|
||||||
|
from llama_index.core.llms import MessageRole
|
||||||
|
from llama_index.core.vector_stores.types import MetadataFilter, MetadataFilters
|
||||||
|
|
||||||
|
from app.api.routers.events import EventCallbackHandler
|
||||||
|
from app.api.routers.models import (
|
||||||
|
ChatConfig,
|
||||||
|
ChatData,
|
||||||
|
Message,
|
||||||
|
Result,
|
||||||
|
SourceNodes,
|
||||||
|
)
|
||||||
|
from app.api.routers.vercel_response import VercelStreamResponse
|
||||||
|
from app.api.services.llama_cloud import LLamaCloudFileService
|
||||||
|
from app.engine import get_chat_engine
|
||||||
|
|
||||||
|
chat_router = r = APIRouter()
|
||||||
|
|
||||||
|
logger = logging.getLogger("uvicorn")
|
||||||
|
|
||||||
|
|
||||||
|
def process_response_nodes(
    nodes: List[NodeWithScore],
    background_tasks: BackgroundTasks,
):
    """Schedule a background download for each file the source nodes require.

    Args:
        nodes: Source nodes returned by the chat engine.
        background_tasks: FastAPI task queue that receives one download task
            per file reported by ``SourceNodes.get_download_files``.
    """
    for cloud_file in SourceNodes.get_download_files(nodes):
        background_tasks.add_task(
            LLamaCloudFileService.download_llamacloud_pipeline_file,
            cloud_file,
        )
|
||||||
|
|
||||||
|
|
||||||
|
# streaming endpoint - delete if not needed
|
||||||
|
@r.post("")
async def chat(
    request: Request,
    data: ChatData,
    background_tasks: BackgroundTasks,
    chat_engine: BaseChatEngine = Depends(get_chat_engine),
):
    """Streaming chat endpoint.

    Builds a chat engine filtered to the documents referenced in the
    conversation, streams the answer for the last message and schedules
    background downloads for the source nodes.

    Args:
        request: Incoming HTTP request, forwarded to the stream response.
        data: Chat payload; its message list is emptied here (see below).
        background_tasks: FastAPI task queue used for file downloads.
        chat_engine: Injected default engine. It is replaced by an engine
            built with the request-specific filters and params.

    Raises:
        HTTPException: status 500 wrapping any error raised while handling
            the chat.
    """
    try:
        last_message_content = data.get_last_message_content()
        # Collect the referenced document ids BEFORE the messages are cleared;
        # reading them afterwards always produced an empty list, so the
        # doc_id filter was never applied.
        doc_ids = data.get_chat_document_ids()

        # The history-based prompt has not been tuned yet, so the chat
        # history is temporarily dropped.
        data.messages.clear()
        messages = data.get_history_messages()

        filters = generate_filters(doc_ids)
        params = data.data or {}
        # Use a %s placeholder: passing an argument without one makes the
        # logging module report a formatting error instead of the message.
        logger.info("Creating chat engine with filters: %s", filters.dict())
        chat_engine = get_chat_engine(filters=filters, params=params)

        event_handler = EventCallbackHandler()
        chat_engine.callback_manager.handlers.append(event_handler)  # type: ignore

        response = await chat_engine.astream_chat(last_message_content, messages)
        process_response_nodes(response.source_nodes, background_tasks)

        return VercelStreamResponse(request, event_handler, response, data)
    except Exception as e:
        logger.exception("Error in chat engine", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error in chat engine: {e}",
        ) from e
|
||||||
|
|
||||||
|
|
||||||
|
def generate_filters(doc_ids):
    """Build the metadata filters used for retrieval.

    Args:
        doc_ids: Document ids referenced in the chat; may be empty.

    Returns:
        ``MetadataFilters`` that admit every document whose ``private`` value
        is not "true". When ``doc_ids`` is non-empty, documents whose
        ``doc_id`` is in ``doc_ids`` are admitted as well (combined with "or").
    """
    # "nin" (not in) keeps all public documents, i.e. those that do not have
    # the private key set to "true".
    public_only = MetadataFilter(
        key="private",
        value=["true"],
        operator="nin",  # type: ignore
    )

    if len(doc_ids) == 0:
        return MetadataFilters(filters=[public_only])

    selected_docs = MetadataFilter(
        key="doc_id",
        value=doc_ids,
        operator="in",  # type: ignore
    )
    return MetadataFilters(
        filters=[public_only, selected_docs],
        condition="or",  # type: ignore
    )
|
||||||
|
|
||||||
|
|
||||||
|
# non-streaming endpoint - delete if not needed
|
||||||
|
@r.post("/request")
async def chat_request(
    data: ChatData,
    chat_engine: BaseChatEngine = Depends(get_chat_engine),
) -> Result:
    """Non-streaming chat endpoint.

    Sends the last message together with the preceding history to the chat
    engine and returns the assistant reply plus its source nodes.
    """
    question = data.get_last_message_content()
    history = data.get_history_messages()

    reply = await chat_engine.achat(question, history)

    answer = Message(role=MessageRole.ASSISTANT, content=reply.response)
    sources = SourceNodes.from_source_nodes(reply.source_nodes)
    return Result(result=answer, nodes=sources)
|
||||||
|
|
||||||
|
|
||||||
|
@r.get("/config")
async def chat_config() -> ChatConfig:
    """Return the chat UI configuration.

    ``CONVERSATION_STARTERS`` is stripped and split on the two-character
    sequence backslash + "n". When the variable is unset or blank, the
    starter questions are None.
    """
    raw_starters = os.getenv("CONVERSATION_STARTERS")
    if not raw_starters or not raw_starters.strip():
        return ChatConfig(starter_questions=None)
    return ChatConfig(starter_questions=raw_starters.strip().split("\\n"))
|
||||||
|
|
||||||
|
|
||||||
|
@r.get("/config/llamacloud")
async def chat_llama_cloud_config():
    """Return the LlamaCloud projects and the configured default pipeline.

    The ``pipeline`` entry is a dict with ``pipeline`` and ``project`` only
    when both ``LLAMA_CLOUD_INDEX_NAME`` and ``LLAMA_CLOUD_PROJECT_NAME`` are
    set to non-empty values; otherwise it is None.
    """
    projects = LLamaCloudFileService.get_all_projects_with_pipelines()
    pipeline_name = os.getenv("LLAMA_CLOUD_INDEX_NAME")
    project_name = os.getenv("LLAMA_CLOUD_PROJECT_NAME")

    if pipeline_name and project_name:
        pipeline_config = {
            "pipeline": pipeline_name,
            "project": project_name,
        }
    else:
        pipeline_config = None

    return {
        "projects": projects,
        "pipeline": pipeline_config,
    }
|
||||||
@@ -0,0 +1,149 @@
|
|||||||
|
import json
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from typing import AsyncGenerator, Dict, Any, List, Optional
|
||||||
|
from llama_index.core.callbacks.base import BaseCallbackHandler
|
||||||
|
from llama_index.core.callbacks.schema import CBEventType
|
||||||
|
from llama_index.core.tools.types import ToolOutput
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class CallbackEvent(BaseModel):
|
||||||
|
event_type: CBEventType
|
||||||
|
payload: Optional[Dict[str, Any]] = None
|
||||||
|
event_id: str = ""
|
||||||
|
|
||||||
|
def get_retrieval_message(self) -> dict | None:
|
||||||
|
if self.payload:
|
||||||
|
nodes = self.payload.get("nodes")
|
||||||
|
if nodes:
|
||||||
|
msg = f"根据查询检索到 {len(nodes)} 源文件"
|
||||||
|
else:
|
||||||
|
msg = f"查询检索中: '{self.payload.get('query_str')}'"
|
||||||
|
return {
|
||||||
|
"type": "events",
|
||||||
|
"data": {"title": msg},
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_tool_message(self) -> dict | None:
|
||||||
|
func_call_args = self.payload.get("function_call")
|
||||||
|
if func_call_args is not None and "tool" in self.payload:
|
||||||
|
tool = self.payload.get("tool")
|
||||||
|
return {
|
||||||
|
"type": "events",
|
||||||
|
"data": {
|
||||||
|
"title": f"调用工具 {tool.name} ,参数: {func_call_args}",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def _is_output_serializable(self, output: Any) -> bool:
|
||||||
|
try:
|
||||||
|
json.dumps(output)
|
||||||
|
return True
|
||||||
|
except TypeError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_agent_tool_response(self) -> dict | None:
|
||||||
|
response = self.payload.get("response")
|
||||||
|
if response is not None:
|
||||||
|
sources = response.sources
|
||||||
|
for source in sources:
|
||||||
|
# Return the tool response here to include the toolCall information
|
||||||
|
if isinstance(source, ToolOutput):
|
||||||
|
if self._is_output_serializable(source.raw_output):
|
||||||
|
output = source.raw_output
|
||||||
|
else:
|
||||||
|
output = source.content
|
||||||
|
|
||||||
|
return {
|
||||||
|
"type": "tools",
|
||||||
|
"data": {
|
||||||
|
"toolOutput": {
|
||||||
|
"output": output,
|
||||||
|
"isError": source.is_error,
|
||||||
|
},
|
||||||
|
"toolCall": {
|
||||||
|
"id": None, # There is no tool id in the ToolOutput
|
||||||
|
"name": source.tool_name,
|
||||||
|
"input": source.raw_input,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def to_response(self):
|
||||||
|
try:
|
||||||
|
match self.event_type:
|
||||||
|
case "retrieve":
|
||||||
|
return self.get_retrieval_message()
|
||||||
|
case "function_call":
|
||||||
|
return self.get_tool_message()
|
||||||
|
case "agent_step":
|
||||||
|
return self.get_agent_tool_response()
|
||||||
|
case _:
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"转换回应时间时发生错误,原因: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class EventCallbackHandler(BaseCallbackHandler):
    """Callback handler that queues renderable events for asynchronous streaming."""

    _aqueue: asyncio.Queue
    # Set to True by the consumer side to let async_event_gen finish once the
    # queue has drained.
    is_done: bool = False

    def __init__(
        self,
    ):
        """Initialize the base callback handler."""
        # These event types are ignored for both event starts and ends.
        ignored_events = [
            CBEventType.CHUNKING,
            CBEventType.NODE_PARSING,
            CBEventType.EMBEDDING,
            CBEventType.LLM,
            CBEventType.TEMPLATING,
        ]
        super().__init__(ignored_events, ignored_events)
        self._aqueue = asyncio.Queue()

    def on_event_start(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        **kwargs: Any,
    ) -> str:
        """Queue the starting event if it is renderable and return its id."""
        event = CallbackEvent(event_id=event_id, event_type=event_type, payload=payload)
        if event.to_response() is not None:
            self._aqueue.put_nowait(event)
        # The signature promises a str; previously this implicitly returned None.
        return event_id

    def on_event_end(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        **kwargs: Any,
    ) -> None:
        """Queue the finished event if it is renderable."""
        event = CallbackEvent(event_id=event_id, event_type=event_type, payload=payload)
        if event.to_response() is not None:
            self._aqueue.put_nowait(event)

    def start_trace(self, trace_id: Optional[str] = None) -> None:
        """No-op."""

    def end_trace(
        self,
        trace_id: Optional[str] = None,
        trace_map: Optional[Dict[str, List[str]]] = None,
    ) -> None:
        """No-op."""

    async def async_event_gen(self) -> AsyncGenerator[CallbackEvent, None]:
        """Yield queued events until the queue is empty and ``is_done`` is set."""
        while not self._aqueue.empty() or not self.is_done:
            try:
                # The short timeout lets the loop re-check is_done regularly.
                yield await asyncio.wait_for(self._aqueue.get(), timeout=0.1)
            except asyncio.TimeoutError:
                pass
|
||||||
@@ -0,0 +1,253 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import Any, Dict, List, Literal, Optional, Set
|
||||||
|
|
||||||
|
from llama_index.core.llms import ChatMessage, MessageRole
|
||||||
|
from llama_index.core.schema import NodeWithScore
|
||||||
|
from pydantic import BaseModel, Field, validator, field_validator
|
||||||
|
from pydantic.alias_generators import to_camel
|
||||||
|
|
||||||
|
logger = logging.getLogger("uvicorn")
|
||||||
|
|
||||||
|
|
||||||
|
class FileContent(BaseModel):
    """Content of an uploaded file: either inline text or document references."""

    # "text": value holds the raw text; "ref": value holds document IDs.
    type: Literal["text", "ref"]
    # A string for pure-text files, otherwise a list of document IDs.
    value: str | List[str]
|
||||||
|
|
||||||
|
|
||||||
|
class File(BaseModel):
    """Description of one file attached to a chat message."""

    id: str
    content: FileContent
    filename: str
    filesize: int
    filetype: str
|
||||||
|
|
||||||
|
|
||||||
|
class AnnotationFileData(BaseModel):
    """Payload of a file annotation: the list of attached files."""

    files: List[File] = Field(
        default=[],
        description="List of files",
    )

    class Config:
        # Field names are exposed in camelCase.
        alias_generator = to_camel
        json_schema_extra = {
            "example": {
                "csvFiles": [
                    {
                        "content": "Name, Age\nAlice, 25\nBob, 30",
                        "filename": "example.csv",
                        "filesize": 123,
                        "id": "123",
                        "type": "text/csv",
                    }
                ]
            }
        }
|
||||||
|
|
||||||
|
|
||||||
|
class Annotation(BaseModel):
    """A typed annotation attached to a chat message."""

    type: str
    data: AnnotationFileData | List[str]

    def to_content(self) -> str | None:
        """Render the annotation as extra prompt context.

        Only "document_file" annotations carrying file data are supported,
        and of those only CSV files produce content.

        Returns:
            The CSV files wrapped in fenced ``csv`` blocks, or None when the
            annotation is unsupported or contains no CSV file.
        """
        # ``data`` may also be a plain list of strings, which has no
        # ``files`` attribute; treat that like any other unsupported case
        # instead of raising AttributeError.
        if self.type != "document_file" or not isinstance(
            self.data, AnnotationFileData
        ):
            logger.warning(
                f"The annotation {self.type} is not supported for generating context content"
            )
            return None

        # We only support generating context content for CSV files for now
        csv_files = [file for file in self.data.files if file.filetype == "csv"]
        if not csv_files:
            return None
        return "Use data from following CSV raw content\n" + "\n".join(
            [f"```csv\n{csv_file.content.value}\n```" for csv_file in csv_files]
        )
|
||||||
|
|
||||||
|
|
||||||
|
class Message(BaseModel):
    """One chat message with an optional list of annotations."""

    role: MessageRole
    content: str
    annotations: List[Annotation] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ChatData(BaseModel):
    """Request body of the chat endpoints: the conversation plus extra data."""

    messages: List[Message]
    data: Any = None

    class Config:
        json_schema_extra = {
            "example": {
                "messages": [
                    {
                        "role": "user",
                        "content": "What standards for letters exist?",
                    }
                ]
            }
        }

    @field_validator("messages")
    @classmethod
    def messages_must_not_be_empty(cls, v):
        """Reject requests that carry no message at all."""
        if len(v) == 0:
            raise ValueError("Messages must not be empty")
        return v

    def get_last_message_content(self) -> str:
        """
        Get the content of the last message along with the data content if available.
        Fallback to use data content from previous messages

        Raises:
            ValueError: if the chat contains no message.
        """
        if not self.messages:
            raise ValueError("There is not any message in the chat")
        message_content = self.messages[-1].content
        for message in reversed(self.messages):
            if message.role != MessageRole.USER or message.annotations is None:
                continue
            # Build a real list: a ``filter`` object is always truthy, so the
            # emptiness check below never triggered and the fallback to
            # earlier messages was skipped.
            annotation_contents = [
                content
                for content in (
                    annotation.to_content() for annotation in message.annotations
                )
                if content
            ]
            if not annotation_contents:
                continue
            annotation_text = "\n".join(annotation_contents)
            message_content = f"{message_content}\n{annotation_text}"
            break
        return message_content

    def get_history_messages(self) -> List[ChatMessage]:
        """
        Get the history messages (every message except the last one)
        """
        return [
            ChatMessage(role=message.role, content=message.content)
            for message in self.messages[:-1]
        ]

    def is_last_message_from_user(self) -> bool:
        """Return True when the last message has the user role."""
        return self.messages[-1].role == MessageRole.USER

    def get_chat_document_ids(self) -> List[str]:
        """
        Get the de-duplicated document IDs referenced by "ref" files in the
        annotations of user messages (order is not guaranteed)
        """
        document_ids: List[str] = []
        for message in self.messages:
            if message.role != MessageRole.USER or message.annotations is None:
                continue
            for annotation in message.annotations:
                # ``data`` may be a plain list of strings without ``files``.
                if annotation.type != "document_file" or not isinstance(
                    annotation.data, AnnotationFileData
                ):
                    continue
                for fi in annotation.data.files or []:
                    if fi.content.type != "ref":
                        continue
                    value = fi.content.value
                    if isinstance(value, str):
                        # ``+=`` on a str would add one entry per character.
                        document_ids.append(value)
                    else:
                        document_ids.extend(value)
        return list(set(document_ids))
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaCloudFile(BaseModel):
    """A file identified by its name and pipeline; hashable for de-duplication."""

    file_name: str
    pipeline_id: str

    def __eq__(self, other):
        """Two files are equal when both file name and pipeline id match."""
        if isinstance(other, LlamaCloudFile):
            own_key = (self.file_name, self.pipeline_id)
            other_key = (other.file_name, other.pipeline_id)
            return own_key == other_key
        return NotImplemented

    def __hash__(self):
        """Hash on the same (file name, pipeline id) pair used by ``__eq__``."""
        return hash((self.file_name, self.pipeline_id))
|
||||||
|
|
||||||
|
|
||||||
|
class SourceNodes(BaseModel):
    """Serializable view of a retrieved source node."""

    id: str
    metadata: Dict[str, Any]
    score: Optional[float]
    text: str
    url: Optional[str]

    @classmethod
    def from_source_node(cls, source_node: NodeWithScore):
        """Build a ``SourceNodes`` entry from one ``NodeWithScore``."""
        metadata = source_node.node.metadata
        url = cls.get_url_from_metadata(metadata)
        # The node text itself is exposed (an earlier variant used the file
        # name or the node id instead).
        text = source_node.node.text
        return cls(
            id=source_node.node.node_id,
            metadata=metadata,
            score=source_node.score,
            text=text,  # type: ignore
            url=url,
        )

    @classmethod
    def get_url_from_metadata(cls, metadata: Dict[str, Any]) -> Optional[str]:
        """Derive the URL of a source document from its metadata.

        Returns:
            A file-server URL when ``file_name`` is present and
            ``FILESERVER_URL_PREFIX`` is set; otherwise the metadata's
            ``URL`` entry, which may be None (hence ``Optional[str]``; the
            previous ``str`` annotation was inaccurate).
        """
        url_prefix = os.getenv("FILESERVER_URL_PREFIX")
        if not url_prefix:
            logger.warning(
                "Warning: FILESERVER_URL_PREFIX not set in environment variables. Can't use file server"
            )
        file_name = metadata.get("file_name")

        if not (file_name and url_prefix):
            # fallback to URL in metadata (e.g. for websites)
            return metadata.get("URL")

        # file_name exists and file server is configured
        pipeline_id = metadata.get("pipeline_id")
        if pipeline_id and metadata.get("private") is None:
            # file is from LlamaCloud and was not ingested locally
            file_name = f"{pipeline_id}${file_name}"
            return f"{url_prefix}/output/llamacloud/{file_name}"
        is_private = metadata.get("private", "false") == "true"
        if is_private:
            return f"{url_prefix}/output/uploaded/{file_name}"
        return f"{url_prefix}/data/{file_name}"

    @classmethod
    def from_source_nodes(cls, source_nodes: List[NodeWithScore]):
        """Convert a list of ``NodeWithScore`` into ``SourceNodes`` entries."""
        return [cls.from_source_node(node) for node in source_nodes]

    @staticmethod
    def get_download_files(nodes: List[NodeWithScore]) -> Set[LlamaCloudFile]:
        """Return the distinct LlamaCloud files referenced by ``nodes``.

        A node qualifies when its metadata has no ``private`` entry and has
        both ``pipeline_id`` and ``file_name``.
        """
        source_nodes = SourceNodes.from_source_nodes(nodes)
        llama_cloud_files = [
            LlamaCloudFile(
                file_name=node.metadata.get("file_name"),
                pipeline_id=node.metadata.get("pipeline_id"),
            )
            for node in source_nodes
            if (
                node.metadata.get("private")
                is None  # Only download files are from LlamaCloud and were not ingested locally
                and node.metadata.get("pipeline_id") is not None
                and node.metadata.get("file_name") is not None
            )
        ]
        # Remove duplicates and return
        return set(llama_cloud_files)
|
||||||
|
|
||||||
|
|
||||||
|
class Result(BaseModel):
    # Pairs a chat message with the source nodes attached to it.
    # (Kept as a comment rather than a docstring so the generated schema is unchanged.)
    result: Message
    nodes: List[SourceNodes]
|
||||||
|
|
||||||
|
|
||||||
|
class ChatConfig(BaseModel):
    # Optional starter questions; serialized under the camelCase key "starterQuestions".
    starter_questions: Optional[List[str]] = Field(
        default=None,
        description="List of starter questions",
        serialization_alias="starterQuestions",
    )

    class Config:
        # Example payload attached to the generated JSON schema
        json_schema_extra = {
            "example": {
                "starterQuestions": [
                    "What standards for letters exist?",
                    "What are the requirements for a letter to be considered a letter?",
                ]
            }
        }
|
||||||
@@ -0,0 +1,25 @@
|
|||||||
|
import logging
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from app.api.services.file import PrivateFileService
|
||||||
|
|
||||||
|
file_upload_router = r = APIRouter()
|
||||||
|
|
||||||
|
logger = logging.getLogger("uvicorn")
|
||||||
|
|
||||||
|
|
||||||
|
class FileUploadRequest(BaseModel):
    # Request body of the upload endpoint; the value is handed unchanged to
    # PrivateFileService.process_file.
    base64: str
|
||||||
|
|
||||||
|
|
||||||
|
@r.post("")
def upload_file(request: FileUploadRequest) -> List[str]:
    # Delegates the upload to PrivateFileService and returns whatever it returns.
    # Any failure is logged with its traceback and reported to the client as a
    # generic HTTP 500, so internal error details are not leaked.
    try:
        logger.info("Processing file")
        document_ids = PrivateFileService.process_file(request.base64)
    except Exception as e:
        logger.error(f"Error processing file: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail="Error processing file")
    else:
        return document_ids
|
||||||
@@ -0,0 +1,109 @@
|
|||||||
|
import json
|
||||||
|
|
||||||
|
from aiostream import stream
|
||||||
|
from fastapi import Request
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
from llama_index.core.chat_engine.types import StreamingAgentChatResponse
|
||||||
|
|
||||||
|
from app.api.routers.events import EventCallbackHandler
|
||||||
|
from app.api.routers.models import ChatData, Message, SourceNodes
|
||||||
|
from app.api.services.suggestion import NextQuestionSuggestion
|
||||||
|
|
||||||
|
|
||||||
|
class VercelStreamResponse(StreamingResponse):
    """
    Streaming response that re-encodes the chat engine output into the stream format expected by Vercel
    """

    TEXT_PREFIX = "0:"
    DATA_PREFIX = "8:"

    @classmethod
    def convert_text(cls, token: str):
        """Frame a text token as one ``TEXT_PREFIX`` line."""
        # json.dumps escapes newlines and double quotes so a token cannot break the stream
        encoded = json.dumps(token)
        return f"{cls.TEXT_PREFIX}{encoded}\n"

    @classmethod
    def convert_data(cls, data: dict):
        """Frame a payload as one ``DATA_PREFIX`` line holding a one-element JSON array."""
        return f"{cls.DATA_PREFIX}[{json.dumps(data)}]\n"

    def __init__(
        self,
        request: Request,
        event_handler: EventCallbackHandler,
        response: StreamingAgentChatResponse,
        chat_data: ChatData,
    ):
        super().__init__(
            content=VercelStreamResponse.content_generator(
                request, event_handler, response, chat_data
            )
        )

    @classmethod
    async def content_generator(
        cls,
        request: Request,
        event_handler: EventCallbackHandler,
        response: StreamingAgentChatResponse,
        chat_data: ChatData,
    ):
        """Merge the chat text stream and the event stream into one sequence of framed lines.

        The text generator leads: after the answer tokens it emits suggested
        questions (when there are any), marks the event handler as done and
        finally emits the source nodes. Streaming stops early once the client
        has disconnected.
        """

        async def _chat_response_generator():
            # Forward every token while accumulating the complete answer
            answer_text = ""
            async for token in response.async_response_gen():
                answer_text += token
                yield VercelStreamResponse.convert_text(token)

            # Ask for follow-up questions based on the conversation plus the new answer
            assistant_reply = Message(role="assistant", content=answer_text)
            questions = await NextQuestionSuggestion.suggest_next_questions(
                chat_data.messages + [assistant_reply]
            )
            if len(questions) > 0:
                suggestion_payload = {
                    "type": "suggested_questions",
                    "data": questions,
                }
                yield VercelStreamResponse.convert_data(suggestion_payload)

            # This generator is the leading stream: once it is finished, finish the event stream too
            event_handler.is_done = True

            # Last item: the source nodes behind the answer
            serialized_nodes = [
                SourceNodes.from_source_node(node).dict()
                for node in response.source_nodes
            ]
            yield cls.convert_data(
                {
                    "type": "sources",
                    "data": {"nodes": serialized_nodes},
                }
            )

        async def _event_generator():
            # Forward only the events that have a response representation
            async for event in event_handler.async_event_gen():
                payload = event.to_response()
                if payload is not None:
                    yield VercelStreamResponse.convert_data(payload)

        merged = stream.merge(_chat_response_generator(), _event_generator())
        started = False
        async with merged.stream() as streamer:
            async for chunk in streamer:
                if not started:
                    started = True
                    # Stream a blank message to start the stream
                    yield VercelStreamResponse.convert_text("")

                yield chunk

                if await request.is_disconnected():
                    break
|
||||||
@@ -0,0 +1,113 @@
|
|||||||
|
import base64
|
||||||
|
import mimetypes
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from app.engine.index import get_index
|
||||||
|
from llama_index.core import VectorStoreIndex
|
||||||
|
from llama_index.core.ingestion import IngestionPipeline
|
||||||
|
from llama_index.core.readers.file.base import (
|
||||||
|
_try_loading_included_file_formats as get_file_loaders_map,
|
||||||
|
)
|
||||||
|
from llama_index.core.readers.file.base import (
|
||||||
|
default_file_metadata_func,
|
||||||
|
)
|
||||||
|
from llama_index.core.schema import Document
|
||||||
|
from llama_index.indices.managed.llama_cloud.base import LlamaCloudIndex
|
||||||
|
from llama_index.readers.file import FlatReader
|
||||||
|
|
||||||
|
|
||||||
|
def get_llamaparse_parser():
    """Return a LlamaParse parser when the ``file`` loader config enables it, else ``None``."""
    # Imported inside the function, as in the original
    from app.engine.loaders import load_configs
    from app.engine.loaders.file import FileLoaderConfig, llama_parse_parser

    loader_settings = FileLoaderConfig(**load_configs()["file"])
    if not loader_settings.use_llama_parse:
        return None
    return llama_parse_parser()
|
||||||
|
|
||||||
|
|
||||||
|
def default_file_loaders_map():
    """Return the built-in extension-to-reader map with ``.txt`` mapped to ``FlatReader``."""
    loaders_by_extension = get_file_loaders_map()
    loaders_by_extension.update({".txt": FlatReader})
    return loaders_by_extension
|
||||||
|
|
||||||
|
|
||||||
|
class PrivateFileService:
    """Store uploaded files under ``PRIVATE_STORE_PATH`` and add their content to the index."""

    PRIVATE_STORE_PATH = "output/uploaded"

    @staticmethod
    def preprocess_base64_file(base64_content: str) -> tuple:
        """Split a base64 data URL into decoded bytes and a file extension.

        Args:
            base64_content: Text of the form ``data:<mime type>[;...],<base64 data>``.

        Returns:
            ``(file_bytes, extension)`` where ``extension`` is what
            ``mimetypes.guess_extension`` reports for the MIME type.

        Raises:
            ValueError: If the text is not a data URL, or the MIME type maps to
                no known extension (previously this produced a file name
                ending in the literal text ``None``).
        """
        header, separator, data = base64_content.partition(",")
        media_part = header.split(";")[0]
        if not separator or ":" not in media_part:
            raise ValueError("Invalid file content: expected a base64 data URL")
        mime_type = media_part.split(":", 1)[1]
        extension = mimetypes.guess_extension(mime_type)
        if extension is None:
            raise ValueError(f"Unsupported MIME type: {mime_type}")
        # File data as bytes
        return base64.b64decode(data), extension

    @staticmethod
    def store_and_parse_file(file_data, extension) -> "List[Document]":
        """Write the file to the private store and parse it into documents.

        LlamaParse is used when it is enabled; otherwise the default reader
        registered for ``extension`` is used. Every returned document gets the
        generated ``file_name`` and ``private="true"`` in its metadata.

        Raises:
            ValueError: If no reader supports ``extension``. The check runs
                before anything is written, so rejected uploads leave no file behind.
        """
        # Resolve the reader first so unsupported files never reach the disk
        reader = get_llamaparse_parser()
        if reader is None:
            reader_cls = default_file_loaders_map().get(extension)
            if reader_cls is None:
                raise ValueError(f"File extension {extension} is not supported")
            reader = reader_cls()

        # Store file to the private directory under a random file name
        os.makedirs(PrivateFileService.PRIVATE_STORE_PATH, exist_ok=True)
        file_name = f"{uuid4().hex}{extension}"
        file_path = Path(PrivateFileService.PRIVATE_STORE_PATH) / file_name
        with open(file_path, "wb") as f:
            f.write(file_data)

        documents = reader.load_data(file_path)
        # Add custom metadata
        for doc in documents:
            doc.metadata["file_name"] = file_name
            doc.metadata["private"] = "true"
        return documents

    @staticmethod
    def process_file(base64_content: str) -> List[str]:
        """Store an uploaded file, index its content and return the document ids.

        Args:
            base64_content: The upload as a base64 data URL.

        Returns:
            The ``doc_id`` of every document parsed from the file.
        """
        file_data, extension = PrivateFileService.preprocess_base64_file(base64_content)
        documents = PrivateFileService.store_and_parse_file(file_data, extension)

        current_index = get_index()

        if isinstance(current_index, LlamaCloudIndex):
            # LlamaCloudIndex is a managed index so we don't need to process the nodes,
            # just insert the documents
            for doc in documents:
                current_index.insert(doc)
        else:
            # Turn the documents into nodes exactly once; the original ran the
            # pipeline a second time up front and discarded that result.
            pipeline = IngestionPipeline()
            nodes = pipeline.run(documents=documents)

            # Add the nodes to the index and persist it
            if current_index is None:
                current_index = VectorStoreIndex(nodes=nodes)
            else:
                current_index.insert_nodes(nodes=nodes)
            current_index.storage_context.persist(
                persist_dir=os.environ.get("STORAGE_DIR", "storage")
            )

        # Return the document ids
        return [doc.doc_id for doc in documents]
|
||||||
@@ -0,0 +1,114 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from app.api.routers.models import LlamaCloudFile
|
||||||
|
|
||||||
|
logger = logging.getLogger("uvicorn")
|
||||||
|
|
||||||
|
|
||||||
|
class LLamaCloudFileService:
    """Client for the LlamaCloud REST API with a local store for downloaded pipeline files."""

    LLAMA_CLOUD_URL = "https://cloud.llamaindex.ai/api/v1"
    LOCAL_STORE_PATH = "output/llamacloud"

    DOWNLOAD_FILE_NAME_TPL = "{pipeline_id}${filename}"

    # Seconds to wait on LlamaCloud; without a timeout a stalled connection blocks forever
    REQUEST_TIMEOUT = 60

    @classmethod
    def get_all_projects(cls) -> List[Dict[str, Any]]:
        """Return the JSON body of ``GET /projects``."""
        url = f"{cls.LLAMA_CLOUD_URL}/projects"
        return cls._make_request(url)

    @classmethod
    def get_all_pipelines(cls) -> List[Dict[str, Any]]:
        """Return the JSON body of ``GET /pipelines``."""
        url = f"{cls.LLAMA_CLOUD_URL}/pipelines"
        return cls._make_request(url)

    @classmethod
    def get_all_projects_with_pipelines(cls) -> List[Dict[str, Any]]:
        """Return every project with its pipelines under a ``pipelines`` key.

        Best effort: any failure is logged and an empty list is returned.
        """
        try:
            projects = cls.get_all_projects()
            pipelines = cls.get_all_pipelines()
            return [
                {
                    **project,
                    "pipelines": [p for p in pipelines if p["project_id"] == project["id"]],
                }
                for project in projects
            ]
        except Exception as error:
            logger.error(f"Error listing projects and pipelines: {error}")
            return []

    @classmethod
    def _get_files(cls, pipeline_id: str) -> List[Dict[str, Any]]:
        """Return the JSON body of ``GET /pipelines/{pipeline_id}/files``."""
        url = f"{cls.LLAMA_CLOUD_URL}/pipelines/{pipeline_id}/files"
        return cls._make_request(url)

    @classmethod
    def _get_file_detail(cls, project_id: str, file_id: str) -> Dict[str, Any]:
        """Return the JSON body of ``GET /files/{file_id}/content`` for the given project."""
        url = f"{cls.LLAMA_CLOUD_URL}/files/{file_id}/content?project_id={project_id}"
        return cls._make_request(url)

    @classmethod
    def _download_file(cls, url: str, local_file_path: str):
        """Stream ``url`` to ``local_file_path``.

        The data is written to a ``.part`` file and moved into place only after
        the download finished, because callers treat an existing file as
        already downloaded. A failed download removes the partial file and
        re-raises.
        """
        logger.info(f"Downloading file to {local_file_path}")
        # Create directory if it doesn't exist
        os.makedirs(cls.LOCAL_STORE_PATH, exist_ok=True)
        partial_path = f"{local_file_path}.part"
        try:
            with requests.get(url, stream=True, timeout=cls.REQUEST_TIMEOUT) as r:
                r.raise_for_status()
                with open(partial_path, "wb") as f:
                    for chunk in r.iter_content(chunk_size=8192):
                        f.write(chunk)
            os.replace(partial_path, local_file_path)
        except Exception:
            if os.path.exists(partial_path):
                os.remove(partial_path)
            raise
        logger.info("File downloaded successfully")

    @classmethod
    def download_llamacloud_pipeline_file(
        cls,
        file: "LlamaCloudFile",
        force_download: bool = False,
    ):
        """Download one pipeline file into the local store unless it is already there.

        Args:
            file: Identifies the file by ``file_name`` and ``pipeline_id``.
            force_download: Download even when the local file exists.

        Best effort: errors are logged, never raised.
        """
        file_name = file.file_name
        pipeline_id = file.pipeline_id

        # Check whether the file already exists
        downloaded_file_path = cls.get_file_path(file_name, pipeline_id)
        if os.path.exists(downloaded_file_path) and not force_download:
            logger.debug(f"File {file_name} already exists in local storage")
            return
        try:
            logger.info(f"Downloading file {file_name} for pipeline {pipeline_id}")
            files = cls._get_files(pipeline_id)
            if not files or not isinstance(files, list):
                raise Exception("No files found in LlamaCloud")
            for file_entry in files:
                if file_entry["name"] == file_name:
                    file_id = file_entry["file_id"]
                    project_id = file_entry["project_id"]
                    file_detail = cls._get_file_detail(project_id, file_id)
                    cls._download_file(file_detail["url"], downloaded_file_path)
                    break
            else:
                # Previously a missing file was skipped without any trace
                logger.warning(f"File {file_name} not found in pipeline {pipeline_id}")
        except Exception as error:
            # Logged at error level, matching get_all_projects_with_pipelines
            logger.error(f"Error fetching file from LlamaCloud: {error}")

    @classmethod
    def get_file_name(cls, name: str, pipeline_id: str) -> str:
        """Return the local file name, ``{pipeline_id}${name}``."""
        return cls.DOWNLOAD_FILE_NAME_TPL.format(pipeline_id=pipeline_id, filename=name)

    @classmethod
    def get_file_path(cls, name: str, pipeline_id: str) -> str:
        """Return the path of the file inside ``LOCAL_STORE_PATH``."""
        return os.path.join(cls.LOCAL_STORE_PATH, cls.get_file_name(name, pipeline_id))

    @staticmethod
    def _make_request(
        url: str, data=None, headers: Optional[Dict] = None, method: str = "get"
    ):
        """Send a request and return the decoded JSON body.

        When ``headers`` is omitted, JSON is requested and the
        ``LLAMA_CLOUD_API_KEY`` environment variable is sent as bearer token.

        Raises:
            requests.HTTPError: For a 4xx/5xx response (via ``raise_for_status``).
        """
        if headers is None:
            headers = {
                "Accept": "application/json",
                "Authorization": f'Bearer {os.getenv("LLAMA_CLOUD_API_KEY")}',
            }
        response = requests.request(
            method,
            url,
            headers=headers,
            data=data,
            timeout=LLamaCloudFileService.REQUEST_TIMEOUT,
        )
        response.raise_for_status()
        return response.json()
|
||||||
@@ -0,0 +1,48 @@
|
|||||||
|
from typing import List
|
||||||
|
|
||||||
|
from app.api.routers.models import Message
|
||||||
|
from llama_index.core.prompts import PromptTemplate
|
||||||
|
from llama_index.core.settings import Settings
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
# Prompt for follow-up question suggestions. Both {conversation} and
# {number_of_questions} are template variables filled in by suggest_next_questions;
# the count was previously written as "$number_of_questions", which the template
# never substituted.
NEXT_QUESTIONS_SUGGESTION_PROMPT = PromptTemplate(
    "你是一个乐于助人的助手!你的任务是对用户可能会问的下一个问题给出建议。 "
    "\n这是对话历史记录"
    "\n---------------------\n{conversation}\n---------------------"
    "考虑到对话历史记录,仅限于现在知识库已有内容, 请给我 {number_of_questions} 个你接下来可能会问题的问题!"
)
N_QUESTION_TO_GENERATE = 3


class NextQuestions(BaseModel):
    """A list of questions that user might ask next"""

    questions: List[str]


class NextQuestionSuggestion:
    """Suggests follow-up questions for a conversation using the configured LLM."""

    @staticmethod
    async def suggest_next_questions(
        messages: List[Message],
        number_of_questions: int = N_QUESTION_TO_GENERATE,
    ) -> List[str]:
        """Ask ``Settings.llm`` for questions the user might ask next.

        Args:
            messages: The conversation so far, oldest first.
            number_of_questions: How many questions to request.

        Returns:
            The ``questions`` list of the structured LLM output.
        """
        # Reduce the cost by only using the last two messages: the most recent
        # user message and the most recent assistant message.
        last_user_message = None
        last_assistant_message = None
        for message in reversed(messages):
            if message.role == "user" and last_user_message is None:
                last_user_message = f"User: {message.content}"
            elif message.role == "assistant" and last_assistant_message is None:
                last_assistant_message = f"Assistant: {message.content}"
            if last_user_message is not None and last_assistant_message is not None:
                break

        # Leave out a role that never spoke instead of sending the text "None"
        conversation: str = "\n".join(
            part
            for part in (last_user_message, last_assistant_message)
            if part is not None
        )

        output: NextQuestions = await Settings.llm.astructured_predict(
            NextQuestions,
            prompt=NEXT_QUESTIONS_SUGGESTION_PROMPT,
            conversation=conversation,
            # Must match the template variable name (was misspelled "nun_questions")
            number_of_questions=number_of_questions,
        )

        return output.questions
|
||||||
@@ -0,0 +1,22 @@
|
|||||||
|
import logging
|
||||||
|
from llama_index.core.indices import VectorStoreIndex
|
||||||
|
from app.engine.vectordb import get_vector_store
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger("uvicorn")
|
||||||
|
|
||||||
|
index = None
|
||||||
|
|
||||||
|
def get_index(params=None):
    """Return the module-wide index, building it from the vector store on first use.

    ``params`` is accepted but not used by this function.
    """
    global index
    if index is not None:
        return index

    logger.info("Connecting vector store...")
    vector_store = get_vector_store()
    # Load the index from the vector store.
    # If you are using a vector store that doesn't store text,
    # you must load the index from both the vector store and the document store.
    index = VectorStoreIndex.from_vector_store(vector_store)
    logger.info("Finished load index from vector store.")
    return index
|
||||||
@@ -0,0 +1,61 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
from llama_index.core.agent import AgentRunner, ReActChatFormatter
|
||||||
|
from llama_index.core.settings import Settings
|
||||||
|
from llama_index.core.tools.query_engine import QueryEngineTool
|
||||||
|
|
||||||
|
from app.engine.engine import create_query_engine, create_summary_query_engine
|
||||||
|
from app.engine.index import get_index
|
||||||
|
#from app.engine.loaders.db import makeDescriptionByEngine
|
||||||
|
from app.engine.tools import ToolFactory
|
||||||
|
|
||||||
|
|
||||||
|
def get_chat_engine(filters=None, params=None):
    """Build the agent that answers chat requests.

    Reads ``SYSTEM_PROMPT``, ``TOP_K`` and ``RERANK_ENABLED`` from the
    environment, registers a summary tool and a knowledge-base query tool when
    an index is available, adds the tools from ``ToolFactory.from_env()`` and
    wraps everything in an ``AgentRunner`` with a customised ReAct prompt.

    Args:
        filters: Passed through to both query engines.
        params: Accepted but not used by this function.

    Returns:
        The agent created by ``AgentRunner.from_llm``.
    """
    system_prompt = os.getenv("SYSTEM_PROMPT")
    top_k = int(os.getenv("TOP_K", "3"))
    # The raw environment string was used as the flag before, so "false" or "0"
    # still switched the reranker on. Any other non-empty value keeps enabling it.
    rerank_flag = os.getenv("RERANK_ENABLED", "")
    use_reranker = rerank_flag.strip().lower() not in ("", "0", "false", "no", "off")
    tools = []

    # NOTE: an SQL query tool named "zjdata_query_tool" used to be registered here;
    # it is currently disabled, although the "zj_query_tool" description still mentions it.

    # Add query tool if index exists
    index = get_index()
    if index is not None:
        summary_query_engine = create_summary_query_engine(index, top_k, use_reranker, filters)
        summary_query_tool = QueryEngineTool.from_defaults(
            query_engine=summary_query_engine,
            name="summary_query_tool",
            description="适用于任何需要进行全面总结、概括的要求。",
        )
        query_engine = create_query_engine(index, top_k, use_reranker, filters)
        query_engine_tool = QueryEngineTool.from_defaults(
            query_engine=query_engine,
            name="zj_query_tool",
            description="由博微公司编制的关于电力造价知识、电力造价编制软件知识和造价工程文件结构的知识库。适用于查询电力领域、电力造价领域、博微、博微电力、博微造价等业务等内容。如果本知识库没有直接答案但有解决思路的可以返回解决办法后建议使用“zjdata_query_tool”工具。",
        )

        tools.append(summary_query_tool)
        tools.append(query_engine_tool)

    # Add additional tools
    tools += ToolFactory.from_env()

    # ReAct system header; {tool_desc} and {tool_names} are filled in by the formatter
    prefix_messages = ("""您的设计旨在帮助完成各种任务,从回答问题到提供其他类型分析的摘要。\n\n##工具\n\n你可以访问各种工具。你有责任按照你认为合适的顺序使用这些工具来完成当前的任务。\n这可能需要将任务分解为子任务,并使用不同的工具来完成每个子任务。\n\n你可以访问以下工具:\n{tool_desc}\n\n\n##输出格式\n\n请用与问题相同的语言回答,并使用以下格式:\n\n \nThought: 用户当前的语言是:(user's language)。我需要使用工具来帮助我回答问题。\nAction: 如果使用工具,则为工具名称(one of {tool_names})。\nAction Input: 输入给工具的内容,使用JSON格式表示kwargs(例如{{\"input\": \"hello world\", \"num_beams\": 5}})\n \n\n请始终以Thought开始。\n\n请始终以Thought开始。\n\n请始终以Thought开始。\n\n请始终以Thought开始。\n\n切勿用Markdown代码标记包围你的响应。如果需要,可以在响应中使用代码标记。\n\n请为Action Input使用有效的JSON格式。不要这样做{{\'input\': \'hello world\', \'num_beams\': 5}}。\n\n如果使用此格式,用户将以下面的格式进行回应:\n\n \nObservation: 工具响应\n \n\n你应该继续重复上述格式,直到你有足够的信息来回答问题而无需使用更多工具。此时,你必须使用以下两种格式之一进行回答:\n\n \nThought: 我可以不用任何工具来回答。我将使用用户的语言来回答。\nAnswer: [你的答案(与用户问题相同的语言)]\n \n\n \nThought: 我无法使用提供的工具回答问题。\nAnswer: [你的答案(与用户问题相同的语言)]\n \n\n##如果从工具中得到的回应是Empty Response,那么只需要回答“我不知道”,不需要额外回答别的内容。## 当前对话\n\n以下是当前对话,由人类和助手的消息交替组成。\n""")
    react_chat_formatter = ReActChatFormatter.from_defaults(prefix_messages)
    agentrunner = AgentRunner.from_llm(
        llm=Settings.llm,
        tools=tools,
        react_chat_formatter=react_chat_formatter,
        system_prompt=system_prompt,
        verbose=True,
    )
    return agentrunner
|
||||||
|
# create the function calling worker for reasoning
|
||||||
|
# worker = FunctionCallingAgentWorker.from_tools(
|
||||||
|
# tools, verbose=True
|
||||||
|
# )
|
||||||
|
#
|
||||||
|
# # wrap the worker in the top-level planner
|
||||||
|
# return StructuredPlannerAgent(worker, tools)
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
STORAGE_DIR = "storage" # directory to cache the generated index
|
||||||
@@ -0,0 +1,108 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
from llama_index.core import SummaryIndex, SQLDatabase, VectorStoreIndex
|
||||||
|
from llama_index.core.indices.struct_store import SQLTableRetrieverQueryEngine
|
||||||
|
from llama_index.core.objects import SQLTableNodeMapping, ObjectIndex, SQLTableSchema
|
||||||
|
from llama_index.core.query_engine import RetrieverQueryEngine
|
||||||
|
from llama_index.core.response_synthesizers import ResponseMode
|
||||||
|
from llama_index.readers.database import DatabaseReader
|
||||||
|
from sqlalchemy import create_engine
|
||||||
|
|
||||||
|
from app.engine.prompt import text_qa_template, refine_template, summary_template, simple_template
|
||||||
|
from app.engine.retriever.HybridRetriever import HybridRetriever
|
||||||
|
from app.settings import get_node_postprocessors
|
||||||
|
|
||||||
|
def makeDescriptionByEngine(sql_database:SQLDatabase):
    """Build one ``SQLTableSchema`` per usable table of the database.

    Tables with more than 150 columns are left out. Only the table
    ``gongchengshuxing`` gets a context string: a prefix followed by the values
    read from its ``name`` column (rows 2 to 30 of the query result).
    """
    reader = DatabaseReader(sql_database)
    schemas = []
    for table_name in sql_database.get_usable_table_names():
        if len(sql_database.get_table_columns(table_name)) > 150:
            continue

        context_text = ""
        if table_name == 'gongchengshuxing':
            rows = reader.load_data(query='select name from gongchengshuxing')
            # The first row is skipped and at most the first 30 rows are considered;
            # each row text is split on ':' and the part after the first ':' is kept.
            names = [row.text.split(':')[1] for row in rows[1:30]]
            context_text = '该表中有以下属性:' + ','.join(names)

        schemas.append(SQLTableSchema(table_name=table_name, context_str=context_text))

    return schemas
|
||||||
|
|
||||||
|
def get_Retriever(index,**kwargs):
    """Return a ``HybridRetriever`` or the index's own retriever.

    Hybrid retrieval is chosen when the ``HYBRID_ENABLED`` environment variable
    title-cases to ``"True"``; its ``alpha`` comes from ``HYBRID_ALPHA``
    (default ``0.5``). ``kwargs`` are forwarded to whichever retriever is built.
    """
    hybrid_flag = os.getenv("HYBRID_ENABLED", 'False')
    if hybrid_flag.title() != 'True':
        return index.as_retriever(**kwargs)

    weight = float(os.getenv("HYBRID_ALPHA", "0.5"))
    return HybridRetriever(index, alpha=weight, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
sql_database = None
|
||||||
|
sql_obj_index = None
|
||||||
|
|
||||||
|
# Create an SQL table-retriever query engine
def create_sql_query_engine(top_k=3, use_reranker=False, filters=None):
    """Create a query engine that answers questions from the SQL database.

    This function used to share the name ``create_summary_query_engine`` with
    the index-based function defined further down in this module. The later
    definition replaced it at import time, so this SQL variant could never be
    called. It now has its own name; ``create_summary_query_engine`` keeps
    resolving to the index-based function exactly as before.

    The database handle and the table object index are created on first use
    from ``SQL_DATABASE_URL`` and cached in module globals.

    Args:
        top_k: Number of table schemas the retriever returns per query.
        use_reranker: Accepted but not used by this function.
        filters: Accepted but not used by this function.

    Returns:
        A ``SQLTableRetrieverQueryEngine`` over the cached database.
    """
    global sql_obj_index
    global sql_database
    if sql_obj_index is None or sql_database is None:
        sqlengine = create_engine(os.getenv("SQL_DATABASE_URL", ""))
        sql_database = SQLDatabase(sqlengine)
        table_schema_objs = makeDescriptionByEngine(sql_database)
        table_node_mapping = SQLTableNodeMapping(sql_database)

        sql_obj_index = ObjectIndex.from_objects(
            table_schema_objs,
            table_node_mapping,
            index_cls=VectorStoreIndex,
        )

    # Build the SQL query engine on top of the table retriever
    sql_query_engine = SQLTableRetrieverQueryEngine(
        sql_database,
        sql_obj_index.as_retriever(similarity_top_k=top_k),
        verbose=True,
    )
    return sql_query_engine
|
||||||
|
|
||||||
|
# Create a summary query engine
|
||||||
|
def create_summary_query_engine(index, top_k=3, use_reranker=False, filters=None):
    """Create a streaming tree-summarize query engine over the index's nodes.

    The nodes are fetched with ``index.vector_store.get_nodes(node_ids=None)``.
    ``top_k``, ``use_reranker`` and ``filters`` are accepted but not used here.
    """
    stored_nodes = index.vector_store.get_nodes(node_ids=None)
    return SummaryIndex(stored_nodes).as_query_engine(
        response_mode=ResponseMode.TREE_SUMMARIZE,
        use_async=True,
        streaming=True,
    )
|
||||||
|
|
||||||
|
# Create a query engine
|
||||||
|
def create_query_engine(index, top_k=3, use_reranker=False, filters=None):
    """Create the streaming retrieval query engine for the knowledge base.

    Args:
        index: Index the retriever is built from (see ``get_Retriever``).
        top_k: Passed to the retriever as ``similarity_top_k``.
        use_reranker: When truthy, the node postprocessors from
            ``get_node_postprocessors()`` are attached.
        filters: Passed to the retriever.
    """
    # Postprocessors are only looked up when reranking is requested
    node_postprocessors = get_node_postprocessors() if use_reranker else None
    retriever = get_Retriever(index, similarity_top_k=top_k, filters=filters)

    return RetrieverQueryEngine.from_args(
        retriever,
        text_qa_template=text_qa_template,
        refine_template=refine_template,
        summary_template=summary_template,
        simple_template=simple_template,
        node_postprocessors=node_postprocessors,
        use_async=True,
        streaming=True,
    )
|
||||||
@@ -0,0 +1,94 @@
|
|||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
|
from app.engine.loaders import get_documents
|
||||||
|
from app.engine.vectordb import get_vector_store
|
||||||
|
from app.settings import init_settings
|
||||||
|
from app.engine.retriever.CHBM25Retriever import CHBM25Retriever
|
||||||
|
from llama_index.core.ingestion import IngestionPipeline
|
||||||
|
from llama_index.core.node_parser import SentenceSplitter
|
||||||
|
from llama_index.core.settings import Settings
|
||||||
|
from llama_index.core.storage import StorageContext
|
||||||
|
from llama_index.core.storage.docstore import SimpleDocumentStore
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
STORAGE_DIR = os.getenv("STORAGE_DIR", "storage")
|
||||||
|
|
||||||
|
|
||||||
|
def get_doc_store():
    """Load the document store from ``STORAGE_DIR``, or start an empty in-memory one.

    A fresh store is used when the directory does not exist, since there is
    nothing to load from.
    """
    if not os.path.exists(STORAGE_DIR):
        return SimpleDocumentStore()
    return SimpleDocumentStore.from_persist_dir(STORAGE_DIR)
|
||||||
|
|
||||||
|
|
||||||
|
def run_pipeline(docstore, vector_store, documents):
    """Split and embed ``documents`` and store the results.

    The pipeline consists of a ``SentenceSplitter`` configured from
    ``Settings.chunk_size`` / ``Settings.chunk_overlap`` followed by
    ``Settings.embed_model``; it uses the ``upserts_and_delete`` docstore strategy.

    Returns:
        The nodes returned by the pipeline run.
    """
    splitter = SentenceSplitter(
        chunk_size=Settings.chunk_size,
        chunk_overlap=Settings.chunk_overlap,
    )
    ingestion = IngestionPipeline(
        transformations=[splitter, Settings.embed_model],
        docstore=docstore,
        docstore_strategy="upserts_and_delete",
        vector_store=vector_store,
    )

    # Run the ingestion pipeline and store the results
    return ingestion.run(show_progress=True, documents=documents)
|
||||||
|
|
||||||
|
|
||||||
|
def persist_storage(docstore, vector_store):
    """Persist the document store and vector store to ``STORAGE_DIR``."""
    StorageContext.from_defaults(
        docstore=docstore,
        vector_store=vector_store,
    ).persist(STORAGE_DIR)
|
||||||
|
|
||||||
|
|
||||||
|
def persist_BMRetriever(vector_store):
    """Build a ``CHBM25Retriever`` from the vector store's nodes and persist it.

    The target directory comes from ``BM_RETRIEVER_PATH`` (default
    ``storage_bm``) and the retriever's ``similarity_top_k`` from ``TOP_K``
    (default ``3``).
    """
    persist_dir = os.getenv("BM_RETRIEVER_PATH", "storage_bm")
    result_count = int(os.getenv("TOP_K", "3"))
    bm25_retriever = CHBM25Retriever.from_defaults(
        similarity_top_k=result_count,
        nodes=vector_store.get_nodes([]),
    )
    bm25_retriever.persist(persist_dir)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_datasource():
    """Ingest the configured documents and persist the index and the BM25 retriever."""
    init_settings()
    logger.info("Generate index for the provided data")

    # Get the stores and documents or create new ones
    documents = get_documents()
    for document in documents:
        # Set private=false to mark the document as public (required for filtering)
        document.metadata["private"] = "false"
    docstore = get_doc_store()
    vector_store = get_vector_store()

    # Run the ingestion pipeline; the returned nodes are not needed here
    run_pipeline(docstore, vector_store, documents)

    # Build the index and persist storage
    persist_storage(docstore, vector_store)
    persist_BMRetriever(vector_store)

    logger.info("Finished generating the index")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    from phoenix.trace import using_project

    # os.getenv returns None when PHOENIX_PROJECT_NAME is unset, and None + "_generate"
    # raised TypeError before any data was generated. Fall back to a fixed base name.
    # NOTE(review): "default" is an assumed fallback - confirm the preferred project name.
    project_name = os.getenv("PHOENIX_PROJECT_NAME", "default") + "_generate"
    with using_project(project_name):
        generate_datasource()
|
||||||
@@ -0,0 +1,93 @@
|
|||||||
|
from llama_index.core import PromptTemplate
|
||||||
|
|
||||||
|
# Question-answering prompt (Chinese): a role / skills / constraints preamble,
# then the retrieved context ({context_str}) and the user question
# ({query_str}). The text is runtime prompt content and must not be edited
# casually.
text_qa_template_str = (
    "# 角色\n"
    "你是一名博微造价工程数据查询助手,专精于电力工程文件中的信息。"
    "你的职责是提供有关电力造价、造价编制软件、文件结构及相关数据的精准、客观的回答,"
    "如同直接从文件中提取的内容。\n"
    "知识库中已经导入一个工程的全部数据,请你站在当前工程的角度回答用户关于工程文件的问题。\n"
    "例如:询问“此工程”指当前导入的工程。询问“此工程名称”指当前导入的工程的工程名称。\n"

    "## 技能\n"
    "### 技能 1: 数据查询与提供\n"
    "- 准确回答所有关于电力工程造价的相关问题。\n"
    "- 提供具体数据,如成本估算、材料清单、劳动力需求等。\n"
    "- 确保提供的信息严格基于工程文档中的记录。\n"

    "### 技能 2: 技术性解释\n"
    "- 解释造价工程中的技术术语和概念。\n"
    "- 为复杂的工程细节提供清晰易懂的说明。\n"

    "## 约束\n"
    "- 仅回答与电力工程造价文件相关的具体问题。\n"
    "- 不进行任何超出文件内容的猜测或假设。\n"
    "- 所有回答均基于文件内容,采用客观和技术性的语言。\n"
    "- 请基于这些信息回答问题。如果无法找到相关信息,请不要额外发散回答,不要回答多余的信息,只需要回答“我不知道这个问题的答案”。\n"
    "以下为上下文信息\n"
    "---------------------\n"
    "{context_str}\n"
    "---------------------\n"
    "请根据上下文信息而非先前知识回答我的问题或回复我的指令。前面的上下文信息可能有用,也可能没用,你需要从我给出的上下文信息中选出与我的问题最相关的那些,来为你的回答提供依据。回答一定要忠于原文,简洁但不丢信息,不要胡乱编造。如果无法找到相关信息,请不要额外发散回答,不要回答多余的信息,只需要回答“我不知道这个问题的答案”。我的问题或指令是什么语种,你就用什么语种回复。\n"
    "如果是表结构或者是数据库的相关内容,只用于推导问题,不需要告诉用户数据库或表结构等物理信息。\n"

    "问题:{query_str}\n"
    "你的回复: "
)


text_qa_template = PromptTemplate(text_qa_template_str)
|
||||||
|
|
||||||
|
# Refine prompt (Chinese): shows the original question ({query_str}) and the
# current answer ({existing_answer}), then extra context ({context_msg}) that
# may be used to improve the answer.
# NOTE(review): the "return the original answer unchanged" instruction appears
# twice in a row below -- confirm whether the repetition is intentional
# emphasis or a copy-paste leftover.
refine_template_str = (
    "这是原本的问题: {query_str}\n"
    "我们已经提供了回答: {existing_answer}\n"
    "现在我们有机会改进这个回答 "
    "使用以下更多上下文(仅当有助于改进回答时使用)\n"
    "如果新的上下文对回答没有影响,或者原来的回答已经正确,不要在上次回答的后边再加上多余的补充信息,直接返回原本的回答。\n"
    "如果新的上下文对回答没有影响,或者原来的回答已经正确,不要在上次回答的后边再加上多余的补充信息,直接返回原本的回答。\n"
    "------------\n"
    "{context_msg}\n"
    "------------\n"
    "如果回答中已经包含有正确答案,不要返回多余的解释等信息,只返回正确答案\n"
    "如果是表结构或者是数据库的相关内容,仅用于推导问题,不需要告诉用户数据库或表结构等物理信息。\n"
    "改进的回答: "
)

refine_template = PromptTemplate(refine_template_str)
|
||||||
|
|
||||||
|
# Summary prompt (Chinese): the same role / skills / constraints preamble as
# the QA prompt, followed by context gathered from multiple sources
# ({context_str}) and the query ({query_str}).
summary_template_str = (
    "# 角色\n"
    "你是一名博微造价工程数据查询助手,专精于电力工程文件中的信息。"
    "你的职责是提供有关电力造价、造价编制软件、文件结构及相关数据的精准、客观的回答,"
    "如同直接从文件中提取的内容。\n"

    "## 技能\n"
    "### 技能 1: 数据查询与提供\n"
    "- 准确回答所有关于电力工程造价的相关问题。\n"
    "- 提供具体数据,如成本估算、材料清单、劳动力需求等。\n"
    "- 确保提供的信息严格基于工程文档中的记录。\n"

    "### 技能 2: 技术性解释\n"
    "- 解释造价工程中的技术术语和概念。\n"
    "- 为复杂的工程细节提供清晰易懂的说明。\n"

    "## 约束\n"
    "- 仅回答与电力工程造价文件相关的具体问题。\n"
    "- 不进行任何超出文件内容的猜测或假设。\n"
    "- 所有回答均基于文件内容,采用客观和技术性的语言。\n"
    "- 请基于这些信息回答问题。如果无法找到相关信息,请不要额外发散回答,不要回答多余的信息,只需要回答“我不知道这个问题的答案”。\n"
    "来自多个来源的上下文信息如下。\n"
    "---------------------\n"
    "{context_str}\n"
    "---------------------\n"
    "鉴于来自多个来源的信息而非先验知识, "
    "回答查询。\n"
    "如果是表结构或者是数据库的相关内容,只用于推导问题,不需要告诉用户数据库或表结构等物理信息。\n"
    "Query: {query_str}\n"
    "Answer: "
)
summary_template = PromptTemplate(summary_template_str)
|
||||||
|
|
||||||
|
# Pass-through prompt: the template is just the query itself, with no
# surrounding instructions or context.
simple_template_str = (
    "{query_str}"
)
simple_template = PromptTemplate(simple_template_str)
|
||||||
@@ -0,0 +1,71 @@
|
|||||||
|
import os
|
||||||
|
from llama_index.vector_stores.chroma import ChromaVectorStore
|
||||||
|
from llama_index.vector_stores.qdrant import QdrantVectorStore
|
||||||
|
from qdrant_client import qdrant_client
|
||||||
|
|
||||||
|
# Module-wide Qdrant client, created on first use and reused afterwards.
qclient = None


def get_qdrant_vector_store():
    """Return a ``QdrantVectorStore`` backed by the shared Qdrant client.

    Environment variables read:
        VECTOR_STORE_COLLECTION: collection name (default ``"default"``).
        VECTOR_STORE_PATH: if set, a local Qdrant instance at this path is used.
        VECTOR_STORE_HOST / VECTOR_STORE_PORT: server address used when no
            path is set (defaults ``127.0.0.1`` and ``6333``).

    Raises:
        ValueError: if neither a path nor a host is configured, or if
            VECTOR_STORE_PORT is not an integer.
    """
    global qclient

    collection_name = os.getenv("VECTOR_STORE_COLLECTION", "default")
    vector_store_path = os.getenv("VECTOR_STORE_PATH")
    # These must be plain values: a trailing comma after either expression
    # would silently turn it into a 1-tuple and break the client call.
    host = os.getenv("VECTOR_STORE_HOST", "127.0.0.1")
    port = int(os.getenv("VECTOR_STORE_PORT", "6333"))

    # Either a local path or a remote host is sufficient, so only fail when
    # both are missing.
    if not vector_store_path and not host:
        raise ValueError(
            "Please provide either VECTOR_STORE_PATH or VECTOR_STORE_HOST and VECTOR_STORE_PORT"
        )

    # If VECTOR_STORE_PATH is set, use a local Qdrant from that path;
    # otherwise connect to the remote server.
    if qclient is None:
        if vector_store_path:
            qclient = qdrant_client.QdrantClient(path=vector_store_path)
        else:
            qclient = qdrant_client.QdrantClient(host=host, port=port)

    return QdrantVectorStore(client=qclient, collection_name=collection_name)
|
||||||
|
|
||||||
|
def get_chroma_vector_store():
    """Return a ``ChromaVectorStore`` configured from environment variables.

    A local store is opened when ``VECTOR_STORE_PATH`` is set; otherwise a
    remote one is opened from ``VECTOR_STORE_HOST`` / ``VECTOR_STORE_PORT``
    (ChromaDB Cloud is not supported yet). The collection is named by
    ``VECTOR_STORE_COLLECTION`` (default ``"default"``) and uses cosine space.

    Raises:
        ValueError: if no path is set and host or port is missing.
    """
    collection_name = os.getenv("VECTOR_STORE_COLLECTION", "default")
    vector_store_path = os.getenv("VECTOR_STORE_PATH")
    cosine_space = {"metadata": {"hnsw:space": "cosine"}}

    if vector_store_path:
        return ChromaVectorStore.from_params(
            persist_dir=vector_store_path,
            collection_name=collection_name,
            collection_kwargs=cosine_space,
        )

    host = os.getenv("VECTOR_STORE_HOST")
    port = os.getenv("VECTOR_STORE_PORT")
    if not (host and port):
        raise ValueError(
            "Please provide either VECTOR_STORE_PATH or VECTOR_STORE_HOST and VECTOR_STORE_PORT"
        )
    return ChromaVectorStore.from_params(
        host=host,
        port=int(port),
        collection_name=collection_name,
        collection_kwargs=cosine_space,
    )
|
||||||
|
|
||||||
|
def get_vector_store():
    """Return the vector store selected by ``VECTOR_STORE_TYPE``.

    Raises:
        ValueError: if the variable is unset or is neither ``"chroma"`` nor
            ``"qdrant"``.
    """
    store_type = os.getenv("VECTOR_STORE_TYPE")

    if store_type == "chroma":
        return get_chroma_vector_store()
    if store_type == "qdrant":
        return get_qdrant_vector_store()
    raise ValueError(f"Invalid vector store type: {store_type}")
|
||||||
@@ -0,0 +1,40 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
from app.engine.loaders.db import DBLoaderConfig, get_db_documents
|
||||||
|
from app.engine.loaders.file import FileLoaderConfig, get_file_documents
|
||||||
|
from app.engine.loaders.web import WebLoaderConfig, get_web_documents
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def load_configs():
    """Parse ``config/loaders.yaml`` and return its content.

    The file is read as UTF-8 explicitly so that non-ASCII content loads the
    same way regardless of the platform's default encoding.
    """
    with open("config/loaders.yaml", "r", encoding="utf-8") as f:
        configs = yaml.safe_load(f)
    return configs
|
||||||
|
|
||||||
|
|
||||||
|
def get_documents():
|
||||||
|
documents = []
|
||||||
|
config = load_configs()
|
||||||
|
if config is None or len(config.items()) == 0:
|
||||||
|
return documents
|
||||||
|
|
||||||
|
for loader_type, loader_config in config.items():
|
||||||
|
logger.info(
|
||||||
|
f"Loading documents from loader: {loader_type}, config: {loader_config}"
|
||||||
|
)
|
||||||
|
|
||||||
|
loader_config = loader_config or []
|
||||||
|
match loader_type:
|
||||||
|
case "file":
|
||||||
|
document = get_file_documents(FileLoaderConfig(**loader_config))
|
||||||
|
case "web":
|
||||||
|
document = get_web_documents(WebLoaderConfig(**loader_config))
|
||||||
|
case "db":
|
||||||
|
document = get_db_documents(configs=[DBLoaderConfig(**cfg) for cfg in loader_config])
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"Invalid loader type: {loader_type}")
|
||||||
|
documents.extend(document)
|
||||||
|
|
||||||
|
return documents
|
||||||
@@ -0,0 +1,140 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
|
from llama_index.core import SQLDatabase, Document
|
||||||
|
from llama_index.readers.database import DatabaseReader
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy import create_engine, text
|
||||||
|
from sqlalchemy.engine import Engine
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class CustomDatabaseReader(DatabaseReader):
    """Database reader that renders one query result as a single Document.

    The document text is the caller's explanation, a header line with the
    column names, and one comma-separated line per row.

    Args:
        sql_database (Optional[SQLDatabase]): SQL database to use,
            including table names to specify.
            See :ref:`Ref-Struct-Store` for more details.

        OR

        engine (Optional[Engine]): SQLAlchemy Engine object of the database connection.

        OR

        uri (Optional[str]): uri of the database connection.

        OR

        scheme (Optional[str]): scheme of the database connection.
        host (Optional[str]): host of the database connection.
        port (Optional[str]): port of the database connection.
        user (Optional[str]): user of the database connection.
        password (Optional[str]): password of the database connection.
        dbname (Optional[str]): dbname of the database connection.

    Raises:
        ValueError: if none of the connection options above is provided.
    """

    def __init__(
        self,
        sql_database: Optional[SQLDatabase] = None,
        engine: Optional[Engine] = None,
        uri: Optional[str] = None,
        scheme: Optional[str] = None,
        host: Optional[str] = None,
        port: Optional[str] = None,
        user: Optional[str] = None,
        password: Optional[str] = None,
        dbname: Optional[str] = None,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Initialize from the first connection option that is provided."""
        if sql_database:
            self.sql_database = sql_database
        elif engine:
            self.sql_database = SQLDatabase(engine, *args, **kwargs)
        elif uri:
            self.uri = uri
            self.sql_database = SQLDatabase.from_uri(uri, *args, **kwargs)
        elif scheme and host and port and user and password and dbname:
            uri = f"{scheme}://{user}:{password}@{host}:{port}/{dbname}"
            self.uri = uri
            self.sql_database = SQLDatabase.from_uri(uri, *args, **kwargs)
        else:
            raise ValueError(
                "You must provide either a SQLDatabase, "
                "a SQL Alchemy Engine, a valid connection URI, or a valid "
                "set of credentials."
            )

    def load_data(self, query: str, explanation: str) -> List[Document]:
        """Run ``query`` and return its result as a one-element Document list.

        Args:
            query (str): SQL text to execute.
            explanation (str): Text placed on the first line of the document.

        Returns:
            List[Document]: A list holding the single Document built from
            the result set.

        Raises:
            ValueError: if ``query`` is None.
        """
        # Validate before connecting so a bad call never opens a connection.
        if query is None:
            raise ValueError("A query parameter is necessary to filter the data")

        with self.sql_database.engine.connect() as connection:
            result = connection.execute(text(query))
            header = ", ".join(f"{column}" for column in result.keys())
            rows = [
                ", ".join(f"{value}" for value in row) for row in result.fetchall()
            ]

        # Explanation, header and every row each end with a newline.
        doc_text = "\n".join([explanation, header, *rows]) + "\n"

        doc = Document(text=doc_text)
        doc.metadata["name"] = query
        doc.metadata["context"] = query
        doc.metadata["file_type"] = "application/vnd.ms-excel"
        return [doc]
|
||||||
|
|
||||||
|
class DBLoaderConfig(BaseModel):
    """Configuration for loading documents from one database."""

    # Connection URI passed to sqlalchemy's create_engine by get_db_documents.
    uri: str
    # One dict per query; get_db_documents reads its "sql" and "explanation" keys.
    queries: List[dict]
|
||||||
|
|
||||||
|
def get_db_documents(configs: list[DBLoaderConfig]):
    """Run every configured query and return the resulting documents.

    Args:
        configs: Database configurations; each provides a connection URI and
            a list of query dicts with "sql" and "explanation" keys.

    Returns:
        The documents produced by all queries, or an empty list when no
        configuration is given or the first URI is empty.
    """
    docs = []

    if len(configs) == 0 or configs[0].uri == "":
        logger.warning(
            "Failed to load database, error message: uri is empty. Return as empty document list."
        )
        return docs

    for entry in configs:
        engine = create_engine(entry.uri)
        try:
            loader = CustomDatabaseReader(SQLDatabase(engine))
            for query_dict in entry.queries:
                query = query_dict.get("sql", "")
                # A key present with a null value yields None; the reader
                # concatenates the explanation, so normalise it to "".
                explanation = query_dict.get("explanation") or ""
                logger.info(f"Loading data from database with query: {query}")
                docs.extend(loader.load_data(query=query, explanation=explanation))
        finally:
            # Release the engine's pooled connections once this database is
            # done, even if a query failed.
            engine.dispose()

    return docs
|
||||||
@@ -0,0 +1,88 @@
|
|||||||
|
import os
|
||||||
|
import logging
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from llama_index.core.readers.base import BaseReader
|
||||||
|
from llama_index.core.readers.json import JSONReader
|
||||||
|
from llama_parse import LlamaParse
|
||||||
|
from pydantic import BaseModel, validator
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class FileLoaderConfig(BaseModel):
    """Settings for loading documents from a local directory."""

    # Directory to read from; the validator below rejects paths that are not
    # existing directories.
    data_dir: str = "data"
    # When true, get_file_documents uses the LlamaParse extractor instead of
    # the local JSON reader.
    use_llama_parse: bool = False

    @validator("data_dir")
    def data_dir_must_exist(cls, v):
        """Reject a ``data_dir`` that is not an existing directory."""
        if not os.path.isdir(v):
            raise ValueError(f"Directory '{v}' does not exist")
        return v
|
||||||
|
|
||||||
|
|
||||||
|
def llama_parse_parser():
    """Create the LlamaParse parser used for file extraction.

    Raises:
        ValueError: if ``LLAMA_CLOUD_API_KEY`` is not set in the environment.
    """
    api_key = os.getenv("LLAMA_CLOUD_API_KEY")
    if api_key is None:
        raise ValueError(
            "LLAMA_CLOUD_API_KEY environment variable is not set. "
            "Please set it in .env file or in your shell environment then run again!"
        )

    parser_options = {
        "result_type": "markdown",
        "verbose": True,
        "language": "en",
        "ignore_errors": False,
    }
    return LlamaParse(**parser_options)
|
||||||
|
|
||||||
|
|
||||||
|
def llama_parse_extractor() -> Dict[str, LlamaParse]:
    """Map every LlamaParse-supported file type to one shared parser."""
    from llama_parse.utils import SUPPORTED_FILE_TYPES

    shared_parser = llama_parse_parser()
    return dict.fromkeys(SUPPORTED_FILE_TYPES, shared_parser)
|
||||||
|
|
||||||
|
def llama_local_extractor() -> Dict[str, BaseReader]:
    """Return the local extractor map: ``.json`` files go to a JSONReader."""
    json_reader = JSONReader(clean_json=False, levels_back=0)
    return {".json": json_reader}
|
||||||
|
|
||||||
|
|
||||||
|
def get_file_documents(config: FileLoaderConfig):
    """Read every file under ``config.data_dir`` into documents.

    Files are read recursively with ``SimpleDirectoryReader``, using either
    the LlamaParse extractor or the local one depending on
    ``config.use_llama_parse``. An error whose innermost traceback frame is
    ``_add_files`` is treated as an empty data directory and yields ``[]``;
    any other error is re-raised.
    """
    from llama_index.core.readers import SimpleDirectoryReader

    try:
        if config.use_llama_parse:
            # LlamaParse is async first, so nest_asyncio is needed to run it
            # in sync mode.
            import nest_asyncio

            nest_asyncio.apply()
            extractors = llama_parse_extractor()
        else:
            extractors = llama_local_extractor()

        return SimpleDirectoryReader(
            config.data_dir,
            recursive=True,
            filename_as_id=True,
            raise_on_error=True,
            file_extractor=extractors,
        ).load_data()
    except Exception as e:
        import traceback

        # Identify where the error was raised by the name of the innermost
        # traceback frame.
        innermost_frame = traceback.extract_tb(e.__traceback__)[-1]
        if innermost_frame.name != "_add_files":
            # Not the empty-data-dir case: propagate.
            raise e

        logger.warning(
            f"Failed to load file documents, error message: {e} . Return as empty document list."
        )
        return []
|
||||||
@@ -0,0 +1,37 @@
|
|||||||
|
import os
|
||||||
|
import json
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class CrawlUrl(BaseModel):
    """One site to crawl, as consumed by get_web_documents."""

    # URL handed to the scraper's load_data call.
    base_url: str
    # Passed to WholeSiteReader as ``prefix``.
    prefix: str
    # Passed to WholeSiteReader as ``max_depth``; must be >= 0, defaults to 1.
    max_depth: int = Field(default=1, ge=0)
|
||||||
|
|
||||||
|
|
||||||
|
class WebLoaderConfig(BaseModel):
    """Settings for the Selenium-based web loader."""

    # Extra command-line arguments added to the Chrome options.
    # NOTE(review): annotated as list[str] but the default is None, so
    # consumers must treat None as "no arguments".
    driver_arguments: list[str] = Field(default=None)
    # Sites to crawl; empty by default.
    urls: list[CrawlUrl] = []
|
||||||
|
|
||||||
|
|
||||||
|
def get_web_documents(config: WebLoaderConfig):
    """Crawl every configured site with a Chrome web driver.

    Args:
        config: Driver arguments and the list of sites to crawl.

    Returns:
        The documents scraped from all sites, in configuration order.
    """
    from llama_index.readers.web import WholeSiteReader
    from selenium import webdriver
    from selenium.webdriver.chrome.options import Options

    options = Options()
    for arg in config.driver_arguments or []:
        options.add_argument(arg)

    docs = []
    # Both lists are guarded against None so an unset value means "nothing
    # to do" rather than a TypeError while iterating.
    for url in config.urls or []:
        # NOTE(review): a fresh Chrome driver is created per site; confirm
        # that WholeSiteReader releases it when load_data finishes.
        scraper = WholeSiteReader(
            prefix=url.prefix,
            max_depth=url.max_depth,
            driver=webdriver.Chrome(options=options),
        )
        docs.extend(scraper.load_data(url.base_url))

    return docs
|
||||||
@@ -9,7 +9,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
def load_configs():
|
def load_configs():
|
||||||
with open("config/loaders.yaml",'r', encoding='utf-8') as f:
|
with open("config/loaders.yaml") as f:
|
||||||
configs = yaml.safe_load(f)
|
configs = yaml.safe_load(f)
|
||||||
return configs
|
return configs
|
||||||
|
|
||||||
|
|||||||
@@ -2,14 +2,17 @@ import logging
|
|||||||
from typing import Any, List, Optional
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
from llama_index.core import SQLDatabase, Document
|
from llama_index.core import SQLDatabase, Document
|
||||||
|
from llama_index.core.objects import SQLTableSchema
|
||||||
|
from llama_index.core.readers.base import BaseReader
|
||||||
from llama_index.readers.database import DatabaseReader
|
from llama_index.readers.database import DatabaseReader
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlalchemy import create_engine, text
|
from sqlalchemy import create_engine
|
||||||
|
from sqlalchemy import text
|
||||||
from sqlalchemy.engine import Engine
|
from sqlalchemy.engine import Engine
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class CustomDatabaseReader(DatabaseReader):
|
class CustomDatabaseReader(BaseReader):
|
||||||
"""Simple Database reader.
|
"""Simple Database reader.
|
||||||
|
|
||||||
Concatenates each row into Document used by LlamaIndex.
|
Concatenates each row into Document used by LlamaIndex.
|
||||||
@@ -73,30 +76,28 @@ class CustomDatabaseReader(DatabaseReader):
|
|||||||
"set of credentials."
|
"set of credentials."
|
||||||
)
|
)
|
||||||
|
|
||||||
def load_data(self, query: str, explanation: str) -> List[Document]:
|
def load_data(self, query: str) -> List[Document]:
|
||||||
"""Query and load data from the Database, returning a list of Documents.
|
"""Query and load data from the Database, returning a list of Documents.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query (str): Query parameter to filter tables and rows.
|
query (str): Query parameter to filter tables and rows.
|
||||||
explanation (str): Explanation for the query to be included in the document.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[Document]: A list of Document objects.
|
List[Document]: A list of Document objects.
|
||||||
"""
|
"""
|
||||||
dco_str = explanation + "\n"
|
dco_str = ""
|
||||||
|
|
||||||
with self.sql_database.engine.connect() as connection:
|
with self.sql_database.engine.connect() as connection:
|
||||||
if query is None:
|
if query is None:
|
||||||
raise ValueError("A query parameter is necessary to filter the data")
|
raise ValueError("A query parameter is necessary to filter the data")
|
||||||
else:
|
else:
|
||||||
result = connection.execute(text(query))
|
result = connection.execute(text(query))
|
||||||
|
|
||||||
dco_str += ", ".join(
|
dco_str = ", ".join(
|
||||||
[f"{entry}" for entry in result.keys()]
|
[f"{entry}" for entry in result.keys()]
|
||||||
) + "\n"
|
)
|
||||||
|
|
||||||
for item in result.fetchall():
|
for item in result.fetchall():
|
||||||
# Fetch each item
|
# fetch each item
|
||||||
record_str = ", ".join(
|
record_str = ", ".join(
|
||||||
[f"{entry}" for col, entry in zip(result.keys(), item)]
|
[f"{entry}" for col, entry in zip(result.keys(), item)]
|
||||||
)
|
)
|
||||||
@@ -110,7 +111,7 @@ class CustomDatabaseReader(DatabaseReader):
|
|||||||
|
|
||||||
class DBLoaderConfig(BaseModel):
|
class DBLoaderConfig(BaseModel):
|
||||||
uri: str
|
uri: str
|
||||||
queries: List[dict]
|
queries: List[str]
|
||||||
|
|
||||||
def get_db_documents(configs: list[DBLoaderConfig]):
|
def get_db_documents(configs: list[DBLoaderConfig]):
|
||||||
docs = []
|
docs = []
|
||||||
@@ -122,19 +123,33 @@ def get_db_documents(configs: list[DBLoaderConfig]):
|
|||||||
return docs
|
return docs
|
||||||
|
|
||||||
metadata = {
|
metadata = {
|
||||||
'file_type': 'application/booway.document.zj',
|
#'file_name':'',
|
||||||
|
'file_type':'application/booway.document.zj',
|
||||||
|
#'file_path':'',
|
||||||
|
#'file_size':'',
|
||||||
|
#'creation_date':'',
|
||||||
|
#'last_modified_date':'',
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#from llama_index.readers.database import DatabaseReader
|
||||||
for entry in configs:
|
for entry in configs:
|
||||||
engine = create_engine(entry.uri)
|
engine = create_engine(entry.uri)
|
||||||
sql_database = SQLDatabase(engine)
|
sql_database = SQLDatabase(engine)
|
||||||
|
|
||||||
|
# table_schema_objs = makeDescriptionByEngine(sql_database)
|
||||||
|
# table_node_mapping = SQLTableNodeMapping(sql_database)
|
||||||
|
#
|
||||||
|
# nodes = table_node_mapping.to_nodes(table_schema_objs)
|
||||||
|
# for node in nodes:
|
||||||
|
# node.metadata.update(metadata)
|
||||||
|
#
|
||||||
|
# docs.extend(nodes)
|
||||||
|
|
||||||
|
queries = entry.queries or []
|
||||||
loader = CustomDatabaseReader(sql_database)
|
loader = CustomDatabaseReader(sql_database)
|
||||||
for query_dict in entry.queries:
|
for query in queries:
|
||||||
query = query_dict.get("sql", "")
|
|
||||||
explanation = query_dict.get("explanation", "")
|
|
||||||
logger.info(f"Loading data from database with query: {query}")
|
logger.info(f"Loading data from database with query: {query}")
|
||||||
documents = loader.load_data(query=query, explanation=explanation)
|
documents = loader.load_data(query=query)
|
||||||
|
|
||||||
docs.extend(documents)
|
docs.extend(documents)
|
||||||
return docs
|
return docs
|
||||||
|
|||||||
@@ -39,16 +39,15 @@ refine_template_str = (
|
|||||||
"这是原本的问题: {query_str}\n"
|
"这是原本的问题: {query_str}\n"
|
||||||
"我们已经提供了回答: {existing_answer}\n"
|
"我们已经提供了回答: {existing_answer}\n"
|
||||||
"现在我们有机会改进这个回答 "
|
"现在我们有机会改进这个回答 "
|
||||||
"使用以下更多上下文(仅当有助于改进回答时使用)\n"
|
"使用以下更多上下文(仅当需要用时)\n"
|
||||||
"------------\n"
|
"------------\n"
|
||||||
"{context_msg}\n"
|
"{context_msg}\n"
|
||||||
"------------\n"
|
"------------\n"
|
||||||
"如果新的上下文对回答没有影响,或者原来的回答已经正确,直接返回原本的回答。\n"
|
"根据新的上下文, 请改进原来的回答。"
|
||||||
"如果新的上下文有助于改进,请基于它更新回答,但不要引入与问题无关的信息。\n"
|
"如果新的上下文没有用, 直接返回原本的回答。\n"
|
||||||
"如果是表结构或者是数据库的相关内容,仅用于推导问题,不需要告诉用户数据库或表结构等物理信息。\n"
|
"如果是表结构或者是数据库的相关内容,只用于推导问题,不需要告诉用户数据库或表结构等物理信息。\n"
|
||||||
"改进的回答: "
|
"改进的回答: "
|
||||||
)
|
)
|
||||||
|
|
||||||
refine_template = PromptTemplate(refine_template_str)
|
refine_template = PromptTemplate(refine_template_str)
|
||||||
|
|
||||||
summary_template_str = (
|
summary_template_str = (
|
||||||
|
|||||||
@@ -0,0 +1,133 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
|
from typing import Any, Callable, Dict, List, Optional, cast
|
||||||
|
|
||||||
|
from llama_index.core.base.base_retriever import BaseRetriever
|
||||||
|
from llama_index.core.callbacks.base import CallbackManager
|
||||||
|
from llama_index.core.constants import DEFAULT_SIMILARITY_TOP_K
|
||||||
|
from llama_index.core.indices.vector_store.base import VectorStoreIndex
|
||||||
|
from llama_index.core.schema import BaseNode, IndexNode, NodeWithScore, QueryBundle
|
||||||
|
from llama_index.core.storage.docstore.types import BaseDocumentStore
|
||||||
|
from llama_index.core.vector_stores.utils import (
|
||||||
|
node_to_metadata_dict,
|
||||||
|
metadata_dict_to_node,
|
||||||
|
)
|
||||||
|
|
||||||
|
import bm25s
|
||||||
|
from app.engine.retriever.CHTokener import chTokenize
|
||||||
|
|
||||||
|
# Maps retriever attribute names to the constructor keyword each one is saved
# under in the persisted JSON file.
CHDEFAULT_PERSIST_ARGS = {"similarity_top_k": "similarity_top_k", "_verbose": "verbose"}

# File, inside the persist directory, that stores the constructor arguments.
CHDEFAULT_PERSIST_FILENAME = "retriever.json"


class CHBM25Retriever(BaseRetriever):
    """BM25 retriever whose corpus and queries are tokenized by ``chTokenize``."""

    def __init__(
        self,
        nodes: Optional[List[BaseNode]] = None,
        existing_bm25: Optional[bm25s.BM25] = None,
        similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K,
        callback_manager: Optional[CallbackManager] = None,
        objects: Optional[List[IndexNode]] = None,
        object_map: Optional[dict] = None,
        verbose: bool = False,
    ) -> None:
        """Wrap an existing BM25 index, or build one from ``nodes``.

        Raises:
            ValueError: if neither ``nodes`` nor ``existing_bm25`` is given.
        """
        self.similarity_top_k = similarity_top_k
        if existing_bm25 is not None:
            self.bm25 = existing_bm25
            self.corpus = existing_bm25.corpus
        else:
            if nodes is None:
                raise ValueError("Please pass nodes or an existing BM25 object.")

            # Each node is stored as a metadata dict so it can be rebuilt
            # from the retrieved index later.
            self.corpus = [node_to_metadata_dict(node) for node in nodes]

            corpus_tokens = chTokenize(
                [node.get_content() for node in nodes],
                show_progress=verbose,
            )
            self.bm25 = bm25s.BM25()
            self.bm25.index(corpus_tokens, show_progress=verbose)
        super().__init__(
            callback_manager=callback_manager,
            object_map=object_map,
            objects=objects,
            verbose=verbose,
        )

    @classmethod
    def from_defaults(
        cls,
        index: Optional[VectorStoreIndex] = None,
        nodes: Optional[List[BaseNode]] = None,
        docstore: Optional[BaseDocumentStore] = None,
        similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K,
        verbose: bool = False,
    ) -> "CHBM25Retriever":
        """Build a retriever from exactly one of index, nodes or docstore.

        Raises:
            ValueError: if zero or more than one source is provided.
        """
        if sum(bool(val) for val in [index, nodes, docstore]) != 1:
            raise ValueError("Please pass exactly one of index, nodes, or docstore.")

        if index is not None:
            docstore = index.docstore

        if docstore is not None:
            nodes = cast(List[BaseNode], list(docstore.docs.values()))

        assert (
            nodes is not None
        ), "Please pass exactly one of index, nodes, or docstore."

        return cls(
            nodes=nodes,
            similarity_top_k=similarity_top_k,
            verbose=verbose,
        )

    def get_persist_args(self) -> Dict[str, Any]:
        """Get Persist Args Dict to Save."""
        return {
            CHDEFAULT_PERSIST_ARGS[key]: getattr(self, key)
            for key in CHDEFAULT_PERSIST_ARGS
            if hasattr(self, key)
        }

    def persist(self, path: str, **kwargs: Any) -> None:
        """Persist the retriever to a directory."""
        self.bm25.save(path, corpus=self.corpus, **kwargs)
        with open(
            os.path.join(path, CHDEFAULT_PERSIST_FILENAME), "w", encoding="utf-8"
        ) as f:
            json.dump(self.get_persist_args(), f, indent=2)

    @classmethod
    def from_persist_dir(cls, path: str, **kwargs: Any) -> "CHBM25Retriever":
        """Load the retriever from a directory."""
        bm25 = bm25s.BM25.load(path, load_corpus=True, **kwargs)
        with open(
            os.path.join(path, CHDEFAULT_PERSIST_FILENAME), encoding="utf-8"
        ) as f:
            retriever_data = json.load(f)
        return cls(existing_bm25=bm25, **retriever_data)

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Return up to ``similarity_top_k`` nodes scored by BM25 for the query."""
        # Never ask for more results than there are documents; bm25s is
        # understood to reject a larger k (see its retrieve documentation).
        top_k = self.similarity_top_k
        if self.corpus is not None:
            top_k = min(top_k, len(self.corpus))
        if top_k <= 0:
            return []

        tokenized_query = chTokenize(
            query_bundle.query_str, show_progress=self._verbose
        )
        indexes, scores = self.bm25.retrieve(
            tokenized_query, k=top_k, show_progress=self._verbose
        )

        # batched, but only one query
        indexes = indexes[0]
        scores = scores[0]

        nodes: List[NodeWithScore] = []
        for idx, score in zip(indexes, scores):
            # idx can be an int or a dict of the node
            node_dict = idx if isinstance(idx, dict) else self.corpus[int(idx)]
            node = metadata_dict_to_node(node_dict)
            nodes.append(NodeWithScore(node=node, score=float(score)))

        return nodes
|
||||||
@@ -0,0 +1,46 @@
|
|||||||
|
from typing import Any, Dict, List, Union, Callable, NamedTuple
|
||||||
|
from bm25s.tokenization import *
|
||||||
|
|
||||||
|
# tqdm is optional: when it cannot be imported, fall back to a stand-in that
# ignores every extra argument and returns the iterable unchanged.
try:
    from tqdm.auto import tqdm
except ImportError:

    def tqdm(iterable, *args, **kwargs):
        return iterable
|
||||||
|
|
||||||
|
|
||||||
|
# Lazily populated cache of the NLTK Chinese stop-word list.
_CHINESE_STOPWORDS = None


def _chinese_stopwords():
    """Return the NLTK Chinese stop words as a frozenset, loaded only once."""
    global _CHINESE_STOPWORDS
    if _CHINESE_STOPWORDS is None:
        from nltk.corpus import stopwords

        _CHINESE_STOPWORDS = frozenset(stopwords.words("chinese"))
    return _CHINESE_STOPWORDS


def chinese_tokenizer(text: str) -> List[str]:
    """Segment ``text`` with jieba and drop Chinese stop words.

    Args:
        text: The text to segment.

    Returns:
        The jieba tokens, in order, excluding any token that appears in the
        NLTK Chinese stop-word list.
    """
    import jieba

    # The stop-word list is fetched once and held as a set, so filtering is
    # one hash lookup per token instead of re-reading the list every time.
    stop_words = _chinese_stopwords()
    return [token for token in jieba.lcut(text) if token not in stop_words]
|
||||||
|
|
||||||
|
def chTokenize(
    texts,
    show_progress: bool = True,
    leave: bool = False,
) -> Union[List[List[str]], Tokenized]:
    """Tokenize one or more texts into integer ids with a shared vocabulary.

    Args:
        texts: A single string or an iterable of strings.
        show_progress: Whether the progress bar is enabled.
        leave: Forwarded to the progress bar's ``leave`` option.

    Returns:
        A ``Tokenized`` holding, per text, the list of token ids, plus the
        token-to-id vocabulary built in first-seen order.
    """
    if isinstance(texts, str):
        texts = [texts]

    vocab = {}
    all_ids = []
    progress = tqdm(
        texts, desc="Split strings", leave=leave, disable=not show_progress
    )
    for text in progress:
        # setdefault assigns the next free id the first time a token is seen
        # and returns the existing id afterwards.
        all_ids.append(
            [vocab.setdefault(token, len(vocab)) for token in chinese_tokenizer(text)]
        )

    return Tokenized(ids=all_ids, vocab=vocab)
|
||||||
|
|
||||||
@@ -0,0 +1,67 @@
|
|||||||
|
import os
|
||||||
|
from typing import Optional, Any, Dict, List
|
||||||
|
|
||||||
|
from llama_index.core.base.base_retriever import BaseRetriever
|
||||||
|
from llama_index.core.schema import NodeWithScore, QueryBundle
|
||||||
|
|
||||||
|
from app.engine.retriever.CHBM25Retriever import CHBM25Retriever
|
||||||
|
|
||||||
|
|
||||||
|
class HybridRetriever(BaseRetriever):
    """Retriever that blends vector-similarity scores with BM25 scores.

    Each node's final score is ``alpha * vector_score + (1 - alpha) *
    bm25_score``; a node found by only one retriever gets 0 for the other.
    """

    def __init__(
        self,
        vector_index,
        similarity_top_k: int = 2,
        out_top_k: Optional[int] = None,
        alpha: float = 0.5,
        filters=None,
        **kwargs: Any,
    ) -> None:
        """Set up the vector retriever and load or build the BM25 retriever.

        Args:
            vector_index: Index providing the vector retriever and, when no
                persisted BM25 data exists, the nodes to index.
            similarity_top_k: Top-k used by each underlying retriever.
            out_top_k: Number of merged results to return; defaults to
                ``similarity_top_k``.
            alpha: Weight of the vector score in the blend.
            filters: Forwarded to the vector retriever.
            **kwargs: Forwarded to ``BaseRetriever``.
        """
        super().__init__(**kwargs)
        self._vector_index = vector_index
        self._embed_model = vector_index._embed_model
        self._out_top_k = out_top_k or similarity_top_k
        self._alpha = alpha
        self._vecRetriever = vector_index.as_retriever(
            similarity_top_k=similarity_top_k, filters=filters
        )

        storage_dir = os.getenv("BM_RETRIEVER_PATH", "storage_bm")
        if os.path.exists(storage_dir) and len(os.listdir(storage_dir)) > 0:
            self._bm25Retriever = CHBM25Retriever.from_persist_dir(storage_dir)
        else:
            # Nothing persisted yet: build from the vector store's nodes,
            # keep the retriever on the instance (it is needed by _retrieve),
            # and save it for the next start.
            self._bm25Retriever = CHBM25Retriever.from_defaults(
                similarity_top_k=similarity_top_k,
                nodes=self._vector_index.vector_store.get_nodes(None),
            )
            self._bm25Retriever.persist(storage_dir)

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Return the top merged nodes, each rescored with the blended score."""
        vec_nodes: List[NodeWithScore] = self._vecRetriever.retrieve(
            query_bundle.query_str
        )
        bm_nodes: List[NodeWithScore] = self._bm25Retriever.retrieve(
            query_bundle.query_str
        )

        bm_by_id: Dict[str, NodeWithScore] = {node.node_id: node for node in bm_nodes}

        scored = []
        for node in vec_nodes:
            # Popping removes the match so only BM25-only nodes remain after
            # this loop.
            bm_match = bm_by_id.pop(node.node_id, None)
            bm_score = (bm_match.score or 0.0) if bm_match is not None else 0.0
            # A missing (None) score counts as 0 rather than raising.
            blended = (self._alpha * (node.score or 0.0)) + (
                (1 - self._alpha) * bm_score
            )
            scored.append((blended, node))

        for node in bm_by_id.values():
            scored.append(((1 - self._alpha) * (node.score or 0.0), node))

        scored.sort(key=lambda pair: pair[0], reverse=True)
        for blended, node in scored:
            node.score = blended
        return [node for _, node in scored][: self._out_top_k]
|
||||||
@@ -0,0 +1,36 @@
|
|||||||
|
from llama_index.core.tools.function_tool import FunctionTool
|
||||||
|
|
||||||
|
|
||||||
|
def duckduckgo_search(
    query: str,
    region: str = "wt-wt",
    max_results: int = 10,
):
    """
    Use this function to search for any query in DuckDuckGo.
    Args:
        query (str): The query to search in DuckDuckGo.
        region Optional(str): The region to be used for the search in [country-language] convention, ex us-en, uk-en, ru-ru, etc...
        max_results Optional(int): The maximum number of results to be returned. Default is 10.
    """
    # Imported lazily so the module can be loaded without the optional
    # dependency; the error is only raised when the tool is actually used.
    try:
        from duckduckgo_search import DDGS
    except ImportError as err:
        raise ImportError(
            "duckduckgo_search package is required to use this function. "
            "Please install it by running: `poetry add duckduckgo_search` or `pip install duckduckgo_search`"
        ) from err

    params = {
        "keywords": query,
        "region": region,
        "max_results": max_results,
    }
    with DDGS() as ddg:
        # Materialise the results before the client is closed.
        return list(ddg.text(**params))
|
||||||
|
|
||||||
|
|
||||||
|
def get_tools(**kwargs):
    """Return a one-element list holding the DuckDuckGo search tool.

    ``kwargs`` is accepted for interface compatibility and is not used.
    """
    search_tool = FunctionTool.from_defaults(duckduckgo_search)
    return [search_tool]
|
||||||
@@ -0,0 +1,60 @@
|
|||||||
|
import os
|
||||||
|
import yaml
|
||||||
|
import json
|
||||||
|
import importlib
|
||||||
|
from cachetools import cached, LRUCache
|
||||||
|
from llama_index.core.tools.tool_spec.base import BaseToolSpec
|
||||||
|
from llama_index.core.tools.function_tool import FunctionTool
|
||||||
|
|
||||||
|
|
||||||
|
class ToolType:
    """String identifiers naming the source a tool is loaded from."""

    # Mapped to the "llama_index.tools" package by ToolFactory.
    LLAMAHUB = "llamahub"
    # Mapped to the "app.engine.tools" package by ToolFactory.
    LOCAL = "local"
|
||||||
|
|
||||||
|
|
||||||
|
class ToolFactory:
    """Load tools by name from llama_index tool packages or local modules."""

    # Root package searched for each tool type.
    TOOL_SOURCE_PACKAGE_MAP = {
        ToolType.LLAMAHUB: "llama_index.tools",
        ToolType.LOCAL: "app.engine.tools",
    }

    @staticmethod
    def load_tools(tool_type: str, tool_name: str, config: dict) -> list[FunctionTool]:
        """Import one configured tool and return its function tools.

        Args:
            tool_type: Key of ``TOOL_SOURCE_PACKAGE_MAP``.
            tool_name: Either ``"<package>.<ClassName>"`` for a tool spec
                (recognised by ``"ToolSpec"`` in the name) or the name of a
                module that exposes ``get_tools(**config)``.
            config: Keyword arguments for the tool spec constructor or for
                ``get_tools``; ``None`` is treated as an empty mapping.

        Returns:
            The tools provided by the tool spec or module.

        Raises:
            KeyError: If ``tool_type`` is not a known tool type.
            ValueError: If the tool cannot be imported or resolved, or if a
                module's ``get_tools`` returns something other than
                ``FunctionTool`` instances.
        """
        source_package = ToolFactory.TOOL_SOURCE_PACKAGE_MAP[tool_type]
        # A YAML entry without a body is parsed as None.
        config = config or {}
        try:
            if "ToolSpec" in tool_name:
                # Split on the last dot so dotted package paths also work.
                tool_package, tool_cls_name = tool_name.rsplit(".", 1)
                module = importlib.import_module(f"{source_package}.{tool_package}")
                tool_class = getattr(module, tool_cls_name)
                tool_spec: BaseToolSpec = tool_class(**config)
                return tool_spec.to_tool_list()

            module = importlib.import_module(f"{source_package}.{tool_name}")
            tools = module.get_tools(**config)
            if not all(isinstance(tool, FunctionTool) for tool in tools):
                raise ValueError(
                    f"The module {module} does not contain valid tools"
                )
            return tools
        except ImportError as e:
            raise ValueError(f"Failed to import tool {tool_name}: {e}") from e
        except AttributeError as e:
            raise ValueError(f"Failed to load tool {tool_name}: {e}") from e

    @staticmethod
    def from_env() -> list[FunctionTool]:
        """Load every tool declared in ``config/tools.yaml``.

        The file maps each tool type to a mapping of tool name to config.
        A missing file, an empty file or an empty tool-type section simply
        contributes no tools.
        """
        tools = []
        config_path = "config/tools.yaml"
        if not os.path.exists(config_path):
            return tools

        with open(config_path, "r", encoding="utf-8") as f:
            tool_configs = yaml.safe_load(f)
        if not tool_configs:
            return tools

        for tool_type, config_entries in tool_configs.items():
            if not config_entries:
                continue
            for tool_name, config in config_entries.items():
                tools.extend(ToolFactory.load_tools(tool_type, tool_name, config))
        return tools
|
||||||
@@ -0,0 +1,108 @@
|
|||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
import logging
|
||||||
|
import requests
|
||||||
|
from typing import Optional
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from llama_index.core.tools import FunctionTool
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ImageGeneratorToolOutput(BaseModel):
    """Outcome of one image generation request."""

    # Required flag; the two optional fields below default to None.
    is_success: bool = Field(
        default=...,
        description="Whether the image generation was successful.",
    )
    image_url: Optional[str] = Field(
        default=None,
        description="The URL of the generated image.",
    )
    error_message: Optional[str] = Field(
        default=None,
        description="The error message if the image generation failed.",
    )
|
||||||
|
|
||||||
|
|
||||||
|
class ImageGeneratorTool:
    """Generate images with the Stability AI API and save them locally.

    Generated images are written to ``_IMG_OUTPUT_DIR`` and returned as URLs
    built from the ``FILESERVER_URL_PREFIX`` environment variable.
    """

    _IMG_OUTPUT_FORMAT = "webp"
    _IMG_OUTPUT_DIR = "output/tool"
    _IMG_GEN_API = "https://api.stability.ai/v2beta/stable-image/generate/core"
    # Seconds to wait for the API before giving up; without a timeout a
    # stalled connection would block the caller indefinitely.
    _REQUEST_TIMEOUT = 120

    def __init__(self, api_key: str = None):
        """Read the configuration and validate it.

        Args:
            api_key: Stability API key; falls back to the
                ``STABILITY_API_KEY`` environment variable when empty.

        Raises:
            ValueError: If no API key is available or
                ``FILESERVER_URL_PREFIX`` is not set.
        """
        if not api_key:
            api_key = os.getenv("STABILITY_API_KEY")
        self._api_key = api_key
        self.fileserver_url_prefix = os.getenv("FILESERVER_URL_PREFIX")
        # `not` also rejects an empty key, which could never authenticate.
        if not self._api_key:
            raise ValueError(
                "STABILITY_API_KEY key is required to run image generator. Get it here: https://platform.stability.ai/account/keys"
            )
        if self.fileserver_url_prefix is None:
            raise ValueError("FILESERVER_URL_PREFIX is required.")

    def _prepare_output_dir(self):
        """
        Create the output directory if it doesn't exist
        """
        os.makedirs(self._IMG_OUTPUT_DIR, exist_ok=True)

    def _save_image(self, image_data: bytes):
        """Write ``image_data`` to a uniquely named file and return its URL."""
        self._prepare_output_dir()
        filename = f"{uuid.uuid4()}.{self._IMG_OUTPUT_FORMAT}"
        output_path = os.path.join(self._IMG_OUTPUT_DIR, filename)
        with open(output_path, "wb") as f:
            f.write(image_data)
        # Use the prefix validated in __init__ instead of re-reading the
        # environment, so the URL matches the checked configuration.
        url = f"{self.fileserver_url_prefix}/{self._IMG_OUTPUT_DIR}/{filename}"
        logger.info(f"Saved image to {output_path}.\nURL: {url}")
        return url

    def _call_stability_api(self, prompt: str):
        """POST ``prompt`` to the image API and return the HTTP response.

        Raises:
            requests.HTTPError: If the API answers with an error status.
        """
        headers = {
            "authorization": f"Bearer {self._api_key}",
            "accept": "image/*",
        }
        data = {
            "prompt": prompt,
            "output_format": self._IMG_OUTPUT_FORMAT,
        }

        response = requests.post(
            self._IMG_GEN_API,
            headers=headers,
            # A dummy file part makes requests encode the body as
            # multipart/form-data.
            files={"none": ""},
            data=data,
            timeout=self._REQUEST_TIMEOUT,
        )
        response.raise_for_status()
        return response

    def generate_image(self, prompt: str) -> ImageGeneratorToolOutput:
        """
        Use this tool to generate an image based on the prompt.
        Args:
            prompt (str): The prompt to generate the image from.
        """
        # Failures are reported through the output object rather than raised,
        # so the caller of the tool always receives a structured result.
        try:
            response = self._call_stability_api(prompt)
            image_url = self._save_image(response.content)
        except Exception as e:
            logger.exception("Failed to generate image")
            return ImageGeneratorToolOutput(
                is_success=False,
                error_message=str(e),
            )
        return ImageGeneratorToolOutput(
            is_success=True,
            image_url=image_url,
        )
|
||||||
|
|
||||||
|
|
||||||
|
def get_tools(**kwargs):
    """Build an ``ImageGeneratorTool`` from ``kwargs`` and wrap its
    ``generate_image`` method as a one-element tool list."""
    generator = ImageGeneratorTool(**kwargs)
    image_tool = FunctionTool.from_defaults(generator.generate_image)
    return [image_tool]
|
||||||
@@ -0,0 +1,143 @@
|
|||||||
|
import os
|
||||||
|
import logging
|
||||||
|
import base64
|
||||||
|
import uuid
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from typing import List, Tuple, Dict, Optional
|
||||||
|
from llama_index.core.tools import FunctionTool
|
||||||
|
from e2b_code_interpreter import CodeInterpreter
|
||||||
|
from e2b_code_interpreter.models import Logs
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class InterpreterExtraResult(BaseModel):
    """One item of interpreter output: inline content or a saved file."""

    # Format name of the item.
    type: str
    # Raw content, when the item is returned inline.
    content: Optional[str] = None
    # Name and URL of the saved file, when the item is written to disk.
    filename: Optional[str] = None
    url: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class E2BToolOutput(BaseModel):
    """Structured result returned by the code interpreter tool."""

    # True when the executed cell reported an error.
    is_error: bool
    # Logs object taken from the execution.
    logs: Logs
    # Parsed extra results; empty when there are none.
    results: List[InterpreterExtraResult] = []
|
||||||
|
|
||||||
|
|
||||||
|
class E2BCodeInterpreter:
    """Run code in an E2B sandbox and expose produced files through URLs.

    File results are decoded from base64, stored under ``output_dir`` and
    linked with the ``FILESERVER_URL_PREFIX`` environment variable.
    """

    output_dir = "output/tool"

    def __init__(self, api_key: str = None):
        """Validate the configuration and create the sandbox client.

        Args:
            api_key: E2B API key; falls back to the ``E2B_API_KEY``
                environment variable when ``None``.

        Raises:
            ValueError: If no API key or no ``FILESERVER_URL_PREFIX`` is
                available.
        """
        if api_key is None:
            api_key = os.getenv("E2B_API_KEY")
        filesever_url_prefix = os.getenv("FILESERVER_URL_PREFIX")
        if not api_key:
            raise ValueError(
                "E2B_API_KEY key is required to run code interpreter. Get it here: https://e2b.dev/docs/getting-started/api-key"
            )
        if not filesever_url_prefix:
            raise ValueError(
                "FILESERVER_URL_PREFIX is required to display file output from sandbox"
            )

        # The attribute name keeps its original spelling for compatibility.
        self.filesever_url_prefix = filesever_url_prefix
        self.interpreter = CodeInterpreter(api_key=api_key)

    def __del__(self):
        # __init__ may have raised before the interpreter was created; in
        # that case there is nothing to close.
        interpreter = getattr(self, "interpreter", None)
        if interpreter is not None:
            interpreter.close()

    def get_output_path(self, filename: str) -> str:
        """Return the path of ``filename`` inside the output directory,
        creating the directory if needed."""
        os.makedirs(self.output_dir, exist_ok=True)
        return os.path.join(self.output_dir, filename)

    def save_to_disk(self, base64_data: str, ext: str) -> Dict:
        """Decode ``base64_data`` and write it to a uniquely named file.

        Returns:
            A dict with the keys ``"outputPath"`` and ``"filename"``.

        Raises:
            IOError: If the file cannot be written (logged, then re-raised).
        """
        filename = f"{uuid.uuid4()}.{ext}"  # generate a unique filename
        buffer = base64.b64decode(base64_data)
        output_path = self.get_output_path(filename)

        try:
            with open(output_path, "wb") as file:
                file.write(buffer)
        except IOError as e:
            logger.error(f"Failed to write to file {output_path}: {str(e)}")
            raise

        logger.info(f"Saved file to {output_path}")

        return {
            "outputPath": output_path,
            "filename": filename,
        }

    def get_file_url(self, filename: str) -> str:
        """Return the URL under which a saved file is served."""
        return f"{self.filesever_url_prefix}/{self.output_dir}/{filename}"

    def parse_result(self, result) -> List[InterpreterExtraResult]:
        """
        The result could include multiple formats (e.g. png, svg, etc.) but encoded in base64
        We save each result to disk and return saved file metadata (extension, filename, url)
        """
        if not result:
            return []

        output = []

        # Parsing is best effort: on failure the error is logged and the
        # items collected so far are returned.
        try:
            for ext in result.formats():
                data = result[ext]
                if ext in ("png", "svg", "jpeg", "pdf"):
                    # File formats are written to disk and linked by URL.
                    saved = self.save_to_disk(data, ext)
                    filename = saved["filename"]
                    output.append(
                        InterpreterExtraResult(
                            type=ext,
                            filename=filename,
                            url=self.get_file_url(filename),
                        )
                    )
                else:
                    # Any other format is passed through as inline content.
                    output.append(
                        InterpreterExtraResult(
                            type=ext,
                            content=data,
                        )
                    )
        except Exception:
            logger.exception("Error when parsing output from E2b interpreter tool")

        return output

    def interpret(self, code: str) -> E2BToolOutput:
        """
        Execute python code in a Jupyter notebook cell, the tool will return result, stdout, stderr, display_data, and error.

        Parameters:
            code (str): The python code to be executed in a single cell.
        """
        logger.info(
            f"\n{'='*50}\n> Running following AI-generated code:\n{code}\n{'='*50}"
        )
        execution = self.interpreter.notebook.exec_cell(code)

        if execution.error:
            logger.error("Error when executing code: %s", execution.error)
            return E2BToolOutput(is_error=True, logs=execution.logs, results=[])

        # Only the first result of the cell is parsed; any further results
        # are not included in the output.
        results = self.parse_result(execution.results[0]) if execution.results else []
        return E2BToolOutput(is_error=False, logs=execution.logs, results=results)
|
||||||
|
|
||||||
|
|
||||||
|
def get_tools(**kwargs):
    """Build an ``E2BCodeInterpreter`` from ``kwargs`` and wrap its
    ``interpret`` method as a one-element tool list."""
    code_interpreter = E2BCodeInterpreter(**kwargs)
    interpret_tool = FunctionTool.from_defaults(code_interpreter.interpret)
    return [interpret_tool]
|
||||||
@@ -0,0 +1,78 @@
|
|||||||
|
from typing import Dict, List, Tuple
|
||||||
|
from llama_index.tools.openapi import OpenAPIToolSpec
|
||||||
|
from llama_index.tools.requests import RequestsToolSpec
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAPIActionToolSpec(OpenAPIToolSpec, RequestsToolSpec):
    """
    A combination of OpenAPI and Requests tool specs that can parse OpenAPI specs and make requests.

    openapi_uri: str: The file path or URL to the OpenAPI spec.
    domain_headers: dict: Whitelist domains and the headers to use.
    """

    spec_functions = OpenAPIToolSpec.spec_functions + RequestsToolSpec.spec_functions
    # Cached parsed specs by URI (shared by all instances).
    _specs: Dict[str, Tuple[Dict, List[str]]] = {}

    def __init__(self, openapi_uri: str, domain_headers: dict = None, **kwargs):
        """Load (or reuse) the spec at ``openapi_uri`` and initialise both
        parent tool specs.

        Args:
            openapi_uri: ``http(s)://`` URL or ``file://`` URI of the spec.
            domain_headers: Mapping of domain to request headers. The server
                hosts listed in the spec are added with empty headers when
                missing.
            **kwargs: Accepted for interface compatibility; not used.

        Raises:
            ValueError: If the spec cannot be loaded or is invalid.
        """
        # Work on a copy so the caller's dict is not modified.
        domain_headers = dict(domain_headers) if domain_headers else {}

        if openapi_uri not in self._specs:
            openapi_spec, servers = self._load_openapi_spec(openapi_uri)
            self._specs[openapi_uri] = (openapi_spec, servers)
        else:
            openapi_spec, servers = self._specs[openapi_uri]

        # Add the servers to the domain headers if they are not already present
        for server in servers:
            domain_headers.setdefault(server, {})

        OpenAPIToolSpec.__init__(self, spec=openapi_spec)
        RequestsToolSpec.__init__(self, domain_headers)

    @staticmethod
    def _load_openapi_spec(uri: str) -> Tuple[Dict, List[str]]:
        """
        Load an OpenAPI spec from a URI.

        Args:
            uri (str): An ``http(s)://`` URL or ``file://`` URI of the OpenAPI spec.

        Returns:
            Tuple[Dict, List[str]]: The parsed spec and the network locations
            (host[:port]) of the servers it declares.

        Raises:
            ValueError: If the URI scheme is unsupported, the download fails,
                the document is not a mapping, or a server entry has no ``url``.
        """
        import yaml
        from urllib.parse import urlparse

        if uri.startswith("http"):
            import requests

            # Bounded wait so an unresponsive host cannot block start-up.
            response = requests.get(uri, timeout=30)
            if response.status_code != 200:
                raise ValueError(
                    "Could not initialize OpenAPIActionToolSpec: "
                    f"Failed to load OpenAPI spec from {uri}, status code: {response.status_code}"
                )
            spec = yaml.safe_load(response.text)
        elif uri.startswith("file"):
            filepath = urlparse(uri).path
            with open(filepath, "r", encoding="utf-8") as file:
                spec = yaml.safe_load(file)
        else:
            raise ValueError(
                "Could not initialize OpenAPIActionToolSpec: Invalid OpenAPI URI provided. "
                "Only HTTP and file path are supported."
            )

        # An empty or scalar document would otherwise fail with an
        # AttributeError on `spec.get` below.
        if not isinstance(spec, dict):
            raise ValueError(
                "Could not initialize OpenAPIActionToolSpec: Invalid OpenAPI spec provided. "
                "The document is not a mapping."
            )

        # Add the servers to the whitelist
        try:
            servers = [
                urlparse(server["url"]).netloc for server in spec.get("servers", [])
            ]
        except KeyError as e:
            raise ValueError(
                "Could not initialize OpenAPIActionToolSpec: Invalid OpenAPI spec provided. "
                "Could not get `servers` from the spec."
            ) from e
        return spec, servers
|
||||||
@@ -0,0 +1,73 @@
|
|||||||
|
"""Open Meteo weather map tool spec."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import requests
|
||||||
|
import pytz
|
||||||
|
from llama_index.core.tools import FunctionTool
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class OpenMeteoWeather:
    """Look up weather information through the Open-Meteo HTTP APIs."""

    geo_api = "https://geocoding-api.open-meteo.com/v1"
    weather_api = "https://api.open-meteo.com/v1"
    # Seconds to wait for either API; without a timeout a stalled connection
    # would block the caller indefinitely.
    _REQUEST_TIMEOUT = 30

    @classmethod
    def _get_geo_location(cls, location: str) -> dict:
        """Get geo location from location name.

        Returns:
            A dict with the keys ``id``, ``name``, ``latitude`` and
            ``longitude`` of the first search result.

        Raises:
            ValueError: If the search returns no result for ``location``.
            Exception: If the API answers with a non-200 status.
        """
        params = {"name": location, "count": 10, "language": "en", "format": "json"}
        response = requests.get(
            f"{cls.geo_api}/search", params=params, timeout=cls._REQUEST_TIMEOUT
        )
        if response.status_code != 200:
            raise Exception(f"Failed to fetch geo location: {response.status_code}")

        # Report a missing or empty result list with a clear message instead
        # of a bare KeyError/IndexError.
        results = response.json().get("results")
        if not results:
            raise ValueError(f"No geo location found for: {location}")

        result = results[0]
        return {
            "id": result["id"],
            "name": result["name"],
            "latitude": result["latitude"],
            "longitude": result["longitude"],
        }

    @classmethod
    def get_weather_information(cls, location: str) -> dict:
        """Use this function to get the weather of any given location.
        Note that the weather code should follow WMO Weather interpretation codes (WW):
        0: Clear sky
        1, 2, 3: Mainly clear, partly cloudy, and overcast
        45, 48: Fog and depositing rime fog
        51, 53, 55: Drizzle: Light, moderate, and dense intensity
        56, 57: Freezing Drizzle: Light and dense intensity
        61, 63, 65: Rain: Slight, moderate and heavy intensity
        66, 67: Freezing Rain: Light and heavy intensity
        71, 73, 75: Snow fall: Slight, moderate, and heavy intensity
        77: Snow grains
        80, 81, 82: Rain showers: Slight, moderate, and violent
        85, 86: Snow showers slight and heavy
        95: Thunderstorm: Slight or moderate
        96, 99: Thunderstorm with slight and heavy hail
        """
        logger.info(
            f"Calling open-meteo api to get weather information of location: {location}"
        )
        geo_location = cls._get_geo_location(location)
        timezone = pytz.timezone("UTC").zone
        params = {
            "latitude": geo_location["latitude"],
            "longitude": geo_location["longitude"],
            "current": "temperature_2m,weather_code",
            "hourly": "temperature_2m,weather_code",
            "daily": "weather_code",
            "timezone": timezone,
        }
        response = requests.get(
            f"{cls.weather_api}/forecast", params=params, timeout=cls._REQUEST_TIMEOUT
        )
        if response.status_code != 200:
            raise Exception(
                f"Failed to fetch weather information: {response.status_code}"
            )
        return response.json()
|
||||||
|
|
||||||
|
|
||||||
|
def get_tools(**kwargs):
    """Return a one-element list wrapping
    ``OpenMeteoWeather.get_weather_information``.

    ``kwargs`` is accepted for interface compatibility and is not used.
    """
    weather_tool = FunctionTool.from_defaults(OpenMeteoWeather.get_weather_information)
    return [weather_tool]
|
||||||
@@ -0,0 +1,272 @@
|
|||||||
|
"""Xinference embeddings file."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from enum import Enum
|
||||||
|
from http import HTTPStatus
|
||||||
|
from typing import Any, Dict, List, Optional, Union, Tuple
|
||||||
|
|
||||||
|
from llama_index.core.base.embeddings.base import BaseEmbedding, Embedding, dispatcher
|
||||||
|
from llama_index.core.bridge.pydantic import PrivateAttr
|
||||||
|
from llama_index.core.callbacks import CBEventType, EventPayload
|
||||||
|
from llama_index.core.embeddings.multi_modal_base import MultiModalEmbedding
|
||||||
|
from llama_index.core.instrumentation.events.rerank import ReRankStartEvent, ReRankEndEvent
|
||||||
|
from llama_index.core.postprocessor.types import BaseNodePostprocessor
|
||||||
|
from llama_index.core.schema import ImageType, NodeWithScore, QueryBundle
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Presumably the maximum input length per embedding request; it is not
# referenced by the code visible in this module - confirm before relying on it.
EMBED_MAX_INPUT_LENGTH = 2048
# Default value of the `embed_batch_size` argument of XinferenceEmbedding.
EMBED_MAX_BATCH_SIZE = 1
|
||||||
|
|
||||||
|
|
||||||
|
class XinferenceEmbedding(BaseEmbedding):
    """Text embedding model served by an Xinference endpoint.

    The model handle and its description are fetched from the endpoint when
    the instance is created; embeddings are produced by calling
    ``create_embedding`` on that handle.
    """

    model_description: Dict[str, Any] = Field(
        description="The model description from Xinference."
    )
    # Underscore-prefixed attributes are private state, not model fields.
    _generator: Any = PrivateAttr()
    _model_uid: str = PrivateAttr()
    _endpoint: str = PrivateAttr()

    def __init__(
        self,
        model_uid: str,
        endpoint: str,
        embed_batch_size: int = EMBED_MAX_BATCH_SIZE,
        dimensions: Optional[int] = None,
        additional_kwargs: Optional[Dict[str, Any]] = None,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        api_version: Optional[str] = None,
        max_retries: int = 10,
        **kwargs: Any,
    ) -> None:
        """Connect to ``endpoint`` and bind to the model ``model_uid``.

        Args:
            model_uid: UID of a model launched on the Xinference endpoint.
            endpoint: URL of the Xinference endpoint.
            embed_batch_size: Replaced by the model's ``replica`` value.
            dimensions: Replaced by the model's ``dimensions`` value.
            additional_kwargs, api_key, api_base, api_version, max_retries,
            **kwargs: Forwarded unchanged to the base class initializer.

        Raises:
            ImportError: If the ``xinference`` package is not installed.
            RuntimeError: If the model cannot be obtained from the endpoint.
        """
        # NOTE(review): the values reported by the server override the
        # caller's `embed_batch_size` and `dimensions`; using the replica
        # count as batch size looks unintended - confirm.
        generator, model_description, embed_batch_size, dimensions = self.load_model(
            model_uid, endpoint
        )
        super().__init__(
            embed_batch_size=embed_batch_size,
            dimensions=dimensions,
            model_name=model_uid,
            model_description=model_description,
            additional_kwargs=additional_kwargs,
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            max_retries=max_retries,
            **kwargs,
        )
        # Private attributes are assigned after the base initializer has run
        # so the assignment works regardless of how the base model sets up
        # its private state.
        self._generator = generator
        self._model_uid = model_uid
        self._endpoint = endpoint

    def load_model(
        self, model_uid: str, endpoint: str
    ) -> Tuple[Any, Dict[str, Any], int, int]:
        """Fetch the model handle and description from the endpoint.

        Returns:
            ``(generator, model_description, replica, dimensions)`` where the
            last two values are read from the model description.

        Raises:
            ImportError: If the ``xinference`` package is not installed.
            RuntimeError: If the endpoint does not provide the model.
            KeyError: If the description lacks ``replica`` or ``dimensions``.
        """
        try:
            from xinference.client import RESTfulClient
        except ImportError as err:
            raise ImportError(
                "Could not import Xinference library. "
                'Please install Xinference with `pip install "xinference[all]"`'
            ) from err

        client = RESTfulClient(endpoint)
        generator = client.get_model(model_uid)
        model_description = client.list_models().get(model_uid)

        # Explicit checks instead of `assert`, which is removed under -O.
        if generator is None or model_description is None:
            raise RuntimeError(
                "Could not get model from endpoint. "
                "Please make sure Xinference endpoint is running at the correct port."
            )

        replica = model_description["replica"]
        dimensions = model_description["dimensions"]
        return generator, model_description, replica, dimensions

    @classmethod
    def class_name(cls) -> str:
        """Return the class identifier string."""
        return "XinferenceEmbedding"

    def _get_text_embedding(self, text: str) -> Embedding:
        """Embed ``text`` synchronously with one request to the endpoint."""
        if self._generator is None:
            raise RuntimeError("Xinference model handle is not initialized.")

        response = self._generator.create_embedding(input=text)
        return response["data"][0]["embedding"]

    def _get_query_embedding(self, query: str) -> Embedding:
        """Embed ``query`` synchronously; queries are embedded like texts."""
        return self._get_text_embedding(query)

    async def _aget_query_embedding(self, query: str) -> Embedding:
        """Embed ``query``; delegates to the synchronous implementation."""
        # NOTE(review): this runs the synchronous call inside the coroutine
        # and therefore does not yield while the request is in flight.
        return self._get_query_embedding(query)
|
||||||
|
|
||||||
|
class XinferenceRerank(BaseNodePostprocessor):
    """Node postprocessor that reranks nodes with an Xinference rerank model.

    Nodes are scored against the query by the remote model, optionally
    filtered by a minimum score (``threshold``) and cut to ``top_n`` entries.
    """

    model_description: Dict[str, Any] = Field(
        description="The model description from Xinference."
    )
    # Underscore-prefixed attributes are private state, not model fields.
    _generator: Any = PrivateAttr()
    _model_uid: str = PrivateAttr()
    _endpoint: str = PrivateAttr()
    model: str = Field(description="Xinference rerank model name.")
    # Both limits are optional: None disables the respective filter.
    top_n: Optional[int] = Field(default=None, description="Top N nodes to return.")
    threshold: Optional[float] = Field(
        default=None, description="threshold nodes to return."
    )

    def __init__(
        self,
        model_uid: str,
        endpoint: str,
        top_n: Optional[int] = None,
        threshold: Optional[float] = None,
        return_documents: bool = False,
    ):
        """Connect to ``endpoint`` and bind to the rerank model ``model_uid``.

        Args:
            model_uid: UID of a rerank model launched on the endpoint.
            endpoint: URL of the Xinference endpoint.
            top_n: Maximum number of nodes to return; ``None`` for no limit.
            threshold: Minimum relevance score a node needs to be kept;
                ``None`` keeps every node.
            return_documents: Forwarded to the base class initializer.

        Raises:
            ImportError: If the ``xinference`` package is not installed.
            RuntimeError: If the model cannot be obtained from the endpoint.
        """
        generator, model_description = self.load_model(model_uid, endpoint)
        super().__init__(
            top_n=top_n,
            model=model_uid,
            model_uid=model_uid,
            threshold=threshold,
            return_documents=return_documents,
            model_description=model_description,
        )
        # Private attributes are assigned after the base initializer has run.
        self._generator = generator
        self._model_uid = model_uid
        self._endpoint = endpoint

    @classmethod
    def class_name(cls) -> str:
        """Return the class identifier string."""
        return "XinferenceRerank"

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        """Rerank ``nodes`` against the query.

        Returns:
            New ``NodeWithScore`` objects carrying the relevance score from
            the rerank model, highest score first, after applying
            ``threshold`` and ``top_n``.

        Raises:
            ValueError: If ``query_bundle`` is ``None``.
        """
        if query_bundle is None:
            raise ValueError("Missing query bundle in extra info.")
        if len(nodes) == 0:
            return []

        dispatcher.event(
            ReRankStartEvent(
                nodes=nodes,
                top_n=self.top_n,
                query=query_bundle,
                model_name=self.model,
            )
        )

        with self.callback_manager.event(
            CBEventType.RERANKING,
            payload={
                EventPayload.NODES: nodes,
                EventPayload.MODEL_NAME: self.model,
                EventPayload.QUERY_STR: query_bundle.query_str,
                EventPayload.TOP_K: self.top_n,
            },
        ) as event:
            texts = [node.node.get_content() for node in nodes]
            response = self._generator.rerank(texts, query_bundle.query_str)

            new_nodes = []
            for result in response["results"]:
                score = result["relevance_score"]
                # Without a threshold every reranked node is kept.
                if self.threshold is None or score >= self.threshold:
                    new_nodes.append(
                        NodeWithScore(node=nodes[result["index"]].node, score=score)
                    )

            # Sort defensively so truncation keeps the highest-scoring nodes
            # regardless of the order in which the server returns results.
            new_nodes.sort(key=lambda item: item.score, reverse=True)
            if self.top_n is not None:
                new_nodes = new_nodes[: self.top_n]

            event.on_end(payload={EventPayload.NODES: new_nodes})

        dispatcher.event(ReRankEndEvent(nodes=new_nodes))
        return new_nodes

    def load_model(self, model_uid: str, endpoint: str) -> Tuple[Any, Dict[str, Any]]:
        """Fetch the model handle and description from the endpoint.

        Returns:
            ``(generator, model_description)``.

        Raises:
            ImportError: If the ``xinference`` package is not installed.
            RuntimeError: If the endpoint does not provide the model.
        """
        try:
            from xinference.client import RESTfulClient
        except ImportError as err:
            raise ImportError(
                "Could not import Xinference library. "
                'Please install Xinference with `pip install "xinference[all]"`'
            ) from err

        client = RESTfulClient(endpoint)
        generator = client.get_model(model_uid)
        model_description = client.list_models().get(model_uid)

        # Explicit checks instead of `assert`, which is removed under -O.
        if generator is None or model_description is None:
            raise RuntimeError(
                "Could not get model from endpoint. "
                "Please make sure Xinference endpoint is running at the correct port."
            )

        return generator, model_description
|
||||||
Reference in New Issue
Block a user