From 6e473499b8d65fd3e1df6409adb5d51e6788ecec Mon Sep 17 00:00:00 2001 From: wanyaokun <12345678> Date: Wed, 21 Aug 2024 19:30:43 +0800 Subject: [PATCH] =?UTF-8?q?=E6=96=B0=E5=A2=9E=E6=B7=B7=E5=90=88=E6=A3=80?= =?UTF-8?q?=E7=B4=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/app/engine/__init__.py | 83 +++++++++++++++++++++++++++++++++- 1 file changed, 81 insertions(+), 2 deletions(-) diff --git a/backend/app/engine/__init__.py b/backend/app/engine/__init__.py index 444f925..fc36c14 100644 --- a/backend/app/engine/__init__.py +++ b/backend/app/engine/__init__.py @@ -13,6 +13,77 @@ from app.engine.tools import ToolFactory from app.engine.index import get_index from app.settings import get_node_postprocessors +from llama_index.core.retrievers import BaseRetriever +from llama_index.core import QueryBundle +from llama_index.core.schema import NodeWithScore +from typing import List, Any, Optional,Dict +from llama_index.core.query_engine.retriever_query_engine import RetrieverQueryEngine + +class HybridRetriever(BaseRetriever): + def __init__( + self, + vector_index, + similarity_top_k: int = 2, + out_top_k: Optional[int] = None, + alpha: float = 0.5, + filters = None, + **kwargs: Any, + ) -> None: + from llama_index.retrievers.bm25 import BM25Retriever + from nltk.corpus import stopwords + + super().__init__(**kwargs) + self._vector_index = vector_index + self._embed_model = vector_index._embed_model + self._out_top_k = out_top_k or similarity_top_k + self._vecRetriever = vector_index.as_retriever( + similarity_top_k=similarity_top_k,filters = filters + ) + self._bm25Retriever = BM25Retriever.from_defaults(similarity_top_k=similarity_top_k, + nodes=self._vector_index.vector_store.get_nodes(None), + language=stopwords.words('chinese')) + self._alpha = alpha + + def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: + vecNodes:List[NodeWithScore] = self._vecRetriever.retrieve(query_bundle.query_str) + bmNodes:List[NodeWithScore] = self._bm25Retriever.retrieve(query_bundle.query_str) + + bmDic:Dict[str,NodeWithScore] = {} + for node in bmNodes: + bmDic[node.node_id] = node + + result_tups = [] + for i in range(len(vecNodes)): + node = vecNodes[i] + bmScore = 0.0 + if node.node_id in bmDic: + bmScore = bmDic[node.node_id].score + bmDic.pop(node.node_id) + else: + bmScore = 0.0 + full_similarity = (self._alpha * node.score) + ( + (1 - self._alpha) * bmScore + ) + result_tups.append((full_similarity, node)) + + for _,node in bmDic.items(): + full_similarity = (1 - self._alpha) * node.score + result_tups.append((full_similarity, node)) + + result_tups = sorted(result_tups, key=lambda x: x[0], reverse=True) + for full_score, node in result_tups: + node.score = full_score + return [n for _, n in result_tups][:self._out_top_k] + +def get_Retriever(index,**kwargs): + bEnableHybrid = True if os.getenv("HYBRID_ENABLED",False).title() == 'True' else False + if bEnableHybrid: + alpha = float(os.getenv("HYBRID_ALPHA", "0.5")) + retriever = HybridRetriever(index,alpha = alpha,**kwargs) + else: + retriever = index.as_retriever(**kwargs) + return retriever + sql_database = None sql_obj_index = None @@ -55,10 +126,12 @@ def get_chat_engine(filters=None, params=None): # 创建向量检索查询工具 postprocess = get_node_postprocessors() - query_engine = index.as_query_engine( - similarity_top_k=top_k, filters=filters, + query_engine = RetrieverQueryEngine.from_args( + get_Retriever(index,similarity_top_k=top_k, + filters=filters), node_postprocessors=postprocess, ) + query_engine_tool = QueryEngineTool.from_defaults(query_engine=query_engine, name="zj_query_tool", description="由博微公司编制的关于电力造价知识、电力造价编制软件知识和造价工程文件结构的知识库。适用于查询电力领域、电力造价领域、博微、博微电力、博微造价等业务等内容。如果本知识库没有直接答案但有解决思路的可以返回解决办法后建议使用“zjdata_query_tool”工具。如果你不知道答案,就说你不知道,不要编造答案。", ) @@ -83,3 +156,9 @@ def get_chat_engine(filters=None, params=None): # # # wrap the worker in the top-level planner # return StructuredPlannerAgent(worker, tools) + + + + + +