import os
from llama_index.core.schema import TransformComponent, BaseNode
from llama_index.core.graph_stores.types import (
    KG_NODES_KEY,
    KG_RELATIONS_KEY,
)
from app.engine.loaders.markdownReader import ChunkMarkdownReader
# Star import supplies RagEntityNode and RagRelation. It is kept in its
# original position because a star import can rebind earlier names.
from app.engine.graph.graphTypes import *
import uuid
from typing import Optional


class PrjGraphExtractor(TransformComponent):
    """Attach project / file / record graph entities to llama-index nodes.

    For every incoming node the markdown content is parsed into records;
    a "file" entity is created from the node's ``file_name`` metadata, one
    entity is created per record, and relations linking project -> file ->
    record are stored in the node metadata under ``KG_NODES_KEY`` and
    ``KG_RELATIONS_KEY``.
    """

    ProjectName: str
    # Cache of created entities, keyed by entity name, used to reuse ids for
    # identical (name, label, properties) entities.
    # NOTE(review): declared at class level; whether each instance gets its
    # own copy depends on how the base class treats underscore attributes.
    _nodeMaps = {}
    # Id of the project entity; set by _addPrjNode.
    _prjID = ''

    def __init__(self, PrjName: str):
        """Create an extractor for the project named *PrjName*."""
        super().__init__(ProjectName=PrjName)

    def __call__(
        self, llama_nodes: list[BaseNode], **kwargs
    ) -> list[BaseNode]:
        """Annotate every node in *llama_nodes* and return the same list.

        The project entity is registered on the first node; nodes whose file
        name (without extension) is '工程属性' are handled as attribute files,
        all others as common files.

        Raises:
            ValueError: propagated from the per-node handlers.
        """
        if llama_nodes:
            self._addPrjNode(llama_nodes[0])
        for llama_node in llama_nodes:
            if self._getFileName(llama_node) == '工程属性':
                self._dealAttributeNode(llama_node)
            else:
                self._dealCommonNode(llama_node)
        return llama_nodes

    def _dealCommonNode(self, llama_node: BaseNode):
        """Add file and record entities for a regular (non-attribute) file.

        The file name may contain at most one underscore; the file entity's
        label is the part after the underscore (or the whole name when there
        is none). Records are linked to the file with a "包含" relation.

        Raises:
            ValueError: if the file name contains more than one underscore,
                or if a record has no name field.
        """
        fileName = self._getFileName(llama_node)
        nameParts = fileName.split('_')
        # Validate before touching the node metadata so a rejected file name
        # does not leave the node with its KG entries popped.
        if len(nameParts) > 2:
            raise ValueError("文件名存在多个下划线")
        self._linkFileRecords(
            llama_node,
            fileName=fileName,
            fileLabel=nameParts[-1],
            recordRelation="包含",
        )

    def _dealAttributeNode(self, llama_node: BaseNode):
        """Add file and record entities for the project-attribute file.

        The file entity uses the file name as both name and label; records
        are linked to the file with a "聚合" relation.

        Raises:
            ValueError: if a record has no name field.
        """
        fileName = self._getFileName(llama_node)
        self._linkFileRecords(
            llama_node,
            fileName=fileName,
            fileLabel=fileName,
            recordRelation="聚合",
        )

    def _linkFileRecords(
        self,
        llama_node: BaseNode,
        fileName: str,
        fileLabel: str,
        recordRelation: str,
    ):
        """Create the file entity, its record entities and their relations.

        Args:
            llama_node: node whose content is parsed and whose metadata is
                updated in place.
            fileName: name of the file entity.
            fileLabel: label of the file entity.
            recordRelation: label of the file -> record relations.
        """
        existing_nodes: list = llama_node.metadata.pop(KG_NODES_KEY, [])
        existing_relations: list = llama_node.metadata.pop(KG_RELATIONS_KEY, [])
        records = self._getRecordNode(llama_node)

        fileID = self._add_node(
            existing_nodes=existing_nodes, name=fileName, label=fileLabel
        )
        for record in records:
            recordName = self._getRecordName(fileName, record)
            recordID = self._add_node(
                existing_nodes=existing_nodes,
                name=recordName,
                label=recordName,
                properties=record,
            )
            existing_relations.append(
                RagRelation(
                    label=recordRelation,
                    source_id=fileID,
                    target_id=recordID,
                )
            )
        # NOTE(review): the project -> file relation is emitted once per
        # file, after the record loop — confirm this matches the intent.
        existing_relations.append(
            RagRelation(
                label="包含",
                source_id=self._prjID,
                target_id=fileID,
            )
        )

        llama_node.metadata[KG_NODES_KEY] = existing_nodes
        llama_node.metadata[KG_RELATIONS_KEY] = existing_relations

    def _getRecordNode(self, llama_node: BaseNode):
        """Parse the node's markdown content and return the reader's records."""
        reader = ChunkMarkdownReader()
        reader.markdown_to_tups(llama_node.get_content())
        return reader.records()

    def _getFileName(self, llama_node: BaseNode):
        """Return the node's ``file_name`` metadata without its extension."""
        fileName: str = os.path.splitext(llama_node.metadata['file_name'])[0]
        return fileName

    def _addPrjNode(self, llama_node: BaseNode):
        """Register the project entity on *llama_node* and remember its id."""
        existing_nodes: list = llama_node.metadata.pop(KG_NODES_KEY, [])
        self._prjID = self._add_node(
            existing_nodes=existing_nodes,
            name=self.ProjectName,
            label=self.ProjectName,
        )
        llama_node.metadata[KG_NODES_KEY] = existing_nodes

    def _getRecordName(self, fileName: str, record: dict):
        """Return the value of the first field whose key contains '名称'.

        *fileName* is accepted for interface compatibility and is not used.

        Raises:
            ValueError: if no key of *record* contains '名称'.
        """
        for key, value in record.items():
            if '名称' in key:
                return value
        raise ValueError('记录名称为空')

    def _add_node(
        self,
        existing_nodes: list,
        name: str,
        label: str,
        properties: Optional[dict] = None,
    ):
        """Return the id of the entity (name, label, properties), creating it if new.

        A matching cached entity has its id reused and nothing is appended to
        *existing_nodes*. Otherwise a new ``RagEntityNode`` with a fresh
        UUID is appended to *existing_nodes* and added to the cache.

        Args:
            existing_nodes: list that receives a newly created entity.
            name: entity name (also the cache key).
            label: entity label.
            properties: entity properties; defaults to an empty dict.

        Returns:
            The entity id as a string.
        """
        # A fresh dict per call: a shared ``{}`` default would be handed to
        # every entity created without explicit properties.
        if properties is None:
            properties = {}

        for cached in self._nodeMaps.get(name, []):
            if (
                cached.properties == properties
                and cached.name == name
                and cached.label == label
            ):
                return cached.id

        nodeID = str(uuid.uuid1())
        newNode = RagEntityNode(
            name=name, label=label, properties=properties, uid=nodeID
        )
        existing_nodes.append(newNode)
        self._nodeMaps.setdefault(name, []).append(newNode)
        return nodeID