优化属性图检索功能及支持OpenAI线上模型
This commit is contained in:
@@ -0,0 +1,42 @@
|
||||
|
||||
from typing import Any, List, Dict, Sequence, Tuple, Optional
|
||||
from llama_index.core.graph_stores.simple_labelled import SimplePropertyGraphStore
|
||||
from llama_index.core.graph_stores.types import LabelledNode,ChunkNode,LabelledPropertyGraph
|
||||
from app.engine.graph.graphTypes import *
|
||||
|
||||
class RAGPropertyGraphStore(SimplePropertyGraphStore):
|
||||
@classmethod
|
||||
def from_dict(
|
||||
cls,
|
||||
data: dict,
|
||||
) -> "SimplePropertyGraphStore":
|
||||
"""Load from dict."""
|
||||
# need to load nodes manually
|
||||
node_dicts = data["nodes"]
|
||||
relation_dicts = data["relations"]
|
||||
|
||||
kg_nodes: Dict[str, LabelledNode] = {}
|
||||
kg_relations: Dict[str, LabelledNode] = {}
|
||||
for id, node_dict in node_dicts.items():
|
||||
if "name" in node_dict:
|
||||
kg_nodes[id] = RagEntityNode.model_validate(node_dict)
|
||||
elif "text" in node_dict:
|
||||
kg_nodes[id] = ChunkNode.model_validate(node_dict)
|
||||
else:
|
||||
raise ValueError(f"Could not infer node type for data: {node_dict!s}")
|
||||
|
||||
for id, node_dict in relation_dicts.items():
|
||||
kg_relations[id] = RagRelation.model_validate(node_dict)
|
||||
|
||||
# clear the nodes, to load later
|
||||
data["nodes"] = {}
|
||||
data["relations"] = {}
|
||||
|
||||
# load the graph
|
||||
graph = LabelledPropertyGraph.model_validate(data)
|
||||
|
||||
# add the node back
|
||||
graph.nodes = kg_nodes
|
||||
graph.relations = kg_relations
|
||||
|
||||
return cls(graph)
|
||||
Reference in New Issue
Block a user