dev #1

Merged
ly merged 41 commits from dev into main 2024-08-22 09:41:14 +08:00
Showing only changes of commit 6e473499b8 - Show all commits
+81 -2
View File
@@ -13,6 +13,77 @@ from app.engine.tools import ToolFactory
from app.engine.index import get_index from app.engine.index import get_index
from app.settings import get_node_postprocessors from app.settings import get_node_postprocessors
from llama_index.core.retrievers import BaseRetriever
from llama_index.core import QueryBundle
from llama_index.core.schema import NodeWithScore
from typing import List, Any, Optional,Dict
from llama_index.core.query_engine.retriever_query_engine import RetrieverQueryEngine
class HybridRetriever(BaseRetriever):
def __init__(
self,
vector_index,
similarity_top_k: int = 2,
out_top_k: Optional[int] = None,
alpha: float = 0.5,
filters = None,
**kwargs: Any,
) -> None:
from llama_index.retrievers.bm25 import BM25Retriever
from nltk.corpus import stopwords
super().__init__(**kwargs)
self._vector_index = vector_index
self._embed_model = vector_index._embed_model
self._out_top_k = out_top_k or similarity_top_k
self._vecRetriever = vector_index.as_retriever(
similarity_top_k=similarity_top_k,filters = filters
)
self._bm25Retriever = BM25Retriever.from_defaults(similarity_top_k=similarity_top_k,
nodes=self._vector_index.vector_store.get_nodes(None),
language=stopwords.words('chinese'))
self._alpha = alpha
def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
vecNodes:List[NodeWithScore] = self._vecRetriever.retrieve(query_bundle.query_str)
bmNodes:List[NodeWithScore] = self._bm25Retriever.retrieve(query_bundle.query_str)
bmDic:Dict[str,NodeWithScore] = {}
for node in bmNodes:
bmDic[node.node_id] = node
result_tups = []
for i in range(len(vecNodes)):
node = vecNodes[i]
bmScore = 0.0
if node.node_id in bmDic:
bmScore = bmDic[node.node_id].score
bmDic.pop(node.node_id)
else:
bmScore = 0.0
full_similarity = (self._alpha * node.score) + (
(1 - self._alpha) * bmScore
)
result_tups.append((full_similarity, node))
for _,node in bmDic.items():
full_similarity = (1 - self._alpha) * node.score
result_tups.append((full_similarity, node))
result_tups = sorted(result_tups, key=lambda x: x[0], reverse=True)
for full_score, node in result_tups:
node.score = full_score
return [n for _, n in result_tups][:self._out_top_k]
def get_Retriever(index,**kwargs):
bEnableHybrid = True if os.getenv("HYBRID_ENABLED",False).title() == 'True' else False
if bEnableHybrid:
alpha = float(os.getenv("HYBRID_ALPHA", "0.5"))
retriever = HybridRetriever(index,alpha = alpha,**kwargs)
else:
retriever = index.as_retriever(**kwargs)
return retriever
sql_database = None sql_database = None
sql_obj_index = None sql_obj_index = None
@@ -55,10 +126,12 @@ def get_chat_engine(filters=None, params=None):
# 创建向量检索查询工具 # 创建向量检索查询工具
postprocess = get_node_postprocessors() postprocess = get_node_postprocessors()
query_engine = index.as_query_engine( query_engine = RetrieverQueryEngine.from_args(
similarity_top_k=top_k, filters=filters, get_Retriever(index,similarity_top_k=top_k,
filters=filters),
node_postprocessors=postprocess, node_postprocessors=postprocess,
) )
query_engine_tool = QueryEngineTool.from_defaults(query_engine=query_engine, name="zj_query_tool", query_engine_tool = QueryEngineTool.from_defaults(query_engine=query_engine, name="zj_query_tool",
description="由博微公司编制的关于电力造价知识、电力造价编制软件知识和造价工程文件结构的知识库。适用于查询电力领域、电力造价领域、博微、博微电力、博微造价等业务等内容。如果本知识库没有直接答案但有解决思路的可以返回解决办法后建议使用“zjdata_query_tool”工具。如果你不知道答案,就说你不知道,不要编造答案。", description="由博微公司编制的关于电力造价知识、电力造价编制软件知识和造价工程文件结构的知识库。适用于查询电力领域、电力造价领域、博微、博微电力、博微造价等业务等内容。如果本知识库没有直接答案但有解决思路的可以返回解决办法后建议使用“zjdata_query_tool”工具。如果你不知道答案,就说你不知道,不要编造答案。",
) )
@@ -83,3 +156,9 @@ def get_chat_engine(filters=None, params=None):
# #
# # wrap the worker in the top-level planner # # wrap the worker in the top-level planner
# return StructuredPlannerAgent(worker, tools) # return StructuredPlannerAgent(worker, tools)