Files
KG_generation/supplement_kg.py
chentianrui 0a4dedda1c 更新代码
2025-10-14 16:13:18 +08:00

411 lines
17 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
第三步:向上汇总费用预览
"""
import json
import os
from typing import Dict, List, Any, Tuple, Optional
import copy
import re
class ExpenseProcessor:
    """
    Roll cost entries up an expense-preview ("费用预览") tree.

    Nodes are dicts that may carry ``GUID``, ``sum``, ``rcj`` and ``children``.
    A node whose ``children`` contain bare cost entries (dicts with exactly the
    keys ``id`` and ``cost``) is treated as a leaf: those entries become its
    ``sum``. Any other node's ``sum`` is the per-``id`` total of its own
    pre-existing ``sum`` entries plus the ``sum`` of each processed child.
    For bill-of-quantities ("清单") projects, leaf costs are first multiplied by
    the ``数量`` (quantity) of the node with the same GUID in
    ``projectData.projectDivision``. Totals are written back as strings.
    """

    def __init__(self):
        # Stateless: all functionality lives in static methods and classmethods.
        pass

    @staticmethod
    def normalize_guid(guid: str) -> str:
        """
        Normalize a GUID so that it is wrapped in exactly one pair of curly braces.

        ``str.strip("{}")`` removes every leading/trailing brace, so ``"abc"``,
        ``"{abc}"`` and ``"{{abc}}"`` all become ``"{abc}"``.

        :param guid: raw GUID string
        :return: the normalized GUID; a falsy value (e.g. ``""``) is returned unchanged
        """
        if not guid:
            return guid
        return "{" + guid.strip("{}") + "}"

    @staticmethod
    def is_cost_item(obj: Any) -> bool:
        """
        Return True if *obj* is a bare cost entry: a dict whose only keys are ``id`` and ``cost``.

        A dict carrying any additional field is treated as a structural node,
        not as a cost entry.
        """
        return isinstance(obj, dict) and "id" in obj and "cost" in obj and len(obj) == 2

    @staticmethod
    def extract_costs_from_children(node: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Return deep copies of the bare cost entries found directly in ``node["children"]``.

        :param node: expense-preview node
        :return: list of cost entries (empty if ``children`` is missing or not a list)
        """
        children = node.get("children")
        if not isinstance(children, list):
            return []
        return [copy.deepcopy(child) for child in children if ExpenseProcessor.is_cost_item(child)]

    @staticmethod
    def _accumulate_costs(totals: Dict[Any, float], cost_items: Any) -> None:
        """
        Add each ``{"id": ..., "cost": ...}`` entry of *cost_items* into *totals*, keyed by id.

        Non-list input, and entries that are not dicts holding both keys, are skipped.
        An id is registered (starting at 0.0) even when its cost cannot be parsed
        as a number; the unparseable cost itself contributes nothing.
        """
        if not isinstance(cost_items, list):
            return
        for cost_item in cost_items:
            if not (isinstance(cost_item, dict) and "id" in cost_item and "cost" in cost_item):
                continue
            item_id = cost_item["id"]
            totals.setdefault(item_id, 0.0)
            try:
                totals[item_id] += float(cost_item["cost"])
            except (ValueError, TypeError):
                pass  # best-effort by design: non-numeric costs are ignored

    @staticmethod
    def calculate_parent_costs(node: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Compute a node's summed costs, grouped by cost ``id``.

        Three sources are added together, in this order (which also fixes the
        order of ids in the result): the node's own ``sum`` entries, bare cost
        entries directly in its ``children``, and the ``sum`` of every structural
        child. Children are not recursed into, so they must already have been
        processed (``process_node`` processes children before calling this).

        :param node: expense-preview node
        :return: list of ``{"id": ..., "cost": str(total)}`` entries
        """
        totals: Dict[Any, float] = {}
        ExpenseProcessor._accumulate_costs(totals, node.get("sum"))
        ExpenseProcessor._accumulate_costs(totals, ExpenseProcessor.extract_costs_from_children(node))
        children = node.get("children")
        if isinstance(children, list):
            for child in children:
                if isinstance(child, dict) and not ExpenseProcessor.is_cost_item(child):
                    ExpenseProcessor._accumulate_costs(totals, child.get("sum"))
        return [{"id": item_id, "cost": str(total)} for item_id, total in totals.items()]

    @staticmethod
    def find_guid_quantity(project_data: Optional[Dict[str, Any]], guid: str) -> float:
        """
        Look up the quantity (``数量``) of the node with the given GUID in ``projectDivision``.

        ``projectDivision`` is searched depth-first (dict values and list items in
        order). A dict matches when its ``GUID`` is a string equal to *guid* once
        surrounding braces are stripped from both; nodes without a string GUID
        never match. The first match with a valid quantity wins. A match whose
        quantity is missing or non-numeric does not stop the search of other
        branches, but that match's own descendants are not searched.

        :param project_data: project data containing a ``projectDivision`` key
        :param guid: GUID to look for, with or without braces (e.g. ``"{12345678-...}"``)
        :return: the quantity as a float
        :raises KeyError: if ``projectDivision`` is missing or no node has this GUID
        :raises ValueError: if matching nodes exist but none has a usable ``数量`` value
        """
        if not project_data or "projectDivision" not in project_data:
            raise KeyError("projectDivision not found in project_data")

        guid_clean = guid.strip("{}")
        # Error from the first matching-but-invalid node; raised only if no valid match exists.
        first_error: Optional[ValueError] = None

        def parse_quantity(node: Dict[str, Any]) -> float:
            if "数量" not in node:
                raise ValueError(f"Node with GUID {guid} has no '数量' field")
            quantity = node["数量"]
            try:
                return float(quantity)
            except (ValueError, TypeError) as e:
                raise ValueError(f"Invalid quantity value for GUID {guid}: {quantity}") from e

        def search(node: Any) -> Optional[float]:
            """Return the first valid quantity under *node*, or None if there is none."""
            nonlocal first_error
            if isinstance(node, dict):
                node_guid = node.get("GUID")
                if isinstance(node_guid, str) and node_guid.strip("{}") == guid_clean:
                    try:
                        return parse_quantity(node)
                    except ValueError as err:
                        if first_error is None:
                            first_error = err
                        return None
                candidates = node.values()
            elif isinstance(node, list):
                candidates = node
            else:
                return None
            for candidate in candidates:
                if isinstance(candidate, (dict, list)):
                    found = search(candidate)
                    if found is not None:
                        return found
            return None

        quantity = search(project_data["projectDivision"])
        if quantity is not None:
            return quantity
        if first_error is not None:
            raise first_error
        raise KeyError(f"projectDivision中没找到对应的GUID {guid}")

    @staticmethod
    def process_node(
        node: Dict[str, Any], project_data: Optional[Dict[str, Any]] = None, is_bill_engineering: Optional[bool] = None
    ) -> Dict[str, Any]:
        """
        Return a processed deep copy of *node* whose ``sum`` holds its rolled-up costs.

        The input node is not modified. The GUID is normalized and the keys
        ``sum``, ``rcj`` and ``children`` are guaranteed to exist (``rcj`` is
        otherwise left untouched). Leaf nodes (children holding cost entries)
        have those entries moved into ``sum`` and ``children`` cleared; other
        nodes have their children processed recursively and ``sum`` recomputed
        with ``calculate_parent_costs``.

        :param node: expense-preview node
        :param project_data: project data used to look up GUID quantities
        :param is_bill_engineering: True for a bill-of-quantities project; None means False
        :return: the processed node
        :raises KeyError, ValueError: propagated from ``find_guid_quantity`` for bill projects
        """
        result = copy.deepcopy(node)
        if "GUID" in result:
            result["GUID"] = ExpenseProcessor.normalize_guid(result["GUID"])
        result.setdefault("sum", [])
        result.setdefault("rcj", [])
        result.setdefault("children", [])
        is_bill_engineering = bool(is_bill_engineering)

        direct_costs = ExpenseProcessor.extract_costs_from_children(result)
        if direct_costs:
            # Leaf node. For bill-of-quantities projects, scale each cost by the
            # quantity of the matching projectDivision node.
            if is_bill_engineering and project_data and "GUID" in result:
                quantity = ExpenseProcessor.find_guid_quantity(project_data, result["GUID"])
                for cost_item in direct_costs:
                    try:
                        cost_item["cost"] = str(float(cost_item["cost"]) * quantity)
                    except (ValueError, TypeError):
                        pass  # keep an unparseable cost unchanged
            # NOTE(review): any structural (non-cost) children of a leaf are dropped
            # here; confirm mixed children never occur in practice.
            result["sum"] = direct_costs
            result["children"] = []
            return result

        # Structural node: process children first so each child's ``sum`` already
        # holds its subtree total when this node's ``sum`` is computed.
        result["children"] = [
            ExpenseProcessor.process_node(child, project_data, is_bill_engineering)
            for child in (result["children"] or [])
            if not ExpenseProcessor.is_cost_item(child)
        ]
        result["sum"] = ExpenseProcessor.calculate_parent_costs(result)
        return result

    @staticmethod
    def process_expense_preview(
        expense_preview: Dict[str, Any],
        project_data: Optional[Dict[str, Any]] = None,
        is_bill_engineering: Optional[bool] = None,
    ) -> Dict[str, Any]:
        """
        Process every node list of an expense preview and return a new preview.

        Node lists may sit directly under a category key or one level deeper
        under a subcategory key; every other value is copied through unchanged.

        :param expense_preview: expense-preview data
        :param project_data: project data used to look up GUID quantities
        :param is_bill_engineering: True for a bill-of-quantities project; None means False
        :return: processed deep copy of *expense_preview*
        """
        if is_bill_engineering is None:
            is_bill_engineering = False

        def process_list(nodes: List[Any]) -> List[Dict[str, Any]]:
            return [ExpenseProcessor.process_node(item, project_data, is_bill_engineering) for item in nodes]

        result = copy.deepcopy(expense_preview)
        for category_key, category_value in expense_preview.items():
            if isinstance(category_value, dict):
                for subcategory_key, subcategory_value in category_value.items():
                    if isinstance(subcategory_value, list):
                        result[category_key][subcategory_key] = process_list(subcategory_value)
            elif isinstance(category_value, list):
                result[category_key] = process_list(category_value)
        return result

    @classmethod
    def load_and_process_from_file(
        cls, input_path: str, output_path: Optional[str] = None, is_bill_engineering: Optional[bool] = None
    ) -> Optional[Dict[str, Any]]:
        """
        Load a JSON file, roll up its ``projectData.expensePreview`` and optionally save the result.

        :param input_path: path of the UTF-8 JSON file to read
        :param output_path: where to write the processed JSON; nothing is written if falsy
        :param is_bill_engineering: project type; None means auto-detect from the data
        :return: the processed data, or None if the path is missing or any error occurs
        """
        try:
            with open(input_path, "r", encoding="utf-8") as f:
                data = json.load(f)
            if not ("projectData" in data and "expensePreview" in data["projectData"]):
                print(f"警告: 文件 {input_path} 中未找到 projectData.expensePreview 路径")
                return None
            processed_data = cls.process_raw_data(data, is_bill_engineering)
            if output_path:
                with open(output_path, "w", encoding="utf-8") as f:
                    json.dump(processed_data, f, ensure_ascii=False, indent=4)
                print(f"处理完成,结果已保存到 {output_path}")
            return processed_data
        except Exception as e:
            # Per-file boundary: report the failure and return None so a batch run continues.
            print(f"处理文件 {input_path} 时出错: {str(e)}")
            return None

    @classmethod
    def process_raw_data(cls, raw_data: Dict[str, Any], is_bill_engineering: Optional[bool] = None) -> Dict[str, Any]:
        """
        Return a deep copy of *raw_data* with ``projectData.expensePreview`` rolled up.

        :param raw_data: loaded project document
        :param is_bill_engineering: project type; None means auto-detect via ``_determine_project_type``
        :return: processed copy of *raw_data*
        :raises ValueError: if ``projectData.expensePreview`` is missing
        """
        if not ("projectData" in raw_data and "expensePreview" in raw_data["projectData"]):
            raise ValueError("未找到 projectData.expensePreview 路径")
        if is_bill_engineering is None:
            project_type = _determine_project_type(raw_data)
            is_bill_engineering = project_type == "inventory"
            print(f"自动判断工程类型: {'清单工程' if is_bill_engineering else '预算工程'}")
        processed_data = copy.deepcopy(raw_data)
        processed_data["projectData"]["expensePreview"] = cls.process_expense_preview(
            raw_data["projectData"]["expensePreview"],
            # Quantities are only looked up for bill-of-quantities projects.
            raw_data["projectData"] if is_bill_engineering else None,
            is_bill_engineering,
        )
        return processed_data

    @classmethod
    def process_directory(
        cls, input_dir: str, output_dir: str, is_bill_engineering: Optional[bool] = None
    ) -> List[Tuple[str, str]]:
        """
        Process every ``*.json`` file in *input_dir*, writing results under the same name in *output_dir*.

        :param input_dir: directory to scan (not recursive)
        :param output_dir: output directory, created if missing
        :param is_bill_engineering: project type; None means auto-detect per file
        :return: ``(input_file, output_file)`` pairs for the files processed successfully
        """
        os.makedirs(output_dir, exist_ok=True)
        # os.listdir() order is arbitrary; sort so runs are reproducible.
        json_files = sorted(f for f in os.listdir(input_dir) if f.lower().endswith(".json"))
        if not json_files:
            print(f"警告: 在目录 {input_dir} 中没有找到JSON文件")
            return []
        successful_files = []
        for file_name in json_files:
            input_file = os.path.join(input_dir, file_name)
            output_file = os.path.join(output_dir, file_name)
            print(f"处理文件: {input_file}")
            processed_data = cls.load_and_process_from_file(input_file, output_file, is_bill_engineering)
            if processed_data is not None:
                successful_files.append((input_file, output_file))
                print(f"✅ 成功处理: {file_name}")
            else:
                print(f"❌ 处理失败: {file_name}")
        return successful_files
import re
# Maps the raw project-type labels found in basicData onto the two canonical
# categories used by this module: "预算" (budget) and "清单" (bill of quantities).
_PROJECT_TYPE_MAPPING = {
    "概预算工程": "预算",
    "初步设计概算": "预算",
    "可行性研究投资估算": "预算",
    "施工图预算": "预算",
    "配网定额计价": "预算",
    "招标控制价": "清单",
    "投标报价": "清单",
    "招投标工程": "清单",
    "配网清单招投标计价": "清单",
}

# basicData keys consulted, in priority order, to find the project-type label.
_PROJECT_TYPE_KEYS = ("项目类型", "工程类型", "工程类别")


def _determine_project_type(data):
    """
    Classify a project as bill-of-quantities or budget from its ``basicData``.

    The labels under ``basicData["项目类型"]``, ``basicData["工程类型"]`` and
    ``basicData["工程类别"]`` are tried in that order (surrounding whitespace
    stripped); the first label found in ``_PROJECT_TYPE_MAPPING`` decides.
    If no label is present or recognized, the project is treated as a budget
    project, consistent with ExpenseProcessor treating an unspecified
    ``is_bill_engineering`` as False.

    :param data: loaded project document
    :return: ``'inventory'`` for a bill-of-quantities project, ``'budget'`` otherwise
    """
    # NOTE(review): basicData is read from the top level of *data*; confirm
    # callers never nest it under projectData.
    basic_data = data.get("basicData") or {}
    for key in _PROJECT_TYPE_KEYS:
        raw_label = basic_data.get(key)
        if not raw_label:
            continue
        engineering_type = str(raw_label).strip()
        mapped_type = _PROJECT_TYPE_MAPPING.get(engineering_type)
        if mapped_type == "预算":
            print(f"根据项目类型 '{engineering_type}' 判断为预算工程")
            return "budget"
        if mapped_type == "清单":
            print(f"根据项目类型 '{engineering_type}' 判断为清单工程")
            return "inventory"
        # Unrecognized label: skip it and try the next key.
        print(f"项目类型 '{engineering_type}' 未在映射中定义,跳过")
    return "budget"
def costsummary_upwards(
    input_dir: str, output_dir: str, is_bill_engineering: Optional[bool] = None
) -> List[Tuple[str, str]]:
    """
    Roll expense-preview costs upwards for every JSON file in a directory.

    Convenience entry point that delegates to ``ExpenseProcessor.process_directory``.

    :param input_dir: directory containing the input ``*.json`` files
    :param output_dir: directory that receives the processed files
    :param is_bill_engineering: project type; None lets each file be auto-detected
    :return: ``(input_file, output_file)`` pairs for the files processed successfully
    """
    processed_pairs = ExpenseProcessor.process_directory(
        input_dir=input_dir,
        output_dir=output_dir,
        is_bill_engineering=is_bill_engineering,
    )
    return processed_pairs
if __name__ == "__main__":
    # Script entry point: process every JSON file under the input directory.
    # No project type is passed, so each file's type is detected automatically.
    source_directory = "data/input/json"
    target_directory = "data/input/merged"
    result = costsummary_upwards(source_directory, target_directory)
    if not result:
        print("\n没有文件被成功处理")
    else:
        print(f"\n成功处理了 {len(result)} 个文件:")
        for src, dst in result:
            print(f" {os.path.basename(src)} -> {os.path.basename(dst)}")