42 lines
1.5 KiB
Python
42 lines
1.5 KiB
Python
|
|
from typing import Any, List, Dict, Sequence, Tuple, Optional
|
|
from llama_index.core.graph_stores.simple_labelled import SimplePropertyGraphStore
|
|
from llama_index.core.graph_stores.types import LabelledNode,ChunkNode,LabelledPropertyGraph
|
|
from app.engine.graph.graphTypes import *
|
|
|
|
class RAGPropertyGraphStore(SimplePropertyGraphStore):
    """Simple property graph store that deserializes project-specific types.

    ``from_dict`` rebuilds entity nodes as ``RagEntityNode`` and relations as
    ``RagRelation``. Both names come from the ``app.engine.graph.graphTypes``
    star import, presumably.
    """

    @classmethod
    def from_dict(
        cls,
        data: dict,
    ) -> "RAGPropertyGraphStore":
        """Load a store from a serialized dict.

        Args:
            data: Serialized graph. It must contain ``"nodes"`` and
                ``"relations"`` mappings keyed by id. All other keys are passed
                through to ``LabelledPropertyGraph.model_validate``. The
                caller's dict is not modified.

        Returns:
            A new store wrapping the rebuilt graph.

        Raises:
            KeyError: If ``"nodes"`` or ``"relations"`` is missing from ``data``.
            ValueError: If a node dict has neither a ``"name"`` nor a ``"text"``
                key, so its type cannot be inferred.
        """
        # Nodes and relations are rebuilt by hand so they come back as the
        # project subclasses. Presumably the graph model's base-class field
        # types would otherwise be used — TODO confirm.
        node_dicts = data["nodes"]
        relation_dicts = data["relations"]

        kg_nodes: Dict[str, LabelledNode] = {}
        for node_id, node_dict in node_dicts.items():
            # The node type is inferred from which payload key is present.
            if "name" in node_dict:
                kg_nodes[node_id] = RagEntityNode.model_validate(node_dict)
            elif "text" in node_dict:
                kg_nodes[node_id] = ChunkNode.model_validate(node_dict)
            else:
                raise ValueError(f"Could not infer node type for data: {node_dict!s}")

        kg_relations: Dict[str, RagRelation] = {
            rel_id: RagRelation.model_validate(rel_dict)
            for rel_id, rel_dict in relation_dicts.items()
        }

        # Validate the remaining fields on a shallow copy in which nodes and
        # relations are emptied. This leaves the caller's dict untouched.
        graph = LabelledPropertyGraph.model_validate(
            {**data, "nodes": {}, "relations": {}}
        )

        # Attach the manually rebuilt nodes and relations.
        graph.nodes = kg_nodes
        graph.relations = kg_relations

        return cls(graph)