调整代码结构,同时修改重定义提示词的方式。
This commit is contained in:
@@ -0,0 +1,67 @@
|
||||
import os
|
||||
from typing import Optional, Any, Dict, List
|
||||
|
||||
from llama_index.core.base.base_retriever import BaseRetriever
|
||||
from llama_index.core.schema import NodeWithScore, QueryBundle
|
||||
|
||||
from app.engine.retriever import CHBM25Retriever
|
||||
|
||||
|
||||
class HybridRetriever(BaseRetriever):
|
||||
def __init__(
|
||||
self,
|
||||
vector_index,
|
||||
similarity_top_k: int = 2,
|
||||
out_top_k: Optional[int] = None,
|
||||
alpha: float = 0.5,
|
||||
filters = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self._vector_index = vector_index
|
||||
self._embed_model = vector_index._embed_model
|
||||
self._out_top_k = out_top_k or similarity_top_k
|
||||
self._vecRetriever = vector_index.as_retriever(
|
||||
similarity_top_k=similarity_top_k,filters = filters
|
||||
)
|
||||
|
||||
STORAGE_DIR = os.getenv("BM_RETRIEVER_PATH", "storage_bm")
|
||||
if os.path.exists(STORAGE_DIR) and len(os.listdir(STORAGE_DIR)) > 0:
|
||||
self._bm25Retriever = CHBM25Retriever.from_persist_dir(STORAGE_DIR)
|
||||
else:
|
||||
bmRetriver = CHBM25Retriever.from_defaults(similarity_top_k=similarity_top_k,nodes=self._vector_index.vector_store.get_nodes(None))
|
||||
bmRetriver.persist(STORAGE_DIR)
|
||||
self._alpha = alpha
|
||||
|
||||
|
||||
|
||||
def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
|
||||
vecNodes:List[NodeWithScore] = self._vecRetriever.retrieve(query_bundle.query_str)
|
||||
bmNodes:List[NodeWithScore] = self._bm25Retriever.retrieve(query_bundle.query_str)
|
||||
|
||||
bmDic:Dict[str,NodeWithScore] = {}
|
||||
for node in bmNodes:
|
||||
bmDic[node.node_id] = node
|
||||
|
||||
result_tups = []
|
||||
for i in range(len(vecNodes)):
|
||||
node = vecNodes[i]
|
||||
bmScore = 0.0
|
||||
if node.node_id in bmDic:
|
||||
bmScore = bmDic[node.node_id].score
|
||||
bmDic.pop(node.node_id)
|
||||
else:
|
||||
bmScore = 0.0
|
||||
full_similarity = (self._alpha * node.score) + (
|
||||
(1 - self._alpha) * bmScore
|
||||
)
|
||||
result_tups.append((full_similarity, node))
|
||||
|
||||
for _,node in bmDic.items():
|
||||
full_similarity = (1 - self._alpha) * node.score
|
||||
result_tups.append((full_similarity, node))
|
||||
|
||||
result_tups = sorted(result_tups, key=lambda x: x[0], reverse=True)
|
||||
for full_score, node in result_tups:
|
||||
node.score = full_score
|
||||
return [n for _, n in result_tups][:self._out_top_k]
|
||||
Reference in New Issue
Block a user