删除误上传的文件
This commit is contained in:
@@ -1,22 +0,0 @@
|
||||
import logging
|
||||
from llama_index.core.indices import VectorStoreIndex
|
||||
from app.engine.vectordb import get_vector_store
|
||||
|
||||
|
||||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
index = None
|
||||
|
||||
def get_index(params=None):
    """Return the shared vector index, creating it on the first call.

    The index is built from the configured vector store and cached in the
    module-level ``index`` variable; later calls return the cached object.

    Args:
        params: Accepted for interface compatibility; not used by this function.

    Returns:
        The cached ``VectorStoreIndex`` instance.
    """
    global index
    if index is not None:
        return index

    logger.info("Connecting vector store...")
    # Load the index from the vector store.
    # If you are using a vector store that doesn't store text, you must load
    # the index from both the vector store and the document store.
    index = VectorStoreIndex.from_vector_store(get_vector_store())
    logger.info("Finished load index from vector store.")
    return index
|
||||
@@ -1,61 +0,0 @@
|
||||
import os
|
||||
|
||||
from llama_index.core.agent import AgentRunner, ReActChatFormatter
|
||||
from llama_index.core.settings import Settings
|
||||
from llama_index.core.tools.query_engine import QueryEngineTool
|
||||
|
||||
from app.engine.engine import create_query_engine, create_summary_query_engine
|
||||
from app.engine.index import get_index
|
||||
#from app.engine.loaders.db import makeDescriptionByEngine
|
||||
from app.engine.tools import ToolFactory
|
||||
|
||||
|
||||
def get_chat_engine(filters=None, params=None):
    """Build the ReAct agent used to answer chat requests.

    Two query tools (summary and knowledge-base lookup) are registered when an
    index is available, followed by any tools configured through the
    environment. The agent uses a custom Chinese ReAct system header.

    Environment variables read:
        SYSTEM_PROMPT: optional system prompt passed to the agent.
        TOP_K: number of nodes to retrieve (default 3).
        RERANK_ENABLED: reranker flag; empty/"0"/"false"/"no"/"off" disable it.

    Args:
        filters: Metadata filters forwarded to the query engines.
        params: Accepted for interface compatibility; not used here.

    Returns:
        The configured ``AgentRunner``.
    """
    system_prompt = os.getenv("SYSTEM_PROMPT")
    top_k = int(os.getenv("TOP_K", "3"))
    # os.getenv returns a string, so a raw truthiness test would treat
    # "false" or "0" as enabled. Interpret the flag explicitly.
    rerank_flag = (os.getenv("RERANK_ENABLED") or "").strip().lower()
    use_reranker = rerank_flag not in ("", "0", "false", "no", "off")
    tools = []

    # NOTE: the SQL tool ("zjdata_query_tool") mentioned in the description
    # below is not registered by this function.

    # Add query tools if the index exists
    index = get_index()
    if index is not None:
        summary_query_engine = create_summary_query_engine(index, top_k, use_reranker, filters)
        summary_query_tool = QueryEngineTool.from_defaults(
            query_engine=summary_query_engine,
            name="summary_query_tool",
            description="适用于任何需要进行全面总结、概括的要求。",
        )
        query_engine = create_query_engine(index, top_k, use_reranker, filters)
        query_engine_tool = QueryEngineTool.from_defaults(
            query_engine=query_engine,
            name="zj_query_tool",
            description="由博微公司编制的关于电力造价知识、电力造价编制软件知识和造价工程文件结构的知识库。适用于查询电力领域、电力造价领域、博微、博微电力、博微造价等业务等内容。如果本知识库没有直接答案但有解决思路的可以返回解决办法后建议使用“zjdata_query_tool”工具。",
        )

        tools.append(summary_query_tool)
        tools.append(query_engine_tool)

    # Add additional tools
    tools += ToolFactory.from_env()

    # Custom ReAct system header; {tool_desc} and {tool_names} are filled in
    # by the formatter.
    prefix_messages = ("""您的设计旨在帮助完成各种任务,从回答问题到提供其他类型分析的摘要。\n\n##工具\n\n你可以访问各种工具。你有责任按照你认为合适的顺序使用这些工具来完成当前的任务。\n这可能需要将任务分解为子任务,并使用不同的工具来完成每个子任务。\n\n你可以访问以下工具:\n{tool_desc}\n\n\n##输出格式\n\n请用与问题相同的语言回答,并使用以下格式:\n\n \nThought: 用户当前的语言是:(user's language)。我需要使用工具来帮助我回答问题。\nAction: 如果使用工具,则为工具名称(one of {tool_names})。\nAction Input: 输入给工具的内容,使用JSON格式表示kwargs(例如{{\"input\": \"hello world\", \"num_beams\": 5}})\n \n\n请始终以Thought开始。\n\n请始终以Thought开始。\n\n请始终以Thought开始。\n\n请始终以Thought开始。\n\n切勿用Markdown代码标记包围你的响应。如果需要,可以在响应中使用代码标记。\n\n请为Action Input使用有效的JSON格式。不要这样做{{\'input\': \'hello world\', \'num_beams\': 5}}。\n\n如果使用此格式,用户将以下面的格式进行回应:\n\n \nObservation: 工具响应\n \n\n你应该继续重复上述格式,直到你有足够的信息来回答问题而无需使用更多工具。此时,你必须使用以下两种格式之一进行回答:\n\n \nThought: 我可以不用任何工具来回答。我将使用用户的语言来回答。\nAnswer: [你的答案(与用户问题相同的语言)]\n \n\n \nThought: 我无法使用提供的工具回答问题。\nAnswer: [你的答案(与用户问题相同的语言)]\n \n\n##如果从工具中得到的回应是Empty Response,那么只需要回答“我不知道”,不需要额外回答别的内容。## 当前对话\n\n以下是当前对话,由人类和助手的消息交替组成。\n""")
    react_chat_formatter = ReActChatFormatter.from_defaults(prefix_messages)
    agentrunner = AgentRunner.from_llm(
        llm=Settings.llm,
        tools=tools,
        react_chat_formatter=react_chat_formatter,
        system_prompt=system_prompt,
        verbose=True,
    )
    return agentrunner
|
||||
# create the function calling worker for reasoning
|
||||
# worker = FunctionCallingAgentWorker.from_tools(
|
||||
# tools, verbose=True
|
||||
# )
|
||||
#
|
||||
# # wrap the worker in the top-level planner
|
||||
# return StructuredPlannerAgent(worker, tools)
|
||||
@@ -1 +0,0 @@
|
||||
STORAGE_DIR = "storage"  # Directory in which the generated index is cached.
|
||||
@@ -1,108 +0,0 @@
|
||||
import os
|
||||
|
||||
from llama_index.core import SummaryIndex, SQLDatabase, VectorStoreIndex
|
||||
from llama_index.core.indices.struct_store import SQLTableRetrieverQueryEngine
|
||||
from llama_index.core.objects import SQLTableNodeMapping, ObjectIndex, SQLTableSchema
|
||||
from llama_index.core.query_engine import RetrieverQueryEngine
|
||||
from llama_index.core.response_synthesizers import ResponseMode
|
||||
from llama_index.readers.database import DatabaseReader
|
||||
from sqlalchemy import create_engine
|
||||
|
||||
from app.engine.prompt import text_qa_template, refine_template, summary_template, simple_template
|
||||
from app.engine.retriever.HybridRetriever import HybridRetriever
|
||||
from app.settings import get_node_postprocessors
|
||||
|
||||
def makeDescriptionByEngine(sql_database: SQLDatabase):
    """Build a ``SQLTableSchema`` for every usable table of the database.

    Tables with more than 150 columns are skipped. For the
    ``gongchengshuxing`` table the schema context lists attribute names read
    from ``select name from gongchengshuxing`` (rows 2 to 30 of the result);
    every other table gets an empty context string.

    Args:
        sql_database: Database whose tables are described.

    Returns:
        A list of ``SQLTableSchema`` objects, one per retained table.
    """
    reader = DatabaseReader(sql_database)
    table_schema_objs = []

    for table_name in sql_database.get_usable_table_names():
        if len(sql_database.get_table_columns(table_name)) > 150:
            continue

        stats_txt = ""
        if table_name == 'gongchengshuxing':
            documents = reader.load_data(query='select name from gongchengshuxing')
            # The first row is skipped and at most 29 further rows are used,
            # as before. NOTE(review): reason for skipping row 0 is not
            # visible here - confirm.
            # partition() keeps everything after the first ':' and yields ''
            # (instead of raising IndexError) when no ':' is present.
            names = [doc.text.partition(':')[2] for doc in documents[1:30]]
            stats_txt = '该表中有以下属性:' + ','.join(names)

        table_schema_objs.append(
            SQLTableSchema(table_name=table_name, context_str=stats_txt)
        )

    return table_schema_objs
|
||||
|
||||
def get_Retriever(index, **kwargs):
    """Return the retriever selected by the ``HYBRID_ENABLED`` setting.

    When ``HYBRID_ENABLED`` equals "true" (case-insensitive, surrounding
    whitespace ignored) a ``HybridRetriever`` weighted by ``HYBRID_ALPHA``
    (default 0.5) is built; otherwise the index's own retriever is used.

    Args:
        index: Index to retrieve from.
        **kwargs: Forwarded unchanged to the retriever constructor.

    Returns:
        The retriever object.
    """
    hybrid_enabled = os.getenv("HYBRID_ENABLED", "False").strip().lower() == "true"
    if not hybrid_enabled:
        return index.as_retriever(**kwargs)

    alpha = float(os.getenv("HYBRID_ALPHA", "0.5"))
    return HybridRetriever(index, alpha=alpha, **kwargs)
|
||||
|
||||
|
||||
sql_database = None
|
||||
sql_obj_index = None
|
||||
|
||||
# Create a SQL (text-to-SQL) query engine.
# This function used to be named create_summary_query_engine, but a later
# definition with the same name in this module replaced it, so it could never
# be called. It now has its own name; create_summary_query_engine still
# resolves to the later definition.
def create_sql_query_engine(top_k=3, use_reranker=False, filters=None):
    """Return a query engine that answers questions by querying SQL tables.

    The database connection and the table-schema object index are created on
    first use and cached in the module-level ``sql_database`` and
    ``sql_obj_index`` variables.

    Args:
        top_k: Number of table schemas retrieved per query.
        use_reranker: Accepted for signature parity; not used here.
        filters: Accepted for signature parity; not used here.

    Returns:
        A ``SQLTableRetrieverQueryEngine``.

    Raises:
        ValueError: If ``SQL_DATABASE_URL`` is not set.
    """
    global sql_obj_index
    global sql_database
    if sql_obj_index is None or sql_database is None:
        database_url = os.getenv("SQL_DATABASE_URL", "")
        if not database_url:
            raise ValueError("SQL_DATABASE_URL environment variable is not set")
        sqlengine = create_engine(database_url)
        sql_database = SQLDatabase(sqlengine)
        table_schema_objs = makeDescriptionByEngine(sql_database)
        table_node_mapping = SQLTableNodeMapping(sql_database)

        sql_obj_index = ObjectIndex.from_objects(
            table_schema_objs,
            table_node_mapping,
            index_cls=VectorStoreIndex,
        )

    sql_query_engine = SQLTableRetrieverQueryEngine(
        sql_database,
        sql_obj_index.as_retriever(similarity_top_k=top_k),
        verbose=True,
    )
    return sql_query_engine
|
||||
|
||||
# Create a summary query engine
|
||||
def create_summary_query_engine(index, top_k=3, use_reranker=False, filters=None):
    """Return a streaming tree-summarize query engine over the index's nodes.

    The nodes are fetched from ``index.vector_store`` and wrapped in a
    ``SummaryIndex``.

    Args:
        index: Index whose vector store supplies the nodes.
        top_k: Not used by this function.
        use_reranker: Not used by this function.
        filters: Not used by this function.

    Returns:
        The query engine created by ``SummaryIndex.as_query_engine``.
    """
    all_nodes = index.vector_store.get_nodes(node_ids=None)
    return SummaryIndex(all_nodes).as_query_engine(
        response_mode=ResponseMode.TREE_SUMMARIZE,
        use_async=True,
        streaming=True,
    )
|
||||
|
||||
# Create a query engine
|
||||
def create_query_engine(index, top_k=3, use_reranker=False, filters=None):
    """Return the streaming vector-retrieval query engine.

    Args:
        index: Index to retrieve from.
        top_k: Number of nodes to retrieve.
        use_reranker: When truthy, node postprocessors from
            ``get_node_postprocessors()`` are attached.
        filters: Metadata filters forwarded to the retriever.

    Returns:
        A ``RetrieverQueryEngine`` using the project's prompt templates.
    """
    node_postprocessors = get_node_postprocessors() if use_reranker else None
    retriever = get_Retriever(index, similarity_top_k=top_k, filters=filters)

    return RetrieverQueryEngine.from_args(
        retriever,
        text_qa_template=text_qa_template,
        refine_template=refine_template,
        summary_template=summary_template,
        simple_template=simple_template,
        node_postprocessors=node_postprocessors,
        use_async=True,
        streaming=True,
    )
|
||||
@@ -1,94 +0,0 @@
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
from app.engine.loaders import get_documents
|
||||
from app.engine.vectordb import get_vector_store
|
||||
from app.settings import init_settings
|
||||
from app.engine.retriever.CHBM25Retriever import CHBM25Retriever
|
||||
from llama_index.core.ingestion import IngestionPipeline
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
from llama_index.core.settings import Settings
|
||||
from llama_index.core.storage import StorageContext
|
||||
from llama_index.core.storage.docstore import SimpleDocumentStore
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger()
|
||||
|
||||
STORAGE_DIR = os.getenv("STORAGE_DIR", "storage")
|
||||
|
||||
|
||||
def get_doc_store():
    """Return the document store to use for ingestion.

    Loads the store persisted under ``STORAGE_DIR`` when that path exists;
    otherwise returns a fresh in-memory store, since there is nothing to
    load from.
    """
    if not os.path.exists(STORAGE_DIR):
        return SimpleDocumentStore()
    return SimpleDocumentStore.from_persist_dir(STORAGE_DIR)
|
||||
|
||||
|
||||
def run_pipeline(docstore, vector_store, documents):
    """Split, embed and store ``documents``; return the resulting nodes.

    Args:
        docstore: Document store handed to the ingestion pipeline.
        vector_store: Vector store that receives the embedded nodes.
        documents: Documents to ingest.

    Returns:
        The nodes returned by ``IngestionPipeline.run``.
    """
    splitter = SentenceSplitter(
        chunk_size=Settings.chunk_size,
        chunk_overlap=Settings.chunk_overlap,
    )
    pipeline = IngestionPipeline(
        transformations=[splitter, Settings.embed_model],
        docstore=docstore,
        docstore_strategy="upserts_and_delete",
        vector_store=vector_store,
    )

    # Run the ingestion pipeline and hand back the results
    return pipeline.run(show_progress=True, documents=documents)
|
||||
|
||||
|
||||
def persist_storage(docstore, vector_store):
    """Persist the document store and vector store under ``STORAGE_DIR``."""
    StorageContext.from_defaults(
        docstore=docstore,
        vector_store=vector_store,
    ).persist(STORAGE_DIR)
|
||||
|
||||
|
||||
def persist_BMRetriever(vector_store):
    """Build a BM25 retriever from the vector store's nodes and persist it.

    The target directory comes from ``BM_RETRIEVER_PATH`` (default
    "storage_bm") and the retriever's top-k from ``TOP_K`` (default 3).

    Args:
        vector_store: Store whose ``get_nodes([])`` result is indexed.
    """
    persist_dir = os.getenv("BM_RETRIEVER_PATH", "storage_bm")
    top_k = int(os.getenv("TOP_K", "3"))
    retriever = CHBM25Retriever.from_defaults(
        similarity_top_k=top_k,
        nodes=vector_store.get_nodes([]),
    )
    retriever.persist(persist_dir)
|
||||
|
||||
|
||||
def generate_datasource():
    """Ingest all configured documents and persist the resulting stores.

    Steps: initialise settings, load documents, mark them public, run the
    ingestion pipeline, then persist the storage context and the BM25
    retriever.
    """
    init_settings()
    logger.info("Generate index for the provided data")

    # Load the documents and mark each one as public
    # (private=false is required for filtering).
    documents = get_documents()
    for document in documents:
        document.metadata["private"] = "false"

    # Get the stores, or create new ones
    docstore = get_doc_store()
    vector_store = get_vector_store()

    # Run the ingestion pipeline; the returned nodes are not needed here
    run_pipeline(docstore, vector_store, documents)

    # Persist storage and the keyword retriever
    persist_storage(docstore, vector_store)
    persist_BMRetriever(vector_store)

    logger.info("Finished generating the index")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    from phoenix.trace import using_project

    # PHOENIX_PROJECT_NAME may be unset; fall back to "default" rather than
    # failing with a TypeError on None + str.
    project_name = os.getenv("PHOENIX_PROJECT_NAME", "default") + "_generate"
    with using_project(project_name):
        generate_datasource()
|
||||
@@ -1,93 +0,0 @@
|
||||
from llama_index.core import PromptTemplate
|
||||
|
||||
# Question-answering prompt: role, skills and constraints, followed by the
# retrieved context ({context_str}) and the user question ({query_str}).
text_qa_template_str = (
    # Role
    "# 角色\n"
    "你是一名博微造价工程数据查询助手,专精于电力工程文件中的信息。"
    "你的职责是提供有关电力造价、造价编制软件、文件结构及相关数据的精准、客观的回答,"
    "如同直接从文件中提取的内容。\n"
    "知识库中已经导入一个工程的全部数据,请你站在当前工程的角度回答用户关于工程文件的问题。\n"
    "例如:询问“此工程”指当前导入的工程。询问“此工程名称”指当前导入的工程的工程名称。\n"
    # Skills
    "## 技能\n"
    "### 技能 1: 数据查询与提供\n"
    "- 准确回答所有关于电力工程造价的相关问题。\n"
    "- 提供具体数据,如成本估算、材料清单、劳动力需求等。\n"
    "- 确保提供的信息严格基于工程文档中的记录。\n"
    "### 技能 2: 技术性解释\n"
    "- 解释造价工程中的技术术语和概念。\n"
    "- 为复杂的工程细节提供清晰易懂的说明。\n"
    # Constraints
    "## 约束\n"
    "- 仅回答与电力工程造价文件相关的具体问题。\n"
    "- 不进行任何超出文件内容的猜测或假设。\n"
    "- 所有回答均基于文件内容,采用客观和技术性的语言。\n"
    "- 请基于这些信息回答问题。如果无法找到相关信息,请不要额外发散回答,不要回答多余的信息,只需要回答“我不知道这个问题的答案”。\n"
    # Retrieved context
    "以下为上下文信息\n"
    "---------------------\n"
    "{context_str}\n"
    "---------------------\n"
    # Answering instructions and the question itself
    "请根据上下文信息而非先前知识回答我的问题或回复我的指令。前面的上下文信息可能有用,也可能没用,你需要从我给出的上下文信息中选出与我的问题最相关的那些,来为你的回答提供依据。回答一定要忠于原文,简洁但不丢信息,不要胡乱编造。如果无法找到相关信息,请不要额外发散回答,不要回答多余的信息,只需要回答“我不知道这个问题的答案”。我的问题或指令是什么语种,你就用什么语种回复。\n"
    "如果是表结构或者是数据库的相关内容,只用于推导问题,不需要告诉用户数据库或表结构等物理信息。\n"
    "问题:{query_str}\n"
    "你的回复: "
)

text_qa_template = PromptTemplate(text_qa_template_str)
|
||||
|
||||
# Refine prompt: shows the original question ({query_str}), the current
# answer ({existing_answer}) and extra context ({context_msg}).
refine_template_str = (
    "这是原本的问题: {query_str}\n"
    "我们已经提供了回答: {existing_answer}\n"
    "现在我们有机会改进这个回答 "
    "使用以下更多上下文(仅当有助于改进回答时使用)\n"
    # The following instruction appears twice in the prompt text.
    "如果新的上下文对回答没有影响,或者原来的回答已经正确,不要在上次回答的后边再加上多余的补充信息,直接返回原本的回答。\n"
    "如果新的上下文对回答没有影响,或者原来的回答已经正确,不要在上次回答的后边再加上多余的补充信息,直接返回原本的回答。\n"
    "------------\n"
    "{context_msg}\n"
    "------------\n"
    "如果回答中已经包含有正确答案,不要返回多余的解释等信息,只返回正确答案\n"
    "如果是表结构或者是数据库的相关内容,仅用于推导问题,不需要告诉用户数据库或表结构等物理信息。\n"
    "改进的回答: "
)

refine_template = PromptTemplate(refine_template_str)
|
||||
|
||||
# Summary prompt: same role, skills and constraints as the QA prompt, then
# context gathered from several sources ({context_str}) and the query
# ({query_str}).
summary_template_str = (
    # Role
    "# 角色\n"
    "你是一名博微造价工程数据查询助手,专精于电力工程文件中的信息。"
    "你的职责是提供有关电力造价、造价编制软件、文件结构及相关数据的精准、客观的回答,"
    "如同直接从文件中提取的内容。\n"
    # Skills
    "## 技能\n"
    "### 技能 1: 数据查询与提供\n"
    "- 准确回答所有关于电力工程造价的相关问题。\n"
    "- 提供具体数据,如成本估算、材料清单、劳动力需求等。\n"
    "- 确保提供的信息严格基于工程文档中的记录。\n"
    "### 技能 2: 技术性解释\n"
    "- 解释造价工程中的技术术语和概念。\n"
    "- 为复杂的工程细节提供清晰易懂的说明。\n"
    # Constraints
    "## 约束\n"
    "- 仅回答与电力工程造价文件相关的具体问题。\n"
    "- 不进行任何超出文件内容的猜测或假设。\n"
    "- 所有回答均基于文件内容,采用客观和技术性的语言。\n"
    "- 请基于这些信息回答问题。如果无法找到相关信息,请不要额外发散回答,不要回答多余的信息,只需要回答“我不知道这个问题的答案”。\n"
    # Context from multiple sources
    "来自多个来源的上下文信息如下。\n"
    "---------------------\n"
    "{context_str}\n"
    "---------------------\n"
    "鉴于来自多个来源的信息而非先验知识, "
    "回答查询。\n"
    "如果是表结构或者是数据库的相关内容,只用于推导问题,不需要告诉用户数据库或表结构等物理信息。\n"
    "Query: {query_str}\n"
    "Answer: "
)
summary_template = PromptTemplate(summary_template_str)
|
||||
|
||||
# Pass-through prompt: the query text is sent unchanged.
simple_template_str = "{query_str}"
simple_template = PromptTemplate(simple_template_str)
|
||||
@@ -1,71 +0,0 @@
|
||||
import os
|
||||
from llama_index.vector_stores.chroma import ChromaVectorStore
|
||||
from llama_index.vector_stores.qdrant import QdrantVectorStore
|
||||
from qdrant_client import qdrant_client
|
||||
|
||||
qclient = None
|
||||
|
||||
def get_qdrant_vector_store():
    """Return a ``QdrantVectorStore`` backed by a shared Qdrant client.

    If ``VECTOR_STORE_PATH`` is set, a local on-disk client is used;
    otherwise a remote client is created from ``VECTOR_STORE_HOST`` (default
    "127.0.0.1") and ``VECTOR_STORE_PORT`` (default 6333). The client is
    created once and cached in the module-level ``qclient``.

    Returns:
        A ``QdrantVectorStore`` for the ``VECTOR_STORE_COLLECTION`` collection
        (default "default").

    Raises:
        ValueError: If neither a path nor a host is configured.
    """
    global qclient

    collection_name = os.getenv("VECTOR_STORE_COLLECTION", "default")
    vector_store_path = os.getenv("VECTOR_STORE_PATH")
    # No trailing commas here: they previously turned host and port into
    # one-element tuples.
    host = os.getenv("VECTOR_STORE_HOST", "127.0.0.1")
    port = int(os.getenv("VECTOR_STORE_PORT", "6333"))

    # Only fail when neither connection mode is configured. The former "or"
    # rejected every configuration without a local path.
    if not vector_store_path and not host:
        raise ValueError(
            "Please provide either VECTOR_STORE_PATH or VECTOR_STORE_HOST and VECTOR_STORE_PORT"
        )

    if qclient is None:
        if vector_store_path:
            # Local QdrantVectorStore from the path
            qclient = qdrant_client.QdrantClient(path=vector_store_path)
        else:
            # Remote QdrantVectorStore
            qclient = qdrant_client.QdrantClient(host=host, port=port)

    vector_store = QdrantVectorStore(client=qclient, collection_name=collection_name)
    return vector_store
|
||||
|
||||
def get_chroma_vector_store():
    """Return a ``ChromaVectorStore`` using cosine distance.

    If ``VECTOR_STORE_PATH`` is set, a local persistent store is opened from
    that path; otherwise a remote store is created from ``VECTOR_STORE_HOST``
    and ``VECTOR_STORE_PORT`` (ChromaDB Cloud is not supported yet).

    Raises:
        ValueError: If no path is set and host or port is missing.
    """
    collection = os.getenv("VECTOR_STORE_COLLECTION", "default")
    local_path = os.getenv("VECTOR_STORE_PATH")

    if local_path:
        return ChromaVectorStore.from_params(
            persist_dir=local_path,
            collection_name=collection,
            collection_kwargs={"metadata": {"hnsw:space": "cosine"}},
        )

    remote_host = os.getenv("VECTOR_STORE_HOST")
    remote_port = os.getenv("VECTOR_STORE_PORT")
    if not (remote_host and remote_port):
        raise ValueError(
            "Please provide either VECTOR_STORE_PATH or VECTOR_STORE_HOST and VECTOR_STORE_PORT"
        )
    return ChromaVectorStore.from_params(
        host=remote_host,
        port=int(remote_port),
        collection_name=collection,
        collection_kwargs={"metadata": {"hnsw:space": "cosine"}},
    )
|
||||
|
||||
def get_vector_store():
|
||||
store_type=os.getenv("VECTOR_STORE_TYPE")
|
||||
|
||||
store = None
|
||||
|
||||
match store_type:
|
||||
case "chroma":
|
||||
store = get_chroma_vector_store()
|
||||
case "qdrant":
|
||||
store = get_qdrant_vector_store()
|
||||
case _:
|
||||
raise ValueError(f"Invalid vector store type: {store_type}")
|
||||
|
||||
return store
|
||||
@@ -1,40 +0,0 @@
|
||||
import logging
|
||||
|
||||
import yaml
|
||||
from app.engine.loaders.db import DBLoaderConfig, get_db_documents
|
||||
from app.engine.loaders.file import FileLoaderConfig, get_file_documents
|
||||
from app.engine.loaders.web import WebLoaderConfig, get_web_documents
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_configs():
    """Load and return the loader configuration from ``config/loaders.yaml``.

    The file is read as UTF-8 so that non-ASCII content parses the same way
    regardless of the platform's default encoding.

    Returns:
        The parsed YAML content (``None`` for an empty file).
    """
    with open("config/loaders.yaml", encoding="utf-8") as f:
        configs = yaml.safe_load(f)
    return configs
|
||||
|
||||
|
||||
def get_documents():
|
||||
documents = []
|
||||
config = load_configs()
|
||||
if config is None or len(config.items()) == 0:
|
||||
return documents
|
||||
|
||||
for loader_type, loader_config in config.items():
|
||||
logger.info(
|
||||
f"Loading documents from loader: {loader_type}, config: {loader_config}"
|
||||
)
|
||||
|
||||
loader_config = loader_config or []
|
||||
match loader_type:
|
||||
case "file":
|
||||
document = get_file_documents(FileLoaderConfig(**loader_config))
|
||||
case "web":
|
||||
document = get_web_documents(WebLoaderConfig(**loader_config))
|
||||
case "db":
|
||||
document = get_db_documents(configs=[DBLoaderConfig(**cfg) for cfg in loader_config])
|
||||
case _:
|
||||
raise ValueError(f"Invalid loader type: {loader_type}")
|
||||
documents.extend(document)
|
||||
|
||||
return documents
|
||||
@@ -1,140 +0,0 @@
|
||||
import logging
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from llama_index.core import SQLDatabase, Document
|
||||
from llama_index.readers.database import DatabaseReader
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import create_engine, text
|
||||
from sqlalchemy.engine import Engine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class CustomDatabaseReader(DatabaseReader):
    """Database reader that returns a whole query result as one Document.

    The document text consists of the explanation, a header line with the
    column names, and one comma-separated line per result row.

    Args:
        sql_database (Optional[SQLDatabase]): SQL database to use,
            including table names to specify.
            See :ref:`Ref-Struct-Store` for more details.

        OR

        engine (Optional[Engine]): SQLAlchemy Engine object of the database connection.

        OR

        uri (Optional[str]): uri of the database connection.

        OR

        scheme (Optional[str]): scheme of the database connection.
        host (Optional[str]): host of the database connection.
        port (Optional[int]): port of the database connection.
        user (Optional[str]): user of the database connection.
        password (Optional[str]): password of the database connection.
        dbname (Optional[str]): dbname of the database connection.

    Returns:
        DatabaseReader: A DatabaseReader object.
    """

    def __init__(
        self,
        sql_database: Optional[SQLDatabase] = None,
        engine: Optional[Engine] = None,
        uri: Optional[str] = None,
        scheme: Optional[str] = None,
        host: Optional[str] = None,
        port: Optional[str] = None,
        user: Optional[str] = None,
        password: Optional[str] = None,
        dbname: Optional[str] = None,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Initialize from the first connection description that is provided.

        Raises:
            ValueError: If none of the supported descriptions is given.
        """
        if sql_database:
            self.sql_database = sql_database
        elif engine:
            self.sql_database = SQLDatabase(engine, *args, **kwargs)
        elif uri:
            self.uri = uri
            self.sql_database = SQLDatabase.from_uri(uri, *args, **kwargs)
        elif scheme and host and port and user and password and dbname:
            uri = f"{scheme}://{user}:{password}@{host}:{port}/{dbname}"
            self.uri = uri
            self.sql_database = SQLDatabase.from_uri(uri, *args, **kwargs)
        else:
            raise ValueError(
                "You must provide either a SQLDatabase, "
                "a SQL Alchemy Engine, a valid connection URI, or a valid "
                "set of credentials."
            )

    def load_data(self, query: str, explanation: str = "") -> List[Document]:
        """Run ``query`` and return its result as a single Document.

        Args:
            query (str): SQL statement to execute.
            explanation (str): Text placed on the first line of the document;
                ``None`` is treated as an empty string.

        Returns:
            List[Document]: A one-element list holding the document. Its
            metadata carries the query under "name" and "context" and a
            fixed "file_type".

        Raises:
            ValueError: If ``query`` is None.
        """
        # Validate before opening a connection that would go unused
        if query is None:
            raise ValueError("A query parameter is necessary to filter the data")

        lines = [explanation or ""]
        with self.sql_database.engine.connect() as connection:
            result = connection.execute(text(query))
            # Header line with the column names
            lines.append(", ".join(f"{entry}" for entry in result.keys()))
            # One line per row
            lines.extend(
                ", ".join(f"{entry}" for entry in row) for row in result.fetchall()
            )

        # Every line, including the last, ends with a newline
        doc = Document(text="\n".join(lines) + "\n")
        doc.metadata["name"] = query
        doc.metadata["context"] = query
        doc.metadata["file_type"] = "application/vnd.ms-excel"
        return [doc]
|
||||
|
||||
class DBLoaderConfig(BaseModel):
    """Configuration of one database loader entry."""

    # Connection URI handed to SQLAlchemy's create_engine
    uri: str
    # Query descriptions; the loader reads the "sql" and "explanation" keys
    queries: List[dict]
|
||||
|
||||
def get_db_documents(configs: list[DBLoaderConfig]):
    """Run every configured query and return the resulting documents.

    Entries with an empty ``uri`` are skipped with a warning; previously only
    the first entry was checked. Each engine is disposed after its queries
    have run.

    Args:
        configs: Database loader entries to process.

    Returns:
        A list of documents, one per executed query.
    """
    docs = []

    if len(configs) == 0:
        logger.warning(
            "Failed to load database, error message: uri is empty. Return as empty document list."
        )
        return docs

    for entry in configs:
        if not entry.uri:
            logger.warning("Skipping database loader entry: uri is empty.")
            continue

        engine = create_engine(entry.uri)
        try:
            loader = CustomDatabaseReader(SQLDatabase(engine))
            for query_dict in entry.queries:
                query = query_dict.get("sql", "")
                explanation = query_dict.get("explanation", "")
                logger.info(f"Loading data from database with query: {query}")
                docs.extend(loader.load_data(query=query, explanation=explanation))
        finally:
            # Release pooled connections once this entry is done
            engine.dispose()

    return docs
|
||||
@@ -1,88 +0,0 @@
|
||||
import os
|
||||
import logging
|
||||
from typing import Dict
|
||||
|
||||
from llama_index.core.readers.base import BaseReader
|
||||
from llama_index.core.readers.json import JSONReader
|
||||
from llama_parse import LlamaParse
|
||||
from pydantic import BaseModel, validator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FileLoaderConfig(BaseModel):
    """Configuration of the file loader."""

    # Directory that is scanned for input files
    data_dir: str = "data"
    # Parse files with LlamaParse instead of the local readers
    use_llama_parse: bool = False

    @validator("data_dir")
    def data_dir_must_exist(cls, v):
        """Reject a ``data_dir`` that is not an existing directory."""
        if os.path.isdir(v):
            return v
        raise ValueError(f"Directory '{v}' does not exist")
|
||||
|
||||
|
||||
def llama_parse_parser():
    """Return a ``LlamaParse`` parser producing markdown output.

    Raises:
        ValueError: If ``LLAMA_CLOUD_API_KEY`` is not set.
    """
    api_key = os.getenv("LLAMA_CLOUD_API_KEY")
    if api_key is None:
        raise ValueError(
            "LLAMA_CLOUD_API_KEY environment variable is not set. "
            "Please set it in .env file or in your shell environment then run again!"
        )
    return LlamaParse(
        result_type="markdown",
        verbose=True,
        language="en",
        ignore_errors=False,
    )
|
||||
|
||||
|
||||
def llama_parse_extractor() -> Dict[str, LlamaParse]:
    """Map every file type supported by LlamaParse to one shared parser."""
    from llama_parse.utils import SUPPORTED_FILE_TYPES

    shared_parser = llama_parse_parser()
    return dict.fromkeys(SUPPORTED_FILE_TYPES, shared_parser)
|
||||
|
||||
def llama_local_extractor() -> Dict[str, BaseReader]:
    """Return the local extractors: a ``JSONReader`` for ``.json`` files."""
    json_reader = JSONReader(clean_json=False, levels_back=0)
    return {".json": json_reader}
|
||||
|
||||
|
||||
def get_file_documents(config: FileLoaderConfig):
    """Load all documents found under ``config.data_dir``.

    Files are read recursively with ``SimpleDirectoryReader`` using either
    the LlamaParse extractors or the local ones. An error raised from the
    reader's ``_add_files`` step (empty data directory) is logged and turned
    into an empty list; any other error is re-raised.

    Args:
        config: File loader configuration.

    Returns:
        The loaded documents, or an empty list for an empty data directory.
    """
    from llama_index.core.readers import SimpleDirectoryReader

    try:
        if config.use_llama_parse:
            # LlamaParse is async first, so nest_asyncio is needed to run it
            # in sync mode
            import nest_asyncio

            nest_asyncio.apply()
            extractors = llama_parse_extractor()
        else:
            extractors = llama_local_extractor()

        return SimpleDirectoryReader(
            config.data_dir,
            recursive=True,
            filename_as_id=True,
            raise_on_error=True,
            file_extractor=extractors,
        ).load_data()
    except Exception as e:
        import sys
        import traceback

        # Identify the function in which the exception was raised
        innermost_frame = traceback.extract_tb(sys.exc_info()[2])[-1]
        if innermost_frame.name != "_add_files":
            # Not the empty-data-dir case: propagate
            raise e

        logger.warning(
            f"Failed to load file documents, error message: {e} . Return as empty document list."
        )
        return []
|
||||
@@ -1,37 +0,0 @@
|
||||
import os
|
||||
import json
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class CrawlUrl(BaseModel):
    """One crawl target of the web loader."""

    # URL passed to the site reader's load_data
    base_url: str
    # Passed to the site reader as its `prefix` argument
    prefix: str
    # Crawl depth; must be >= 0
    max_depth: int = Field(default=1, ge=0)
|
||||
|
||||
|
||||
class WebLoaderConfig(BaseModel):
    """Configuration of the web loader."""

    # Command-line arguments added to the Chrome driver options
    driver_arguments: list[str] = Field(default=None)
    # Sites to crawl
    urls: list[CrawlUrl] = []
|
||||
|
||||
|
||||
def get_web_documents(config: WebLoaderConfig):
    """Crawl every configured site and return the scraped documents.

    A Chrome driver is created per URL with the configured driver arguments.

    Args:
        config: Web loader configuration.

    Returns:
        A list with the documents of all crawled sites.
    """
    from llama_index.readers.web import WholeSiteReader
    from selenium import webdriver
    from selenium.webdriver.chrome.options import Options

    options = Options()
    for arg in config.driver_arguments or []:
        options.add_argument(arg)

    docs = []
    # Iterate the None-safe list; previously it was computed but the loop
    # still walked config.urls directly.
    for url in config.urls or []:
        scraper = WholeSiteReader(
            prefix=url.prefix,
            max_depth=url.max_depth,
            driver=webdriver.Chrome(options=options),
        )
        docs.extend(scraper.load_data(url.base_url))

    return docs
|
||||
@@ -1,133 +0,0 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional, cast
|
||||
|
||||
from llama_index.core.base.base_retriever import BaseRetriever
|
||||
from llama_index.core.callbacks.base import CallbackManager
|
||||
from llama_index.core.constants import DEFAULT_SIMILARITY_TOP_K
|
||||
from llama_index.core.indices.vector_store.base import VectorStoreIndex
|
||||
from llama_index.core.schema import BaseNode, IndexNode, NodeWithScore, QueryBundle
|
||||
from llama_index.core.storage.docstore.types import BaseDocumentStore
|
||||
from llama_index.core.vector_stores.utils import (
|
||||
node_to_metadata_dict,
|
||||
metadata_dict_to_node,
|
||||
)
|
||||
|
||||
import bm25s
|
||||
from app.engine.retriever.CHTokener import chTokenize
|
||||
|
||||
CHDEFAULT_PERSIST_ARGS = {"similarity_top_k": "similarity_top_k", "_verbose": "verbose"}
|
||||
|
||||
CHDEFAULT_PERSIST_FILENAME = "retriever.json"
|
||||
|
||||
class CHBM25Retriever(BaseRetriever):
|
||||
def __init__(
|
||||
self,
|
||||
nodes: Optional[List[BaseNode]] = None,
|
||||
existing_bm25: Optional[bm25s.BM25] = None,
|
||||
similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K,
|
||||
callback_manager: Optional[CallbackManager] = None,
|
||||
objects: Optional[List[IndexNode]] = None,
|
||||
object_map: Optional[dict] = None,
|
||||
verbose: bool = False,
|
||||
) -> None:
|
||||
self.similarity_top_k = similarity_top_k
|
||||
if existing_bm25 is not None:
|
||||
self.bm25 = existing_bm25
|
||||
self.corpus = existing_bm25.corpus
|
||||
else:
|
||||
from nltk.corpus import stopwords
|
||||
if nodes is None:
|
||||
raise ValueError("Please pass nodes or an existing BM25 object.")
|
||||
|
||||
self.corpus = [node_to_metadata_dict(node) for node in nodes]
|
||||
|
||||
corpus_tokens = chTokenize(
|
||||
[node.get_content() for node in nodes],
|
||||
show_progress=verbose,
|
||||
)
|
||||
self.bm25 = bm25s.BM25()
|
||||
self.bm25.index(corpus_tokens, show_progress=verbose)
|
||||
super().__init__(
|
||||
callback_manager=callback_manager,
|
||||
object_map=object_map,
|
||||
objects=objects,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_defaults(
|
||||
cls,
|
||||
index: Optional[VectorStoreIndex] = None,
|
||||
nodes: Optional[List[BaseNode]] = None,
|
||||
docstore: Optional[BaseDocumentStore] = None,
|
||||
similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K,
|
||||
verbose: bool = False,
|
||||
) -> "CHBM25Retriever":
|
||||
if sum(bool(val) for val in [index, nodes, docstore]) != 1:
|
||||
raise ValueError("Please pass exactly one of index, nodes, or docstore.")
|
||||
|
||||
if index is not None:
|
||||
docstore = index.docstore
|
||||
|
||||
if docstore is not None:
|
||||
nodes = cast(List[BaseNode], list(docstore.docs.values()))
|
||||
|
||||
assert (
|
||||
nodes is not None
|
||||
), "Please pass exactly one of index, nodes, or docstore."
|
||||
|
||||
return cls(
|
||||
nodes=nodes,
|
||||
similarity_top_k=similarity_top_k,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
def get_persist_args(self) -> Dict[str, Any]:
|
||||
"""Get Persist Args Dict to Save."""
|
||||
return {
|
||||
CHDEFAULT_PERSIST_ARGS[key]: getattr(self, key)
|
||||
for key in CHDEFAULT_PERSIST_ARGS
|
||||
if hasattr(self, key)
|
||||
}
|
||||
|
||||
def persist(self, path: str, **kwargs: Any) -> None:
|
||||
"""Persist the retriever to a directory."""
|
||||
self.bm25.save(path, corpus=self.corpus, **kwargs)
|
||||
with open(os.path.join(path, CHDEFAULT_PERSIST_FILENAME), "w") as f:
|
||||
json.dump(self.get_persist_args(), f, indent=2)
|
||||
|
||||
@classmethod
|
||||
def from_persist_dir(cls, path: str, **kwargs: Any) -> "CHBM25Retriever":
|
||||
"""Load the retriever from a directory."""
|
||||
bm25 = bm25s.BM25.load(path, load_corpus=True, **kwargs)
|
||||
with open(os.path.join(path, CHDEFAULT_PERSIST_FILENAME)) as f:
|
||||
retriever_data = json.load(f)
|
||||
return cls(existing_bm25=bm25, **retriever_data)
|
||||
|
||||
def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
    """Tokenize the query with ``chTokenize`` and return the BM25 hits as scored nodes."""
    query_tokens = chTokenize(query_bundle.query_str, show_progress=self._verbose)
    batch_hits, batch_scores = self.bm25.retrieve(
        query_tokens, k=self.similarity_top_k, show_progress=self._verbose
    )

    def _to_node(hit):
        # A hit is either the stored node dict itself or a position in the corpus.
        payload = hit if isinstance(hit, dict) else self.corpus[int(hit)]
        return metadata_dict_to_node(payload)

    # Results are batched, but only one query was submitted: use the first row.
    return [
        NodeWithScore(node=_to_node(hit), score=float(score))
        for hit, score in zip(batch_hits[0], batch_scores[0])
    ]
|
||||
@@ -1,46 +0,0 @@
|
||||
from typing import Any, Dict, List, Union, Callable, NamedTuple
|
||||
from bm25s.tokenization import *
|
||||
|
||||
try:
|
||||
from tqdm.auto import tqdm
|
||||
except ImportError:
|
||||
|
||||
def tqdm(iterable, *args, **kwargs):
|
||||
return iterable
|
||||
|
||||
|
||||
def chinese_tokenizer(text: str) -> List[str]:
    """Split ``text`` into words with jieba and drop Chinese stopwords.

    Args:
        text: The text to segment.

    Returns:
        The jieba tokens, in order, excluding any token found in NLTK's
        ``'chinese'`` stopword list.
    """
    import jieba
    from nltk.corpus import stopwords

    tokens = jieba.lcut(text)
    if not tokens:
        # Nothing to filter, so the stopword corpus is not touched at all.
        return []

    # Load the stopword list once and use a set for O(1) membership tests.
    # The original called stopwords.words() again for every single token and
    # scanned the resulting list linearly each time.
    chinese_stopwords = frozenset(stopwords.words('chinese'))
    return [token for token in tokens if token not in chinese_stopwords]
|
||||
|
||||
def chTokenize(
    texts,
    show_progress: bool = True,
    leave: bool = False,
) -> Union[List[List[str]], Tokenized]:
    """Tokenize one or more texts with ``chinese_tokenizer`` into integer ids.

    Args:
        texts: A single string (treated as a one-element list) or an
            iterable of strings.
        show_progress: When False, the progress bar is disabled.
        leave: Passed to ``tqdm`` as its ``leave`` argument.

    Returns:
        A ``Tokenized`` whose ``ids`` holds one list of token ids per text
        and whose ``vocab`` maps each token to its id. Ids are assigned in
        order of first appearance, starting at 0.
    """
    documents = [texts] if isinstance(texts, str) else texts

    vocab = {}
    encoded_docs = []
    progress = tqdm(
        documents, desc="Split strings", leave=leave, disable=not show_progress
    )
    for document in progress:
        # setdefault hands out the next free id the first time a token is seen.
        encoded_docs.append(
            [vocab.setdefault(word, len(vocab)) for word in chinese_tokenizer(document)]
        )

    return Tokenized(ids=encoded_docs, vocab=vocab)
|
||||
|
||||
@@ -1,67 +0,0 @@
|
||||
import os
|
||||
from typing import Optional, Any, Dict, List
|
||||
|
||||
from llama_index.core.base.base_retriever import BaseRetriever
|
||||
from llama_index.core.schema import NodeWithScore, QueryBundle
|
||||
|
||||
from app.engine.retriever.CHBM25Retriever import CHBM25Retriever
|
||||
|
||||
|
||||
class HybridRetriever(BaseRetriever):
    """Retriever that combines vector-index search with Chinese BM25 search.

    Both retrievers are queried with the same query string and their scores
    are merged per node as ``alpha * vector_score + (1 - alpha) * bm25_score``.
    A node returned by only one retriever gets 0 for the other score.

    NOTE(review): the two scores are combined as-is, without normalization;
    confirm the vector and BM25 score ranges are comparable for this store.
    """

    def __init__(
        self,
        vector_index,
        similarity_top_k: int = 2,
        out_top_k: Optional[int] = None,
        alpha: float = 0.5,
        filters = None,
        **kwargs: Any,
    ) -> None:
        """Set up the vector retriever and load or build the BM25 retriever.

        Args:
            vector_index: Index providing ``as_retriever`` and ``vector_store``.
            similarity_top_k: Number of hits requested from each retriever.
            out_top_k: Number of merged results returned; falls back to
                ``similarity_top_k`` when not given (or 0).
            alpha: Weight of the vector score; BM25 gets ``1 - alpha``.
            filters: Passed to the vector retriever only.
            **kwargs: Forwarded to ``BaseRetriever.__init__``.
        """
        super().__init__(**kwargs)
        self._vector_index = vector_index
        self._embed_model = vector_index._embed_model
        self._out_top_k = out_top_k or similarity_top_k
        self._alpha = alpha
        self._vecRetriever = vector_index.as_retriever(
            similarity_top_k=similarity_top_k, filters=filters
        )

        # The BM25 retriever is cached on disk; the directory is configurable
        # through the BM_RETRIEVER_PATH environment variable.
        storage_dir = os.getenv("BM_RETRIEVER_PATH", "storage_bm")
        if os.path.exists(storage_dir) and len(os.listdir(storage_dir)) > 0:
            self._bm25Retriever = CHBM25Retriever.from_persist_dir(storage_dir)
        else:
            # presumably get_nodes(None) returns every node in the store —
            # verify against the configured vector store implementation.
            bm25_retriever = CHBM25Retriever.from_defaults(
                similarity_top_k=similarity_top_k,
                nodes=self._vector_index.vector_store.get_nodes(None),
            )
            bm25_retriever.persist(storage_dir)
            # Keep the freshly built retriever. The original only persisted it,
            # leaving self._bm25Retriever unset, so the first query raised
            # AttributeError until the process was restarted.
            self._bm25Retriever = bm25_retriever

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Query both retrievers and return the top merged, re-scored nodes.

        Returns:
            At most ``out_top_k`` nodes sorted by combined score (highest
            first); each node's ``score`` is overwritten with that score.
        """
        vec_nodes: List[NodeWithScore] = self._vecRetriever.retrieve(query_bundle.query_str)
        bm_nodes: List[NodeWithScore] = self._bm25Retriever.retrieve(query_bundle.query_str)

        bm_by_id: Dict[str, NodeWithScore] = {node.node_id: node for node in bm_nodes}
        bm_weight = 1 - self._alpha

        scored = []
        for node in vec_nodes:
            # pop() so that only BM25-exclusive hits remain afterwards.
            bm_hit = bm_by_id.pop(node.node_id, None)
            bm_score = bm_hit.score if bm_hit is not None else 0.0
            # A missing (None) score is treated as 0 rather than crashing.
            combined = self._alpha * (node.score or 0.0) + bm_weight * (bm_score or 0.0)
            scored.append((combined, node))

        # Nodes found only by BM25 have no vector score contribution.
        for node in bm_by_id.values():
            scored.append((bm_weight * (node.score or 0.0), node))

        scored.sort(key=lambda pair: pair[0], reverse=True)
        for combined, node in scored:
            node.score = combined
        return [node for _, node in scored][:self._out_top_k]
|
||||
@@ -1,36 +0,0 @@
|
||||
from llama_index.core.tools.function_tool import FunctionTool
|
||||
|
||||
|
||||
def duckduckgo_search(
    query: str,
    region: str = "wt-wt",
    max_results: int = 10,
):
    """
    Use this function to search for any query in DuckDuckGo.
    Args:
        query (str): The query to search in DuckDuckGo.
        region Optional(str): The region to be used for the search in [country-language] convention, ex us-en, uk-en, ru-ru, etc...
        max_results Optional(int): The maximum number of results to be returned. Default is 10.
    """
    try:
        from duckduckgo_search import DDGS
    except ImportError as err:
        # The two sentences were previously concatenated without a space
        # ("function.Please"); the original error is now chained as the cause.
        raise ImportError(
            "duckduckgo_search package is required to use this function. "
            "Please install it by running: `poetry add duckduckgo_search` or `pip install duckduckgo_search`"
        ) from err

    params = {
        "keywords": query,
        "region": region,
        "max_results": max_results,
    }
    with DDGS() as ddg:
        # Materialize the results before the client is closed.
        return list(ddg.text(**params))
|
||||
|
||||
|
||||
def get_tools(**kwargs):
    """Return the DuckDuckGo search function wrapped as a one-element tool list.

    Keyword arguments are accepted but not used.
    """
    search_tool = FunctionTool.from_defaults(duckduckgo_search)
    return [search_tool]
|
||||
@@ -1,60 +0,0 @@
|
||||
import os
|
||||
import yaml
|
||||
import json
|
||||
import importlib
|
||||
from cachetools import cached, LRUCache
|
||||
from llama_index.core.tools.tool_spec.base import BaseToolSpec
|
||||
from llama_index.core.tools.function_tool import FunctionTool
|
||||
|
||||
|
||||
class ToolType:
    """String identifiers for the supported tool sources."""

    LLAMAHUB, LOCAL = "llamahub", "local"
|
||||
|
||||
|
||||
class ToolFactory:
    """Loads ``FunctionTool`` objects from tool modules named in configuration."""

    # Python package that is searched for each tool source.
    TOOL_SOURCE_PACKAGE_MAP = {
        ToolType.LLAMAHUB: "llama_index.tools",
        ToolType.LOCAL: "app.engine.tools",
    }

    @staticmethod
    def load_tools(tool_type: str, tool_name: str, config: dict) -> list[FunctionTool]:
        """Import one tool module and return the tools it provides.

        Args:
            tool_type: Key of ``TOOL_SOURCE_PACKAGE_MAP`` selecting the package.
            tool_name: Either ``"<module>.<ClassName>"`` when the name contains
                "ToolSpec", or the name of a module exposing ``get_tools``.
            config: Keyword arguments for the tool spec class / ``get_tools``;
                ``None`` is treated as no arguments.

        Raises:
            KeyError: If ``tool_type`` is not a known source.
            ValueError: If the module cannot be imported, the expected
                attribute is missing, or ``get_tools`` returns something
                other than ``FunctionTool`` instances.
        """
        source_package = ToolFactory.TOOL_SOURCE_PACKAGE_MAP[tool_type]
        # A tool listed without options loads from YAML as None; `**None`
        # would raise TypeError, so fall back to an empty mapping.
        config = config or {}
        try:
            if "ToolSpec" in tool_name:
                tool_package, tool_cls_name = tool_name.split(".")
                module_name = f"{source_package}.{tool_package}"
                module = importlib.import_module(module_name)
                tool_class = getattr(module, tool_cls_name)
                tool_spec: BaseToolSpec = tool_class(**config)
                return tool_spec.to_tool_list()
            else:
                module = importlib.import_module(f"{source_package}.{tool_name}")
                tools = module.get_tools(**config)
                if not all(isinstance(tool, FunctionTool) for tool in tools):
                    raise ValueError(
                        f"The module {module} does not contain valid tools"
                    )
                return tools
        except ImportError as e:
            raise ValueError(f"Failed to import tool {tool_name}: {e}")
        except AttributeError as e:
            raise ValueError(f"Failed to load tool {tool_name}: {e}")

    @staticmethod
    def from_env() -> list[FunctionTool]:
        """Load every tool declared in ``config/tools.yaml``.

        The file maps a tool type to a mapping of tool name -> config.
        A missing file, an empty file, or an empty tool-type section yields
        no tools for that part.
        """
        tools = []
        if not os.path.exists("config/tools.yaml"):
            return tools

        # Read as UTF-8 explicitly so non-ASCII config values do not depend
        # on the platform's default encoding.
        with open("config/tools.yaml", "r", encoding="utf-8") as f:
            tool_configs = yaml.safe_load(f)

        if not tool_configs:
            return tools

        for tool_type, config_entries in tool_configs.items():
            if not config_entries:
                continue
            for tool_name, config in config_entries.items():
                tools.extend(
                    ToolFactory.load_tools(tool_type, tool_name, config)
                )
        return tools
|
||||
@@ -1,108 +0,0 @@
|
||||
import os
|
||||
import uuid
|
||||
import logging
|
||||
import requests
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel, Field
|
||||
from llama_index.core.tools import FunctionTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ImageGeneratorToolOutput(BaseModel):
    """Result returned by the image generator tool."""

    # Required flag; the other two fields default to None.
    is_success: bool = Field(
        ..., description="Whether the image generation was successful."
    )
    image_url: Optional[str] = Field(
        default=None, description="The URL of the generated image."
    )
    error_message: Optional[str] = Field(
        default=None,
        description="The error message if the image generation failed.",
    )
|
||||
|
||||
|
||||
class ImageGeneratorTool:
    """Generates images through the Stability AI API and stores them locally.

    Generated files are written to ``_IMG_OUTPUT_DIR`` and exposed through a
    URL built from the ``FILESERVER_URL_PREFIX`` environment variable.
    """

    _IMG_OUTPUT_FORMAT = "webp"
    _IMG_OUTPUT_DIR = "output/tool"
    _IMG_GEN_API = "https://api.stability.ai/v2beta/stable-image/generate/core"
    # Seconds to wait for the API; without a timeout a stalled connection
    # would block the caller indefinitely.
    _REQUEST_TIMEOUT = 120

    def __init__(self, api_key: str = None):
        """Resolve the API key and file server prefix.

        Args:
            api_key: Stability API key; when empty, ``STABILITY_API_KEY``
                from the environment is used.

        Raises:
            ValueError: If no API key is available or
                ``FILESERVER_URL_PREFIX`` is not set.
        """
        if not api_key:
            api_key = os.getenv("STABILITY_API_KEY")
        self._api_key = api_key
        self.fileserver_url_prefix = os.getenv("FILESERVER_URL_PREFIX")
        # Truthiness check: an empty key is as unusable as a missing one.
        if not self._api_key:
            raise ValueError(
                "STABILITY_API_KEY key is required to run image generator. Get it here: https://platform.stability.ai/account/keys"
            )
        if self.fileserver_url_prefix is None:
            raise ValueError("FILESERVER_URL_PREFIX is required.")

    def _prepare_output_dir(self):
        """
        Create the output directory if it doesn't exist
        """
        os.makedirs(self._IMG_OUTPUT_DIR, exist_ok=True)

    def _save_image(self, image_data: bytes):
        """Write the image bytes to a uniquely named file and return its URL."""
        self._prepare_output_dir()
        filename = f"{uuid.uuid4()}.{self._IMG_OUTPUT_FORMAT}"
        output_path = os.path.join(self._IMG_OUTPUT_DIR, filename)
        with open(output_path, "wb") as f:
            f.write(image_data)
        # Use the prefix validated in __init__ rather than re-reading the env.
        url = f"{self.fileserver_url_prefix}/{self._IMG_OUTPUT_DIR}/{filename}"
        logger.info(f"Saved image to {output_path}.\nURL: {url}")
        return url

    def _call_stability_api(self, prompt: str):
        """POST the prompt to the image API and return the successful response.

        Raises:
            requests.RequestException: On connection problems, timeout, or a
                non-success HTTP status.
        """
        headers = {
            "authorization": f"Bearer {self._api_key}",
            "accept": "image/*",
        }
        data = {
            "prompt": prompt,
            "output_format": self._IMG_OUTPUT_FORMAT,
        }

        response = requests.post(
            self._IMG_GEN_API,
            headers=headers,
            # Dummy file part makes requests encode the body as multipart form data.
            files={"none": ""},
            data=data,
            timeout=self._REQUEST_TIMEOUT,
        )
        response.raise_for_status()

        return response

    def generate_image(self, prompt: str) -> ImageGeneratorToolOutput:
        """
        Use this tool to generate an image based on the prompt.
        Args:
            prompt (str): The prompt to generate the image from.
        """
        try:
            # Call the Stability API
            response = self._call_stability_api(prompt)

            # Save the image and get the URL
            image_url = self._save_image(response.content)

            return ImageGeneratorToolOutput(
                is_success=True,
                image_url=image_url,
            )
        except Exception as e:
            # Best effort: report the failure to the caller instead of raising.
            logger.exception("Failed to generate image: %s", e)
            return ImageGeneratorToolOutput(
                is_success=False,
                error_message=str(e),
            )
|
||||
|
||||
|
||||
def get_tools(**kwargs):
    """Build an ``ImageGeneratorTool`` from ``kwargs`` and wrap its ``generate_image`` as a tool."""
    generator = ImageGeneratorTool(**kwargs)
    return [FunctionTool.from_defaults(generator.generate_image)]
|
||||
@@ -1,143 +0,0 @@
|
||||
import os
|
||||
import logging
|
||||
import base64
|
||||
import uuid
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Tuple, Dict, Optional
|
||||
from llama_index.core.tools import FunctionTool
|
||||
from e2b_code_interpreter import CodeInterpreter
|
||||
from e2b_code_interpreter.models import Logs
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InterpreterExtraResult(BaseModel):
    """One output item of an interpreter run: inline content or a saved file."""

    # Format name of the item (e.g. the extension of a saved file).
    type: str
    # Inline payload, set for formats that are not written to disk.
    content: Optional[str] = None
    # Name and URL of the saved file, set for formats written to disk.
    filename: Optional[str] = None
    url: Optional[str] = None
|
||||
|
||||
|
||||
class E2BToolOutput(BaseModel):
    """Outcome of executing a code cell: error flag, logs, and parsed results."""

    # True when the executed cell reported an error.
    is_error: bool
    # Logs object taken from the execution.
    logs: Logs
    # Parsed result items; empty when there was an error or no result.
    results: List[InterpreterExtraResult] = []
|
||||
|
||||
|
||||
class E2BCodeInterpreter:
|
||||
|
||||
output_dir = "output/tool"
|
||||
|
||||
def __init__(self, api_key: str = None):
|
||||
if api_key is None:
|
||||
api_key = os.getenv("E2B_API_KEY")
|
||||
filesever_url_prefix = os.getenv("FILESERVER_URL_PREFIX")
|
||||
if not api_key:
|
||||
raise ValueError(
|
||||
"E2B_API_KEY key is required to run code interpreter. Get it here: https://e2b.dev/docs/getting-started/api-key"
|
||||
)
|
||||
if not filesever_url_prefix:
|
||||
raise ValueError(
|
||||
"FILESERVER_URL_PREFIX is required to display file output from sandbox"
|
||||
)
|
||||
|
||||
self.filesever_url_prefix = filesever_url_prefix
|
||||
self.interpreter = CodeInterpreter(api_key=api_key)
|
||||
|
||||
def __del__(self):
|
||||
self.interpreter.close()
|
||||
|
||||
def get_output_path(self, filename: str) -> str:
|
||||
# if output directory doesn't exist, create it
|
||||
if not os.path.exists(self.output_dir):
|
||||
os.makedirs(self.output_dir, exist_ok=True)
|
||||
return os.path.join(self.output_dir, filename)
|
||||
|
||||
def save_to_disk(self, base64_data: str, ext: str) -> Dict:
|
||||
filename = f"{uuid.uuid4()}.{ext}" # generate a unique filename
|
||||
buffer = base64.b64decode(base64_data)
|
||||
output_path = self.get_output_path(filename)
|
||||
|
||||
try:
|
||||
with open(output_path, "wb") as file:
|
||||
file.write(buffer)
|
||||
except IOError as e:
|
||||
logger.error(f"Failed to write to file {output_path}: {str(e)}")
|
||||
raise e
|
||||
|
||||
logger.info(f"Saved file to {output_path}")
|
||||
|
||||
return {
|
||||
"outputPath": output_path,
|
||||
"filename": filename,
|
||||
}
|
||||
|
||||
def get_file_url(self, filename: str) -> str:
|
||||
return f"{self.filesever_url_prefix}/{self.output_dir}/{filename}"
|
||||
|
||||
def parse_result(self, result) -> List[InterpreterExtraResult]:
|
||||
"""
|
||||
The result could include multiple formats (e.g. png, svg, etc.) but encoded in base64
|
||||
We save each result to disk and return saved file metadata (extension, filename, url)
|
||||
"""
|
||||
if not result:
|
||||
return []
|
||||
|
||||
output = []
|
||||
|
||||
try:
|
||||
formats = result.formats()
|
||||
results = [result[format] for format in formats]
|
||||
|
||||
for ext, data in zip(formats, results):
|
||||
match ext:
|
||||
case "png" | "svg" | "jpeg" | "pdf":
|
||||
result = self.save_to_disk(data, ext)
|
||||
filename = result["filename"]
|
||||
output.append(
|
||||
InterpreterExtraResult(
|
||||
type=ext,
|
||||
filename=filename,
|
||||
url=self.get_file_url(filename),
|
||||
)
|
||||
)
|
||||
case _:
|
||||
output.append(
|
||||
InterpreterExtraResult(
|
||||
type=ext,
|
||||
content=data,
|
||||
)
|
||||
)
|
||||
except Exception as error:
|
||||
logger.exception(error, exc_info=True)
|
||||
logger.error("Error when parsing output from E2b interpreter tool", error)
|
||||
|
||||
return output
|
||||
|
||||
def interpret(self, code: str) -> E2BToolOutput:
|
||||
"""
|
||||
Execute python code in a Jupyter notebook cell, the toll will return result, stdout, stderr, display_data, and error.
|
||||
|
||||
Parameters:
|
||||
code (str): The python code to be executed in a single cell.
|
||||
"""
|
||||
logger.info(
|
||||
f"\n{'='*50}\n> Running following AI-generated code:\n{code}\n{'='*50}"
|
||||
)
|
||||
exec = self.interpreter.notebook.exec_cell(code)
|
||||
|
||||
if exec.error:
|
||||
logger.error("Error when executing code", exec.error)
|
||||
output = E2BToolOutput(is_error=True, logs=exec.logs, results=[])
|
||||
else:
|
||||
if len(exec.results) == 0:
|
||||
output = E2BToolOutput(is_error=False, logs=exec.logs, results=[])
|
||||
else:
|
||||
results = self.parse_result(exec.results[0])
|
||||
output = E2BToolOutput(is_error=False, logs=exec.logs, results=results)
|
||||
return output
|
||||
|
||||
|
||||
def get_tools(**kwargs):
    """Build an ``E2BCodeInterpreter`` from ``kwargs`` and wrap its ``interpret`` as a tool."""
    code_interpreter = E2BCodeInterpreter(**kwargs)
    return [FunctionTool.from_defaults(code_interpreter.interpret)]
|
||||
@@ -1,78 +0,0 @@
|
||||
from typing import Dict, List, Tuple
|
||||
from llama_index.tools.openapi import OpenAPIToolSpec
|
||||
from llama_index.tools.requests import RequestsToolSpec
|
||||
|
||||
|
||||
class OpenAPIActionToolSpec(OpenAPIToolSpec, RequestsToolSpec):
    """
    A combination of OpenAPI and Requests tool specs that can parse OpenAPI specs and make requests.

    openapi_uri: str: The file path or URL to the OpenAPI spec.
    domain_headers: dict: Whitelist domains and the headers to use.
    """

    spec_functions = OpenAPIToolSpec.spec_functions + RequestsToolSpec.spec_functions
    # Cached parsed specs by URI (shared by all instances of the class)
    _specs: Dict[str, Tuple[Dict, List[str]]] = {}

    def __init__(self, openapi_uri: str, domain_headers: dict = None, **kwargs):
        """Load (or reuse) the spec at ``openapi_uri`` and initialize both parents.

        Every server host found in the spec is added to the domain whitelist
        with empty headers unless the caller already supplied an entry for it.
        Extra keyword arguments are accepted but not used.
        """
        # Work on a copy so the caller's dict is not modified when server
        # hosts are added below.
        domain_headers = dict(domain_headers) if domain_headers is not None else {}

        if openapi_uri not in self._specs:
            openapi_spec, servers = self._load_openapi_spec(openapi_uri)
            self._specs[openapi_uri] = (openapi_spec, servers)
        else:
            openapi_spec, servers = self._specs[openapi_uri]

        # Add the servers to the domain headers if they are not already present
        for server in servers:
            domain_headers.setdefault(server, {})

        OpenAPIToolSpec.__init__(self, spec=openapi_spec)
        RequestsToolSpec.__init__(self, domain_headers)

    @staticmethod
    def _load_openapi_spec(uri: str) -> Tuple[Dict, List[str]]:
        """
        Load an OpenAPI spec from a URI.

        Args:
            uri (str): A URL starting with "http" or a URI starting with "file".

        Returns:
            Tuple[Dict, List[str]]: The parsed spec and the network locations
            (host[:port]) of the entries in its `servers` list.

        Raises:
            ValueError: If the spec cannot be downloaded, the URI scheme is
                not supported, or a server entry has no `url`.
        """
        import yaml
        from urllib.parse import urlparse

        if uri.startswith("http"):
            import requests

            # Bounded wait so an unresponsive host cannot hang initialization.
            response = requests.get(uri, timeout=30)
            if response.status_code != 200:
                raise ValueError(
                    "Could not initialize OpenAPIActionToolSpec: "
                    f"Failed to load OpenAPI spec from {uri}, status code: {response.status_code}"
                )
            spec = yaml.safe_load(response.text)
        elif uri.startswith("file"):
            filepath = urlparse(uri).path
            with open(filepath, "r") as file:
                spec = yaml.safe_load(file)
        else:
            raise ValueError(
                "Could not initialize OpenAPIActionToolSpec: Invalid OpenAPI URI provided. "
                "Only HTTP and file path are supported."
            )
        # Add the servers to the whitelist
        try:
            servers = [
                urlparse(server["url"]).netloc for server in spec.get("servers", [])
            ]
        except KeyError as e:
            raise ValueError(
                "Could not initialize OpenAPIActionToolSpec: Invalid OpenAPI spec provided. "
                "Could not get `servers` from the spec."
            ) from e
        return spec, servers
|
||||
@@ -1,73 +0,0 @@
|
||||
"""Open Meteo weather map tool spec."""
|
||||
|
||||
import logging
|
||||
import requests
|
||||
import pytz
|
||||
from llama_index.core.tools import FunctionTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpenMeteoWeather:
    """Looks up a place name with the Open-Meteo geocoding API and fetches its forecast."""

    geo_api = "https://geocoding-api.open-meteo.com/v1"
    weather_api = "https://api.open-meteo.com/v1"
    # Seconds to wait for either API before giving up.
    _REQUEST_TIMEOUT = 30

    @classmethod
    def _get_geo_location(cls, location: str) -> dict:
        """Get geo location from location name.

        Returns:
            A dict with the ``id``, ``name``, ``latitude`` and ``longitude``
            of the first search result.

        Raises:
            Exception: If the API does not answer with HTTP 200.
            ValueError: If the search returns no result for ``location``.
        """
        params = {"name": location, "count": 10, "language": "en", "format": "json"}
        response = requests.get(
            f"{cls.geo_api}/search", params=params, timeout=cls._REQUEST_TIMEOUT
        )
        if response.status_code != 200:
            raise Exception(f"Failed to fetch geo location: {response.status_code}")

        # An unknown place yields no usable "results" entry; report that
        # clearly instead of failing with a bare KeyError/IndexError.
        results = response.json().get("results")
        if not results:
            raise ValueError(f"No geo location found for: {location}")

        result = results[0]
        return {
            "id": result["id"],
            "name": result["name"],
            "latitude": result["latitude"],
            "longitude": result["longitude"],
        }

    @classmethod
    def get_weather_information(cls, location: str) -> dict:
        """Use this function to get the weather of any given location.
        Note that the weather code should follow WMO Weather interpretation codes (WW):
        0: Clear sky
        1, 2, 3: Mainly clear, partly cloudy, and overcast
        45, 48: Fog and depositing rime fog
        51, 53, 55: Drizzle: Light, moderate, and dense intensity
        56, 57: Freezing Drizzle: Light and dense intensity
        61, 63, 65: Rain: Slight, moderate and heavy intensity
        66, 67: Freezing Rain: Light and heavy intensity
        71, 73, 75: Snow fall: Slight, moderate, and heavy intensity
        77: Snow grains
        80, 81, 82: Rain showers: Slight, moderate, and violent
        85, 86: Snow showers slight and heavy
        95: Thunderstorm: Slight or moderate
        96, 99: Thunderstorm with slight and heavy hail
        """
        logger.info(
            f"Calling open-meteo api to get weather information of location: {location}"
        )
        geo_location = cls._get_geo_location(location)
        timezone = pytz.timezone("UTC").zone
        params = {
            "latitude": geo_location["latitude"],
            "longitude": geo_location["longitude"],
            "current": "temperature_2m,weather_code",
            "hourly": "temperature_2m,weather_code",
            "daily": "weather_code",
            "timezone": timezone,
        }
        response = requests.get(
            f"{cls.weather_api}/forecast", params=params, timeout=cls._REQUEST_TIMEOUT
        )
        if response.status_code != 200:
            raise Exception(
                f"Failed to fetch weather information: {response.status_code}"
            )
        return response.json()
|
||||
|
||||
|
||||
def get_tools(**kwargs):
    """Return ``OpenMeteoWeather.get_weather_information`` wrapped as a one-element tool list.

    Keyword arguments are accepted but not used.
    """
    weather_tool = FunctionTool.from_defaults(OpenMeteoWeather.get_weather_information)
    return [weather_tool]
|
||||
Reference in New Issue
Block a user