167 lines
7.3 KiB
Python
167 lines
7.3 KiB
Python
import os
|
|
|
|
from llama_index.core import SQLDatabase, SummaryIndex, VectorStoreIndex
|
|
from llama_index.core.indices.struct_store import SQLTableRetrieverQueryEngine
|
|
from llama_index.core.objects import SQLTableNodeMapping, ObjectIndex
|
|
from llama_index.core.settings import Settings
|
|
from llama_index.core.agent import AgentRunner, StructuredPlannerAgent, FunctionCallingAgentWorker
|
|
from llama_index.core.tools.query_engine import QueryEngineTool
|
|
from sqlalchemy import create_engine, Engine
|
|
|
|
from app.engine.loaders.db import makeDescriptionByEngine
|
|
from app.engine.tools import ToolFactory
|
|
from app.engine.index import get_index
|
|
from app.engine.retriever.CHBM25Retriever import CHBM25Retriever
|
|
from app.settings import get_node_postprocessors
|
|
|
|
from llama_index.core.retrievers import BaseRetriever
|
|
from llama_index.core import QueryBundle
|
|
from llama_index.core.schema import NodeWithScore
|
|
from typing import List, Any, Optional,Dict
|
|
from llama_index.core.query_engine.retriever_query_engine import RetrieverQueryEngine
|
|
|
|
class HybridRetriever(BaseRetriever):
    """Combine vector-similarity and BM25 keyword retrieval into one ranking.

    For every candidate node the fused score is
    ``alpha * vector_score + (1 - alpha) * bm25_score``. A node returned by only
    one of the two retrievers scores 0.0 on the other side, and a missing
    (``None``) score is treated as 0.0.

    Results are sorted by fused score (descending). The fused score is written
    back to ``node.score``, and the first ``out_top_k`` nodes are returned.

    NOTE(review): raw scores are mixed without normalisation. Vector
    similarities and BM25 scores are presumably on different scales, so the
    effective weighting may differ from ``alpha`` — confirm.

    NOTE(review): ``filters`` are applied to the vector retriever only, so
    BM25-only hits bypass them — confirm this is intended.
    """

    def __init__(
        self,
        vector_index,
        similarity_top_k: int = 2,
        out_top_k: Optional[int] = None,
        alpha: float = 0.5,
        filters=None,
        **kwargs: Any,
    ) -> None:
        """Create the vector retriever and load (or build) the BM25 retriever.

        Args:
            vector_index: index exposing ``as_retriever``, ``_embed_model`` and
                ``vector_store.get_nodes``.
            similarity_top_k: candidates requested from each retriever.
            out_top_k: number of fused results returned. A falsy value falls
                back to ``similarity_top_k``.
            alpha: weight of the vector score; ``1 - alpha`` weights BM25.
            filters: metadata filters forwarded to the vector retriever.
            **kwargs: forwarded to ``BaseRetriever.__init__``.
        """
        super().__init__(**kwargs)
        self._vector_index = vector_index
        self._embed_model = vector_index._embed_model
        self._out_top_k = out_top_k or similarity_top_k
        self._alpha = alpha
        self._vecRetriever = vector_index.as_retriever(
            similarity_top_k=similarity_top_k, filters=filters
        )
        # Assigned on both the "load" and the "build" path. Previously the
        # build path left this attribute unset, so the first query crashed.
        self._bm25Retriever = self._load_or_build_bm25(similarity_top_k)

    def _load_or_build_bm25(self, similarity_top_k: int):
        """Return the BM25 retriever persisted in ``BM_RETRIEVER_PATH``, building it if absent.

        The directory defaults to ``storage_bm``. When it is missing or empty,
        a retriever is built from every node in the vector store and
        persisted there for later runs.

        NOTE(review): a loaded retriever is not passed ``similarity_top_k``,
        and it is never rebuilt when the vector store changes — confirm
        both are acceptable.
        """
        storage_dir = os.getenv("BM_RETRIEVER_PATH", "storage_bm")
        if os.path.isdir(storage_dir) and os.listdir(storage_dir):
            return CHBM25Retriever.from_persist_dir(storage_dir)

        bm25_retriever = CHBM25Retriever.from_defaults(
            similarity_top_k=similarity_top_k,
            nodes=self._vector_index.vector_store.get_nodes(None),
        )
        bm25_retriever.persist(storage_dir)
        return bm25_retriever

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Run both retrievers on the query text and return the fused top results."""
        query = query_bundle.query_str
        vec_nodes: List[NodeWithScore] = self._vecRetriever.retrieve(query)
        bm_nodes: List[NodeWithScore] = self._bm25Retriever.retrieve(query)

        # Index BM25 hits by node id. Entries are popped when matched, so
        # whatever remains afterwards are BM25-only hits.
        bm_by_id: Dict[str, NodeWithScore] = {n.node_id: n for n in bm_nodes}

        result_tups = []
        for node in vec_nodes:
            bm_hit = bm_by_id.pop(node.node_id, None)
            bm_score = (bm_hit.score or 0.0) if bm_hit is not None else 0.0
            fused = self._alpha * (node.score or 0.0) + (1 - self._alpha) * bm_score
            result_tups.append((fused, node))

        for node in bm_by_id.values():
            result_tups.append(((1 - self._alpha) * (node.score or 0.0), node))

        result_tups.sort(key=lambda pair: pair[0], reverse=True)
        for fused, node in result_tups:
            node.score = fused
        return [node for _, node in result_tups][: self._out_top_k]
|
|
|
|
def get_Retriever(index, **kwargs):
    """Return a retriever for ``index``, hybrid (vector + BM25) when enabled.

    Hybrid mode is enabled when the ``HYBRID_ENABLED`` environment variable is
    ``true`` (case-insensitive, surrounding whitespace ignored). In that case
    ``HYBRID_ALPHA`` (default ``"0.5"``) sets the vector-score weight of the
    ``HybridRetriever``. Otherwise ``index.as_retriever(**kwargs)`` is returned.

    Args:
        index: vector index to retrieve from.
        **kwargs: forwarded to ``HybridRetriever`` or ``index.as_retriever``
            (e.g. ``similarity_top_k``, ``filters``).
    """
    # os.getenv returns the default unchanged, so it must be a str for the
    # string methods below to work when the variable is unset.
    hybrid_enabled = os.getenv("HYBRID_ENABLED", "False").strip().lower() == "true"
    if not hybrid_enabled:
        return index.as_retriever(**kwargs)

    alpha = float(os.getenv("HYBRID_ALPHA", "0.5"))
    return HybridRetriever(index, alpha=alpha, **kwargs)
|
|
|
|
# Module-level SQL caches shared across calls; both start as None and are
# filled in lazily by get_chat_engine() on its first run.
sql_database = sql_obj_index = None
|
|
|
|
def _get_sql_query_tool(top_k: int):
    """Build the ``zjdata_query_tool`` SQL-table tool, or return None if no DB is configured.

    On first use this populates the module-level ``sql_database`` and
    ``sql_obj_index`` caches from ``SQL_DATABASE_URL``.
    """
    global sql_obj_index
    global sql_database
    if sql_obj_index is None:
        db_url = os.getenv("SQL_DATABASE_URL", "")
        if not db_url:
            # create_engine("") cannot parse an empty URL. Skip SQL support
            # instead of failing the whole chat engine.
            return None
        sql_database = SQLDatabase(create_engine(db_url))
        table_schema_objs = makeDescriptionByEngine(sql_database)
        table_node_mapping = SQLTableNodeMapping(sql_database)
        sql_obj_index = ObjectIndex.from_objects(
            table_schema_objs,
            table_node_mapping,
            index_cls=VectorStoreIndex,
        )

    # Create the SQL query tool over the cached table-schema index.
    sql_query_engine = SQLTableRetrieverQueryEngine(
        sql_database,
        sql_obj_index.as_retriever(similarity_top_k=top_k),
        verbose=True,
    )
    return QueryEngineTool.from_defaults(
        query_engine=sql_query_engine,
        name="zjdata_query_tool",
        description="来源于一个由博微公司电力造价软件编制的造价工程文件。该文件以多张表格的形式存储存储了整个工程的全部数据内容。适用于以详细的自然语言查询表格数据方式查询造价工程各项具体属性、费用的数值。请先使用“zj_query_tool”无法解决才使用本工具",
    )


def _build_index_tools(index, top_k: int, filters=None) -> list:
    """Return ``[summary_query_tool, zj_query_tool]`` built over the document index."""
    # Summary tool: a SummaryIndex over every node in the vector store.
    summary_index = SummaryIndex(index.vector_store.get_nodes(node_ids=None))
    summary_query_tool = QueryEngineTool.from_defaults(
        query_engine=summary_index.as_query_engine(),
        name="summary_query_tool",
        description="适用于任何需要进行全面总结、概括的要求。",
    )

    # Vector-retrieval query tool (hybrid vector + BM25 when HYBRID_ENABLED).
    postprocess = get_node_postprocessors()
    query_engine = RetrieverQueryEngine.from_args(
        get_Retriever(index, similarity_top_k=top_k, filters=filters),
        node_postprocessors=postprocess,
    )
    query_engine_tool = QueryEngineTool.from_defaults(
        query_engine=query_engine,
        name="zj_query_tool",
        description="由博微公司编制的关于电力造价知识、电力造价编制软件知识和造价工程文件结构的知识库。适用于查询电力领域、电力造价领域、博微、博微电力、博微造价等业务等内容。如果本知识库没有直接答案但有解决思路的可以返回解决办法后建议使用“zjdata_query_tool”工具。如果你不知道答案,就说你不知道,不要编造答案。",
    )
    return [summary_query_tool, query_engine_tool]


def get_chat_engine(filters=None, params=None):
    """Build the tool-using agent for chat.

    Registered tools:
      * ``summary_query_tool`` and ``zj_query_tool``, only when ``get_index()``
        returns an index;
      * any extra tools from ``ToolFactory.from_env()``.

    The SQL tool (``zjdata_query_tool``) is still built, which initialises the
    module-level SQL caches when ``SQL_DATABASE_URL`` is set. It is not
    registered with the agent.

    Args:
        filters: metadata filters applied to the document vector retriever.
        params: accepted for interface compatibility; currently unused.

    Returns:
        ``AgentRunner.from_llm`` built with ``Settings.llm`` and the
        ``SYSTEM_PROMPT`` environment variable.
    """
    system_prompt = os.getenv("SYSTEM_PROMPT")
    top_k = int(os.getenv("TOP_K", "3"))

    # NOTE(review): zj_query_tool's description tells the LLM to fall back to
    # "zjdata_query_tool", but that tool is not registered below — confirm.
    sql_query_tool = _get_sql_query_tool(top_k)

    tools = []
    index = get_index()
    if index is not None:
        tools += _build_index_tools(index, top_k, filters)
    # Disabled: register the SQL tool.
    # if sql_query_tool is not None:
    #     tools.append(sql_query_tool)

    # Add additional tools configured via the environment.
    tools += ToolFactory.from_env()

    # Disabled alternative: a function-calling worker wrapped in the
    # top-level planner:
    #     worker = FunctionCallingAgentWorker.from_tools(tools, verbose=True)
    #     return StructuredPlannerAgent(worker, tools)
    return AgentRunner.from_llm(
        llm=Settings.llm,
        tools=tools,
        system_prompt=system_prompt,
        verbose=True,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|