207 lines
6.9 KiB
Python
207 lines
6.9 KiB
Python
"""
构建知识图谱:清洗节点属性,规范化数据。

Knowledge-graph construction: clean node attributes and normalize the data.
"""
|
|
|
|
import json
|
|
import os
|
|
import csv
|
|
from collections import defaultdict
|
|
|
|
|
|
def load_template(template_path):
    """Load the template CSV and group the allowed attribute names by node type.

    The first row is treated as a header and skipped. Every following row must
    have at least two columns: column 0 is the node type and column 1 is an
    attribute name. Both are whitespace-stripped. Rows with fewer than two
    columns, or with an empty type or attribute name, are ignored. Extra columns
    are ignored as well.

    Args:
        template_path: Path to the template CSV file (read as UTF-8).

    Returns:
        A ``defaultdict(set)`` mapping node type -> set of allowed attribute
        names. An empty file yields an empty mapping. If the file cannot be
        opened, decoded or parsed, the error is printed and ``{}`` is returned.
    """
    allowed_attributes = defaultdict(set)

    try:
        # newline="" is required by the csv module so that quoted fields
        # containing line breaks are parsed correctly.
        with open(template_path, "r", encoding="utf-8", newline="") as f:
            reader = csv.reader(f)
            # Skip the header row. The default avoids StopIteration on an empty file.
            next(reader, None)

            for row in reader:
                if len(row) < 2:
                    continue
                node_type = row[0].strip()
                attr_name = row[1].strip()
                if node_type and attr_name:
                    allowed_attributes[node_type].add(attr_name)
    except (OSError, UnicodeDecodeError, csv.Error) as e:
        print(f"读取模板文件出错: {e}")
        return {}

    return allowed_attributes
|
|
|
|
|
|
# Type codes / aliases -> canonical node type names used by the template.
# Built once at import time instead of on every call.
_TYPE_MAPPING = {
    "0": "定额",
    "1": "主材",
    "2": "人工",
    "3": "材料",
    "4": "机械",
    "5": "设备",
    "配件": "设备",
    "8": "清单",
}


def get_mapped_type(original_type):
    """Map a type code (or alias) to its canonical type name.

    String codes such as ``"0"`` and the alias ``"配件"`` are looked up in
    ``_TYPE_MAPPING``. Integer codes (e.g. ``0`` coming from JSON) are mapped
    the same way as their string form. Any value without a mapping is returned
    unchanged. This includes unknown strings, ``None``, floats, booleans and
    unhashable values such as lists or dicts.

    Args:
        original_type: The raw type value read from a node.

    Returns:
        The mapped type name, or ``original_type`` itself when no mapping applies.
    """
    # bool is an int subclass but is never a type code; leave it untouched.
    if isinstance(original_type, bool):
        return original_type
    if isinstance(original_type, int):
        return _TYPE_MAPPING.get(str(original_type), original_type)
    if isinstance(original_type, str):
        return _TYPE_MAPPING.get(original_type, original_type)
    # Anything else cannot match a string key. Returning early also avoids a
    # TypeError from dict.get on unhashable values (lists/dicts).
    return original_type
|
|
|
|
|
|
# Keys holding nested nodes: always kept, and recursed into when they are lists.
_RECURSIVE_KEYS = ("children", "材机列表")


def _normalize_node(node):
    """Return a shallow copy of *node* with GUID keys unified and type values mapped."""
    normalized = {}
    for key, value in node.items():
        # "guid"/"Guid"/... all become "GUID" so the template only needs one spelling.
        normalized["GUID" if key.lower() == "guid" else key] = value
    for type_key in ("类型", "type"):
        if type_key in normalized:
            normalized[type_key] = get_mapped_type(normalized[type_key])
    return normalized


def _filter_node(node, allowed_attributes):
    """Filter one node (and its nested nodes) against the template; non-dicts pass through."""
    if not isinstance(node, dict):
        return node

    normalized = _normalize_node(node)
    # Prefer "type" over "类型". Both were already mapped by _normalize_node.
    node_type = normalized.get("type", normalized.get("类型", "未知类型"))

    # Template types come from CSV text, so only str types can match. The
    # isinstance check also avoids a TypeError for unhashable type values.
    allowed = allowed_attributes.get(node_type) if isinstance(node_type, str) else None

    if allowed is not None:
        # Known type: keep only template attributes plus the nested-node keys.
        filtered = {}
        for name, value in normalized.items():
            if name in _RECURSIVE_KEYS:
                if isinstance(value, list):
                    filtered[name] = [_filter_node(child, allowed_attributes) for child in value]
                else:
                    filtered[name] = value
            elif name in allowed:
                filtered[name] = value
        return filtered

    # Unknown type: keep every key, but still filter nested dicts and list items.
    result = {}
    for name, value in normalized.items():
        if isinstance(value, dict):
            result[name] = _filter_node(value, allowed_attributes)
        elif isinstance(value, list):
            result[name] = [_filter_node(item, allowed_attributes) for item in value]
        else:
            result[name] = value
    return result


def _deep_traverse(obj, allowed_attributes):
    """Walk *obj*, filtering every dict that carries a "type" or "类型" key."""
    if isinstance(obj, dict):
        if "type" in obj or "类型" in obj:
            return _filter_node(obj, allowed_attributes)
        return {key: _deep_traverse(value, allowed_attributes) for key, value in obj.items()}
    if isinstance(obj, list):
        return [_deep_traverse(item, allowed_attributes) for item in obj]
    return obj


def filter_attributes(json_file_path, template_path, output_file_path):
    """Filter node attributes in a JSON file according to the template CSV.

    Reads ``json_file_path`` and walks ``data["projectData"]["projectDivision"]``.
    Every dict carrying a ``"type"`` or ``"类型"`` key is treated as a node:

    - Its type value is mapped via ``get_mapped_type``.
    - Any GUID-like key is renamed to ``"GUID"``.
    - If the template lists its type, only the template's attributes and the
      ``"children"`` / ``"材机列表"`` keys are kept, with list items under those
      keys filtered recursively.
    - Otherwise all keys are kept, but nested dicts and list items are still filtered.

    The whole document is then written to ``output_file_path`` as UTF-8 JSON
    (``indent=4``). Missing parent directories are created.

    Problems are reported with ``print`` and never raised: an empty or
    unreadable template, a missing ``projectData.projectDivision``, or any
    other exception (printed with a traceback).

    Args:
        json_file_path: Path of the input JSON file.
        template_path: Path of the template CSV file.
        output_file_path: Path of the output JSON file.
    """
    try:
        allowed_attributes = load_template(template_path)
        if not allowed_attributes:
            print("模板加载失败或为空")
            return

        with open(json_file_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        # Guard the types as well as the keys. Otherwise a non-dict payload
        # fails later with an unrelated TypeError.
        project_data = data.get("projectData") if isinstance(data, dict) else None
        if not isinstance(project_data, dict) or "projectDivision" not in project_data:
            print(f"文件 {json_file_path} 中不包含projectData.projectDivision数据")
            return

        # project_data is data["projectData"], so this updates data in place.
        project_data["projectDivision"] = _deep_traverse(
            project_data["projectDivision"], allowed_attributes
        )

        # Without this, writing to a not-yet-existing directory raises FileNotFoundError.
        output_dir = os.path.dirname(output_file_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        with open(output_file_path, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=4)

        print(f"筛选完成,结果已保存到 {output_file_path}")

    except Exception as e:
        # Script-level boundary: report the failure instead of propagating it.
        print(f"处理文件时出错: {e}")
        import traceback

        traceback.print_exc()
|
|
|
|
|
|
def main(
    *,
    json_file_path="KG_generation/dataset/json/主网清单/架空.json",
    template_path="KG_generation/节点属性模板.csv",
    output_file_path="KG_generation/dataset/json/主网清单/架空_clean.json",
):
    """Entry point: clean one JSON file against the node-attribute template.

    Called without arguments, it processes the original hard-coded dataset
    (paths relative to the current working directory). The keyword arguments
    let other callers point it at different files.

    Keyword Args:
        json_file_path: Input JSON file.
        template_path: Template CSV file listing the allowed attributes per type.
        output_file_path: Where the cleaned JSON is written.
    """
    filter_attributes(json_file_path, template_path, output_file_path)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the cleaning pipeline only when executed directly, not when imported.
    main()
|