优化属性图检索功能及支持OpenAI线上模型

This commit is contained in:
wanyaokun
2024-09-20 17:34:38 +08:00
parent 092f7230c1
commit f7260da6d9
12 changed files with 350 additions and 76 deletions
@@ -0,0 +1,63 @@
from typing import Any, Callable, List, Optional, Union
from llama_index.core.llms.llm import LLM
from llama_index.core.indices.property_graph.sub_retrievers.base import (
BasePGRetriever,
)
from llama_index.core.graph_stores.types import (
PropertyGraphStore,
KG_SOURCE_REL,
)
from llama_index.core.settings import Settings
from llama_index.core.schema import (
NodeWithScore,
QueryBundle,
)
from llama_index.core.graph_stores.types import EntityNode
class GraphKeyWordRetriever(BasePGRetriever):
def __init__(
self,
graph_store: PropertyGraphStore,
include_text: bool = True,
path_depth: int = 1,
llm: Optional[LLM] = None,
**kwargs: Any,
) -> None:
self._llm = llm or Settings.llm
self._path_depth = path_depth
super().__init__(graph_store=graph_store, include_text=include_text, **kwargs)
def _prepare_matches(self,query_bundle: QueryBundle) -> List[NodeWithScore]:
kg_nodes = []
labelNodes = self._graph_store.get()
for labelNode in labelNodes:
if isinstance(labelNode,EntityNode) and labelNode.name in query_bundle.query_str:
kg_nodes.append(labelNode)
triplets = self._graph_store.get_rel_map(
kg_nodes,
depth=self._path_depth,
ignore_rels=[KG_SOURCE_REL],
)
return self._get_nodes_with_score(triplets)
async def _aprepare_matches(self,query_bundle: QueryBundle) -> List[NodeWithScore]:
kg_nodes = []
labelNodes = await self._graph_store.aget()
for labelNode in labelNodes:
if isinstance(labelNode,EntityNode) and labelNode.name in query_bundle.query_str:
kg_nodes.append(labelNode)
triplets = await self._graph_store.aget_rel_map(
kg_nodes,
depth=self._path_depth,
ignore_rels=[KG_SOURCE_REL],
)
return self._get_nodes_with_score(triplets)
def retrieve_from_graph(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
return self._prepare_matches(query_bundle)
async def aretrieve_from_graph(
self, query_bundle: QueryBundle
) -> List[NodeWithScore]:
return await self._aprepare_matches(query_bundle)