# zjdataai-app/backend/app/engine/retriever/graphBM25Retriever.py

import logging
import os
from typing import Any, Dict, List, Optional, Sequence

from llama_index.core.graph_stores.types import (
    KG_SOURCE_REL,
    VECTOR_SOURCE_KEY,
    PropertyGraphStore,
)
from llama_index.core.indices.property_graph.sub_retrievers.base import BasePGRetriever
from llama_index.core.schema import BaseNode, NodeWithScore, QueryBundle

from app.engine.retriever.CHBM25Retriever import CHBM25Retriever
logger = logging.getLogger(__name__)


class GraphBM25Retriever(BasePGRetriever):
    """Property-graph sub-retriever seeded by BM25 keyword hits.

    A query is first run through a persisted ``CHBM25Retriever``. Each hit is
    mapped to a graph node id (``VECTOR_SOURCE_KEY`` metadata, falling back to
    the node's own id), the graph store is asked for the relations around
    those nodes, and every returned triplet is scored with the best BM25
    score of its two endpoints.
    """

    def __init__(
        self,
        graph_store: PropertyGraphStore,
        include_text: bool = True,
        path_depth: int = 1,
        similarity_score: Optional[float] = None,
        **kwargs: Any,
    ) -> None:
        """Create the retriever and load the persisted BM25 index if present.

        Args:
            graph_store: Property graph store queried for nodes and relations.
            include_text: Forwarded to ``BasePGRetriever``.
            path_depth: Depth passed to the graph store's relation-map lookup.
            similarity_score: Optional minimum score; triplets scoring below
                it are dropped. ``None`` disables filtering.
            **kwargs: Forwarded to ``BasePGRetriever``.
        """
        self._path_depth = path_depth
        self._similarity_score = similarity_score
        # Always define the attribute so a missing index degrades to "no
        # results" instead of an AttributeError at query time.
        self._bm25Retriever: Optional[CHBM25Retriever] = None
        storage_dir = os.getenv("BM_RETRIEVER_PATH", "storage_bm")
        if os.path.isdir(storage_dir) and os.listdir(storage_dir):
            self._bm25Retriever = CHBM25Retriever.from_persist_dir(storage_dir)
        else:
            logger.warning(
                "BM25 index directory %r is missing or empty; "
                "GraphBM25Retriever will return no results.",
                storage_dir,
            )
        super().__init__(graph_store=graph_store, include_text=include_text, **kwargs)

    def retrieve_from_graph(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Synchronous counterpart of :meth:`aretrieve_from_graph`."""
        seed_scores = self._seed_scores(query_bundle)
        if not seed_scores:
            return []
        kg_nodes = self._graph_store.get(ids=list(seed_scores))
        triplets = self._graph_store.get_rel_map(
            kg_nodes, depth=self._path_depth, ignore_rels=[KG_SOURCE_REL]
        )
        return self._rank_triplets(triplets, seed_scores)

    async def aretrieve_from_graph(
        self, query_bundle: QueryBundle
    ) -> List[NodeWithScore]:
        """Retrieve scored graph triplets around the BM25 hits for a query.

        Returns:
            Triplet nodes sorted by descending score, or an empty list when
            the BM25 index is unavailable or yields no hits.
        """
        seed_scores = self._seed_scores(query_bundle)
        if not seed_scores:
            # Skip the graph round-trip: with no seed ids there is nothing to
            # expand, and a store may treat an empty id list as "no filter".
            return []
        kg_nodes = await self._graph_store.aget(ids=list(seed_scores))
        triplets = await self._graph_store.aget_rel_map(
            kg_nodes, depth=self._path_depth, ignore_rels=[KG_SOURCE_REL]
        )
        return self._rank_triplets(triplets, seed_scores)

    def _seed_scores(self, query_bundle: QueryBundle) -> Dict[str, float]:
        """Run BM25 and map each hit's graph node id to its best score."""
        if self._bm25Retriever is None:
            return {}
        # NOTE(review): the raw query string (not the QueryBundle) is passed
        # to the private ``_retrieve``, exactly as before -- confirm that this
        # is what CHBM25Retriever expects.
        hits: List[NodeWithScore] = self._bm25Retriever._retrieve(
            query_bundle.query_str
        )
        kg_ids = self._get_kg_ids([hit.node for hit in hits])
        seed_scores: Dict[str, float] = {}
        for kg_id, hit in zip(kg_ids, hits):
            # A hit without a score counts as 0.0 so max() never sees None.
            score = hit.score if hit.score is not None else 0.0
            # Several hits can point at the same graph node; keep the best.
            seed_scores[kg_id] = max(score, seed_scores.get(kg_id, score))
        return seed_scores

    def _rank_triplets(
        self, triplets: Sequence[Any], seed_scores: Dict[str, float]
    ) -> List[NodeWithScore]:
        """Score, optionally threshold, and sort triplets by descending score."""
        scored = [
            (
                triplet,
                # Endpoints that were not BM25 hits contribute 0.0.
                max(
                    seed_scores.get(triplet[0].id, 0.0),
                    seed_scores.get(triplet[2].id, 0.0),
                ),
            )
            for triplet in triplets
        ]
        # ``is not None`` so that an explicit threshold of 0.0 is honoured.
        if self._similarity_score is not None:
            scored = [
                pair for pair in scored if pair[1] >= self._similarity_score
            ]
        scored.sort(key=lambda pair: pair[1], reverse=True)
        return self._get_nodes_with_score(
            [triplet for triplet, _ in scored], [score for _, score in scored]
        )

    def _get_kg_ids(self, kg_nodes: Sequence[BaseNode]) -> List[str]:
        """Backward compatibility method to get kg_ids from kg_nodes."""
        return [node.metadata.get(VECTOR_SOURCE_KEY, node.id_) for node in kg_nodes]