60 lines
2.1 KiB
Python
60 lines
2.1 KiB
Python
|
|
from typing import Any, List, Dict, Sequence, Tuple, Optional
|
|
from llama_index.core.graph_stores.simple_labelled import SimplePropertyGraphStore
|
|
from llama_index.core.graph_stores.types import LabelledNode,ChunkNode,LabelledPropertyGraph
|
|
from app.engine.graph.graphTypes import *
|
|
|
|
class RAGPropertyGraphStore(SimplePropertyGraphStore):
    """In-memory property graph store that restores RAG-specific node types.

    Overrides ``from_dict`` so entity nodes are rebuilt as ``RagEntityNode``
    and relations as ``RagRelation``, and ``save_networkx_graph`` so entity
    nodes are labelled by their ``name`` in the rendered HTML.
    """

    @classmethod
    def from_dict(
        cls,
        data: dict,
    ) -> "RAGPropertyGraphStore":
        """Load a graph store from its serialized dict form.

        Nodes and relations are validated manually so their concrete types are
        preserved:

        - a node dict with a ``"name"`` key becomes a ``RagEntityNode``;
        - a node dict with a ``"text"`` key becomes a ``ChunkNode``;
        - every relation dict becomes a ``RagRelation``.

        The remaining fields are validated by ``LabelledPropertyGraph``.

        Args:
            data: Serialized graph containing ``"nodes"`` and ``"relations"``
                mappings (id -> dict). The dict is not modified.

        Returns:
            A new store wrapping the reconstructed graph.

        Raises:
            KeyError: If ``data`` lacks ``"nodes"`` or ``"relations"``.
            ValueError: If a node dict has neither ``"name"`` nor ``"text"``.
        """
        # Read both mappings up front so a missing key fails before any work.
        node_dicts = data["nodes"]
        relation_dicts = data["relations"]

        kg_nodes: Dict[str, LabelledNode] = {}
        for node_id, node_dict in node_dicts.items():
            if "name" in node_dict:
                kg_nodes[node_id] = RagEntityNode.model_validate(node_dict)
            elif "text" in node_dict:
                kg_nodes[node_id] = ChunkNode.model_validate(node_dict)
            else:
                raise ValueError(f"Could not infer node type for data: {node_dict!s}")

        kg_relations: Dict[str, RagRelation] = {
            rel_id: RagRelation.model_validate(rel_dict)
            for rel_id, rel_dict in relation_dicts.items()
        }

        # Validate the rest of the graph with empty node/relation maps; the
        # typed objects built above are attached afterwards. A shallow copy is
        # used so the caller's dict keeps its "nodes"/"relations" entries.
        graph_data = {**data, "nodes": {}, "relations": {}}
        graph = LabelledPropertyGraph.model_validate(graph_data)

        # Re-attach the typed nodes and relations.
        graph.nodes = kg_nodes
        graph.relations = kg_relations

        return cls(graph)

    def save_networkx_graph(self, name: str = "kg.html") -> None:
        """Render the graph to an interactive HTML file, useful for debugging.

        Only ``EntityNode`` instances are added explicitly, labelled with their
        ``name``. Each triplet becomes a directed edge labelled with its middle
        element. Endpoints that were not added explicitly (for example chunk
        nodes) are created implicitly by ``add_edge`` and carry no label
        attribute.

        ``networkx`` and ``pyvis`` are imported lazily, so they are only
        required when this method is called.

        Args:
            name: Path of the HTML file to write.
        """
        import networkx as nx
        from pyvis.network import Network

        G = nx.DiGraph()
        for node in self.graph.nodes.values():
            # NOTE(review): EntityNode is not imported explicitly in this file;
            # it presumably comes from the star import of graphTypes — confirm.
            if isinstance(node, EntityNode):
                G.add_node(node.id, label=node.name)
        for subject_id, relation_key, object_id in self.graph.triplets:
            G.add_edge(subject_id, object_id, label=relation_key)

        net = Network(notebook=False, directed=True)
        net.from_nx(G)
        net.write_html(name)