128 lines
4.8 KiB
Python
128 lines
4.8 KiB
Python
import os
|
|
from llama_index.core.schema import TransformComponent, BaseNode
|
|
from llama_index.core.graph_stores.types import (
|
|
EntityNode,
|
|
Relation,
|
|
Triplet,
|
|
KG_NODES_KEY,
|
|
KG_RELATIONS_KEY,
|
|
)
|
|
from app.engine.loaders.projectJson import ProjectJson
|
|
from app.engine.loaders.markdownReader import ChunkMarkdownReader
|
|
|
|
class PrjGraphExtractor(TransformComponent):
    """Attach project knowledge-graph entities and relations to llama-index nodes.

    For every node, the stem of its ``file_name`` metadata names a "file"
    entity. The node text is parsed with ``ChunkMarkdownReader`` into records,
    and each record becomes its own entity linked to that file entity. Every
    file entity is also linked to a single project entity named
    ``ProjectName``.

    Results are appended to the lists stored under ``KG_NODES_KEY`` /
    ``KG_RELATIONS_KEY`` in each node's metadata. Those lists are presumably
    consumed later by a llama-index property-graph index; confirm against the
    caller.
    """

    # Name of the project; used as both name and label of the root entity.
    ProjectName: str

    def __init__(self, PrjName: str):
        super().__init__(ProjectName=PrjName)

    def __call__(
        self, llama_nodes: list[BaseNode], **kwargs
    ) -> list[BaseNode]:
        """Annotate ``llama_nodes`` in place with graph metadata and return them.

        The project entity is stored on the first node only, so it appears
        once per call. Nodes whose file stem is ``'工程属性'`` ("project
        attributes") are handled by ``_dealAttributeNode``. All others go
        through ``_dealCommonNode``.

        Raises:
            KeyError: if a node has no ``file_name`` metadata.
            ValueError: propagated from the ``_deal*`` methods.
        """
        if llama_nodes:
            self._addPrjNode(llama_nodes[0])

        for llama_node in llama_nodes:
            if self._getFileName(llama_node) == '工程属性':
                self._dealAttributeNode(llama_node)
            else:
                self._dealCommonNode(llama_node)
        return llama_nodes

    def _dealCommonNode(self, llama_node: BaseNode):
        """Add a file entity plus one entity per record for an ordinary file.

        The file stem may contain at most one underscore. The file entity is
        always named by the full stem. Its label is the whole stem when there is
        no underscore, otherwise the part after the underscore.

        Each record entity is linked from the file entity with a ``"包含"``
        ("contains") relation. The file entity is linked from the project
        entity the same way.

        Raises:
            ValueError: if the stem has more than one underscore, or a record
                has no key containing ``'名称'``. All checks run before the
                node's metadata is modified, so a failure leaves it intact.
        """
        fileName = self._getFileName(llama_node)
        fInfos = fileName.split('_')
        if len(fInfos) > 2:
            raise ValueError("文件名存在多个下划线")

        # Build everything first so a ValueError cannot drop existing metadata.
        new_nodes = [EntityNode(name=fileName, label=fInfos[-1])]
        new_relations = []
        for record in self._getRecordNode(llama_node):
            rcdName = self._getRecordName(fileName, record)
            # NOTE(review): record entities use their own name as label here,
            # whereas attribute records use '属性'; confirm this is intended.
            new_nodes.append(
                EntityNode(name=rcdName, label=rcdName, properties=record)
            )
            new_relations.append(self._relation("包含", fileName, rcdName))
        new_relations.append(self._relation("包含", self.ProjectName, fileName))

        self._appendToMetadata(llama_node, new_nodes, new_relations)

    def _dealAttributeNode(self, llama_node: BaseNode):
        """Add the project-attribute file entity and one entity per attribute.

        The file entity is named and labelled by the file stem. Each record
        becomes an entity labelled ``'属性'`` ("attribute"), linked from the
        file entity with a ``"聚合"`` ("aggregates") relation. The file entity
        is linked from the project entity with ``"包含"``.

        Raises:
            ValueError: if a record has no key containing ``'名称'``. This is
                checked before the node's metadata is modified.
        """
        fileName = self._getFileName(llama_node)

        new_nodes = [EntityNode(name=fileName, label=fileName)]
        new_relations = []
        for record in self._getRecordNode(llama_node):
            attName = self._getRecordName(fileName, record)
            new_nodes.append(
                EntityNode(name=attName, label='属性', properties=record)
            )
            new_relations.append(self._relation("聚合", fileName, attName))
        new_relations.append(self._relation("包含", self.ProjectName, fileName))

        self._appendToMetadata(llama_node, new_nodes, new_relations)

    @staticmethod
    def _relation(label: str, source_id: str, target_id: str) -> Relation:
        """Build a property-less ``Relation`` from ``source_id`` to ``target_id``."""
        return Relation(
            label=label,
            source_id=source_id,
            target_id=target_id,
            properties={},
        )

    @staticmethod
    def _appendToMetadata(llama_node: BaseNode, nodes: list, relations: list):
        """Append ``nodes`` / ``relations`` to the node's KG metadata lists.

        The existing lists (or new empty ones) are popped and re-inserted, so
        both keys end up at the end of the metadata dict.
        """
        existing_nodes: list = llama_node.metadata.pop(KG_NODES_KEY, [])
        existing_relations: list = llama_node.metadata.pop(KG_RELATIONS_KEY, [])
        existing_nodes.extend(nodes)
        existing_relations.extend(relations)
        llama_node.metadata[KG_NODES_KEY] = existing_nodes
        llama_node.metadata[KG_RELATIONS_KEY] = existing_relations

    def _getRecordNode(self, llama_node: BaseNode):
        """Parse the node's text with ``ChunkMarkdownReader`` and return its records.

        Each record is iterated with ``.items()`` by the callers. The result is
        presumably a list of dicts built from markdown tables; verify against
        ``ChunkMarkdownReader.records``.
        """
        reader = ChunkMarkdownReader()
        reader.markdown_to_tups(llama_node.get_content())
        return reader.records()

    def _getFileName(self, llama_node: BaseNode) -> str:
        """Return the node's ``file_name`` metadata with its extension removed.

        Raises:
            KeyError: if the metadata has no ``file_name`` entry.
        """
        return os.path.splitext(llama_node.metadata['file_name'])[0]

    def _addPrjNode(self, llama_node: BaseNode):
        """Append the project entity to the node's ``KG_NODES_KEY`` list.

        The ``KG_RELATIONS_KEY`` entry is left untouched.
        """
        existing_nodes: list = llama_node.metadata.pop(KG_NODES_KEY, [])
        existing_nodes.append(
            EntityNode(name=self.ProjectName, label=self.ProjectName)
        )
        llama_node.metadata[KG_NODES_KEY] = existing_nodes

    def _getRecordName(self, fileName: str, record: dict):
        """Return the value of the first key of ``record`` that contains '名称'.

        ``fileName`` is currently unused; it is kept for interface stability.

        Raises:
            ValueError: if no key contains '名称' ("name").
        """
        for name, value in record.items():
            if '名称' in name:
                return value
        raise ValueError('记录名称为空')