import os

from llama_index.core.schema import TransformComponent, BaseNode
from llama_index.core.graph_stores.types import (
    EntityNode,
    Relation,
    Triplet,
    KG_NODES_KEY,
    KG_RELATIONS_KEY,
)

from app.engine.loaders.projectJson import ProjectJson
from app.engine.loaders.markdownReader import ChunkMarkdownReader


class PrjGraphExtractor(TransformComponent):
    """Attach project / file / record graph data to llama_index nodes.

    For every node the extractor derives a file name from the node's
    ``file_name`` metadata, parses the node content into records with
    ``ChunkMarkdownReader`` and stores the resulting ``EntityNode`` and
    ``Relation`` objects under ``KG_NODES_KEY`` / ``KG_RELATIONS_KEY`` in the
    node metadata.
    """

    # Name of the project; used as the project entity's name/label and as the
    # source id of every project -> file relation.
    ProjectName: str

    def __init__(self, PrjName: str):
        """Create an extractor for the project called ``PrjName``."""
        super().__init__(ProjectName=PrjName)

    def __call__(
        self, llama_nodes: list[BaseNode], **kwargs
    ) -> list[BaseNode]:
        """Annotate every node in ``llama_nodes`` in place and return the list.

        The project entity is added to the first node only. Nodes whose file
        name is '工程属性' are handled as attribute nodes, all others as
        common nodes.

        Raises:
            ValueError: propagated from the per-node handlers (file name with
                more than one underscore, or a record without a name field).
            KeyError: if a node has no ``file_name`` metadata entry.
        """
        if llama_nodes:
            self._addPrjNode(llama_nodes[0])

        for llama_node in llama_nodes:
            if self._getFileName(llama_node) == '工程属性':
                self._dealAttributeNode(llama_node)
            else:
                self._dealCommonNode(llama_node)
        return llama_nodes

    def _dealCommonNode(self, llama_node: BaseNode):
        """Add the file entity, its record entities and relations to a node.

        The file name may contain at most one underscore: without one, the
        file name is both entity name and label; with one, the part after the
        underscore becomes the label.

        Raises:
            ValueError: if the file name contains more than one underscore.
        """
        fileName = self._getFileName(llama_node)
        fInfos = fileName.split('_')

        # Validate the name before touching metadata so a rejected file
        # leaves the node exactly as it was.
        if len(fInfos) == 1:
            fileEntity = EntityNode(name=fInfos[0], label=fInfos[0])
        elif len(fInfos) == 2:
            fileEntity = EntityNode(name=fileName, label=fInfos[1])
        else:
            raise ValueError("文件名存在多个下划线")

        # Record entities of a common file are labelled with their own name.
        self._attachFileGraph(
            llama_node, fileName, fileEntity, relationLabel="包含"
        )

    def _dealAttributeNode(self, llama_node: BaseNode):
        """Add the attribute-file entity, its attribute entities and relations.

        Every record becomes an entity labelled '属性' that is linked to the
        file entity with a '聚合' relation.
        """
        fileName = self._getFileName(llama_node)
        self._attachFileGraph(
            llama_node,
            fileName,
            EntityNode(name=fileName, label=fileName),
            relationLabel="聚合",
            recordLabel='属性',
        )

    def _attachFileGraph(
        self,
        llama_node: BaseNode,
        fileName: str,
        fileEntity: EntityNode,
        relationLabel: str,
        recordLabel=None,
    ):
        """Build the entities/relations for one file and store them on the node.

        Args:
            llama_node: node whose content is parsed and whose metadata is
                updated.
            fileName: id used for the file entity in relations.
            fileEntity: entity representing the file itself.
            relationLabel: label of each file -> record relation.
            recordLabel: label for record entities; ``None`` means each
                record entity is labelled with its own name.

        Raises:
            ValueError: if a record has no name field (see ``_getRecordName``).
        """
        newNodes = [fileEntity]
        newRelations = []
        for record in self._getRecordNode(llama_node):
            rcdName = self._getRecordName(fileName, record)
            label = rcdName if recordLabel is None else recordLabel
            newNodes.append(
                EntityNode(name=rcdName, label=label, properties=record)
            )
            newRelations.append(
                Relation(
                    label=relationLabel,
                    source_id=fileName,
                    target_id=rcdName,
                    properties={},
                )
            )
        # One project -> file relation per file, after the record relations.
        newRelations.append(
            Relation(
                label="包含",
                source_id=self.ProjectName,
                target_id=fileName,
                properties={},
            )
        )

        # Metadata is only modified once everything that can raise has run,
        # so previously stored entities/relations are never lost on error.
        existing_nodes: list = llama_node.metadata.pop(KG_NODES_KEY, [])
        existing_relations: list = llama_node.metadata.pop(KG_RELATIONS_KEY, [])
        existing_nodes.extend(newNodes)
        existing_relations.extend(newRelations)
        llama_node.metadata[KG_NODES_KEY] = existing_nodes
        llama_node.metadata[KG_RELATIONS_KEY] = existing_relations

    def _getRecordNode(self, llama_node: BaseNode):
        """Parse the node content and return the reader's records.

        Callers iterate the result and treat each element as a dict.
        """
        reader = ChunkMarkdownReader()
        reader.markdown_to_tups(llama_node.get_content())
        return reader.records()

    def _getFileName(self, llama_node: BaseNode) -> str:
        """Return the node's ``file_name`` metadata without its extension.

        Raises:
            KeyError: if the metadata has no ``file_name`` entry.
        """
        fileName: str = os.path.splitext(llama_node.metadata['file_name'])[0]
        return fileName

    def _addPrjNode(self, llama_node: BaseNode):
        """Append the project entity to the node's ``KG_NODES_KEY`` list."""
        existing_nodes: list = llama_node.metadata.pop(KG_NODES_KEY, [])
        existing_nodes.append(
            EntityNode(name=self.ProjectName, label=self.ProjectName)
        )
        llama_node.metadata[KG_NODES_KEY] = existing_nodes

    def _getRecordName(self, fileName: str, record: dict):
        """Return the value of the first field whose key contains '名称'.

        ``fileName`` is accepted for interface compatibility and not used.

        Raises:
            ValueError: if no key of ``record`` contains '名称'.
        """
        for name, value in record.items():
            if '名称' in name:
                return value
        raise ValueError('记录名称为空')