from langchain.chains import LLMChain
from langchain_openai import OpenAI
from langchain_experimental.utilities import PythonREPL
from project_implementation import ProjectBuilder
from prompt_templates import FUNCTION_CALL_PROMPT
import inspect
import project
import io
import sys
from parameter_rewriting import rewrite_query_parameters, KnowledgeGraphProcessor
import json
from langchain.agents import Tool, AgentExecutor, create_react_agent
from llm import llm


# Extract the method definitions of the ProjectTookiIt class.
def get_project_class_methods():
    """Return the source of ``project.ProjectTookiIt`` with its ``__init__`` removed.

    The class source is obtained with ``inspect.getsource``. Lines before the
    ``class ProjectTookiIt`` header are dropped. The ``def __init__`` line and
    every following line are omitted until the first non-blank line that is
    not indented by at least 8 spaces (i.e. the next class-level statement).

    Returns:
        str: The filtered class source, lines joined with ``"\\n"``.
    """
    class_source = inspect.getsource(project.ProjectTookiIt)
    method_body_indent = " " * 8

    kept_lines = []
    header_seen = False
    inside_init = False
    for raw_line in class_source.split("\n"):
        stripped = raw_line.strip()

        # The class header is always kept and switches collection on.
        if stripped.startswith("class ProjectTookiIt"):
            header_seen = True
            kept_lines.append(raw_line)
            continue
        if not header_seen:
            continue

        if stripped.startswith("def __init__"):
            inside_init = True
        elif inside_init and stripped and not raw_line.startswith(method_body_indent):
            # First non-blank line at class-body level ends the __init__ body.
            inside_init = False

        if not inside_init:
            kept_lines.append(raw_line)

    return "\n".join(kept_lines)


# Method listing injected into the agent prompt as ``project_class_methods``.
project_class_methods = get_project_class_methods()

# NOTE(review): ``function_call_chain`` and ``repl`` are not referenced by the
# functions in this module; they may be imported by other modules.
function_call_chain = LLMChain(llm=llm, prompt=FUNCTION_CALL_PROMPT, output_key="code")

# Python REPL executor.
repl = PythonREPL()

# Shared knowledge-graph processor used by the search tool below.
kg_processor = KnowledgeGraphProcessor()


# Tool: search the knowledge base.
def search_knowledge_and_node_definition(query):
    """Look up knowledge-graph data related to ``query``.

    Delegates to ``kg_processor._get_relevant_knowledge`` (a private method of
    ``KnowledgeGraphProcessor``).

    Args:
        query (str): Search keyword (the agent passes a node type name).

    Returns:
        str: The found data pretty-printed as JSON (non-ASCII preserved), or a
        "not found" message when the lookup returns a falsy value.
    """
    knowledge = kg_processor._get_relevant_knowledge(query)
    if not knowledge:
        return f"未找到与'{query}'相关的信息"
    return json.dumps(knowledge, ensure_ascii=False, indent=2)


# Tools exposed to the ReAct agent.
tools = [
    Tool(
        name="search_knowledge_and_node_definition",
        func=search_knowledge_and_node_definition,
        description="获取输入节点的知识图谱结构和对应节点定义类型代码。输入应该是一个节点类型名称。",
    ),
]

# ReAct agent built from the shared LLM and the function-call prompt.
agent = create_react_agent(llm, tools, FUNCTION_CALL_PROMPT)

# Executor; parsing errors in the LLM output are fed back instead of raised.
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True, handle_parsing_errors=True)
def _execute_code(code_str):
    """Execute ``code_str`` and call the ``neo4j_find_function`` it defines.

    The code runs via ``exec`` in a fresh namespace exposing ``ProjectBuilder``
    and the ``project_implementation`` / ``project`` modules. Everything printed
    while executing the code and calling ``neo4j_find_function`` is captured.

    SECURITY NOTE(review): ``code_str`` is either LLM-generated or supplied by
    the caller and is executed with full interpreter privileges. This is only
    acceptable if every caller is trusted.

    Args:
        code_str (str): Python source that must define ``neo4j_find_function()``
            returning a 4-tuple ``(status, data, error, helper_info)``.

    Returns:
        dict: On success, ``status``/``data``/``error``/``helper_info`` from the
        tuple plus ``output`` (captured stdout, stripped). On any exception,
        ``{"status": "error", "error": str(e), "helper_info": [], "traceback": ...}``.
    """
    import traceback
    from contextlib import redirect_stdout

    print(f"开始执行代码: {code_str[:50]}...")
    try:
        namespace = {
            "ProjectBuilder": ProjectBuilder,
            "project_implementation": __import__("project_implementation"),
            "project": __import__("project"),
        }
        captured = io.StringIO()
        # redirect_stdout restores sys.stdout even when exec or the call raises.
        with redirect_stdout(captured):
            exec(code_str, namespace)
            if "neo4j_find_function" not in namespace:
                raise ValueError("代码中未定义neo4j_find_function函数")
            result_tuple = namespace["neo4j_find_function"]()
        output = captured.getvalue().strip()

        if not isinstance(result_tuple, tuple) or len(result_tuple) != 4:
            raise ValueError("函数应返回包含4个元素的元组(status, data, error, helper_info)")
        status, data, error, helper_info = result_tuple

        print(f"执行结果: status={status}, data={data}, error={error}")
        return {
            "status": status,
            "data": data,
            "error": error,
            "helper_info": helper_info,
            "output": output,
        }
    except Exception as e:
        tb_text = traceback.format_exc()
        print(f"执行代码时出错: {e}")
        print(tb_text)
        return {
            "status": "error",
            "error": str(e),
            "helper_info": [],
            "traceback": tb_text,
        }


def nl_query_to_function_call(input_data):
    """Turn a natural-language query into code and run it, or run given code.

    For ``type == "query"`` the ReAct agent generates the code; otherwise
    ``value`` is executed as code directly. If the first execution does not
    report ``status == "success"``, ``rewrite_query_parameters`` is asked for a
    fixed version, which is executed once more.

    Args:
        input_data (dict): ``{"type": "query" | "code", "value": str}``;
            ``type`` defaults to ``"query"`` and ``value`` to ``""``.

    Returns:
        dict: ``code`` (0 success / 1 failure), ``message``,
        ``data`` (``{"value": ..., "code": <code that was run>}``) and
        ``query_status``; failures after execution also carry ``error_info``.
    """
    input_type = input_data.get("type", "query")
    input_value = input_data.get("value", "")

    print("\n====== 开始执行查询函数 ======")
    print(f"查询类型: {input_type}")
    print(f"查询内容: {input_value}")

    # First attempt: have the agent generate the code.
    if input_type == "query":
        print("\n----- 第一次尝试:使用Agent生成代码 -----")
        try:
            agent_response = agent_executor.invoke(
                {
                    "query": input_value,
                    "project_class_methods": project_class_methods,
                }
            )
            code = agent_response["output"].strip()
        except Exception as e:
            # stdout is not redirected at this point, so there is nothing to restore.
            return {
                "code": 1,
                "message": f"Agent生成代码失败: {e}",
                "data": {"value": "", "code": ""},
                "query_status": "Agent生成代码失败",
            }
        print(f"\n生成的代码:\n{code}")
    else:
        code = input_value.strip()
        print(f"\n使用提供的代码:\n{code}")

    original_code = code

    # First execution.
    print("开始执行函数...")
    print("准备执行代码...")
    result = _execute_code(code)
    print(f"执行完成,结果状态: {result.get('status', 'unknown')}")

    if result["status"] == "success":
        print("执行成功,准备返回结果...")
        final_result = {
            "code": 0,
            "message": "成功",
            "data": {
                "value": result["data"],
                "code": original_code,
            },
            "query_status": "第一次查询成功",
        }
    else:
        print(f"执行失败,错误: {result.get('error', 'unknown')}")
        error_info = {
            "error": result["error"],
            "helper_info": result.get("helper_info", []),
            "traceback": result.get("traceback", ""),
        }

        print("\n第一次查询失败,尝试使用LLM修复代码...")
        rewritten = rewrite_query_parameters(original_code=original_code, error_info=error_info)
        # Tolerate a None result or a None "code" value from the rewriter.
        fixed_code = ((rewritten or {}).get("code") or "").strip()

        if not fixed_code:
            print("\nLLM未返回有效代码")
            final_result = {
                "code": 1,
                "message": error_info["error"],
                "data": {"value": "", "code": original_code},
                "error_info": error_info,
                "query_status": "第一次查询失败,LLM未返回有效代码",
            }
        else:
            print(f"\nLLM修复后的代码:\n{fixed_code}")
            # Second execution, with the repaired code.
            result = _execute_code(fixed_code)
            if result["status"] == "success":
                print("\nLLM修复后查询成功")
                # default=str: ``data`` may hold objects json cannot serialize
                # (e.g. project model instances); this dump is for logging only.
                print(f"\n执行结果详情: {json.dumps(result, ensure_ascii=False, indent=2, default=str)}")
                final_result = {
                    "code": 0,
                    "message": "成功",
                    "data": {
                        "value": result["data"],
                        "code": fixed_code,
                    },
                    "query_status": "LLM修复后查询成功",
                }
            else:
                print("\nLLM修复后查询仍然失败")
                final_result = {
                    "code": 1,
                    "message": result["error"],
                    "data": {"value": "", "code": fixed_code},
                    "error_info": {
                        "error": result["error"],
                        "helper_info": result.get("helper_info", []),
                    },
                    "query_status": "LLM修复后查询仍然失败",
                }

    print("函数执行完毕,准备返回最终结果")
    return final_result


def _is_project_object(obj):
    """Return True when ``obj.__module__`` exists and equals ``"project"``."""
    return getattr(obj, "__module__", None) == "project"


def _public_attrs(obj):
    """Return ``obj.__dict__`` without keys that start with an underscore."""
    return {k: v for k, v in obj.__dict__.items() if not k.startswith("_")}


def format_result(result):
    """Render a query result as human-readable text.

    Handled in order:
      * a single ``project``-module object -> JSON of its public attributes;
      * a list containing at least one ``project`` object, where every item
        that has ``__module__`` is a ``project`` object -> JSON list (items
        without ``__dict__`` are stringified);
      * a string -> "not found" heuristics and parsing of a node repr like
        ``找到 ... labels=... properties={...}>``; otherwise it falls through;
      * a list -> numbered entries, mapping-like items expanded key by key;
      * a dict -> one ``key: value`` line per entry;
      * anything else -> ``str(result)``.

    Args:
        result: Query result (list, dict, str, project object or other).

    Returns:
        str: The formatted text.
    """
    if _is_project_object(result):
        return json.dumps(_public_attrs(result), ensure_ascii=False, indent=2)

    # any(): an empty list, or a list of builtins (which have no __module__),
    # must not be treated as a project-object list.
    if (
        isinstance(result, list)
        and any(_is_project_object(item) for item in result)
        and all(_is_project_object(item) for item in result if hasattr(item, "__module__"))
    ):
        formatted_items = [
            _public_attrs(item) if hasattr(item, "__dict__") else str(item)
            for item in result
        ]
        return json.dumps(formatted_items, ensure_ascii=False, indent=2)

    if isinstance(result, str):
        # Heuristic: any "[]" in the text is taken to mean an empty result set.
        if "[]" in result:
            return "未找到匹配的数据。"

        import ast
        import re

        node_match = re.search(r"找到.*?labels=.*?properties=(.*?)>", result)
        if node_match:
            # literal_eval accepts both quote styles, so the captured text is
            # parsed as-is (swapping quotes would break values containing them).
            try:
                props = ast.literal_eval(node_match.group(1))
            except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                props = None
            if isinstance(props, dict):
                formatted = "找到节点:\n"
                for k, v in props.items():
                    formatted += f" {k}: {v}\n"
                return formatted

        count_match = re.search(r"查询结果数量: (\d+)", result)
        if count_match and count_match.group(1) == "0":
            return "未找到匹配的数据。"

    if isinstance(result, list):
        if not result:
            return "未找到匹配的数据。"
        lines = [f"找到 {len(result)} 条匹配结果:"]
        for i, item in enumerate(result, 1):
            lines.append(f"\n结果 {i}:")
            if hasattr(item, "items"):
                # Mapping-like object; fall back to its str() if iteration fails.
                try:
                    for k, v in item.items():
                        lines.append(f" {k}: {v}")
                except Exception:
                    lines.append(f" {item}")
            else:
                lines.append(f" {item}")
        return "\n".join(lines)

    if isinstance(result, dict):
        lines = ["查询结果:"]
        lines.extend(f" {k}: {v}" for k, v in result.items())
        return "\n".join(lines)

    return str(result)


def format_dict_or_item(item):
    """Format a dict as ``" key: value\\n"`` lines; other objects via ``str()``.

    Args:
        item: A dict or any other object.

    Returns:
        str: The formatted text (empty string for an empty dict).
    """
    if isinstance(item, dict):
        return "".join(f" {key}: {value}\n" for key, value in item.items())
    return str(item)


def _extract_all_values(item, result_list):
    """Recursively append every string found in ``item`` to ``result_list``.

    Strings are collected whether they are dict values or list elements.
    Dict entries under the key ``"children"`` are skipped at every depth;
    ``extract_similar_parameters`` handles one level of children itself.
    Other scalar values (numbers, None, ...) are ignored.

    Args:
        item: A dict, list or string to scan.
        result_list (list): Output list, appended to in place.
    """
    if isinstance(item, str):
        result_list.append(item)
    elif isinstance(item, dict):
        for key, value in item.items():
            if key != "children":
                _extract_all_values(value, result_list)
    elif isinstance(item, list):
        for element in item:
            _extract_all_values(element, result_list)


# Helper that collects candidate "similar parameters" for a path.
def extract_similar_parameters(path_parts, knowledge_base):
    """Collect knowledge-base strings related to the given path parts.

    Each part is reduced to its last "/"-separated segment (stripped); blank
    segments are skipped because an empty keyword would match every entry.
    A knowledge-base item matches when one of its top-level string values
    contains the keyword (case-insensitive). For each match, the item's strings
    and those of its direct ``children`` are collected.

    Args:
        path_parts (list[str]): Path strings, e.g. ``"a/b/c"``.
        knowledge_base (list[dict]): Knowledge-base entries.

    Returns:
        str: The de-duplicated, sorted strings joined with ``", "``.
    """
    keywords = []
    for part in path_parts:
        keyword = part.split("/")[-1].strip()
        if keyword:
            keywords.append(keyword)

    collected = []
    for keyword in keywords:
        needle = keyword.lower()
        for item in knowledge_base:
            matched = any(
                isinstance(value, str) and needle in value.lower()
                for value in item.values()
            )
            if not matched:
                continue
            _extract_all_values(item, collected)
            children = item.get("children")
            if isinstance(children, list):
                for child in children:
                    _extract_all_values(child, collected)

    similar_params = sorted(set(collected))
    print(f"找到的相似参数: {similar_params}")
    return ", ".join(similar_params)


# Pull the values out of "key": "value" pairs in a knowledge-base string.
def extract_values_from_kb_string(kb_string):
    """Return the value of every ``"key": "value"`` pair found in ``kb_string``.

    Only non-empty, double-quoted keys and values without embedded quotes are
    recognized; results keep their order of appearance.

    Args:
        kb_string (str): Text (typically JSON) to scan.

    Returns:
        list[str]: The matched values.
    """
    import re

    pairs = re.findall(r'"([^"]+)"\s*:\s*"([^"]+)"', kb_string)
    return [value for _key, value in pairs]


# Example usage:
# question = {
#     "type": "query",
#     "value": "查找一下【工程数据/安装工程/安装/架空输电线路本体工程/杆塔工程/杆塔组立/铁塔、钢管杆组立】的类型为【主材】的【塔材】",
# }
# result = nl_query_to_function_call(question)
# print("\n最终结果:")
# print(result)