修改属性图节点的层级结构,新增子父级关系

This commit is contained in:
wanyaokun
2024-09-24 17:11:20 +08:00
parent e0fc5381d8
commit aace9ce292
7 changed files with 494 additions and 188 deletions
+119 -32
View File
@@ -8,6 +8,73 @@ from app.engine.loaders.markdownReader import ChunkMarkdownReader
from app.engine.graph.graphTypes import *
import uuid
IdField = '_id'
nodeTypeField = 'nodeType'
parentIDField = 'parentId'
class Record:
def __init__(self,id:str,tableName:str,property:dict) -> None:
self._childs = []
self._id = id
self._property:dict = property
self._tableName = tableName
def add(self,rcd:'Record'):
self._childs.append(rcd)
@property
def Property(self):
return self._property
@property
def Id(self):
return self._id
@property
def Name(self):
for name,value in self._property.items():
if '名称' in name:
return value
raise ValueError('记录名称为空')
@property
def Label(self):
if '工程属性' in self._tableName:
label = '属性项'
else:
label = self.Property[nodeTypeField] if nodeTypeField in self.Property else self.Name
label = label + ''
return label
@property
def HasChild(self):
return len(self._childs) > 0
def childCount(self):
return len(self._childs)
def child(self,index:int):
return self._childs[index]
class RecordTreeMake:
def __init__(self,tableName:str,records:list) -> None:
self._records = records
self._tableName = tableName
def make(self)->list[Record]:
rcdMaps:dict[str,Record] = {}
for record in self._records:
parid= ''
if parentIDField in record:
parid = record[parentIDField]
id = record[IdField] if IdField in record else str(uuid.uuid1())
if parid in rcdMaps:
parRcd = rcdMaps[parid]
parRcd.add(Record(id,self._tableName,record))
else:
rcdMaps[id] = Record(id,self._tableName,record)
return list(rcdMaps.values())
class PrjGraphExtractor(TransformComponent):
ProjectName:str
_nodeMaps = {}
@@ -37,24 +104,20 @@ class PrjGraphExtractor(TransformComponent):
records:dict[str,list] = self._getRecordNode(llama_node)
fInfos = fileName.split('_')
if len(fInfos) == 1:
fileID = self._add_node(existing_nodes = existing_nodes,name=fInfos[0], label=fInfos[0])
fileID = self._addTableNode(existing_nodes = existing_nodes,name=fInfos[0], label=fInfos[0])
elif len(fInfos) == 2:
fileID = self._add_node(existing_nodes = existing_nodes,name=fileName, label=fInfos[1])
fileID = self._addTableNode(existing_nodes = existing_nodes,name=fInfos[1], label=fInfos[0])
else:
raise ValueError("文件名存在多个下划线")
raise ValueError("文件名存在多个下划线")
index = 0
for record in records:
index = index + 1
rcdName = self._getRecordName(fileName,record)
rcdid = self._add_node(existing_nodes = existing_nodes,name=rcdName, label=rcdName,properties = record)
existing_relations.append(
RagRelation(
label="包含",
source_id= fileID,
target_id= rcdid
)
)
rdMake = RecordTreeMake(fileName,records)
rcds:list[Record] = rdMake.make()
for record in rcds:
self._make_RecordEdge(existing_nodes = existing_nodes,
existing_relations = existing_relations,
fileID = fileID,
rcd = record)
existing_relations.append(
RagRelation(
@@ -73,13 +136,13 @@ class PrjGraphExtractor(TransformComponent):
existing_nodes:list = llama_node.metadata.pop(KG_NODES_KEY, [])
existing_relations:list = llama_node.metadata.pop(KG_RELATIONS_KEY, [])
records:dict[str,list] = self._getRecordNode(llama_node)
fileID = self._add_node(existing_nodes = existing_nodes,name=fileName, label=fileName)
fileID = self._addTableNode(existing_nodes = existing_nodes,name=fileName, label=fileName)
index = 0
for record in records:
index = index + 1
attName = self._getRecordName(fileName,record)
attID = self._add_node(existing_nodes = existing_nodes,name=attName, label= attName,properties = record)
rdMake = RecordTreeMake(fileName,records)
rcds:list[Record] = rdMake.make()
for record in rcds:
attID = self._add_RecorNode(existing_nodes = existing_nodes,rcd = record)
existing_relations.append(
RagRelation(
label="聚合",
@@ -113,32 +176,56 @@ class PrjGraphExtractor(TransformComponent):
def _addPrjNode(self,llama_node:BaseNode):
existing_nodes:list = llama_node.metadata.pop(KG_NODES_KEY, [])
self._prjID = self._add_node(existing_nodes = existing_nodes,name=self.ProjectName,label=self.ProjectName)
self._prjID = self._addTableNode(existing_nodes = existing_nodes,name=self.ProjectName,label=self.ProjectName)
llama_node.metadata[KG_NODES_KEY] = existing_nodes
def _getRecordName(self,fileName:str,record:dict):
for name,value in record.items():
if '名称' in name:
return value
raise ValueError('记录名称为空')
def _addTableNode(self,existing_nodes:list,name:str,label:str):
return self._add_Node(existing_nodes,name,label,str(uuid.uuid1()))
def _add_RecorNode(self,existing_nodes:list,rcd:Record):
return self._add_Node(existing_nodes,rcd.Name,rcd.Label,rcd.Id,rcd.Property)
def _add_node(self,existing_nodes:list,name:str,label:str,properties:dict = {}):
def _add_Node(self,existing_nodes:list,name:str,label:str,defaultid:str,property:dict= {}):
id:str = ''
if name in self._nodeMaps:
nodes:list[RagEntityNode] = self._nodeMaps[name]
for node in nodes:
if node.properties == properties and node.name == name and node.label == label:
if node.properties == property and node.name == name and node.label == label:
id = node.id
break
if id =='':
id = str(uuid.uuid1())
newNode = RagEntityNode(name = name,label=label,properties = properties,uid = id)
id = defaultid
newNode = RagEntityNode(name = name,label=label,properties = property,uid = id)
existing_nodes.append(newNode)
if name in self._nodeMaps:
nodes:list[RagEntityNode] = self._nodeMaps[name]
nodes.append(newNode)
else:
self._nodeMaps[name] = [newNode]
return id
return id
def _make_RecordEdge(self,existing_nodes:list,existing_relations:list,fileID:str,rcd:Record,parID:str = ''):
rcdID = self._add_RecorNode(existing_nodes = existing_nodes,rcd = rcd)
if fileID!='':
existing_relations.append(
RagRelation(
label="子级",
source_id= fileID,
target_id= rcdID
)
)
if parID!='':
existing_relations.append(
RagRelation(
label="子级",
source_id= parID,
target_id= rcdID
)
)
if rcd.HasChild:
for i in range(rcd.childCount()):
child = rcd.child(i)
self._make_RecordEdge(existing_nodes,existing_relations,'',child,rcdID)