64 lines
2.2 KiB
Python
64 lines
2.2 KiB
Python
from typing import Any, Callable, List, Optional, Union
|
|
from llama_index.core.llms.llm import LLM
|
|
from llama_index.core.indices.property_graph.sub_retrievers.base import (
|
|
BasePGRetriever,
|
|
)
|
|
from llama_index.core.graph_stores.types import (
|
|
PropertyGraphStore,
|
|
KG_SOURCE_REL,
|
|
)
|
|
from llama_index.core.settings import Settings
|
|
from llama_index.core.schema import (
|
|
NodeWithScore,
|
|
QueryBundle,
|
|
)
|
|
from llama_index.core.graph_stores.types import EntityNode
|
|
|
|
|
|
class GraphKeyWordRetriever(BasePGRetriever):
    """Property-graph sub-retriever that seeds retrieval by keyword matching.

    Every node returned by the graph store is inspected; an ``EntityNode``
    whose ``name`` occurs as a case-sensitive substring of the query string
    becomes a seed node. The relation map around the seed nodes is then
    fetched from the store (ignoring ``KG_SOURCE_REL`` relations) and handed
    to ``_get_nodes_with_score`` to build the result list.

    Args:
        graph_store: Property graph store to read nodes and relations from.
        include_text: Forwarded unchanged to ``BasePGRetriever.__init__``.
        path_depth: ``depth`` passed to the store's relation-map lookup.
        llm: Optional LLM; falls back to ``Settings.llm`` when not given.
            It is stored on the instance but not read by any method
            defined in this class.
        **kwargs: Extra keyword arguments forwarded to the base class.
    """

    def __init__(
        self,
        graph_store: PropertyGraphStore,
        include_text: bool = True,
        path_depth: int = 1,
        llm: Optional[LLM] = None,
        **kwargs: Any,
    ) -> None:
        self._llm = llm or Settings.llm
        self._path_depth = path_depth
        super().__init__(graph_store=graph_store, include_text=include_text, **kwargs)

    @staticmethod
    def _match_entities(graph_nodes: List[Any], query_str: str) -> List[EntityNode]:
        """Return the entity nodes whose name appears in ``query_str``."""
        return [
            node
            for node in graph_nodes
            # An empty name is a substring of every string, so without the
            # truthiness check such a node would match every single query.
            if isinstance(node, EntityNode) and node.name and node.name in query_str
        ]

    def _prepare_matches(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Collect keyword-matched entities and expand them into scored nodes."""
        kg_nodes = self._match_entities(
            self._graph_store.get(), query_bundle.query_str
        )
        triplets = self._graph_store.get_rel_map(
            kg_nodes,
            depth=self._path_depth,
            ignore_rels=[KG_SOURCE_REL],
        )
        return self._get_nodes_with_score(triplets)

    async def _aprepare_matches(
        self, query_bundle: QueryBundle
    ) -> List[NodeWithScore]:
        """Async counterpart of ``_prepare_matches`` using the store's async API."""
        kg_nodes = self._match_entities(
            await self._graph_store.aget(), query_bundle.query_str
        )
        triplets = await self._graph_store.aget_rel_map(
            kg_nodes,
            depth=self._path_depth,
            ignore_rels=[KG_SOURCE_REL],
        )
        return self._get_nodes_with_score(triplets)

    def retrieve_from_graph(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve scored nodes for ``query_bundle`` via keyword matching."""
        return self._prepare_matches(query_bundle)

    async def aretrieve_from_graph(
        self, query_bundle: QueryBundle
    ) -> List[NodeWithScore]:
        """Asynchronously retrieve scored nodes for ``query_bundle``."""
        return await self._aprepare_matches(query_bundle)
|