232 lines
8.2 KiB
Python
232 lines
8.2 KiB
Python
import os
|
|
from llama_index.core.schema import TransformComponent, BaseNode
|
|
from llama_index.core.graph_stores.types import (
|
|
KG_NODES_KEY,
|
|
KG_RELATIONS_KEY,
|
|
)
|
|
from app.engine.loaders.markdownReader import ChunkMarkdownReader
|
|
from app.engine.graph.graphTypes import *
|
|
import uuid
|
|
|
|
# Keys looked up in each parsed record dictionary.
IdField: str = '_id'                # key holding a record's own identifier
nodeTypeField: str = 'nodeType'     # key whose value, when present, seeds Record.Label
parentIDField: str = 'parentId'     # key holding the identifier of a record's parent
|
|
|
|
class Record:
    """One table row together with the child rows attached beneath it.

    The row dictionary is stored as given (it is not copied), alongside the
    record identifier and the name of the table the row came from.
    """

    def __init__(self, id: str, tableName: str, property: dict) -> None:
        self._id = id
        self._tableName = tableName
        self._property: dict = property
        self._childs = []

    def add(self, rcd: 'Record'):
        """Append ``rcd`` to the end of this record's child list."""
        self._childs.append(rcd)

    @property
    def Property(self):
        """The row dictionary passed to the constructor."""
        return self._property

    @property
    def Id(self):
        """The identifier passed to the constructor."""
        return self._id

    @property
    def Name(self):
        """Value of the first field whose key contains '名称'.

        Raises:
            ValueError: if no key of the row contains '名称'.
        """
        # A private sentinel distinguishes "no such field" from a field
        # whose stored value happens to be None or empty.
        notFound = object()
        value = next(
            (val for key, val in self._property.items() if '名称' in key),
            notFound,
        )
        if value is notFound:
            raise ValueError('记录名称为空')
        return value

    @property
    def Label(self):
        """Graph label for this record.

        Rows of a table whose name contains '工程属性' are always labelled
        '属性项'. Any other row is labelled with its ``nodeTypeField`` value
        when that key is present, otherwise with ``Name``, followed by '项'.
        """
        if '工程属性' in self._tableName:
            return '属性项'
        row = self.Property
        if nodeTypeField in row:
            prefix = row[nodeTypeField]
        else:
            prefix = self.Name
        return prefix + '项'

    @property
    def HasChild(self):
        """True when at least one child has been added."""
        return self.childCount() > 0

    def childCount(self):
        """Number of children added so far."""
        return len(self._childs)

    def child(self, index: int):
        """Return the child at position ``index`` (list indexing rules apply)."""
        return self._childs[index]
|
|
|
|
class RecordTreeMake:
    """Turn a flat sequence of row dictionaries into a forest of ``Record``.

    Rows are linked through ``parentIDField`` / ``IdField``. A row is attached
    to its parent only if the parent appeared earlier in the sequence; rows
    whose parent is unknown at that point become roots.
    """

    def __init__(self, tableName: str, records: list) -> None:
        """Remember the table name and the rows to be arranged.

        Args:
            tableName: name passed on to every ``Record`` that is created.
            records: iterable of row mappings; each may carry ``IdField`` and
                ``parentIDField`` entries.
        """
        self._records = records
        self._tableName = tableName

    def make(self) -> list[Record]:
        """Build the record forest and return its root records.

        Every record, root or child, is registered by id so that later rows
        can attach to it. Previously only roots were registered, which made
        any grandchild row fall back to being a root.

        Returns:
            The root records, in order of first appearance. A later root with
            the same id as an earlier root replaces it, as before.
        """
        roots: dict[str, Record] = {}
        # Every record seen so far, keyed by id, used to resolve parents.
        byId: dict[str, Record] = {}

        for row in self._records:
            parentId = row[parentIDField] if parentIDField in row else ''
            # Rows without an id of their own get a generated one.
            rowId = row[IdField] if IdField in row else str(uuid.uuid1())
            node = Record(rowId, self._tableName, row)

            parent = byId.get(parentId)
            if parent is not None:
                parent.add(node)
                # setdefault: a child must not shadow an already-registered
                # record that happens to share its id.
                byId.setdefault(rowId, node)
            else:
                roots[rowId] = node
                byId[rowId] = node

        return list(roots.values())
|
|
|
|
class PrjGraphExtractor(TransformComponent):
    """Transform that attaches graph nodes and relations to llama nodes.

    For every incoming node the markdown content is parsed into records, a
    table node is created for the source file, record nodes are created
    beneath it, and the table is linked to a single project node. The results
    are stored in the node metadata under ``KG_NODES_KEY`` and
    ``KG_RELATIONS_KEY``.

    NOTE(review): ``RagEntityNode`` and ``RagRelation`` are not defined in this
    file; presumably they come from the ``graphTypes`` star import — confirm.
    """

    ProjectName:str
    # Cache of created entity nodes, keyed by node name, used for de-duplication.
    # NOTE(review): class-level mutable default; whether it is shared between
    # instances depends on how the base class treats underscore attributes.
    _nodeMaps = {}
    # Id of the project node; set by _addPrjNode.
    _prjID = ''

    def __init__(self,PrjName:str):
        """Create the extractor for the project called ``PrjName``."""
        super().__init__(ProjectName = PrjName)

    def __call__(
        self, llama_nodes: list[BaseNode], **kwargs
    ) -> list[BaseNode]:
        """Annotate every node in ``llama_nodes`` and return the same list.

        The project node is registered on the first node. Nodes whose file
        name (without extension) is exactly '工程属性' are handled as attribute
        tables; all others as common tables.
        """
        if llama_nodes:
            self._addPrjNode(llama_nodes[0])

        for llama_node in llama_nodes:
            if self._getFileName(llama_node) == '工程属性':
                self._dealAttributeNode(llama_node)
            else:
                self._dealCommonNode(llama_node)
        return llama_nodes

    def _dealCommonNode(self,llama_node:BaseNode):
        """Build table/record nodes and '子级' edges for an ordinary table file.

        The file name is either ``<name>`` (used as both name and label) or
        ``<label>_<name>``.

        Raises:
            ValueError: if the file name contains more than one underscore.
        """
        fileName = self._getFileName(llama_node)

        # Validate the file name before popping anything from the metadata so
        # that a rejected node keeps its existing KG metadata intact.
        parts = fileName.split('_')
        if len(parts) == 1:
            tableName = tableLabel = parts[0]
        elif len(parts) == 2:
            tableLabel, tableName = parts
        else:
            raise ValueError("文件名存在多个下划线")

        existing_nodes:list = llama_node.metadata.pop(KG_NODES_KEY, [])
        existing_relations:list = llama_node.metadata.pop(KG_RELATIONS_KEY, [])

        # NOTE(review): RecordTreeMake iterates this value and indexes each
        # item by key, so it is presumably a list of row dicts — confirm
        # against ChunkMarkdownReader.records().
        records = self._getRecordNode(llama_node)
        fileID = self._addTableNode(existing_nodes = existing_nodes,name=tableName, label=tableLabel)

        for record in RecordTreeMake(fileName,records).make():
            self._make_RecordEdge(existing_nodes = existing_nodes,
                                  existing_relations = existing_relations,
                                  fileID = fileID,
                                  rcd = record)

        self._linkTableToProject(llama_node,existing_nodes,existing_relations,fileID)

    def _dealAttributeNode(self,llama_node:BaseNode):
        """Build nodes for the attribute table and link them with '聚合' edges.

        Only the root records returned by ``RecordTreeMake.make`` are added;
        their children are not visited here.
        """
        fileName = self._getFileName(llama_node)
        existing_nodes:list = llama_node.metadata.pop(KG_NODES_KEY, [])
        existing_relations:list = llama_node.metadata.pop(KG_RELATIONS_KEY, [])
        records = self._getRecordNode(llama_node)
        fileID = self._addTableNode(existing_nodes = existing_nodes,name=fileName, label=fileName)

        for record in RecordTreeMake(fileName,records).make():
            attID = self._add_RecorNode(existing_nodes = existing_nodes,rcd = record)
            existing_relations.append(
                RagRelation(
                    label="聚合",
                    source_id= fileID,
                    target_id= attID
                )
            )

        self._linkTableToProject(llama_node,existing_nodes,existing_relations,fileID)

    def _linkTableToProject(self,llama_node:BaseNode,existing_nodes:list,existing_relations:list,fileID:str):
        """Add the project '包含' table edge and store both lists in the metadata."""
        existing_relations.append(
            RagRelation(
                label="包含",
                source_id= self._prjID,
                target_id= fileID
            )
        )
        llama_node.metadata[KG_NODES_KEY] = existing_nodes
        llama_node.metadata[KG_RELATIONS_KEY] = existing_relations

    def _getRecordNode(self,llama_node:BaseNode):
        """Parse the node content with ``ChunkMarkdownReader`` and return its records."""
        reader = ChunkMarkdownReader()
        reader.markdown_to_tups(llama_node.get_content())
        return reader.records()

    def _getFileName(self,llama_node:BaseNode):
        """Return ``metadata['file_name']`` with its extension removed."""
        fileName:str = os.path.splitext(llama_node.metadata['file_name'])[0]
        return fileName

    def _addPrjNode(self,llama_node:BaseNode):
        """Register the project node on ``llama_node`` and remember its id."""
        existing_nodes:list = llama_node.metadata.pop(KG_NODES_KEY, [])
        self._prjID = self._addTableNode(existing_nodes = existing_nodes,name=self.ProjectName,label=self.ProjectName)
        llama_node.metadata[KG_NODES_KEY] = existing_nodes

    def _addTableNode(self,existing_nodes:list,name:str,label:str):
        """Add (or reuse) a property-less node, proposing a fresh UUID as its id."""
        return self._add_Node(existing_nodes,name,label,str(uuid.uuid1()))

    def _add_RecorNode(self,existing_nodes:list,rcd:Record):
        """Add (or reuse) the node describing ``rcd``; returns the node id."""
        return self._add_Node(existing_nodes,rcd.Name,rcd.Label,rcd.Id,rcd.Property)

    def _add_Node(self,existing_nodes:list,name:str,label:str,defaultid:str,property:"dict | None" = None):
        """Return the id of a node with this name, label and properties.

        If a node with equal name, label and properties is already cached, its
        id is returned and nothing is appended. Otherwise a new node with id
        ``defaultid`` is created, appended to ``existing_nodes`` and cached.

        Args:
            existing_nodes: list that receives a newly created node.
            name: node name; also the cache key.
            label: node label.
            defaultid: id to use when a new node has to be created.
            property: node properties; ``None`` means no properties.
        """
        # A fresh dict per call: a literal `{}` default would be one shared
        # object handed to every node created without explicit properties.
        if property is None:
            property = {}

        cached = self._nodeMaps.get(name, [])
        match = next(
            (node for node in cached
             if node.properties == property and node.name == name and node.label == label),
            None,
        )
        # An empty id on the matched node is treated as "no match".
        if match is not None and match.id != '':
            return match.id

        newNode = RagEntityNode(name = name,label=label,properties = property,uid = defaultid)
        existing_nodes.append(newNode)
        self._nodeMaps.setdefault(name, []).append(newNode)
        return defaultid

    def _make_RecordEdge(self,existing_nodes:list,existing_relations:list,fileID:str,rcd:Record,parID:str = ''):
        """Add ``rcd`` and all its descendants, linking each with '子级' edges.

        Args:
            existing_nodes: list that receives newly created nodes.
            existing_relations: list that receives the created relations.
            fileID: id of the table node, or '' to skip the table edge.
            rcd: record to add.
            parID: id of the parent record node, or '' to skip the parent edge.
        """
        rcdID = self._add_RecorNode(existing_nodes = existing_nodes,rcd = rcd)

        # Table edge first, then parent edge; an empty id means "no such edge".
        for sourceID in (fileID, parID):
            if sourceID != '':
                existing_relations.append(
                    RagRelation(
                        label="子级",
                        source_id= sourceID,
                        target_id= rcdID
                    )
                )

        # Descendants hang off their parent record only, never off the table.
        for index in range(rcd.childCount()):
            self._make_RecordEdge(existing_nodes,existing_relations,'',rcd.child(index),rcdID)
|