优化属性图检索功能及支持OpenAI线上模型
This commit is contained in:
@@ -0,0 +1,73 @@
|
||||
import os
|
||||
from typing import Any, List, Sequence, Optional
|
||||
from llama_index.core.schema import BaseNode, NodeWithScore, QueryBundle
|
||||
from llama_index.core.graph_stores.types import (
|
||||
PropertyGraphStore,
|
||||
KG_SOURCE_REL,
|
||||
VECTOR_SOURCE_KEY,
|
||||
)
|
||||
from llama_index.core.indices.property_graph.sub_retrievers.base import BasePGRetriever
|
||||
from llama_index.core.graph_stores.types import (
|
||||
PropertyGraphStore,
|
||||
KG_SOURCE_REL
|
||||
)
|
||||
from app.engine.retriever.CHBM25Retriever import CHBM25Retriever
|
||||
|
||||
class GraphBM25Retriever(BasePGRetriever):
    """Property-graph sub-retriever seeded by BM25 keyword hits.

    The query string is run through a persisted ``CHBM25Retriever``; the
    returned nodes are mapped to graph node ids, the relation map around those
    graph nodes is fetched, and every triplet is scored with the best BM25
    score of its two endpoints.
    """

    def __init__(
        self,
        graph_store: PropertyGraphStore,
        include_text: bool = True,
        path_depth: int = 1,
        similarity_score: Optional[float] = None,
        **kwargs: Any,
    ) -> None:
        """Create the retriever and load the persisted BM25 index if present.

        Args:
            graph_store: Property graph store that is queried for triplets.
            include_text: Forwarded to ``BasePGRetriever``.
            path_depth: Depth passed to the store's relation-map lookup.
            similarity_score: Optional minimum score; triplets scoring below
                it are dropped. ``None`` disables filtering.
            **kwargs: Forwarded to ``BasePGRetriever``.
        """
        self._path_depth = path_depth
        self._similarity_score = similarity_score

        # Always define the attribute so a missing or empty index directory
        # is an explicit "no index" state instead of an AttributeError later.
        self._bm25Retriever: Optional[CHBM25Retriever] = None
        storage_dir = os.getenv("BM_RETRIEVER_PATH", "storage_bm")
        if os.path.isdir(storage_dir) and os.listdir(storage_dir):
            self._bm25Retriever = CHBM25Retriever.from_persist_dir(storage_dir)

        super().__init__(graph_store=graph_store, include_text=include_text, **kwargs)

    def retrieve_from_graph(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Synchronous counterpart of :meth:`aretrieve_from_graph`."""
        hits = self._bm25_hits(query_bundle)
        if not hits:
            return []

        kg_ids = self._get_kg_ids([hit.node for hit in hits])
        kg_nodes = self._graph_store.get(ids=kg_ids)
        triplets = self._graph_store.get_rel_map(
            kg_nodes, depth=self._path_depth, ignore_rels=[KG_SOURCE_REL]
        )
        return self._rank_triplets(triplets, kg_ids, hits)

    async def aretrieve_from_graph(
        self, query_bundle: QueryBundle
    ) -> List[NodeWithScore]:
        """Return scored triplets around the graph nodes matched by BM25.

        Returns an empty list when no BM25 index was loaded or when BM25
        produced no hits.
        """
        hits = self._bm25_hits(query_bundle)
        if not hits:
            # Do not query the store with an empty id list: some stores may
            # treat an empty filter as "no filter" and return every node.
            return []

        kg_ids = self._get_kg_ids([hit.node for hit in hits])
        kg_nodes = await self._graph_store.aget(ids=kg_ids)
        triplets = await self._graph_store.aget_rel_map(
            kg_nodes, depth=self._path_depth, ignore_rels=[KG_SOURCE_REL]
        )
        return self._rank_triplets(triplets, kg_ids, hits)

    def _bm25_hits(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Run the BM25 retriever, or return no hits when no index is loaded."""
        if self._bm25Retriever is None:
            return []
        # NOTE(review): the raw query string is passed here, as in the
        # original code; confirm that CHBM25Retriever._retrieve expects a
        # string rather than a QueryBundle.
        return self._bm25Retriever._retrieve(query_bundle.query_str)

    def _rank_triplets(
        self,
        triplets: Sequence[Any],
        kg_ids: Sequence[str],
        hits: Sequence[NodeWithScore],
    ) -> List[NodeWithScore]:
        """Score, optionally filter, and sort triplets by BM25 relevance."""
        # Map graph id -> BM25 score. setdefault keeps the first occurrence,
        # which matches the semantics of a list.index() lookup.
        score_by_id = {}
        for kg_id, hit in zip(kg_ids, hits):
            score_by_id.setdefault(kg_id, 0.0 if hit.score is None else hit.score)

        # A triplet is (subject, relation, object); it inherits the better
        # score of its two endpoints, or 0.0 if neither was a BM25 hit.
        scored = [
            (
                triplet,
                max(
                    score_by_id.get(triplet[0].id, 0.0),
                    score_by_id.get(triplet[2].id, 0.0),
                ),
            )
            for triplet in triplets
        ]

        # Compare against None so that a threshold of 0.0 is still applied.
        if self._similarity_score is not None:
            scored = [pair for pair in scored if pair[1] >= self._similarity_score]

        scored.sort(key=lambda pair: pair[1], reverse=True)
        return self._get_nodes_with_score(
            [triplet for triplet, _ in scored],
            [score for _, score in scored],
        )

    def _get_kg_ids(self, kg_nodes: Sequence[BaseNode]) -> List[str]:
        """Backward compatibility method to get kg_ids from kg_nodes.

        Uses the ``VECTOR_SOURCE_KEY`` metadata entry when present and falls
        back to the node's own id otherwise.
        """
        return [node.metadata.get(VECTOR_SOURCE_KEY, node.id_) for node in kg_nodes]
|
||||
@@ -0,0 +1,63 @@
|
||||
from typing import Any, Callable, List, Optional, Union
|
||||
from llama_index.core.llms.llm import LLM
|
||||
from llama_index.core.indices.property_graph.sub_retrievers.base import (
|
||||
BasePGRetriever,
|
||||
)
|
||||
from llama_index.core.graph_stores.types import (
|
||||
PropertyGraphStore,
|
||||
KG_SOURCE_REL,
|
||||
)
|
||||
from llama_index.core.settings import Settings
|
||||
from llama_index.core.schema import (
|
||||
NodeWithScore,
|
||||
QueryBundle,
|
||||
)
|
||||
from llama_index.core.graph_stores.types import EntityNode
|
||||
|
||||
|
||||
class GraphKeyWordRetriever(BasePGRetriever):
    """Property-graph sub-retriever based on literal entity-name matching.

    Every ``EntityNode`` in the graph store whose name occurs as a substring
    of the query string is used as a starting point; the relation map around
    those entities is returned as the retrieval result.
    """

    def __init__(
        self,
        graph_store: PropertyGraphStore,
        include_text: bool = True,
        path_depth: int = 1,
        llm: Optional[LLM] = None,
        **kwargs: Any,
    ) -> None:
        """Create the retriever.

        Args:
            graph_store: Property graph store that is scanned for entities.
            include_text: Forwarded to ``BasePGRetriever``.
            path_depth: Depth passed to the store's relation-map lookup.
            llm: Stored on the instance; falls back to ``Settings.llm``.
            **kwargs: Forwarded to ``BasePGRetriever``.
        """
        self._llm = llm or Settings.llm
        self._path_depth = path_depth
        super().__init__(graph_store=graph_store, include_text=include_text, **kwargs)

    @staticmethod
    def _match_entities(nodes: List[Any], query_str: str) -> List[EntityNode]:
        """Return the entity nodes whose name appears in ``query_str``."""
        # The truthiness check on the name matters: an empty string is a
        # substring of every query and would otherwise match unconditionally.
        return [
            node
            for node in nodes
            if isinstance(node, EntityNode) and node.name and node.name in query_str
        ]

    def _prepare_matches(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Collect triplets around entities named in the query (sync)."""
        # get() is called without filters, so all stored nodes are scanned.
        matched = self._match_entities(
            self._graph_store.get(), query_bundle.query_str
        )
        if not matched:
            # Nothing to expand; skip the relation-map round-trip.
            return []

        triplets = self._graph_store.get_rel_map(
            matched,
            depth=self._path_depth,
            ignore_rels=[KG_SOURCE_REL],
        )
        return self._get_nodes_with_score(triplets)

    async def _aprepare_matches(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Collect triplets around entities named in the query (async)."""
        # aget() is called without filters, so all stored nodes are scanned.
        matched = self._match_entities(
            await self._graph_store.aget(), query_bundle.query_str
        )
        if not matched:
            # Nothing to expand; skip the relation-map round-trip.
            return []

        triplets = await self._graph_store.aget_rel_map(
            matched,
            depth=self._path_depth,
            ignore_rels=[KG_SOURCE_REL],
        )
        return self._get_nodes_with_score(triplets)

    def retrieve_from_graph(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve triplets for the query synchronously."""
        return self._prepare_matches(query_bundle)

    async def aretrieve_from_graph(
        self, query_bundle: QueryBundle
    ) -> List[NodeWithScore]:
        """Retrieve triplets for the query asynchronously."""
        return await self._aprepare_matches(query_bundle)
|
||||
Reference in New Issue
Block a user