73 lines
2.8 KiB
Python
73 lines
2.8 KiB
Python
import os
|
|
from typing import Any, List, Sequence, Optional
|
|
from llama_index.core.schema import BaseNode, NodeWithScore, QueryBundle
|
|
from llama_index.core.graph_stores.types import (
|
|
PropertyGraphStore,
|
|
KG_SOURCE_REL,
|
|
VECTOR_SOURCE_KEY,
|
|
)
|
|
from llama_index.core.indices.property_graph.sub_retrievers.base import BasePGRetriever
|
|
from llama_index.core.graph_stores.types import (
|
|
PropertyGraphStore,
|
|
KG_SOURCE_REL
|
|
)
|
|
from app.engine.retriever.CHBM25Retriever import CHBM25Retriever
|
|
|
|
class GraphBM25Retriever(BasePGRetriever):
    """Property-graph retriever seeded by a persisted ``CHBM25Retriever`` index.

    The query is first run through the BM25 retriever loaded from
    ``$BM_RETRIEVER_PATH`` (default ``"storage_bm"``). The matching nodes are
    mapped to graph ids (see ``_get_kg_ids``), fetched from the graph store and
    expanded into a relation map of depth ``path_depth`` (``KG_SOURCE_REL``
    edges are ignored). Each returned triplet is scored with the higher of the
    BM25 scores of its two endpoint nodes (0.0 for endpoints that were not BM25
    hits), optionally filtered by ``similarity_score``, and returned in
    descending score order.

    Args:
        graph_store: Property graph store the hits are looked up in.
        include_text: Forwarded to ``BasePGRetriever``.
        path_depth: ``depth`` passed to ``get_rel_map`` / ``aget_rel_map``.
        similarity_score: Minimum triplet score to keep; ``None`` keeps all.
        **kwargs: Forwarded to ``BasePGRetriever``.
    """

    def __init__(
        self,
        graph_store: PropertyGraphStore,
        include_text: bool = True,
        path_depth: int = 1,
        similarity_score: Optional[float] = None,
        **kwargs: Any,
    ) -> None:
        self._path_depth = path_depth
        self._similarity_score = similarity_score

        # Load the BM25 index only if it has been persisted. Without one the
        # retriever still constructs, and retrieval simply yields no results
        # (previously this left the attribute unset -> AttributeError later).
        storage_dir = os.getenv("BM_RETRIEVER_PATH", "storage_bm")
        self._bm25Retriever: Optional[CHBM25Retriever] = None
        if os.path.isdir(storage_dir) and os.listdir(storage_dir):
            self._bm25Retriever = CHBM25Retriever.from_persist_dir(storage_dir)

        super().__init__(graph_store=graph_store, include_text=include_text, **kwargs)

    def retrieve_from_graph(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Synchronously retrieve BM25-seeded triplets from the graph.

        Returns:
            Triplet nodes (built by the base-class ``_get_nodes_with_score``)
            sorted by descending BM25-derived score; ``[]`` when BM25 finds
            nothing or no index is available.
        """
        hits = self._bm25_hits(query_bundle)
        kg_ids = self._get_kg_ids([hit.node for hit in hits])
        if not kg_ids:
            # NOTE: some graph stores skip the id filter when ``ids`` is empty
            # (returning every node), so never query with an empty id list.
            return []

        kg_nodes = self._graph_store.get(ids=kg_ids)
        triplets = self._graph_store.get_rel_map(
            kg_nodes, depth=self._path_depth, ignore_rels=[KG_SOURCE_REL]
        )
        return self._rank_triplets(triplets, kg_ids, hits)

    async def aretrieve_from_graph(
        self, query_bundle: QueryBundle
    ) -> List[NodeWithScore]:
        """Asynchronously retrieve BM25-seeded triplets from the graph.

        Same contract as ``retrieve_from_graph`` but uses the graph store's
        async ``aget`` / ``aget_rel_map``.
        """
        # NOTE(review): the BM25 lookup itself is synchronous and runs on the
        # event-loop thread; consider offloading it if the index is large.
        hits = self._bm25_hits(query_bundle)
        kg_ids = self._get_kg_ids([hit.node for hit in hits])
        if not kg_ids:
            # See retrieve_from_graph: avoid an unfiltered ``aget(ids=[])``.
            return []

        kg_nodes = await self._graph_store.aget(ids=kg_ids)
        triplets = await self._graph_store.aget_rel_map(
            kg_nodes, depth=self._path_depth, ignore_rels=[KG_SOURCE_REL]
        )
        return self._rank_triplets(triplets, kg_ids, hits)

    def _bm25_hits(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Return the BM25 hits for the query, or ``[]`` if no index is loaded."""
        if self._bm25Retriever is None:
            return []
        # NOTE(review): calls the private ``_retrieve`` of the project-local
        # CHBM25Retriever with the raw query string (not a QueryBundle);
        # confirm this matches that class's signature.
        return self._bm25Retriever._retrieve(query_bundle.query_str)

    def _rank_triplets(
        self,
        triplets: Sequence[Any],
        kg_ids: Sequence[str],
        hits: Sequence[NodeWithScore],
    ) -> List[NodeWithScore]:
        """Score, filter and sort triplets by the BM25 scores of their endpoints.

        Args:
            triplets: Relation-map entries; indices 0 and 2 are the endpoint
                nodes whose ``id`` is matched against ``kg_ids``.
            kg_ids: Graph ids of the BM25 hits, parallel to ``hits``.
            hits: BM25 results supplying the scores.
        """
        # id -> BM25 score. The first occurrence wins (same as the former
        # ``kg_ids.index`` lookup); a missing (None) score counts as 0.0 so
        # that max() below cannot raise TypeError.
        score_by_id = {}
        for kg_id, hit in zip(kg_ids, hits):
            score_by_id.setdefault(kg_id, hit.score if hit.score is not None else 0.0)

        ranked = []
        for triplet in triplets:
            score = max(
                score_by_id.get(triplet[0].id, 0.0),
                score_by_id.get(triplet[2].id, 0.0),
            )
            # ``is not None`` so that an explicit threshold of 0.0 is honoured.
            if self._similarity_score is None or score >= self._similarity_score:
                ranked.append((triplet, score))

        # Stable sort: equal scores keep the graph store's triplet order.
        ranked.sort(key=lambda pair: pair[1], reverse=True)
        return self._get_nodes_with_score(
            [triplet for triplet, _ in ranked], [score for _, score in ranked]
        )

    def _get_kg_ids(self, kg_nodes: Sequence[BaseNode]) -> List[str]:
        """Backward compatibility method to get kg_ids from kg_nodes.

        Uses each node's ``VECTOR_SOURCE_KEY`` metadata value when present and
        falls back to the node's own ``id_``; output order matches input order.
        """
        return [node.metadata.get(VECTOR_SOURCE_KEY, node.id_) for node in kg_nodes]