修改属性图节点的层级结构,新增子父级关系
This commit is contained in:
@@ -14,6 +14,7 @@ from app.engine.response.treeSummResponse import CustomTreeResponse
|
||||
from llama_index.core.settings import Settings
|
||||
from llama_index.core.indices.property_graph import LLMSynonymRetriever,VectorContextRetriever
|
||||
from llama_index.core import PropertyGraphIndex
|
||||
from app.engine.retriever.graphKeyWordRetriever import GraphKeyWordRetriever
|
||||
|
||||
ModelPlateCategory = '模型平台'
|
||||
|
||||
@@ -122,8 +123,7 @@ def create_query_engine(index,top_k=3, use_reranker=False, filters=None, respons
|
||||
llm_query = os.getenv('LLM_QUERY_WAY','rag')
|
||||
if llm_query == 'graph':
|
||||
graphIndex:PropertyGraphIndex = index
|
||||
synonym_retriver = LLMSynonymRetriever(graphIndex.property_graph_store,
|
||||
llm=Settings.llm,
|
||||
keyWord_retriver = GraphKeyWordRetriever(graphIndex.property_graph_store,
|
||||
include_text=False
|
||||
)
|
||||
if graphIndex.property_graph_store.supports_vector_queries:
|
||||
@@ -137,7 +137,7 @@ def create_query_engine(index,top_k=3, use_reranker=False, filters=None, respons
|
||||
include_text=False
|
||||
)
|
||||
|
||||
retriever = graphIndex.as_retriever(sub_retrievers=[synonym_retriver,vector_retriver])
|
||||
retriever = graphIndex.as_retriever(sub_retrievers=[keyWord_retriver,vector_retriver])
|
||||
|
||||
else:
|
||||
retriever = get_Retriever(index,
|
||||
|
||||
@@ -8,6 +8,73 @@ from app.engine.loaders.markdownReader import ChunkMarkdownReader
|
||||
from app.engine.graph.graphTypes import *
|
||||
import uuid
|
||||
|
||||
IdField = '_id'
|
||||
nodeTypeField = 'nodeType'
|
||||
parentIDField = 'parentId'
|
||||
|
||||
class Record:
|
||||
def __init__(self,id:str,tableName:str,property:dict) -> None:
|
||||
self._childs = []
|
||||
self._id = id
|
||||
self._property:dict = property
|
||||
self._tableName = tableName
|
||||
|
||||
def add(self,rcd:'Record'):
|
||||
self._childs.append(rcd)
|
||||
|
||||
@property
|
||||
def Property(self):
|
||||
return self._property
|
||||
|
||||
@property
|
||||
def Id(self):
|
||||
return self._id
|
||||
|
||||
@property
|
||||
def Name(self):
|
||||
for name,value in self._property.items():
|
||||
if '名称' in name:
|
||||
return value
|
||||
raise ValueError('记录名称为空')
|
||||
|
||||
@property
|
||||
def Label(self):
|
||||
if '工程属性' in self._tableName:
|
||||
label = '属性项'
|
||||
else:
|
||||
label = self.Property[nodeTypeField] if nodeTypeField in self.Property else self.Name
|
||||
label = label + '项'
|
||||
return label
|
||||
|
||||
@property
|
||||
def HasChild(self):
|
||||
return len(self._childs) > 0
|
||||
|
||||
def childCount(self):
|
||||
return len(self._childs)
|
||||
|
||||
def child(self,index:int):
|
||||
return self._childs[index]
|
||||
|
||||
class RecordTreeMake:
|
||||
def __init__(self,tableName:str,records:list) -> None:
|
||||
self._records = records
|
||||
self._tableName = tableName
|
||||
|
||||
def make(self)->list[Record]:
|
||||
rcdMaps:dict[str,Record] = {}
|
||||
for record in self._records:
|
||||
parid= ''
|
||||
if parentIDField in record:
|
||||
parid = record[parentIDField]
|
||||
id = record[IdField] if IdField in record else str(uuid.uuid1())
|
||||
if parid in rcdMaps:
|
||||
parRcd = rcdMaps[parid]
|
||||
parRcd.add(Record(id,self._tableName,record))
|
||||
else:
|
||||
rcdMaps[id] = Record(id,self._tableName,record)
|
||||
return list(rcdMaps.values())
|
||||
|
||||
class PrjGraphExtractor(TransformComponent):
|
||||
ProjectName:str
|
||||
_nodeMaps = {}
|
||||
@@ -37,24 +104,20 @@ class PrjGraphExtractor(TransformComponent):
|
||||
records:dict[str,list] = self._getRecordNode(llama_node)
|
||||
fInfos = fileName.split('_')
|
||||
if len(fInfos) == 1:
|
||||
fileID = self._add_node(existing_nodes = existing_nodes,name=fInfos[0], label=fInfos[0])
|
||||
fileID = self._addTableNode(existing_nodes = existing_nodes,name=fInfos[0], label=fInfos[0])
|
||||
elif len(fInfos) == 2:
|
||||
fileID = self._add_node(existing_nodes = existing_nodes,name=fileName, label=fInfos[1])
|
||||
fileID = self._addTableNode(existing_nodes = existing_nodes,name=fInfos[1], label=fInfos[0])
|
||||
else:
|
||||
raise ValueError("文件名存在多个下划线")
|
||||
raise ValueError("文件名存在多个下划线")
|
||||
|
||||
index = 0
|
||||
for record in records:
|
||||
index = index + 1
|
||||
rcdName = self._getRecordName(fileName,record)
|
||||
rcdid = self._add_node(existing_nodes = existing_nodes,name=rcdName, label=rcdName,properties = record)
|
||||
existing_relations.append(
|
||||
RagRelation(
|
||||
label="包含",
|
||||
source_id= fileID,
|
||||
target_id= rcdid
|
||||
)
|
||||
)
|
||||
rdMake = RecordTreeMake(fileName,records)
|
||||
rcds:list[Record] = rdMake.make()
|
||||
|
||||
for record in rcds:
|
||||
self._make_RecordEdge(existing_nodes = existing_nodes,
|
||||
existing_relations = existing_relations,
|
||||
fileID = fileID,
|
||||
rcd = record)
|
||||
|
||||
existing_relations.append(
|
||||
RagRelation(
|
||||
@@ -73,13 +136,13 @@ class PrjGraphExtractor(TransformComponent):
|
||||
existing_nodes:list = llama_node.metadata.pop(KG_NODES_KEY, [])
|
||||
existing_relations:list = llama_node.metadata.pop(KG_RELATIONS_KEY, [])
|
||||
records:dict[str,list] = self._getRecordNode(llama_node)
|
||||
fileID = self._add_node(existing_nodes = existing_nodes,name=fileName, label=fileName)
|
||||
fileID = self._addTableNode(existing_nodes = existing_nodes,name=fileName, label=fileName)
|
||||
|
||||
index = 0
|
||||
for record in records:
|
||||
index = index + 1
|
||||
attName = self._getRecordName(fileName,record)
|
||||
attID = self._add_node(existing_nodes = existing_nodes,name=attName, label= attName,properties = record)
|
||||
rdMake = RecordTreeMake(fileName,records)
|
||||
rcds:list[Record] = rdMake.make()
|
||||
|
||||
for record in rcds:
|
||||
attID = self._add_RecorNode(existing_nodes = existing_nodes,rcd = record)
|
||||
existing_relations.append(
|
||||
RagRelation(
|
||||
label="聚合",
|
||||
@@ -113,32 +176,56 @@ class PrjGraphExtractor(TransformComponent):
|
||||
|
||||
def _addPrjNode(self,llama_node:BaseNode):
|
||||
existing_nodes:list = llama_node.metadata.pop(KG_NODES_KEY, [])
|
||||
self._prjID = self._add_node(existing_nodes = existing_nodes,name=self.ProjectName,label=self.ProjectName)
|
||||
self._prjID = self._addTableNode(existing_nodes = existing_nodes,name=self.ProjectName,label=self.ProjectName)
|
||||
llama_node.metadata[KG_NODES_KEY] = existing_nodes
|
||||
|
||||
def _getRecordName(self,fileName:str,record:dict):
|
||||
for name,value in record.items():
|
||||
if '名称' in name:
|
||||
return value
|
||||
raise ValueError('记录名称为空')
|
||||
def _addTableNode(self,existing_nodes:list,name:str,label:str):
|
||||
return self._add_Node(existing_nodes,name,label,str(uuid.uuid1()))
|
||||
|
||||
def _add_RecorNode(self,existing_nodes:list,rcd:Record):
|
||||
return self._add_Node(existing_nodes,rcd.Name,rcd.Label,rcd.Id,rcd.Property)
|
||||
|
||||
def _add_node(self,existing_nodes:list,name:str,label:str,properties:dict = {}):
|
||||
def _add_Node(self,existing_nodes:list,name:str,label:str,defaultid:str,property:dict= {}):
|
||||
id:str = ''
|
||||
if name in self._nodeMaps:
|
||||
nodes:list[RagEntityNode] = self._nodeMaps[name]
|
||||
for node in nodes:
|
||||
if node.properties == properties and node.name == name and node.label == label:
|
||||
if node.properties == property and node.name == name and node.label == label:
|
||||
id = node.id
|
||||
break
|
||||
|
||||
if id =='':
|
||||
id = str(uuid.uuid1())
|
||||
newNode = RagEntityNode(name = name,label=label,properties = properties,uid = id)
|
||||
id = defaultid
|
||||
newNode = RagEntityNode(name = name,label=label,properties = property,uid = id)
|
||||
existing_nodes.append(newNode)
|
||||
if name in self._nodeMaps:
|
||||
nodes:list[RagEntityNode] = self._nodeMaps[name]
|
||||
nodes.append(newNode)
|
||||
else:
|
||||
self._nodeMaps[name] = [newNode]
|
||||
return id
|
||||
|
||||
return id
|
||||
def _make_RecordEdge(self,existing_nodes:list,existing_relations:list,fileID:str,rcd:Record,parID:str = ''):
|
||||
rcdID = self._add_RecorNode(existing_nodes = existing_nodes,rcd = rcd)
|
||||
if fileID!='':
|
||||
existing_relations.append(
|
||||
RagRelation(
|
||||
label="子级",
|
||||
source_id= fileID,
|
||||
target_id= rcdID
|
||||
)
|
||||
)
|
||||
|
||||
if parID!='':
|
||||
existing_relations.append(
|
||||
RagRelation(
|
||||
label="子级",
|
||||
source_id= parID,
|
||||
target_id= rcdID
|
||||
)
|
||||
)
|
||||
|
||||
if rcd.HasChild:
|
||||
for i in range(rcd.childCount()):
|
||||
child = rcd.child(i)
|
||||
self._make_RecordEdge(existing_nodes,existing_relations,'',child,rcdID)
|
||||
|
||||
@@ -39,4 +39,22 @@ class RAGPropertyGraphStore(SimplePropertyGraphStore):
|
||||
graph.nodes = kg_nodes
|
||||
graph.relations = kg_relations
|
||||
|
||||
return cls(graph)
|
||||
return cls(graph)
|
||||
|
||||
def save_networkx_graph(self, name: str = "kg.html") -> None:
|
||||
"""Display the graph store, useful for debugging."""
|
||||
import networkx as nx
|
||||
|
||||
G = nx.DiGraph()
|
||||
for node in self.graph.nodes.values():
|
||||
if isinstance(node,EntityNode):
|
||||
G.add_node(node.id, label=node.name)
|
||||
for triplet in self.graph.triplets:
|
||||
G.add_edge(triplet[0], triplet[2], label=triplet[1])
|
||||
|
||||
# save to html file
|
||||
from pyvis.network import Network
|
||||
|
||||
net = Network(notebook=False, directed=True)
|
||||
net.from_nx(G)
|
||||
net.write_html(name)
|
||||
@@ -3,7 +3,6 @@ from llama_index.vector_stores.chroma import ChromaVectorStore
|
||||
from llama_index.vector_stores.qdrant import QdrantVectorStore
|
||||
from qdrant_client import qdrant_client
|
||||
from llama_index.graph_stores.neo4j import Neo4jPropertyGraphStore
|
||||
|
||||
qclient = None
|
||||
|
||||
def get_qdrant_vector_store(docType:str):
|
||||
@@ -74,6 +73,12 @@ def get_vector_store(docType:str):
|
||||
return store
|
||||
|
||||
def get_Neo4j_Graph_Store(docType:str):
|
||||
from neo4j import GraphDatabase
|
||||
driver = GraphDatabase.driver(os.getenv('NEO4J_URL'), auth=(os.getenv('NEO4J_USERNAME'), os.getenv('NEO4J_PASSWORD')))
|
||||
with driver.session() as session:
|
||||
session.run("MATCH (n) DETACH DELETE n")
|
||||
driver.close()
|
||||
|
||||
neo4jStore = Neo4jPropertyGraphStore(
|
||||
username= os.getenv('NEO4J_USERNAME'),
|
||||
password= os.getenv('NEO4J_PASSWORD'),
|
||||
|
||||
Reference in New Issue
Block a user