from typing import Any, Callable, List, Optional, Union

from llama_index.core.llms.llm import LLM
from llama_index.core.indices.property_graph.sub_retrievers.base import (
    BasePGRetriever,
)
from llama_index.core.graph_stores.types import (
    PropertyGraphStore,
    KG_SOURCE_REL,
)
from llama_index.core.settings import Settings
from llama_index.core.schema import (
    NodeWithScore,
    QueryBundle,
)
from llama_index.core.graph_stores.types import EntityNode


class GraphKeyWordRetriever(BasePGRetriever):
    """Retrieve graph context by literal keyword matching of entity names.

    All nodes are fetched from the graph store. Every ``EntityNode`` whose
    ``name`` occurs verbatim (case-sensitive substring) in the query string
    becomes a seed. The relation map around the seeds, up to ``path_depth``
    hops and ignoring ``KG_SOURCE_REL`` edges, is returned as scored nodes.

    Args:
        graph_store: Property graph store to query.
        include_text: Forwarded to ``BasePGRetriever``.
        path_depth: ``depth`` passed to ``get_rel_map`` / ``aget_rel_map``.
        llm: Stored as ``self._llm``, falling back to ``Settings.llm``. The
            matching logic in this class does not use it.
        **kwargs: Forwarded to ``BasePGRetriever``.
    """

    def __init__(
        self,
        graph_store: PropertyGraphStore,
        include_text: bool = True,
        path_depth: int = 1,
        llm: Optional[LLM] = None,
        **kwargs: Any,
    ) -> None:
        self._llm = llm or Settings.llm
        self._path_depth = path_depth
        super().__init__(graph_store=graph_store, include_text=include_text, **kwargs)

    @staticmethod
    def _match_entity_nodes(nodes: List[Any], query_str: str) -> List[EntityNode]:
        """Return the entity nodes whose name occurs verbatim in ``query_str``.

        Matching is a case-sensitive substring test. Entities with an empty
        name are skipped: ``"" in query_str`` is always true, so such nodes
        would otherwise be selected for every query.
        """
        return [
            node
            for node in nodes
            if isinstance(node, EntityNode) and node.name and node.name in query_str
        ]

    def _prepare_matches(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Find matching entities and expand them into scored graph context."""
        kg_nodes = self._match_entity_nodes(
            self._graph_store.get(), query_bundle.query_str
        )
        if not kg_nodes:
            # No seed entities: skip the rel-map query entirely instead of
            # sending the store an empty node list.
            return []

        triplets = self._graph_store.get_rel_map(
            kg_nodes,
            depth=self._path_depth,
            ignore_rels=[KG_SOURCE_REL],
        )
        return self._get_nodes_with_score(triplets)

    async def _aprepare_matches(
        self, query_bundle: QueryBundle
    ) -> List[NodeWithScore]:
        """Async counterpart of ``_prepare_matches`` using the store's async API."""
        kg_nodes = self._match_entity_nodes(
            await self._graph_store.aget(), query_bundle.query_str
        )
        if not kg_nodes:
            # No seed entities: skip the rel-map query entirely.
            return []

        triplets = await self._graph_store.aget_rel_map(
            kg_nodes,
            depth=self._path_depth,
            ignore_rels=[KG_SOURCE_REL],
        )
        return self._get_nodes_with_score(triplets)

    def retrieve_from_graph(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve scored graph nodes for ``query_bundle`` (sync)."""
        return self._prepare_matches(query_bundle)

    async def aretrieve_from_graph(
        self, query_bundle: QueryBundle
    ) -> List[NodeWithScore]:
        """Retrieve scored graph nodes for ``query_bundle`` (async)."""
        return await self._aprepare_matches(query_bundle)