67 lines
2.5 KiB
Python
67 lines
2.5 KiB
Python
import os
|
|
from typing import Optional, Any, Dict, List
|
|
|
|
from llama_index.core.base.base_retriever import BaseRetriever
|
|
from llama_index.core.schema import NodeWithScore, QueryBundle
|
|
|
|
from app.engine.retriever import CHBM25Retriever
|
|
|
|
|
|
class HybridRetriever(BaseRetriever):
    """Retriever that fuses dense-vector and BM25 results into one ranked list.

    Both sub-retrievers are queried with the same query string. Each
    candidate's fused score is ``alpha * vector_score + (1 - alpha) *
    bm25_score``. A score counts as 0.0 in two cases: the node was not
    returned by that retriever, or the retriever reported no score. Results
    are sorted by fused score, descending, and truncated to ``out_top_k``.

    NOTE(review): vector similarities and BM25 scores are summed without any
    normalization, so their raw scales also influence the effective
    weighting beyond ``alpha`` — confirm this is intended.
    """

    def __init__(
        self,
        vector_index,
        similarity_top_k: int = 2,
        out_top_k: Optional[int] = None,
        alpha: float = 0.5,
        filters=None,
        **kwargs: Any,
    ) -> None:
        """Build the vector and BM25 sub-retrievers.

        Args:
            vector_index: Index exposing ``as_retriever``, ``vector_store``
                and ``_embed_model`` (presumably a llama_index
                ``VectorStoreIndex`` — confirm against callers).
            similarity_top_k: Candidates requested from the vector retriever
                and from a freshly built BM25 retriever.
            out_top_k: Maximum number of fused results returned. Falls back
                to ``similarity_top_k`` when not given (or falsy).
            alpha: Weight of the vector score; BM25 receives ``1 - alpha``.
            filters: Metadata filters forwarded to the vector retriever only.
            **kwargs: Forwarded unchanged to ``BaseRetriever.__init__``.
        """
        super().__init__(**kwargs)
        self._vector_index = vector_index
        self._embed_model = vector_index._embed_model
        self._out_top_k = out_top_k or similarity_top_k
        self._alpha = alpha
        self._vecRetriever = vector_index.as_retriever(
            similarity_top_k=similarity_top_k, filters=filters
        )

        # The BM25 retriever is persisted to disk so it is built only once.
        # BM_RETRIEVER_PATH overrides the storage location.
        storage_dir = os.getenv("BM_RETRIEVER_PATH", "storage_bm")
        if os.path.isdir(storage_dir) and os.listdir(storage_dir):
            # NOTE(review): a persisted retriever keeps the top-k it was built
            # with; ``similarity_top_k`` is not re-applied here — confirm.
            self._bm25Retriever = CHBM25Retriever.from_persist_dir(storage_dir)
        else:
            # Build from every node in the vector store, persist it, and keep
            # it (the original never assigned it, so the first retrieval
            # after building raised AttributeError).
            self._bm25Retriever = CHBM25Retriever.from_defaults(
                similarity_top_k=similarity_top_k,
                nodes=vector_index.vector_store.get_nodes(None),
            )
            self._bm25Retriever.persist(storage_dir)

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Return the top ``out_top_k`` nodes ranked by fused score.

        Both sub-retrievers receive ``query_bundle.query_str``. The ``score``
        of each returned ``NodeWithScore`` is overwritten in place with its
        fused score.
        """
        query_str = query_bundle.query_str
        vec_nodes: List[NodeWithScore] = self._vecRetriever.retrieve(query_str)
        bm_nodes: List[NodeWithScore] = self._bm25Retriever.retrieve(query_str)
        return self._fuse_scores(vec_nodes, bm_nodes)

    def _fuse_scores(
        self, vec_nodes: List[NodeWithScore], bm_nodes: List[NodeWithScore]
    ) -> List[NodeWithScore]:
        """Merge two scored node lists by a weighted sum of their scores."""
        alpha = self._alpha
        # Index BM25 hits by node id. Duplicate ids keep the last occurrence.
        bm_by_id: Dict[str, NodeWithScore] = {n.node_id: n for n in bm_nodes}

        scored = []
        for node in vec_nodes:
            # pop() so that bm_by_id ends up holding only the BM25-only hits.
            bm_hit = bm_by_id.pop(node.node_id, None)
            bm_score = (bm_hit.score or 0.0) if bm_hit is not None else 0.0
            fused = alpha * (node.score or 0.0) + (1 - alpha) * bm_score
            scored.append((fused, node))

        # Nodes found only by BM25 contribute no vector score.
        for node in bm_by_id.values():
            scored.append(((1 - alpha) * (node.score or 0.0), node))

        # Stable sort: ties keep vector hits ahead of BM25-only hits.
        scored.sort(key=lambda pair: pair[0], reverse=True)
        top = scored[: self._out_top_k]
        for fused, node in top:
            node.score = fused
        return [node for _, node in top]