import os
from typing import Any, Dict, List, Optional, Sequence, Tuple

from llama_index.core.graph_stores.types import (
    KG_SOURCE_REL,
    VECTOR_SOURCE_KEY,
    PropertyGraphStore,
)
from llama_index.core.indices.property_graph.sub_retrievers.base import (
    BasePGRetriever,
)
from llama_index.core.schema import BaseNode, NodeWithScore, QueryBundle

from app.engine.retriever.CHBM25Retriever import CHBM25Retriever


class GraphBM25Retriever(BasePGRetriever):
    """Property-graph sub-retriever that seeds graph traversal with BM25 hits.

    The query is first run through a persisted ``CHBM25Retriever``. The ids of
    the returned nodes are looked up in the graph store, their relation map is
    expanded up to ``path_depth`` hops, and every triplet is scored with the
    best BM25 score of its two endpoints.
    """

    def __init__(
        self,
        graph_store: PropertyGraphStore,
        include_text: bool = True,
        path_depth: int = 1,
        similarity_score: Optional[float] = None,
        **kwargs: Any,
    ) -> None:
        """Build the retriever and load the BM25 index if one is persisted.

        Args:
            graph_store: Property graph store used to resolve nodes/triplets.
            include_text: Forwarded to ``BasePGRetriever``.
            path_depth: Depth passed to the graph store's relation-map lookup.
            similarity_score: Optional minimum score; triplets scoring below
                it are dropped. ``None`` disables the filter.
            **kwargs: Forwarded to ``BasePGRetriever``.
        """
        self._path_depth = path_depth
        self._similarity_score = similarity_score

        # Always define the attribute so retrieval can detect a missing index
        # instead of failing with AttributeError.
        self._bm25Retriever = None
        storage_dir = os.getenv("BM_RETRIEVER_PATH", "storage_bm")
        # isdir (rather than exists) keeps listdir from raising when the path
        # points at a regular file.
        if os.path.isdir(storage_dir) and os.listdir(storage_dir):
            self._bm25Retriever = CHBM25Retriever.from_persist_dir(storage_dir)

        super().__init__(graph_store=graph_store, include_text=include_text, **kwargs)

    def retrieve_from_graph(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Synchronous counterpart of :meth:`aretrieve_from_graph`.

        NOTE(review): ``BasePGRetriever`` appears to declare this method as
        abstract alongside the async variant -- confirm against the installed
        llama-index version.
        """
        kg_ids, score_by_id = self._bm25_seeds(query_bundle)
        if not kg_ids:
            return []

        kg_nodes = self._graph_store.get(ids=kg_ids)
        triplets = self._graph_store.get_rel_map(
            kg_nodes, depth=self._path_depth, ignore_rels=[KG_SOURCE_REL]
        )
        return self._rank_triplets(triplets, score_by_id)

    async def aretrieve_from_graph(
        self, query_bundle: QueryBundle
    ) -> List[NodeWithScore]:
        """Retrieve graph triplets reachable from the BM25 hits for the query.

        Returns:
            Scored nodes built by ``_get_nodes_with_score``, ordered by
            descending score. Empty when no BM25 index is loaded or the query
            produces no hits.
        """
        kg_ids, score_by_id = self._bm25_seeds(query_bundle)
        if not kg_ids:
            # Avoid querying the store with an empty id list.
            # NOTE(review): some graph stores may treat ``ids=[]`` as
            # "no filter" and return every node -- confirm per backend.
            return []

        kg_nodes = await self._graph_store.aget(ids=kg_ids)
        triplets = await self._graph_store.aget_rel_map(
            kg_nodes, depth=self._path_depth, ignore_rels=[KG_SOURCE_REL]
        )
        return self._rank_triplets(triplets, score_by_id)

    def _bm25_seeds(
        self, query_bundle: QueryBundle
    ) -> Tuple[List[str], Dict[str, float]]:
        """Run the BM25 query and return (graph ids, id -> BM25 score)."""
        if self._bm25Retriever is None:
            return [], {}

        # NOTE(review): the query *string* is passed to the private
        # ``_retrieve`` hook; presumably CHBM25Retriever accepts a str here --
        # verify against its implementation.
        hits: List[NodeWithScore] = self._bm25Retriever._retrieve(
            query_bundle.query_str
        )
        kg_ids = self._get_kg_ids([hit.node for hit in hits])

        score_by_id: Dict[str, float] = {}
        for kg_id, hit in zip(kg_ids, hits):
            # setdefault keeps the first occurrence of a duplicated id, the
            # same entry ``list.index`` would have selected.
            score_by_id.setdefault(kg_id, 0.0 if hit.score is None else hit.score)
        return kg_ids, score_by_id

    def _rank_triplets(
        self, triplets: Sequence[Any], score_by_id: Dict[str, float]
    ) -> List[NodeWithScore]:
        """Score triplets by their best endpoint, filter, and sort descending."""
        scored = [
            (
                triplet,
                # triplet[0] / triplet[2] are the two endpoint nodes; an
                # endpoint that was not a BM25 hit contributes 0.0.
                max(
                    score_by_id.get(triplet[0].id, 0.0),
                    score_by_id.get(triplet[2].id, 0.0),
                ),
            )
            for triplet in triplets
        ]

        # ``is not None`` so that an explicit threshold of 0.0 is honoured.
        if self._similarity_score is not None:
            scored = [pair for pair in scored if pair[1] >= self._similarity_score]

        # list.sort is stable, so ties keep the graph store's ordering.
        scored.sort(key=lambda pair: pair[1], reverse=True)
        return self._get_nodes_with_score(
            [triplet for triplet, _ in scored],
            [score for _, score in scored],
        )

    def _get_kg_ids(self, kg_nodes: Sequence[BaseNode]) -> List[str]:
        """Backward compatibility method to get kg_ids from kg_nodes.

        Uses the ``VECTOR_SOURCE_KEY`` metadata entry when present and falls
        back to the node's own ``id_`` otherwise.
        """
        return [node.metadata.get(VECTOR_SOURCE_KEY, node.id_) for node in kg_nodes]