上传文件
This commit is contained in:
+206
@@ -0,0 +1,206 @@
|
||||
"""
Knowledge-graph construction.

Clean node attributes and normalize the node data.
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import csv
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
def load_template(template_path):
    """Load the attribute template CSV and group attribute names by node type.

    The first row is treated as a header and skipped. Each following row
    adds ``row[1]`` (attribute name) to the set stored under ``row[0]``
    (node type). Both cells are stripped of surrounding whitespace; rows with
    fewer than two columns, or with an empty type or attribute name, are
    ignored.

    Args:
        template_path: Path to the template CSV file (read as UTF-8).

    Returns:
        A ``defaultdict(set)`` mapping node type to the set of attribute
        names allowed for it. An empty file yields an empty mapping. If the
        file cannot be opened or parsed, the error is printed and an empty
        ``dict`` is returned instead.
    """
    allowed_attributes = defaultdict(set)

    try:
        # newline="" lets the csv module do its own line-ending handling, so
        # quoted fields containing line breaks are parsed correctly.
        with open(template_path, "r", encoding="utf-8", newline="") as f:
            reader = csv.reader(f)
            # Skip the header row. The default keeps an empty file from
            # raising StopIteration and being reported as a read error.
            next(reader, None)

            for row in reader:
                if len(row) < 2:
                    continue
                node_type = row[0].strip()
                attr_name = row[1].strip()
                if node_type and attr_name:
                    allowed_attributes[node_type].add(attr_name)
    except Exception as e:
        # Deliberate best-effort: callers treat an empty result as
        # "template unavailable" rather than handling an exception.
        print(f"读取模板文件出错: {e}")
        return {}

    return allowed_attributes
|
||||
|
||||
|
||||
def get_mapped_type(original_type):
    """Map a raw node-type code to its named type.

    Known codes are looked up in a fixed table; anything not in the table is
    returned unchanged. Integer codes are looked up by their decimal string,
    so ``1`` and ``"1"`` map to the same name.

    Args:
        original_type: Raw type value, typically a string or an integer code.

    Returns:
        The mapped type name when the value is a known code, otherwise
        ``original_type`` itself (including ``None``, floats, booleans and
        containers, which are never looked up).
    """
    type_mapping = {
        "0": "定额",
        "1": "主材",
        "2": "人工",
        "3": "材料",
        "4": "机械",
        "5": "设备",
        "配件": "设备",
        "8": "清单",
    }

    # The table is keyed by strings, so an integer code has to be converted
    # before the lookup. bool is a subclass of int and is not a type code.
    if isinstance(original_type, int) and not isinstance(original_type, bool):
        return type_mapping.get(str(original_type), original_type)

    if isinstance(original_type, str):
        return type_mapping.get(original_type, original_type)

    # Other values cannot match a string key; returning early also avoids a
    # TypeError for unhashable values such as dicts and lists.
    return original_type
|
||||
|
||||
|
||||
# Keys whose list values hold child nodes that are filtered recursively and
# are always kept, whether or not the template lists them.
_CHILD_LIST_KEYS = ("children", "材机列表")


def _map_node_type(value):
    """Map a raw type value with get_mapped_type, leaving containers as-is."""
    # dict and list are the only unhashable values json.load produces; they
    # cannot be used as lookup keys, so they are passed through untouched.
    if isinstance(value, (dict, list)):
        return value
    return get_mapped_type(value)


def _filter_node(node, allowed_attributes):
    """Return a cleaned copy of ``node`` according to its type's template."""
    if not isinstance(node, dict):
        return node

    # Normalize any spelling of the GUID key ("guid", "Guid", ...) to "GUID".
    processed = {
        ("GUID" if key.lower() == "guid" else key): value
        for key, value in node.items()
    }

    # Replace raw type codes with their mapped names in both type fields.
    for type_key in ("类型", "type"):
        if type_key in processed:
            processed[type_key] = _map_node_type(processed[type_key])

    # "type" takes precedence over "类型"; both are already mapped above.
    node_type = processed.get("type", processed.get("类型", "未知类型"))
    has_template = (
        not isinstance(node_type, (dict, list))
        and node_type in allowed_attributes
    )

    if not has_template:
        # No template for this type: keep every attribute and keep looking
        # for filterable nodes inside nested dicts and lists.
        result = {}
        for key, value in processed.items():
            if isinstance(value, dict):
                result[key] = _filter_node(value, allowed_attributes)
            elif isinstance(value, list):
                result[key] = [
                    _filter_node(item, allowed_attributes) for item in value
                ]
            else:
                result[key] = value
        return result

    allowed = allowed_attributes[node_type]
    filtered = {}
    for attr_name, attr_value in processed.items():
        if attr_name in _CHILD_LIST_KEYS:
            if isinstance(attr_value, list):
                filtered[attr_name] = [
                    _filter_node(child, allowed_attributes)
                    for child in attr_value
                ]
            else:
                filtered[attr_name] = attr_value
        elif attr_name in allowed:
            filtered[attr_name] = attr_value
    return filtered


def _deep_traverse(obj, allowed_attributes):
    """Walk ``obj`` and filter every dict that carries a type field."""
    if isinstance(obj, dict):
        if "type" in obj or "类型" in obj:
            return _filter_node(obj, allowed_attributes)
        # Untyped container: descend into every field.
        return {
            key: _deep_traverse(value, allowed_attributes)
            for key, value in obj.items()
        }
    if isinstance(obj, list):
        return [_deep_traverse(item, allowed_attributes) for item in obj]
    return obj


def filter_attributes(json_file_path, template_path, output_file_path):
    """Filter node attributes in a JSON file according to the CSV template.

    Loads the template with ``load_template``, reads the JSON file, rewrites
    ``projectData.projectDivision`` so that every node whose type appears in
    the template keeps only the attributes listed for that type (plus the
    ``children`` and ``材机列表`` keys), and writes the whole document to
    ``output_file_path``. GUID keys are normalized to ``"GUID"`` and type
    codes are replaced through ``get_mapped_type``.

    Nothing is written when the template is empty or the JSON has no
    ``projectData.projectDivision``; a message is printed instead. Any
    exception is caught, printed with its traceback, and not re-raised.

    Args:
        json_file_path: Path of the input JSON file (read as UTF-8).
        template_path: Path of the template CSV file.
        output_file_path: Path of the JSON file to write (UTF-8).

    Returns:
        None.
    """
    try:
        allowed_attributes = load_template(template_path)
        if not allowed_attributes:
            print("模板加载失败或为空")
            return

        with open(json_file_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        # Both levels must be JSON objects; any other shape is reported as
        # missing data instead of failing on the subscript below.
        project_data = data.get("projectData") if isinstance(data, dict) else None
        if not isinstance(project_data, dict) or "projectDivision" not in project_data:
            print(f"文件 {json_file_path} 中不包含projectData.projectDivision数据")
            return

        project_data["projectDivision"] = _deep_traverse(
            project_data["projectDivision"], allowed_attributes
        )

        with open(output_file_path, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=4)

        print(f"筛选完成,结果已保存到 {output_file_path}")

    except Exception as e:
        # Top-level boundary: report the failure with a traceback and return
        # normally rather than propagating to the caller.
        print(f"处理文件时出错: {e}")
        import traceback

        traceback.print_exc()
|
||||
|
||||
|
||||
def main(
    json_file_path="KG_generation/dataset/json/主网清单/架空.json",
    template_path="KG_generation/节点属性模板.csv",
    output_file_path="KG_generation/dataset/json/主网清单/架空_clean.json",
):
    """Run the attribute filter on one JSON file.

    Called without arguments it processes the default dataset paths, which
    are relative to the current working directory. Pass other paths to clean
    a different file without editing this module.

    Args:
        json_file_path: Input JSON file to clean.
        template_path: CSV template listing the allowed attributes per type.
        output_file_path: Where the cleaned JSON is written.
    """
    filter_attributes(json_file_path, template_path, output_file_path)
|
||||
|
||||
|
||||
# Entry point: run the cleaning step only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user