import os
from llama_index.core.schema import TransformComponent, BaseNode
from llama_index.core.graph_stores.types import (
    KG_NODES_KEY,
    KG_RELATIONS_KEY,
)
from app.engine.loaders.markdownReader import ChunkMarkdownReader
from app.engine.graph.graphTypes import *
import uuid
from typing import Optional

# Keys looked up in each raw record row.
IdField = '_id'
nodeTypeField = 'nodeType'
parentIDField = 'parentId'


class Record:
    """One table row plus the child rows that were attached to it."""

    def __init__(self, id: str, tableName: str, property: dict) -> None:
        self._childs = []
        self._id = id
        self._property: dict = property
        self._tableName = tableName

    def add(self, rcd: 'Record'):
        """Append ``rcd`` to this record's children."""
        self._childs.append(rcd)

    @property
    def Property(self):
        """The raw property mapping this record was built from."""
        return self._property

    @property
    def Id(self):
        """The identifier given at construction time."""
        return self._id

    @property
    def Name(self):
        """Value of the first property whose key contains '名称'.

        Raises:
            ValueError: if no property key contains '名称'.
        """
        for name, value in self._property.items():
            if '名称' in name:
                return value
        raise ValueError('记录名称为空')

    @property
    def Label(self):
        """Graph label for this record.

        Records of a table whose name contains '工程属性' are labelled
        '属性项'. Otherwise the label is the ``nodeType`` property when
        present (falling back to ``Name``) with '项' appended.
        """
        if '工程属性' in self._tableName:
            label = '属性项'
        else:
            label = self.Property[nodeTypeField] if nodeTypeField in self.Property else self.Name
            label = label + '项'
        return label

    @property
    def HasChild(self):
        """True when at least one child has been added."""
        return len(self._childs) > 0

    def childCount(self):
        """Number of direct children."""
        return len(self._childs)

    def child(self, index: int):
        """Return the direct child at ``index``."""
        return self._childs[index]


class RecordTreeMake:
    """Builds a forest of ``Record`` objects from flat rows using ``parentId``."""

    def __init__(self, tableName: str, records: list) -> None:
        self._records = records
        self._tableName = tableName

    def make(self) -> list[Record]:
        """Return the top-level records, with descendants attached beneath them.

        A row is attached to its parent only when the parent row was seen
        earlier in the input; otherwise it becomes a top-level record. Rows
        without an ``_id`` get a generated UUID.
        """
        roots: dict[str, Record] = {}
        # Every record seen so far, so that a row can attach to a parent at
        # any depth. Previously only roots were indexed, which turned every
        # grandchild into a spurious top-level record.
        seen: dict[str, Record] = {}
        for record in self._records:
            parent_id = record[parentIDField] if parentIDField in record else ''
            rcd_id = record[IdField] if IdField in record else str(uuid.uuid1())
            rcd = Record(rcd_id, self._tableName, record)
            parent = seen.get(parent_id)
            if parent is not None:
                parent.add(rcd)
                # Do not displace an already indexed record with the same id.
                seen.setdefault(rcd_id, rcd)
            else:
                # A later root with the same id replaces the earlier one,
                # matching the original dict-overwrite behaviour.
                roots[rcd_id] = rcd
                seen[rcd_id] = rcd
        return list(roots.values())


class PrjGraphExtractor(TransformComponent):
    """Transform that writes graph nodes/relations into each node's metadata.

    For every incoming node the markdown content is parsed into records,
    entity nodes and relations are created for them, and the results are
    stored under ``KG_NODES_KEY`` / ``KG_RELATIONS_KEY`` in the metadata.
    """

    ProjectName: str
    # name -> list of entity nodes already created under that name.
    # NOTE(review): class-level mutable default; whether instances share it
    # depends on how the base class treats underscore attributes - confirm.
    _nodeMaps = {}
    # Id of the project node, set by _addPrjNode.
    _prjID = ''

    def __init__(self, PrjName: str):
        super().__init__(ProjectName=PrjName)

    def __call__(
        self, llama_nodes: list[BaseNode], **kwargs
    ) -> list[BaseNode]:
        """Annotate every node in ``llama_nodes`` in place and return the list."""
        if llama_nodes:
            # The project node is recorded on the first node only.
            self._addPrjNode(llama_nodes[0])
        for llama_node in llama_nodes:
            if self._getFileName(llama_node) == '工程属性':
                self._dealAttributeNode(llama_node)
            else:
                self._dealCommonNode(llama_node)
        return llama_nodes

    def _dealCommonNode(self, llama_node: BaseNode):
        """Build the table node, the record tree and their relations.

        The file name (without extension) is either ``<name>`` or
        ``<label>_<name>``.

        Raises:
            ValueError: if the file name contains more than one underscore.
        """
        fileName = self._getFileName(llama_node)
        existing_nodes: list = llama_node.metadata.pop(KG_NODES_KEY, [])
        existing_relations: list = llama_node.metadata.pop(KG_RELATIONS_KEY, [])
        # Consumed by RecordTreeMake as an iterable of dict-like rows.
        records = self._getRecordNode(llama_node)

        fInfos = fileName.split('_')
        if len(fInfos) == 1:
            fileID = self._addTableNode(existing_nodes=existing_nodes, name=fInfos[0], label=fInfos[0])
        elif len(fInfos) == 2:
            fileID = self._addTableNode(existing_nodes=existing_nodes, name=fInfos[1], label=fInfos[0])
        else:
            raise ValueError("文件名存在多个下划线")

        for record in RecordTreeMake(fileName, records).make():
            self._make_RecordEdge(
                existing_nodes=existing_nodes,
                existing_relations=existing_relations,
                fileID=fileID,
                rcd=record,
            )

        # Link the project node to this table node.
        existing_relations.append(
            RagRelation(
                label="包含",
                source_id=self._prjID,
                target_id=fileID,
            )
        )
        llama_node.metadata[KG_NODES_KEY] = existing_nodes
        llama_node.metadata[KG_RELATIONS_KEY] = existing_relations

    def _dealAttributeNode(self, llama_node: BaseNode):
        """Build the attribute table node and one node per top-level record.

        Only top-level records are added here; children attached to them are
        not traversed.
        """
        fileName = self._getFileName(llama_node)
        existing_nodes: list = llama_node.metadata.pop(KG_NODES_KEY, [])
        existing_relations: list = llama_node.metadata.pop(KG_RELATIONS_KEY, [])
        # Consumed by RecordTreeMake as an iterable of dict-like rows.
        records = self._getRecordNode(llama_node)

        fileID = self._addTableNode(existing_nodes=existing_nodes, name=fileName, label=fileName)
        for record in RecordTreeMake(fileName, records).make():
            attID = self._add_RecorNode(existing_nodes=existing_nodes, rcd=record)
            existing_relations.append(
                RagRelation(
                    label="聚合",
                    source_id=fileID,
                    target_id=attID,
                )
            )

        # Link the project node to this table node.
        existing_relations.append(
            RagRelation(
                label="包含",
                source_id=self._prjID,
                target_id=fileID,
            )
        )
        llama_node.metadata[KG_NODES_KEY] = existing_nodes
        llama_node.metadata[KG_RELATIONS_KEY] = existing_relations

    def _getRecordNode(self, llama_node: BaseNode):
        """Parse the node's content with ChunkMarkdownReader and return its records."""
        content = llama_node.get_content()
        rd = ChunkMarkdownReader()
        rd.markdown_to_tups(content)
        return rd.records()

    def _getFileName(self, llama_node: BaseNode):
        """Return ``metadata['file_name']`` with its extension removed."""
        meta = llama_node.metadata
        fileName: str = os.path.splitext(meta['file_name'])[0]
        return fileName

    def _addPrjNode(self, llama_node: BaseNode):
        """Add the project node to ``llama_node`` and remember its id."""
        existing_nodes: list = llama_node.metadata.pop(KG_NODES_KEY, [])
        self._prjID = self._addTableNode(
            existing_nodes=existing_nodes,
            name=self.ProjectName,
            label=self.ProjectName,
        )
        llama_node.metadata[KG_NODES_KEY] = existing_nodes

    def _addTableNode(self, existing_nodes: list, name: str, label: str):
        """Add a property-less node with a fresh UUID as its default id."""
        return self._add_Node(existing_nodes, name, label, str(uuid.uuid1()))

    def _add_RecorNode(self, existing_nodes: list, rcd: Record):
        """Add a node built from ``rcd``'s name, label, id and properties."""
        return self._add_Node(existing_nodes, rcd.Name, rcd.Label, rcd.Id, rcd.Property)

    def _add_Node(self, existing_nodes: list, name: str, label: str, defaultid: str,
                  property: Optional[dict] = None):
        """Append a new entity node to ``existing_nodes`` and return its id.

        If a node with the same name, label and properties was created
        before, its id is reused; otherwise ``defaultid`` is used. A new
        node object is appended in both cases.
        """
        if property is None:
            # Fresh dict per call: a shared ``{}`` default would be handed to
            # every node created without explicit properties.
            property = {}

        node_id = ''
        for node in self._nodeMaps.get(name, []):
            if node.properties == property and node.name == name and node.label == label:
                node_id = node.id
                break
        if node_id == '':
            node_id = defaultid

        newNode = RagEntityNode(name=name, label=label, properties=property, uid=node_id)
        existing_nodes.append(newNode)
        self._nodeMaps.setdefault(name, []).append(newNode)
        return node_id

    def _make_RecordEdge(self, existing_nodes: list, existing_relations: list,
                         fileID: str, rcd: Record, parID: str = ''):
        """Add ``rcd`` and all its descendants, linking each with '子级' relations.

        A relation is emitted from ``fileID`` and/or ``parID`` to the record
        whenever the respective id is non-empty. Children are linked to their
        parent record only, so recursion passes an empty ``fileID``.
        """
        rcdID = self._add_RecorNode(existing_nodes=existing_nodes, rcd=rcd)
        if fileID != '':
            existing_relations.append(
                RagRelation(
                    label="子级",
                    source_id=fileID,
                    target_id=rcdID,
                )
            )
        if parID != '':
            existing_relations.append(
                RagRelation(
                    label="子级",
                    source_id=parID,
                    target_id=rcdID,
                )
            )
        for index in range(rcd.childCount()):
            self._make_RecordEdge(existing_nodes, existing_relations, '', rcd.child(index), rcdID)