from langchain.chains import LLMChain from langchain_openai import OpenAI from langchain_experimental.utilities import PythonREPL from project_implementation import ProjectBuilder from prompt_templates import FUNCTION_CALL_PROMPT import inspect import project import io import sys from parameter_rewriting import rewrite_query_parameters, KnowledgeGraphProcessor import json from langchain.agents import Tool, AgentExecutor, create_react_agent from langchain.prompts import PromptTemplate from llm import llm # 获取ProjectTookiIt类的方法定义 def get_project_class_methods(): """ 从project模块中提取ProjectTookiIt类的方法定义 Returns: str: 格式化后的方法定义字符串 """ project_class_code = inspect.getsource(project.ProjectTookiIt) lines = project_class_code.split("\n") result_lines = [] in_class = False skip_init = False for line in lines: if line.strip().startswith("class ProjectTookiIt"): in_class = True result_lines.append(line) elif in_class: if line.strip().startswith("def __init__"): skip_init = True elif skip_init and line.strip() and not line.startswith(" " * 8): skip_init = False if not skip_init: result_lines.append(line) return "\n".join(result_lines) # 创建动态提示模板 project_class_methods = get_project_class_methods() # 创建 Chain function_call_chain = LLMChain(llm=llm, prompt=FUNCTION_CALL_PROMPT, output_key="code") # Python 执行器 repl = PythonREPL() # 创建知识图谱处理器实例 kg_processor = KnowledgeGraphProcessor() # 定义搜索知识库的工具 def search_knowledge_base(query): """ 在知识库中搜索关键词 Args: query (str): 搜索关键词 Returns: str: 搜索结果的JSON字符串 """ found_data = kg_processor._search_in_kg(query) if found_data: return json.dumps(found_data, ensure_ascii=False, indent=2) else: return f"未找到与'{query}'相关的信息" # 定义获取节点定义的工具 def get_node_definition(node_type): """ 获取节点类型的定义 Args: node_type (str): 节点类型名称 Returns: str: 节点类型定义 """ definition = kg_processor._get_node_definition(node_type) return definition # 创建工具列表 tools = [ Tool( name="search_knowledge_base", func=search_knowledge_base, description="在知识库中搜索关键词,返回相关信息。输入应该是一个搜索关键词。", ), Tool( name="get_node_definition", func=get_node_definition, description="获取节点类型的定义,返回类型定义代码。输入应该是一个节点类型名称。", ), ] # 创建Agent agent = create_react_agent(llm, tools, FUNCTION_CALL_PROMPT) # 创建Agent执行器 agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True, handle_parsing_errors=True) def nl_query_to_function_call(input_data): """ 将自然语言查询转换为函数调用并执行,或直接执行提供的代码 Args: input_data (dict): 包含type和value的字典 { "type": "query|code", "value": "查询内容或代码" } Returns: dict: 包含状态码、消息和数据的字典 """ input_type = input_data.get("type", "query") input_value = input_data.get("value", "") max_retries = 1 # 设置最大重试次数 current_retry = 0 original_query = input_value # 保存原始查询用于RAG print(f"\n====== 开始处理查询 ======") print(f"查询类型: {input_type}") print(f"查询内容: {input_value}") while current_retry <= max_retries: print(f"\n----- 尝试 #{current_retry + 1} -----") # 如果type是query,使用LLM生成代码 if input_type == "query" and current_retry == 0: # 从查询中提取关键部分 import re # 提取【】中的内容 path_parts = re.findall(r"【([^】]+)】", input_value) # 创建临时代码 temp_code = f'search("{input_value}")' # 对于每个路径,只添加最后一个部分 for part in path_parts: if "/" in part: # 提取路径中的最后一个部分 last_part = part.split("/")[-1].strip() if last_part: temp_code += f'\nsearch("{last_part}")' else: # 如果没有/,直接使用整个部分 temp_code += f'\nsearch("{part}")' # 获取知识库内容和节点定义 knowledge_base, node_definitions = kg_processor._get_relevant_knowledge(temp_code) # 使用Agent执行查询 agent_response = agent_executor.invoke( { "query": input_value, "project_class_methods": project_class_methods, "KnowledgeBase": knowledge_base, "NodeDefinition": node_definitions, } ) # 从Agent响应中提取代码 code = agent_response["output"] print(f"\n生成的代码:\n{code}") else: print(f"\n使用重写后的代码:\n{input_value}") code = input_value # 保存原始代码用于返回 original_code = code # 执行生成的函数并捕获输出 try: # 创建一个新的命名空间来执行代码,包含必要的导入 namespace = { "ProjectBuilder": ProjectBuilder, # 添加ProjectBuilder到命名空间 "project_implementation": __import__("project_implementation"), "project": __import__("project"), } # 执行生成的代码,定义neo4j_find_function函数 exec(code, namespace) # 重定向stdout来捕获print输出 old_stdout = sys.stdout redirected_output = io.StringIO() sys.stdout = redirected_output try: # 执行函数 print("\n执行代码...") result = namespace["neo4j_find_function"]() # 获取捕获的输出 output = redirected_output.getvalue().strip() print(f"\n原始输出:\n{output}") # 检查结果是否为空 is_empty_result = ( not output or output.lower() == "none" or output == "[]" or "未找到" in output or "[]" in output or "None" in output or result is None ) # 如果结果为空,走重写流程 if is_empty_result: print("\n查询未找到结果,尝试定位具体缺失节点...") # 解析原始查询路径中的最后一个节点名 import re match = re.search(r"【([^】]+)】\s*$", original_query) missing_node = match.group(1) if match else "未知节点" error_info = { "error_type": "NodeNotFoundError", "error_message": f"{missing_node} 未找到,请检查该节点是否存在。", "missing_node": missing_node, "original_query": original_query, "executed_code": original_code, } print("结构化错误信息:") print(json.dumps(error_info, ensure_ascii=False, indent=2)) if current_retry < max_retries: print("\n尝试使用RAG重写查询...") try: # 使用RAG重写查询和代码,并传递错误信息 rewritten = rewrite_query_parameters(original_query, original_code, error_info) print(f"\nRAG重写结果: {json.dumps(rewritten, ensure_ascii=False, indent=2)}") # 更新查询和代码 if "query" in rewritten and "code" in rewritten and rewritten["code"] != original_code: print("\nRAG重写成功,使用新代码重试...") input_value = rewritten["code"] # 直接使用重写后的代码 input_type = "code" # 切换到代码模式 current_retry += 1 continue # 继续下一次循环 else: print("\nRAG重写未产生新代码,返回原始错误") except Exception as e: print(f"\nRAG重写失败: {e}") # 记录错误但继续执行 # RAG重写失败或未产生新代码,返回原始错误 query_status = ( "第一次查询失败,RAG重写也失败" if current_retry == 0 else f"第{current_retry+1}次查询失败,RAG重写也失败" ) print(f"\n{query_status}") return { "code": 1, "message": f"{missing_node} 未找到,请检查该节点是否存在。", "data": {"value": "", "code": original_code}, "error_info": error_info, "query_status": query_status, } # 清理输出,只保留有用的结果部分 clean_output = output # 如果输出包含查询结果数量和对象引用 if "查询结果数量:" in output and "", output, re.DOTALL) if node_match: props_str = node_match.group(1).replace("'", '"') try: import ast props = ast.literal_eval(props_str) clean_output = json.dumps(props, ensure_ascii=False, indent=2) except: pass # 如果有查询结果数量信息 count_match = re.search(r"查询结果数量: (\d+)", output) if count_match: count = count_match.group(1) if count == "0": clean_output = "未找到匹配的数据。" is_empty_result = True elif not node_match: # 如果没有提取到节点属性但有结果 clean_output = f"找到 {count} 条匹配结果" # 检查结果对象 if result is not None: if isinstance(result, list): if not result: # 空列表 is_empty_result = True else: # 处理非空列表 formatted_items = [] for item in result: if hasattr(item, "__dict__"): # 提取对象的所有属性 attrs = {k: v for k, v in item.__dict__.items() if not k.startswith("_")} formatted_items.append(attrs) else: formatted_items.append(str(item)) if not is_empty_result: # 只有在不是空结果时才返回成功 query_status = ( "第一次查询成功" if current_retry == 0 else f"第{current_retry+1}次查询成功(RAG重写后)" ) print(f"\n{query_status}") return { "code": 0, "message": "成功", "data": { "value": json.dumps(formatted_items, ensure_ascii=False, indent=2), "code": original_code, }, "query_status": query_status, } elif hasattr(result, "__dict__"): # 单个对象 attrs = {k: v for k, v in result.__dict__.items() if not k.startswith("_")} if not is_empty_result: # 只有在不是空结果时才返回成功 query_status = ( "第一次查询成功" if current_retry == 0 else f"第{current_retry+1}次查询成功(RAG重写后)" ) print(f"\n{query_status}") return { "code": 0, "message": "成功", "data": { "value": json.dumps(attrs, ensure_ascii=False, indent=2), "code": original_code, }, "query_status": query_status, } # 如果没有对象属性但有清理后的输出,且不是空结果 if ( clean_output and clean_output.lower() != "none" and clean_output != "[]" and "未找到" not in clean_output and not is_empty_result ): query_status = ( "第一次查询成功" if current_retry == 0 else f"第{current_retry+1}次查询成功(RAG重写后)" ) print(f"\n{query_status}") return { "code": 0, "message": "成功", "data": {"value": clean_output, "code": original_code}, "query_status": query_status, } finally: # 恢复stdout sys.stdout = old_stdout except Exception as e: import traceback error_details = traceback.format_exc() print(f"\n执行代码时出错: {error_details}") # 如果走到这里,说明结果为空或未找到匹配项,应该执行RAG重写流程 print("\n查询未找到结果,尝试定位具体缺失节点...") # 解析原始查询路径中的最后一个节点名 import re match = re.search(r"【([^】]+)】\s*$", original_query) missing_node = match.group(1) if match else "未知节点" error_info = { "error_type": "NodeNotFoundError", "error_message": f"{missing_node} 未找到,请检查该节点是否存在。", "missing_node": missing_node, "original_query": original_query, "executed_code": original_code, } print("结构化错误信息:") print(json.dumps(error_info, ensure_ascii=False, indent=2)) if current_retry < max_retries: print("\n尝试使用RAG重写查询...") try: # 使用RAG重写查询和代码,并传递错误信息 rewritten = rewrite_query_parameters(original_query, original_code, error_info) print(f"\nRAG重写结果: {json.dumps(rewritten, ensure_ascii=False, indent=2)}") # 更新查询和代码 if "query" in rewritten and "code" in rewritten and rewritten["code"] != original_code: print("\nRAG重写成功,使用新代码重试...") input_value = rewritten["code"] # 直接使用重写后的代码 input_type = "code" # 切换到代码模式 current_retry += 1 continue # 继续下一次循环 else: print("\nRAG重写未产生新代码,返回原始错误") except Exception as e: print(f"\nRAG重写失败: {e}") # 记录错误但继续执行 # RAG重写失败或未产生新代码,返回原始错误 query_status = ( "第一次查询失败,RAG重写也失败" if current_retry == 0 else f"第{current_retry+1}次查询失败,RAG重写也失败" ) print(f"\n{query_status}") return { "code": 1, "message": f"{missing_node} 未找到,请检查该节点是否存在。", "data": {"value": "", "code": original_code}, "error_info": error_info, "query_status": query_status, } # 如果所有重试都失败 print("\n所有重试都失败,无法找到匹配的结果") query_status = "所有重试都失败" return { "code": 1, "message": "所有重试都失败,无法找到匹配的结果", "data": {"value": "", "code": original_code}, "query_status": query_status, } def format_result(result): """ 格式化查询结果 Args: result: 查询结果(可能为 list、dict 或其他类型) Returns: str: 格式化后的结果 """ # 处理 project 对象 if hasattr(result, "__module__") and result.__module__ == "project": # 这是一个 project 模块中的对象 attrs = {k: v for k, v in result.__dict__.items() if not k.startswith("_")} return json.dumps(attrs, ensure_ascii=False, indent=2) # 处理 project 对象列表 if isinstance(result, list) and all( hasattr(item, "__module__") and item.__module__ == "project" for item in result if hasattr(item, "__module__") ): formatted_items = [] for item in result: if hasattr(item, "__dict__"): attrs = {k: v for k, v in item.__dict__.items() if not k.startswith("_")} formatted_items.append(attrs) else: formatted_items.append(str(item)) return json.dumps(formatted_items, ensure_ascii=False, indent=2) # 如果结果是字符串,可能包含调试信息,需要提取有用部分 if isinstance(result, str): # 尝试提取最终结果部分 if "[]" in result: return "未找到匹配的数据。" # 如果包含节点信息,提取关键部分 import re node_match = re.search(r"找到.*?labels=.*?properties=(.*?)>", result) if node_match: try: # 提取属性部分并格式化 props_str = node_match.group(1).replace("'", '"') import ast props = ast.literal_eval(props_str) formatted = "找到节点:\n" for k, v in props.items(): formatted += f" {k}: {v}\n" return formatted except: pass # 如果包含查询结果数量 count_match = re.search(r"查询结果数量: (\d+)", result) if count_match: count = count_match.group(1) if count == "0": return "未找到匹配的数据。" # 如果是列表 if isinstance(result, list): if not result: return "未找到匹配的数据。" lines = [f"找到 {len(result)} 条匹配结果:"] for i, item in enumerate(result, 1): lines.append(f"\n结果 {i}:") if hasattr(item, "items"): # 检查是否有items方法(字典或类似字典的对象) try: for k, v in item.items(): lines.append(f" {k}: {v}") except: lines.append(f" {item}") else: lines.append(f" {item}") return "\n".join(lines) # 如果是字典 elif isinstance(result, dict): lines = ["查询结果:"] for k, v in result.items(): lines.append(f" {k}: {v}") return "\n".join(lines) # 其他类型 else: return str(result) def format_dict_or_item(item): """ 格式化字典或其他对象 Args: item: 字典或其他对象 Returns: str: 格式化后的字符串 """ if isinstance(item, dict): formatted = "" for key, value in item.items(): formatted += f" {key}: {value}\n" return formatted return str(item) question = { "type": "query", "value": "查找一下【工程数据/安装工程/安装/架空输电线路本体工程/杆塔工程/杆塔组立/铁塔、钢管杆组立】的类型为【主材】的【阿巴阿巴】", } result = nl_query_to_function_call(question) print(result)