From 439c809340f165ccbe6f5734a6ebb00cc5727eae Mon Sep 17 00:00:00 2001 From: chentianrui Date: Tue, 3 Jun 2025 14:50:45 +0800 Subject: [PATCH] =?UTF-8?q?=E4=B8=8A=E4=BC=A0=E6=96=87=E4=BB=B6=E8=87=B3?= =?UTF-8?q?=20/?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- langchain_neo4j.py | 608 ++++++++++++++++++++++++ project_implementation.py | 948 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 1556 insertions(+) create mode 100644 langchain_neo4j.py create mode 100644 project_implementation.py diff --git a/langchain_neo4j.py b/langchain_neo4j.py new file mode 100644 index 0000000..12dedde --- /dev/null +++ b/langchain_neo4j.py @@ -0,0 +1,608 @@ +from langchain.chains import LLMChain +from langchain_openai import OpenAI +from langchain_experimental.utilities import PythonREPL +from project_implementation import ProjectBuilder +from prompt_templates import FUNCTION_CALL_PROMPT +import inspect +import project +import io +import sys +from parameter_rewriting import rewrite_query_parameters, KnowledgeGraphProcessor +import json +from langchain.agents import Tool, AgentExecutor, create_react_agent +from langchain.prompts import PromptTemplate + +from llm import llm + + +# 获取ProjectTookiIt类的方法定义 +def get_project_class_methods(): + """ + 从project模块中提取ProjectTookiIt类的方法定义 + + Returns: + str: 格式化后的方法定义字符串 + """ + project_class_code = inspect.getsource(project.ProjectTookiIt) + + lines = project_class_code.split("\n") + result_lines = [] + + in_class = False + skip_init = False + + for line in lines: + if line.strip().startswith("class ProjectTookiIt"): + in_class = True + result_lines.append(line) + elif in_class: + if line.strip().startswith("def __init__"): + skip_init = True + elif skip_init and line.strip() and not line.startswith(" " * 8): + skip_init = False + + if not skip_init: + result_lines.append(line) + + return "\n".join(result_lines) + + +# 创建动态提示模板 +project_class_methods = get_project_class_methods() + +# 创建 Chain +function_call_chain = LLMChain(llm=llm, prompt=FUNCTION_CALL_PROMPT, output_key="code") + +# Python 执行器 +repl = PythonREPL() + +# 创建知识图谱处理器实例 +kg_processor = KnowledgeGraphProcessor() + + +# 定义搜索知识库的工具 +def search_knowledge_base(query): + """ + 在知识库中搜索关键词 + + Args: + query (str): 搜索关键词 + + Returns: + str: 搜索结果的JSON字符串 + """ + found_data = kg_processor._search_in_kg(query) + if found_data: + return json.dumps(found_data, ensure_ascii=False, indent=2) + else: + return f"未找到与'{query}'相关的信息" + + +# 定义获取节点定义的工具 +def get_node_definition(node_type): + """ + 获取节点类型的定义 + + Args: + node_type (str): 节点类型名称 + + Returns: + str: 节点类型定义 + """ + definition = kg_processor._get_node_definition(node_type) + return definition + + +# 创建工具列表 +tools = [ + Tool( + name="search_knowledge_base", + func=search_knowledge_base, + description="在知识库中搜索关键词,返回相关信息。输入应该是一个搜索关键词。", + ), + Tool( + name="get_node_definition", + func=get_node_definition, + description="获取节点类型的定义,返回类型定义代码。输入应该是一个节点类型名称。", + ), +] + +# 创建Agent +agent = create_react_agent(llm, tools, FUNCTION_CALL_PROMPT) + +# 创建Agent执行器 +agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True, handle_parsing_errors=True) + + +def nl_query_to_function_call(input_data): + """ + 将自然语言查询转换为函数调用并执行,或直接执行提供的代码 + + Args: + input_data (dict): 包含type和value的字典 + { + "type": "query|code", + "value": "查询内容或代码" + } + + Returns: + dict: 包含状态码、消息和数据的字典 + """ + input_type = input_data.get("type", "query") + input_value = input_data.get("value", "") + max_retries = 1 # 设置最大重试次数 + current_retry = 0 + + original_query = input_value # 保存原始查询用于RAG + + print(f"\n====== 开始处理查询 ======") + print(f"查询类型: {input_type}") + print(f"查询内容: {input_value}") + + while current_retry <= max_retries: + print(f"\n----- 尝试 #{current_retry + 1} -----") + + # 如果type是query,使用LLM生成代码 + if input_type == "query" and current_retry == 0: + # 从查询中提取关键部分 + import re + + # 提取【】中的内容 + path_parts = re.findall(r"【([^】]+)】", input_value) + + # 创建临时代码 + temp_code = f'search("{input_value}")' + + # 对于每个路径,只添加最后一个部分 + for part in path_parts: + if "/" in part: + # 提取路径中的最后一个部分 + last_part = part.split("/")[-1].strip() + if last_part: + temp_code += f'\nsearch("{last_part}")' + else: + # 如果没有/,直接使用整个部分 + temp_code += f'\nsearch("{part}")' + + # 获取知识库内容和节点定义 + knowledge_base, node_definitions = kg_processor._get_relevant_knowledge(temp_code) + + # 使用Agent执行查询 + agent_response = agent_executor.invoke( + { + "query": input_value, + "project_class_methods": project_class_methods, + "KnowledgeBase": knowledge_base, + "NodeDefinition": node_definitions, + } + ) + + # 从Agent响应中提取代码 + code = agent_response["output"] + print(f"\n生成的代码:\n{code}") + else: + print(f"\n使用重写后的代码:\n{input_value}") + code = input_value + + # 保存原始代码用于返回 + original_code = code + + # 执行生成的函数并捕获输出 + try: + # 创建一个新的命名空间来执行代码,包含必要的导入 + namespace = { + "ProjectBuilder": ProjectBuilder, # 添加ProjectBuilder到命名空间 + "project_implementation": __import__("project_implementation"), + "project": __import__("project"), + } + + # 执行生成的代码,定义neo4j_find_function函数 + exec(code, namespace) + + # 重定向stdout来捕获print输出 + old_stdout = sys.stdout + redirected_output = io.StringIO() + sys.stdout = redirected_output + + try: + # 执行函数 + # print("\n执行代码...") + result = namespace["neo4j_find_function"]() + + # 获取捕获的输出 + output = redirected_output.getvalue().strip() + # print(f"\n原始输出:\n{output}") + + # 检查结果是否为空 + is_empty_result = True # 默认假设为空 + + # 先检查是否有明显的数据内容 + if output and output.strip().startswith("[{") and output.strip().endswith("}]"): + # 看起来是一个对象数组,尝试转换为有效的JSON格式 + try: + # 将单引号替换为双引号以使其成为有效的JSON + # 但要小心处理嵌套的引号 + import ast + + # 使用ast.literal_eval安全地将Python字符串表示转换为Python对象 + parsed_obj = ast.literal_eval(output) + if parsed_obj and len(parsed_obj) > 0: + is_empty_result = False + except: + # 如果解析失败,继续尝试其他方法 + pass + + # 如果上面的方法没有确定结果不为空,继续尝试JSON解析 + if is_empty_result: + try: + if output.strip(): + parsed_output = json.loads(output) + if parsed_output and ( + isinstance(parsed_output, list) + and len(parsed_output) > 0 + or isinstance(parsed_output, dict) + and len(parsed_output) > 0 + ): + is_empty_result = False + except json.JSONDecodeError: + # 不是有效的JSON,使用其他判断方法 + is_empty_result = ( + not output + or output.lower() == "none" + or output == "[]" + or "未找到" in output + or "None" in output + or result is None + ) + + # 如果结果为空,走重写流程 + if is_empty_result: + print("\n查询未找到结果,尝试定位具体缺失节点...") + + # 解析原始查询路径中的最后一个节点名 + import re + + match = re.search(r"【([^】]+)】\s*$", original_query) + missing_node = match.group(1) if match else "未知节点" + + error_info = { + "error_type": "NodeNotFoundError", + "error_message": f"{missing_node} 未找到,请检查该节点是否存在。", + "missing_node": missing_node, + "original_query": original_query, + "executed_code": original_code, + } + + print("结构化错误信息:") + print(json.dumps(error_info, ensure_ascii=False, indent=2)) + + if current_retry < max_retries: + print("\n尝试使用RAG重写查询...") + try: + # 使用RAG重写查询和代码,并传递错误信息 + rewritten = rewrite_query_parameters(original_query, original_code, error_info) + + print(f"\nRAG重写结果: {json.dumps(rewritten, ensure_ascii=False, indent=2)}") + + # 更新查询和代码 + if "query" in rewritten and "code" in rewritten and rewritten["code"] != original_code: + print("\nRAG重写成功,使用新代码重试...") + input_value = rewritten["code"] # 直接使用重写后的代码 + input_type = "code" # 切换到代码模式 + current_retry += 1 + continue # 继续下一次循环 + else: + print("\nRAG重写未产生新代码,返回原始错误") + except Exception as e: + print(f"\nRAG重写失败: {e}") + # 记录错误但继续执行 + + # RAG重写失败或未产生新代码,返回原始错误 + query_status = ( + "第一次查询失败,RAG重写也失败" + if current_retry == 0 + else f"第{current_retry+1}次查询失败,RAG重写也失败" + ) + print(f"\n{query_status}") + return { + "code": 1, + "message": f"{missing_node} 未找到,请检查该节点是否存在。", + "data": {"value": "", "code": original_code}, + "error_info": error_info, + "query_status": query_status, + } + + # 清理输出,只保留有用的结果部分 + clean_output = output + + # 如果输出包含查询结果数量和对象引用 + if "查询结果数量:" in output and "", output, re.DOTALL) + if node_match: + props_str = node_match.group(1).replace("'", '"') + try: + import ast + + props = ast.literal_eval(props_str) + clean_output = json.dumps(props, ensure_ascii=False, indent=2) + except: + pass + + # 如果有查询结果数量信息 + count_match = re.search(r"查询结果数量: (\d+)", output) + if count_match: + count = count_match.group(1) + if count == "0": + clean_output = "未找到匹配的数据。" + is_empty_result = True + elif not node_match: # 如果没有提取到节点属性但有结果 + clean_output = f"找到 {count} 条匹配结果" + + # 检查结果对象 + if result is not None: + if isinstance(result, list): + if not result: # 空列表 + is_empty_result = True + else: + # 处理非空列表 + formatted_items = [] + for item in result: + if hasattr(item, "__dict__"): + # 提取对象的所有属性 + attrs = {k: v for k, v in item.__dict__.items() if not k.startswith("_")} + formatted_items.append(attrs) + else: + formatted_items.append(str(item)) + + if not is_empty_result: # 只有在不是空结果时才返回成功 + query_status = ( + "第一次查询成功" + if current_retry == 0 + else f"第{current_retry+1}次查询成功(RAG重写后)" + ) + print(f"\n{query_status}") + return { + "code": 0, + "message": "成功", + "data": { + "value": json.dumps(formatted_items, ensure_ascii=False, indent=2), + "code": original_code, + }, + "query_status": query_status, + } + elif hasattr(result, "__dict__"): + # 单个对象 + attrs = {k: v for k, v in result.__dict__.items() if not k.startswith("_")} + + if not is_empty_result: # 只有在不是空结果时才返回成功 + query_status = ( + "第一次查询成功" if current_retry == 0 else f"第{current_retry+1}次查询成功(RAG重写后)" + ) + print(f"\n{query_status}") + return { + "code": 0, + "message": "成功", + "data": { + "value": json.dumps(attrs, ensure_ascii=False, indent=2), + "code": original_code, + }, + "query_status": query_status, + } + + # 如果没有对象属性但有清理后的输出,且不是空结果 + if ( + clean_output + and clean_output.lower() != "none" + and clean_output != "[]" + and "未找到" not in clean_output + and not is_empty_result + ): + query_status = ( + "第一次查询成功" if current_retry == 0 else f"第{current_retry+1}次查询成功(RAG重写后)" + ) + print(f"\n{query_status}") + return { + "code": 0, + "message": "成功", + "data": {"value": clean_output, "code": original_code}, + "query_status": query_status, + } + + finally: + # 恢复stdout + sys.stdout = old_stdout + + except Exception as e: + import traceback + + error_details = traceback.format_exc() + print(f"\n执行代码时出错: {error_details}") + + # 如果走到这里,说明结果为空或未找到匹配项,应该执行RAG重写流程 + print("\n查询未找到结果,尝试定位具体缺失节点...") + + # 解析原始查询路径中的最后一个节点名 + import re + + match = re.search(r"【([^】]+)】\s*$", original_query) + missing_node = match.group(1) if match else "未知节点" + + error_info = { + "error_type": "NodeNotFoundError", + "error_message": f"{missing_node} 未找到,请检查该节点是否存在。", + "missing_node": missing_node, + "original_query": original_query, + "executed_code": original_code, + } + + print("结构化错误信息:") + print(json.dumps(error_info, ensure_ascii=False, indent=2)) + + if current_retry < max_retries: + print("\n尝试使用RAG重写查询...") + try: + # 使用RAG重写查询和代码,并传递错误信息 + rewritten = rewrite_query_parameters(original_query, original_code, error_info) + + print(f"\nRAG重写结果: {json.dumps(rewritten, ensure_ascii=False, indent=2)}") + + # 更新查询和代码 + if "query" in rewritten and "code" in rewritten and rewritten["code"] != original_code: + print("\nRAG重写成功,使用新代码重试...") + input_value = rewritten["code"] # 直接使用重写后的代码 + input_type = "code" # 切换到代码模式 + current_retry += 1 + continue # 继续下一次循环 + else: + print("\nRAG重写未产生新代码,返回原始错误") + except Exception as e: + print(f"\nRAG重写失败: {e}") + # 记录错误但继续执行 + + # RAG重写失败或未产生新代码,返回原始错误 + query_status = ( + "第一次查询失败,RAG重写也失败" + if current_retry == 0 + else f"第{current_retry+1}次查询失败,RAG重写也失败" + ) + print(f"\n{query_status}") + return { + "code": 1, + "message": f"{missing_node} 未找到,请检查该节点是否存在。", + "data": {"value": "", "code": original_code}, + "error_info": error_info, + "query_status": query_status, + } + + # 如果所有重试都失败 + print("\n所有重试都失败,无法找到匹配的结果") + query_status = "所有重试都失败" + return { + "code": 1, + "message": "所有重试都失败,无法找到匹配的结果", + "data": {"value": "", "code": original_code}, + "query_status": query_status, + } + + +def format_result(result): + """ + 格式化查询结果 + + Args: + result: 查询结果(可能为 list、dict 或其他类型) + + Returns: + str: 格式化后的结果 + """ + # 处理 project 对象 + if hasattr(result, "__module__") and result.__module__ == "project": + # 这是一个 project 模块中的对象 + attrs = {k: v for k, v in result.__dict__.items() if not k.startswith("_")} + return json.dumps(attrs, ensure_ascii=False, indent=2) + + # 处理 project 对象列表 + if isinstance(result, list) and all( + hasattr(item, "__module__") and item.__module__ == "project" for item in result if hasattr(item, "__module__") + ): + formatted_items = [] + for item in result: + if hasattr(item, "__dict__"): + attrs = {k: v for k, v in item.__dict__.items() if not k.startswith("_")} + formatted_items.append(attrs) + else: + formatted_items.append(str(item)) + + return json.dumps(formatted_items, ensure_ascii=False, indent=2) + + # 如果结果是字符串,可能包含调试信息,需要提取有用部分 + if isinstance(result, str): + # 尝试提取最终结果部分 + if "[]" in result: + return "未找到匹配的数据。" + + # 如果包含节点信息,提取关键部分 + import re + + node_match = re.search(r"找到.*?labels=.*?properties=(.*?)>", result) + if node_match: + try: + # 提取属性部分并格式化 + props_str = node_match.group(1).replace("'", '"') + import ast + + props = ast.literal_eval(props_str) + + formatted = "找到节点:\n" + for k, v in props.items(): + formatted += f" {k}: {v}\n" + return formatted + except: + pass + + # 如果包含查询结果数量 + count_match = re.search(r"查询结果数量: (\d+)", result) + if count_match: + count = count_match.group(1) + if count == "0": + return "未找到匹配的数据。" + + # 如果是列表 + if isinstance(result, list): + if not result: + return "未找到匹配的数据。" + + lines = [f"找到 {len(result)} 条匹配结果:"] + for i, item in enumerate(result, 1): + lines.append(f"\n结果 {i}:") + if hasattr(item, "items"): # 检查是否有items方法(字典或类似字典的对象) + try: + for k, v in item.items(): + lines.append(f" {k}: {v}") + except: + lines.append(f" {item}") + else: + lines.append(f" {item}") + return "\n".join(lines) + + # 如果是字典 + elif isinstance(result, dict): + lines = ["查询结果:"] + for k, v in result.items(): + lines.append(f" {k}: {v}") + return "\n".join(lines) + + # 其他类型 + else: + return str(result) + + +def format_dict_or_item(item): + """ + 格式化字典或其他对象 + + Args: + item: 字典或其他对象 + + Returns: + str: 格式化后的字符串 + """ + if isinstance(item, dict): + formatted = "" + for key, value in item.items(): + formatted += f" {key}: {value}\n" + return formatted + + return str(item) + + +question = { + "type": "query", + "value": "查找一下【工程数据/安装工程/安装/架空输电线路本体工程/杆塔工程/杆塔组立/铁塔、钢管杆组立】的类型为【主材】的【塔材】", +} +result = nl_query_to_function_call(question) +print(result) diff --git a/project_implementation.py b/project_implementation.py new file mode 100644 index 0000000..1fc8fb7 --- /dev/null +++ b/project_implementation.py @@ -0,0 +1,948 @@ +from neo4j import GraphDatabase +from project import * +import atexit + + +class ProjectTookiItNeo4j(ProjectTookiIt): + """ + 基于Neo4j数据库的项目类实现 + """ + + def __init__(self): + """ + 初始化Neo4j连接 + + Args: + uri (str): Neo4j数据库URI + user (str): 用户名 + password (str): 密码 + """ + uri = "bolt://172.20.0.145:7687" + user = "neo4j" + password = "password" + + super().__init__() + self.driver = GraphDatabase.driver(uri, auth=(user, password)) + self.session = self.driver.session() + + # 初始化其他必要的数据结构 + self.material_equipment_dict = {} # 材机字典,键为ID + self.fee_templates = {} # 取费表模板字典,键为ID + self.fee_schedules = {} # 费用表字典,键为ID + self.project_properties = {} # 工程属性字典 + + def close(self): + """ + 关闭数据库连接 + """ + if self.session: + self.session.close() + if self.driver: + self.driver.close() + + # 通用节点查询方法 + def get_node_by_path(self, path, node_labels=None): + """ + 通过路径获取节点对象 + + Args: + path (str): 以'/'分隔的多级节点路径 + node_labels (list): 节点标签列表,用于过滤结果 + + Returns: + dict|None: 节点数据,如果路径不存在返回None + """ + if not path: + return None + + # 分割路径为各个部分 + path_parts = path.split("/") + + # 构建查询 + if len(path_parts) == 1: + # 只有一级路径,直接查询 + if node_labels: + labels_str = ":" + "|:".join(node_labels) + query = f""" + MATCH (n{labels_str}) + WHERE n.name = $name + RETURN n LIMIT 1 + """ + else: + query = """ + MATCH (n) + WHERE n.name = $name + RETURN n LIMIT 1 + """ + params = {"name": path_parts[0]} + else: + # 多级路径,构建路径查询 + last_part = path_parts[-1] + if node_labels: + labels_str = ":" + "|:".join(node_labels) + query = f""" + MATCH path = (root)-[*]->(target{labels_str}) + WHERE target.name = $last_part + RETURN target as n LIMIT 1 + """ + else: + query = """ + MATCH path = (root)-[*]->(target) + WHERE target.name = $last_part + RETURN target as n LIMIT 1 + """ + params = {"last_part": last_part} + + try: + result = self.session.run(query, **params) + record = result.single() + + if not record: + return None + + return record["n"] + except Exception as e: + print(f"获取节点对象时出错: {e}") + return None + + # 项目划分查询方法 + def get_division_item_by_path(self, path): + """ + 通过路径获取项目划分对象 + """ + node_data = self.get_node_by_path(path, ["ProjectDivisionItem"]) + if not node_data: + return None + + item = ProjectDivisionItem() + for key, value in node_data.items(): + if hasattr(item, key): + setattr(item, key, value) + + return item + + def get_division_node_by_parent_and_name(self, parent_path, partial_name): + """ + 通过父节点路径和模糊节点名称获取项目划分对象,包括子节点 + + Args: + parent_path (str): 父节点的路径,以'/'分隔的多级节点路径 + partial_name (str): 目标节点的模糊或不完整名称 + + Returns: + list: 包含所有匹配项目划分节点的列表,如果没有匹配返回空列表 + """ + if not partial_name: + return [] + + # 使用通用方法获取父节点 + parent_node_data = self.get_node_by_path(parent_path, ["ProjectDivisionItem"]) + parent_node_id = parent_node_data["id"] if parent_node_data and "id" in parent_node_data else None + + # 构建查询,根据是否有父节点ID调整查询条件,使用递归关系查询 + if parent_node_id: + query = """ + MATCH (p)-[:CONTAINS|HAS|RELATED_TO*]-(n:ProjectDivisionItem) + WHERE p.id = $parent_id AND n.name CONTAINS $partial_name + RETURN n LIMIT 50 + """ + params = {"parent_id": parent_node_id, "partial_name": partial_name} + else: + query = """ + MATCH (n:ProjectDivisionItem) + WHERE n.name CONTAINS $partial_name + RETURN n LIMIT 50 + """ + params = {"partial_name": partial_name} + + try: + result = self.session.run(query, **params) + + items = [] + for record in result: + node_data = record["n"] + item = ProjectDivisionItem() + for key, value in node_data.items(): + if hasattr(item, key): + setattr(item, key, value) + items.append(item) + + return items + except Exception as e: + print(f"通过父节点路径和模糊名称获取项目划分对象时出错: {e}") + return [] + + # 工程量查询方法 + def get_quantities_by_paths(self, paths_str): + """ + 获取指定项目路径下的工程量对象 + + Args: + paths_str (str): 以'/'分隔的多级节点路径 + + Returns: + ProjectQuantity|None: 对应的工程量对象,如果路径不存在返回None,成功找到返回工程量对象 + """ + if not paths_str: + return None + + # 使用通用方法获取节点,考虑所有可能的工程量类型 + node_data = self.get_node_by_path(paths_str, ["ProjectQuantity", "Quota", "MainMaterial", "Equipment"]) + + if not node_data: + return None + + # 根据节点标签或类型属性创建对应类型的对象 + quantity = self._create_quantity_object(node_data) + + # 填充属性 + for key, value in node_data.items(): + if hasattr(quantity, key): + setattr(quantity, key, value) + + return quantity + + def get_quantities_node_by_parent_and_code(self, parent_path, quantity_type=None, code=None): + """ + 通过父节点路径和编码获取工程量对象(定额),包括子节点 + + Args: + parent_path (str): 父节点的路径,以'/'分隔的多级节点路径 + quantity_type (str): 工程量类型('定额'、'主材'、'设备'或None表示所有类型) + code(str): 工程量编码,以'/'分隔的多个编码 + + Returns: + list: 包含所有匹配节点的列表,如果没有匹配返回空列表 + """ + if not code: + return [] + + # 使用通用方法获取父节点 + parent_node_data = self.get_node_by_path(parent_path) + + # 从路径中获取父节点名称 + path_parts = parent_path.split("/") + parent_name = path_parts[-1] + + # 处理编码,可能有多个编码用/分隔 + code_parts = code.split("/") + code_conditions = [] + for code_part in code_parts: + if code_part: + + code_conditions.append(f"q.编码 = '{code_part}'") + + if not code_conditions: + return [] + + code_query = " OR ".join(code_conditions) + + # 根据工程量类型确定标签 + node_labels = [] + if quantity_type == "定额": + node_labels = ["ProjectQuantity", "Quota"] + elif quantity_type == "主材": + node_labels = ["ProjectQuantity", "MainMaterial"] + elif quantity_type == "设备": + node_labels = ["ProjectQuantity", "Equipment"] + else: + node_labels = ["ProjectQuantity"] + + # 构建标签字符串 + labels_str = ":" + ":".join(node_labels) if node_labels else "" + + # 使用name属性进行匹配 + query = f""" + MATCH (p)-[*1..5]->(q{labels_str}) + WHERE p.name = $parent_name AND ({code_query}) + RETURN q + LIMIT 10 + """ + params = {"parent_name": parent_name} + + try: + result = self.session.run(query, params) + quantities = [] + + for record in result: + node_data = record["q"] + quantity = self._create_quantity_object(node_data, quantity_type) + + # 将节点属性赋值到对象 + for key, value in node_data.items(): + setattr(quantity, key, value) + + # 转换为字典 + if hasattr(quantity, "to_dict"): + quantities.append(quantity.to_dict()) + else: + # 如果没有 to_dict 方法,就用 vars() 动态获取属性 + quantities.append(vars(quantity)) + + return quantities + except Exception as e: + print(f"通过编码获取工程量对象时出错: {e}") + import traceback + + traceback.print_exc() + return [] + + def get_quantities_node_by_parent_and_name(self, parent_path, partial_name, quantity_type=None): + """ + 通过父节点路径、模糊节点名称和类型获取工程量对象(主材或者设备),包括子节点 + + Args: + parent_path (str): 父节点的路径,以'/'分隔的多级节点路径 + partial_name (str): 目标节点的模糊或不完整名称 + quantity_type (str): 工程量类型('定额'、'主材'、'设备'或None表示所有类型) + + Returns: + list: 包含所有匹配节点属性的列表,如果没有匹配返回空列表 + """ + if not partial_name: + return [] + + # 使用通用方法获取父节点 + parent_node_data = self.get_node_by_path(parent_path) + parent_node_id = parent_node_data["id"] if parent_node_data and "id" in parent_node_data else None + + # 根据工程量类型确定标签和类型条件 + node_labels = [] + type_condition = "" + + if quantity_type == "定额": + node_labels = ["ProjectQuantity", "Quota"] + type_condition = "q.类型 = '0'" + elif quantity_type == "主材": + node_labels = ["ProjectQuantity", "MainMaterial"] + type_condition = "q.类型 = '1'" + elif quantity_type == "设备": + node_labels = ["ProjectQuantity", "Equipment"] + type_condition = "q.类型 = '5'" + else: + node_labels = ["ProjectQuantity"] + + # 构建标签字符串 + labels_str = ":" + ":".join(node_labels) if node_labels else "" + + # 扩展关系类型 + relationship_types = "CONTAINS|HAS|RELATED_TO|USES|BELONGS_TO" + + # 构建查询 - 使用递归关系查询 + if parent_node_id: + query = f""" + MATCH (p)-[:{relationship_types}*1..10]->(q{labels_str}) + WHERE p.id = $parent_id AND q.name CONTAINS $partial_name + {f'AND {type_condition}' if type_condition else ''} + RETURN q LIMIT 50 + """ + params = {"parent_id": parent_node_id, "partial_name": partial_name} + else: + query = f""" + MATCH (q{labels_str}) + WHERE q.name CONTAINS $partial_name + {f'AND {type_condition}' if type_condition else ''} + RETURN q LIMIT 50 + """ + params = {"partial_name": partial_name} + + try: + result = self.session.run(query, **params) + + quantities = [] + for record in result: + node_data = record["q"] + + # 创建对应类型的对象 + quantity = self._create_quantity_object(node_data, quantity_type) + + # 填充属性 + for key, value in node_data.items(): + if hasattr(quantity, key): + setattr(quantity, key, value) + + # 将对象的属性转换为字典 + attrs = {} + for key, value in vars(quantity).items(): + if not key.startswith("_"): # 排除私有属性 + attrs[key] = value + + quantities.append(attrs) # 添加属性字典而不是对象 + + return quantities + except Exception as e: + import traceback + + traceback.print_exc() + return [] + + # 辅助方法,用于根据节点数据创建对应类型的工程量对象 + def _create_quantity_object(self, node_data, quantity_type=None): + """ + 根据节点数据创建对应类型的工程量对象 + + Args: + node_data (dict): 节点数据 + quantity_type (str): 工程量类型('定额'、'主材'、'设备'或None) + + Returns: + ProjectQuantity: 创建的工程量对象 + """ + # 如果指定了类型,直接创建对应类型的对象 + if quantity_type == "定额": + return Ration() + elif quantity_type == "主材": + return Material() + elif quantity_type == "设备": + return Equipment() + + # 如果没有指定类型,尝试通过节点属性或标签判断 + if "类型" in node_data: + if node_data["类型"] == "0": + return Ration() + elif node_data["类型"] == "1": + return Material() + elif node_data["类型"] == "5": + return Equipment() + + # 通过标签判断 + labels = list(node_data.labels) if hasattr(node_data, "labels") else [] + if "Quota" in labels: + return Ration() + elif "MainMaterial" in labels: + return Material() + elif "Equipment" in labels: + return Equipment() + + # 默认返回基类对象 + return ProjectQuantity() + + # 材机查询方法实现 + def get_material_equipment_by_path(self, paths_str): + """ + 通过路径获取材机对象 + + Args: + paths_str (str): 以'/'分隔的多级项目划分名称路径 + + Returns: + list: 包含所有匹配的材机对象的列表 + """ + if not paths_str: + return [] + + # 使用通用方法获取节点 + node_data = self.get_node_by_path(paths_str) + node_id = node_data["id"] if node_data and "id" in node_data else None + + if not node_id: + return [] + + # 查询与该节点关联的所有材机对象 + query = """ + MATCH (p)-[r]-(m) + WHERE p.id = $node_id AND (m:MaterialOrEquipment OR m:Material OR m:Equipment) + RETURN m LIMIT 50 + """ + params = {"node_id": node_id} + + try: + result = self.session.run(query, **params) + + materials = [] + for record in result: + material = self._create_material_object(record["m"]) + materials.append(material) + + # 更新缓存 + if hasattr(material, "id") and material.id: + self.material_equipment_dict[material.id] = material + + return materials + except Exception as e: + print(f"通过路径获取材机对象时出错: {e}") + return [] + + def get_material_equipment_by_parent_and_name(self, parent_path, partial_name): + """ + 通过父节点路径和模糊名称获取材机对象 + + Args: + parent_path (str): 父节点的路径,以'/'分隔的多级节点路径 + partial_name (str): 目标节点的模糊或不完整名称 + + Returns: + list: 包含所有匹配的材机对象的列表 + """ + if not partial_name: + return [] + + # 使用通用方法获取父节点 + parent_node_data = self.get_node_by_path(parent_path) + parent_node_id = parent_node_data["id"] if parent_node_data and "id" in parent_node_data else None + + # 构建查询,根据是否有父节点ID调整查询条件 + if parent_node_id: + # 如果找到了父节点,查找与父节点有关系的材机节点 + query = """ + MATCH (p)-[:CONTAINS|HAS|USES|RELATED_TO]-(m) + WHERE p.id = $parent_id AND m.name CONTAINS $partial_name + AND (m:MaterialOrEquipment OR m:Material OR m:Equipment) + RETURN m LIMIT 20 + """ + params = {"parent_id": parent_node_id, "partial_name": partial_name} + else: + # 如果没有找到父节点或没有提供父节点路径,只按名称查询 + query = """ + MATCH (m) + WHERE m.name CONTAINS $partial_name + AND (m:MaterialOrEquipment OR m:Material OR m:Equipment) + RETURN m LIMIT 20 + """ + params = {"partial_name": partial_name} + + try: + result = self.session.run(query, **params) + + materials = [] + for record in result: + material = self._create_material_object(record["m"]) + materials.append(material) + + # 更新缓存 + if hasattr(material, "id") and material.id: + self.material_equipment_dict[material.id] = material + + return materials + except Exception as e: + print(f"通过父节点路径和模糊名称获取材机对象时出错: {e}") + return [] + + # 辅助方法,用于创建材机对象并填充属性 + def _create_material_object(self, node_data): + """ + 根据节点数据创建材机对象并填充属性 + + Args: + node_data (dict): 节点数据 + + Returns: + MaterialOrEquipment: 创建的材机对象 + """ + material = MaterialOrEquipment() + + # 填充属性 + for key, value in node_data.items(): + if hasattr(material, key): + setattr(material, key, value) + + return material + + # 取费表模板查询方法实现 + def get_fee_template_by_path(self, paths_str): + """ + 通过路径获取取费表模板 + + Args: + paths_str (str): 以'/'分隔的多级项目划分名称路径 + + Returns: + list: 包含所有匹配的取费表模板对象的列表 + """ + if not paths_str: + return [] + + # 使用通用方法获取节点 + node_data = self.get_node_by_path(paths_str) + node_id = node_data["id"] if node_data and "id" in node_data else None + + if not node_id: + return [] + + # 查询与该节点关联的所有取费表模板 + query = """ + MATCH (p)-[r]-(t) + WHERE p.id = $node_id AND (t:FeeTableTemplate OR t.type = 'FeeTableTemplate') + RETURN t LIMIT 20 + """ + params = {"node_id": node_id} + + try: + result = self.session.run(query, **params) + + templates = [] + for record in result: + template = self._create_fee_template_object(record["t"]) + templates.append(template) + + return templates + except Exception as e: + print(f"通过路径获取取费表模板时出错: {e}") + return [] + + def get_fee_template_by_parent_and_name(self, parent_path, partial_name): + """ + 通过父节点路径和模糊名称获取取费表模板 + + Args: + parent_path (str): 父节点的路径,以'/'分隔的多级节点路径 + partial_name (str): 目标节点的模糊或不完整名称 + + Returns: + list: 包含所有匹配的取费表模板对象的列表 + """ + if not partial_name: + return [] + + # 使用通用方法获取父节点 + parent_node_data = self.get_node_by_path(parent_path) + parent_node_id = parent_node_data["id"] if parent_node_data and "id" in parent_node_data else None + + # 构建查询,根据是否有父节点ID调整查询条件 + if parent_node_id: + query = """ + MATCH (p)-[:CONTAINS|HAS|USES|RELATED_TO]-(t) + WHERE p.id = $parent_id AND t.name CONTAINS $partial_name + AND (t:FeeTableTemplate OR t.type = 'FeeTableTemplate') + RETURN t LIMIT 20 + """ + params = {"parent_id": parent_node_id, "partial_name": partial_name} + else: + query = """ + MATCH (t) + WHERE t.name CONTAINS $partial_name + AND (t:FeeTableTemplate OR t.type = 'FeeTableTemplate') + RETURN t LIMIT 20 + """ + params = {"partial_name": partial_name} + + try: + result = self.session.run(query, **params) + + templates = [] + for record in result: + template = self._create_fee_template_object(record["t"]) + templates.append(template) + + return templates + except Exception as e: + print(f"通过父节点路径和模糊名称获取取费表模板时出错: {e}") + return [] + + # 辅助方法,用于创建取费表模板对象并填充属性 + def _create_fee_template_object(self, node_data): + """ + 根据节点数据创建取费表模板对象并填充属性 + + Args: + node_data (dict): 节点数据 + + Returns: + FeeTableTemplateItem: 创建的取费表模板对象 + """ + template = FeeTableTemplateItem() + + # 填充属性 + for key, value in node_data.items(): + if hasattr(template, key): + setattr(template, key, value) + + # 更新缓存 + if hasattr(template, "OutlayID") and template.OutlayID: + self.fee_templates[template.OutlayID] = template + + return template + + # 费用表查询方法实现 + def get_fee_schedule_on_auxiliary_expense_table(self, table_name, fee_name, fee): + """ + 在辅助费用表中查找费用 + + Args: + table_name (str): 费用表名称 + fee_name (str): 要查找的费用名称 + fee (str): 匹配的费用值属性名 + + Returns: + str: 匹配到的费用名称节点对应的费用值 + """ + if not table_name or not fee_name or not fee: + return None + + # 构建查询,查找辅助费用表中的特定费用 - 使用任意关系类型 + query = """ + MATCH (t:FeeScheduleItem)-[r]->(f:Fee) + WHERE t.name = $table_name AND f.name = $fee_name + RETURN f LIMIT 1 + """ + params = {"table_name": table_name, "fee_name": fee_name} + + try: + result = self.session.run(query, **params) + + all_records = result.data() + + if len(all_records) > 0: + fee_node = all_records[0]["f"] + value = fee_node.get(fee) + if value is not None: + return value + + return None + except Exception as e: + print(f"在辅助费用表中查找费用时出错: {e}") + return None + + def get_fee_schedule_on_other_expense_table(self, table_name, fee_name, fee): + """ + 在其它费用表中查找费用 + + Args: + table_name (str): 费用表名称 + fee_name (str): 要查找的费用名称 + fee (str): 匹配的费用值属性名 + + Returns: + str: 匹配到的费用名称节点对应的费用值 + """ + if not table_name or not fee_name or not fee: + return None + + # 构建查询,查找其它费用表中的特定费用 - 使用任意关系类型 + query = """ + MATCH (t:FeeScheduleItem)-[r]->(f:Fee) + WHERE t.name = $table_name AND f.name = $fee_name + RETURN f LIMIT 1 + """ + params = {"table_name": table_name, "fee_name": fee_name} + + try: + result = self.session.run(query, **params) + + all_records = result.data() + + if len(all_records) > 0: + fee_node = all_records[0]["f"] + value = fee_node.get(fee) + if value is not None: + return value + + return None + except Exception as e: + print(f"在其它费用表中查找费用时出错: {e}") + return None + + def get_fee_schedule_on_land_acquisition_fee_table_table(self, table_name, fee_name, fee): + """ + 在土地征用费表中查找费用 + + Args: + table_name (str): 费用表名称 + fee_name (str): 要查找的费用名称 + fee (str): 匹配的费用值属性名 + + Returns: + str: 匹配到的费用名称节点对应的费用值 + """ + if not table_name or not fee_name or not fee: + return None + + # 构建查询,查找土地征用费表中的特定费用 - 使用任意关系类型 + query = """ + MATCH (t:FeeScheduleItem)-[r]->(f:Fee) + WHERE t.name = $table_name AND f.name = $fee_name + RETURN f LIMIT 1 + """ + params = {"table_name": table_name, "fee_name": fee_name} + + try: + result = self.session.run(query, **params) + + all_records = result.data() + + if len(all_records) > 0: + fee_node = all_records[0]["f"] + value = fee_node.get(fee) + if value is not None: + return value + + return None + except Exception as e: + print(f"在土地征用费表中查找费用时出错: {e}") + return None + + def get_fee_schedule_on_installation_price_difference_table(self, table_name, fee_name, fee): + """ + 在安装价差表中查找费用 + + Args: + table_name (str): 费用表名称 + fee_name (str): 要查找的费用名称 + fee (str): 匹配的费用值属性名 + + Returns: + str: 匹配到的费用名称节点对应的费用值 + """ + if not table_name or not fee_name or not fee: + return None + + # 构建查询,查找安装价差表中的特定费用 - 使用任意关系类型 + query = """ + MATCH (t:FeeScheduleItem)-[r]->(f:Fee) + WHERE t.name = $table_name AND f.name = $fee_name + RETURN f LIMIT 1 + """ + params = {"table_name": table_name, "fee_name": fee_name} + + try: + result = self.session.run(query, **params) + + all_records = result.data() + + if len(all_records) > 0: + fee_node = all_records[0]["f"] + value = fee_node.get(fee) + if value is not None: + return value + + return None + except Exception as e: + print(f"在安装价差表中查找费用时出错: {e}") + return None + + def get_fee_schedule_on_Engineering_Cost_table(self, table_name, fee_name, fee): + """ + 在工程费用表中查找费用 + + Args: + table_name (str): 费用表名称 + fee_name (str): 要查找的费用名称 + fee (str): 匹配的费用值属性名 + + Returns: + str: 匹配到的费用名称节点对应的费用值 + """ + if not table_name or not fee_name or not fee: + return None + + # 调试输出 + print(f"查询费用表: {table_name}") + print(f"查询费用名称: {fee_name}") + print(f"查询费用属性: {fee}") + + # 构建查询,使用递归关系查询,并扩展关系类型 + query = """ + MATCH (t:FeeScheduleItem)-[r*1..5]->(f:Fee) + WHERE t.name = $table_name AND f.name = $fee_name + RETURN f LIMIT 10 + """ + params = {"table_name": table_name, "fee_name": fee_name} + + print(f"执行查询: {query}") + print(f"参数: {params}") + + try: + result = self.session.run(query, **params) + all_records = result.data() + + print(f"查询结果数量: {len(all_records)}") + + if len(all_records) > 0: + fee_node = all_records[0]["f"] + print(f"找到费用节点: {fee_node}") + + # 获取节点的所有属性,用于调试 + for key, value in fee_node.items(): + print(f"属性: {key} = {value}") + + value = fee_node.get(fee) + if value is not None: + print(f"找到费用值: {value}") + return value + else: + print(f"节点中没有属性: {fee}") + + # 如果没有找到,尝试另一种查询方式,使用CONTAINS进行模糊匹配 + print("尝试使用模糊匹配...") + backup_query = """ + MATCH (t:FeeScheduleItem)-[r*1..5]->(f:Fee) + WHERE t.name CONTAINS $table_name AND f.name CONTAINS $fee_name + RETURN f LIMIT 10 + """ + backup_result = self.session.run(backup_query, **params) + backup_records = backup_result.data() + + print(f"模糊匹配结果数量: {len(backup_records)}") + + if len(backup_records) > 0: + fee_node = backup_records[0]["f"] + print(f"找到费用节点(模糊匹配): {fee_node}") + + # 获取节点的所有属性,用于调试 + for key, value in fee_node.items(): + print(f"属性: {key} = {value}") + + value = fee_node.get(fee) + if value is not None: + print(f"找到费用值: {value}") + return value + else: + print(f"节点中没有属性: {fee}") + + return None + except Exception as e: + print(f"在工程费用表中查找费用时出错: {e}") + import traceback + + traceback.print_exc() + return None + + +class ProjectBuilder: + """ + 项目构建器 + 描述: 用于构建项目对象的构建器 + """ + + _instance = None + + @staticmethod + def build(): + """ + 构建并返回项目实例 + + Returns: + ProjectTookiItNeo4j: 创建的项目实例 + """ + # 如果已经有实例,先关闭它 + if ProjectBuilder._instance is not None: + ProjectBuilder._instance.close() + + # 创建新实例 + + ProjectBuilder._instance = ProjectTookiItNeo4j() + + return ProjectBuilder._instance + + @staticmethod + def close(): + """ + 关闭当前项目实例的连接 + """ + if ProjectBuilder._instance is not None: + ProjectBuilder._instance.close() + ProjectBuilder._instance = None + + +# 注册退出处理函数,确保程序退出时自动关闭连接 +atexit.register(ProjectBuilder.close) + + +# project = ProjectBuilder.build() +# result = project.get_quantities_node_by_parent_and_name( +# "工程数据/安装工程/安装/架空输电线路本体工程/杆塔工程/杆塔组立/铁塔、钢管杆组立", +# "塔材", +# "主材", +# ) + +# print(result)