diff --git a/langchain_neo4j.py b/langchain_neo4j.py new file mode 100644 index 0000000..1815cb7 --- /dev/null +++ b/langchain_neo4j.py @@ -0,0 +1,732 @@ +from langchain.chains import LLMChain +from langchain_openai import OpenAI +from langchain_experimental.utilities import PythonREPL +from project_implementation import ProjectBuilder +from prompt_templates import FUNCTION_CALL_PROMPT +import inspect +import project +import io +import sys +from parameter_rewriting import rewrite_query_parameters, KnowledgeGraphProcessor +import json +from langchain.agents import Tool, AgentExecutor, create_react_agent + +from llm import llm + + +# 获取ProjectTookiIt类的方法定义 +def get_project_class_methods(): + """ + 从project模块中提取ProjectTookiIt类的方法定义 + + Returns: + str: 格式化后的方法定义字符串 + """ + project_class_code = inspect.getsource(project.ProjectTookiIt) + + lines = project_class_code.split("\n") + result_lines = [] + + in_class = False + skip_init = False + + for line in lines: + if line.strip().startswith("class ProjectTookiIt"): + in_class = True + result_lines.append(line) + elif in_class: + if line.strip().startswith("def __init__"): + skip_init = True + elif skip_init and line.strip() and not line.startswith(" " * 8): + skip_init = False + + if not skip_init: + result_lines.append(line) + + return "\n".join(result_lines) + + +# 创建动态提示模板 +project_class_methods = get_project_class_methods() + +# 创建 Chain +function_call_chain = LLMChain(llm=llm, prompt=FUNCTION_CALL_PROMPT, output_key="code") + +# Python 执行器 +repl = PythonREPL() + +# 创建知识图谱处理器实例 +kg_processor = KnowledgeGraphProcessor() + + +# 定义搜索知识库的工具 +def search_knowledge_and_node_definition(query): + """ + 在知识库中搜索关键词 + + Args: + query (str): 搜索关键词 + + Returns: + str: 搜索结果的JSON字符串 + """ + found_data = kg_processor._get_relevant_knowledge(query) + if found_data: + return json.dumps(found_data, ensure_ascii=False, indent=2) + else: + return f"未找到与'{query}'相关的信息" + + +# 创建工具列表 +tools = [ + Tool( + name="search_knowledge_and_node_definition", + func=search_knowledge_and_node_definition, + description="获取输入节点的知识图谱结构和对应节点定义类型代码。输入应该是一个节点类型名称。", + ), +] + +# 创建Agent +agent = create_react_agent(llm, tools, FUNCTION_CALL_PROMPT) + +# 创建Agent执行器 +agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True, handle_parsing_errors=True) + + +def nl_query_to_function_call(input_data): + """ + 将自然语言查询转换为函数调用并执行,或直接执行提供的代码 + + Args: + input_data (dict): 包含type和value的字典 + { + "type": "query|code", + "value": "查询内容或代码" + } + + Returns: + dict: 包含状态码、消息和数据的字典 + """ + input_type = input_data.get("type", "query") + input_value = input_data.get("value", "") + max_retries = 1 # 设置最大重试次数 + current_retry = 0 + + original_query = input_value # 保存原始查询用于RAG + + print(f"\n====== 开始处理查询 ======") + print(f"查询类型: {input_type}") + print(f"查询内容: {input_value}") + + while current_retry <= max_retries: + print(f"\n----- 尝试 #{current_retry + 1} -----") + + # 如果type是query,使用LLM生成代码 + if input_type == "query" and current_retry == 0: + + # 使用Agent执行查询 + agent_response = agent_executor.invoke( + { + "query": input_value, + "project_class_methods": project_class_methods, + } + ) + + # 从Agent响应中提取代码 + code = agent_response["output"] + print(f"\n生成的代码:\n{code}") + else: + print(f"\n使用重写后的代码:\n{code}") + code = code + + # 保存原始代码用于返回 + original_code = code + + # 执行生成的函数并捕获输出 + try: + # 创建一个新的命名空间来执行代码,包含必要的导入 + namespace = { + "ProjectBuilder": ProjectBuilder, # 添加ProjectBuilder到命名空间 + "project_implementation": __import__("project_implementation"), + "project": __import__("project"), + } + + # 执行生成的代码,定义neo4j_find_function函数 + exec(code, namespace) + + # 重定向stdout来捕获print输出 + old_stdout = sys.stdout + redirected_output = io.StringIO() + sys.stdout = redirected_output + + try: + # 执行函数并获取元组结果 + result_tuple = namespace["neo4j_find_function"]() + + # 确保结果是元组且包含4个元素 + if not isinstance(result_tuple, tuple) or len(result_tuple) != 4: + raise ValueError("函数应返回包含4个元素的元组(status, data, error, helper_info)") + + status, data, error, helper_info = result_tuple + + # 获取捕获的输出(如果有) + output = redirected_output.getvalue().strip() + + # 根据状态处理结果 + if status == "success": + query_status = ( + "第一次查询成功" if current_retry == 0 else f"第{current_retry+1}次查询成功(RAG重写后)" + ) + print(f"\n{query_status}") + return { + "code": 0, + "message": "成功", + "data": { + "value": data, # 直接使用data字典,不再需要JSON转换 + "code": original_code, + }, + "query_status": query_status, + } + else: + # 错误情况 + error_info = { + "error": error, + "helper_info": helper_info, + } + + # 重试逻辑 + if current_retry < max_retries: + print("\n尝试使用RAG重写查询...") + try: + rewritten = rewrite_query_parameters(original_code, error_info) + if rewritten and "code" in rewritten and rewritten["code"]: + print(f"\n重写成功,新代码:\n{rewritten['code']}") + current_retry += 1 + code = rewritten["code"] + continue # 重要!继续下一次循环迭代 + except Exception as e: + print(f"\nRAG重写失败: {e}") + # 尝试使用原始查询 + try: + rewritten = rewrite_query_parameters(original_code, error_info) + except Exception as e2: + print(f"再次重写失败: {e2}") + # 继续执行,返回原始错误 + + # 返回错误信息 + query_status = "第一次查询失败" if current_retry == 0 else f"第{current_retry+1}次查询失败" + print(f"\n{query_status}") + return { + "code": 1, + "message": error, + "data": {"value": "", "code": original_code}, + "error_info": error_info, + "query_status": query_status, + } + + # 如果不是新格式,则按照原有逻辑处理 + # 检查结果是否为空 + is_empty_result = True # 默认假设为空 + + # 先检查是否有明显的数据内容 + if output and output.strip().startswith("[{") and output.strip().endswith("}]"): + # 看起来是一个对象数组,尝试转换为有效的JSON格式 + try: + # 将单引号替换为双引号以使其成为有效的JSON + # 但要小心处理嵌套的引号 + import ast + + # 使用ast.literal_eval安全地将Python字符串表示转换为Python对象 + parsed_obj = ast.literal_eval(output) + if parsed_obj and len(parsed_obj) > 0: + is_empty_result = False + except: + # 如果解析失败,继续尝试其他方法 + pass + + # 如果上面的方法没有确定结果不为空,继续尝试JSON解析 + if is_empty_result: + try: + if output.strip(): + parsed_output = json.loads(output) + if parsed_output and ( + isinstance(parsed_output, list) + and len(parsed_output) > 0 + or isinstance(parsed_output, dict) + and len(parsed_output) > 0 + ): + is_empty_result = False + except json.JSONDecodeError: + # 不是有效的JSON,使用其他判断方法 + is_empty_result = ( + not output + or output.lower() == "none" + or output == "[]" + or "未找到" in output + or "None" in output + or result is None + ) + + # 如果结果为空,走重写流程 + if is_empty_result: + # 创建错误信息 + error_info = { + "error": "查询结果为空", + "helper_info": [], + "traceback": "查询执行成功但未返回数据", + } + + if current_retry < max_retries: + print("\n尝试使用RAG重写查询...") + try: + rewritten = rewrite_query_parameters(original_code, error_info) + + # 检查重写是否成功 + if rewritten and "code" in rewritten and rewritten["code"]: + print(f"\n重写成功,新代码:\n{rewritten['code']}") + + # 增加重试计数 + current_retry += 1 + + # 使用重写后的代码进行下一次迭代 + code = rewritten["code"] + continue # 重要!继续下一次循环迭代 + else: + print("\n重写未返回有效代码") + + except Exception as e: + print(f"\nRAG重写失败: {e}") + # 尝试使用原始查询 + try: + rewritten = rewrite_query_parameters(original_code, error_info) + + # 检查重写是否成功 + if rewritten and "code" in rewritten and rewritten["code"]: + print(f"\n第二次重写成功,新代码:\n{rewritten['code']}") + + # 增加重试计数 + current_retry += 1 + + # 使用重写后的代码进行下一次迭代 + code = rewritten["code"] + continue # 重要!继续下一次循环迭代 + else: + print("\n第二次重写未返回有效代码") + except Exception as e2: + print(f"再次重写失败: {e2}") + # 继续执行,返回原始错误 + + # RAG重写失败或未产生新代码,返回原始错误 + query_status = ( + "第一次查询失败,RAG重写也失败" + if current_retry == 0 + else f"第{current_retry+1}次查询失败,RAG重写也失败" + ) + print(f"\n{query_status}") + return { + "code": 1, + "message": "未找到匹配的节点,请检查该节点是否存在。", + "data": {"value": "", "code": original_code}, + "error_info": error_info, + "query_status": query_status, + } + + # 清理输出,只保留有用的结果部分 + clean_output = output + + # 如果输出包含查询结果数量和对象引用 + if "查询结果数量:" in output and "", output, re.DOTALL) + if node_match: + props_str = node_match.group(1).replace("'", '"') + try: + import ast + + props = ast.literal_eval(props_str) + clean_output = json.dumps(props, ensure_ascii=False, indent=2) + except: + pass + + # 如果有查询结果数量信息 + count_match = re.search(r"查询结果数量: (\d+)", output) + if count_match: + count = count_match.group(1) + if count == "0": + clean_output = "未找到匹配的数据。" + is_empty_result = True + elif not node_match: # 如果没有提取到节点属性但有结果 + clean_output = f"找到 {count} 条匹配结果" + + # 检查结果对象 + if result is not None: + if isinstance(result, list): + if not result: # 空列表 + is_empty_result = True + else: + # 处理非空列表 + formatted_items = [] + for item in result: + if hasattr(item, "__dict__"): + # 提取对象的所有属性 + attrs = {k: v for k, v in item.__dict__.items() if not k.startswith("_")} + formatted_items.append(attrs) + else: + formatted_items.append(str(item)) + + if not is_empty_result: # 只有在不是空结果时才返回成功 + query_status = ( + "第一次查询成功" + if current_retry == 0 + else f"第{current_retry+1}次查询成功(RAG重写后)" + ) + print(f"\n{query_status}") + return { + "code": 0, + "message": "成功", + "data": { + "value": json.dumps(formatted_items, ensure_ascii=False, indent=2), + "code": original_code, + }, + "query_status": query_status, + } + elif hasattr(result, "__dict__"): + # 单个对象 + attrs = {k: v for k, v in result.__dict__.items() if not k.startswith("_")} + + if not is_empty_result: # 只有在不是空结果时才返回成功 + query_status = ( + "第一次查询成功" if current_retry == 0 else f"第{current_retry+1}次查询成功(RAG重写后)" + ) + print(f"\n{query_status}") + return { + "code": 0, + "message": "成功", + "data": { + "value": json.dumps(attrs, ensure_ascii=False, indent=2), + "code": original_code, + }, + "query_status": query_status, + } + + # 如果没有对象属性但有清理后的输出,且不是空结果 + if ( + clean_output + and clean_output.lower() != "none" + and clean_output != "[]" + and "未找到" not in clean_output + and not is_empty_result + ): + query_status = ( + "第一次查询成功" if current_retry == 0 else f"第{current_retry+1}次查询成功(RAG重写后)" + ) + print(f"\n{query_status}") + return { + "code": 0, + "message": "成功", + "data": {"value": clean_output, "code": original_code}, + "query_status": query_status, + } + + finally: + # 恢复stdout + sys.stdout = old_stdout + + except Exception as e: + import traceback + + error_details = traceback.format_exc() + print(f"\n执行代码时出错: {error_details}") + + # 使用实际的错误信息创建error_info + error_info = { + "error": str(e), # 使用实际异常消息 + "helper_info": [], # 空的辅助信息 + "traceback": error_details, # 添加完整的堆栈跟踪 + } + + # 如果走到这里,说明结果为空或未找到匹配项,应该执行RAG重写流程 + print("\n查询未找到结果,尝试定位具体缺失节点...") + + # 解析原始查询路径中的最后一个节点名 + import re + + match = re.search(r"【([^】]+)】\s*$", original_query) + missing_node = match.group(1) if match else "未知节点" + + if current_retry < max_retries: + print("\n尝试使用RAG重写查询...") + try: + # 使用提取的值重写 + rewritten = rewrite_query_parameters(original_code, error_info) + + # 检查重写是否成功 + if rewritten and "code" in rewritten and rewritten["code"]: + print(f"\n重写成功,新代码:\n{rewritten['code']}") + + # 增加重试计数 + current_retry += 1 + + # 使用重写后的代码进行下一次迭代 + code = rewritten["code"] + continue # 重要!继续下一次循环迭代 + else: + print("\n重写未返回有效代码") + + except Exception as e: + print(f"\nRAG重写失败: {e}") + # 尝试使用原始查询 + try: + rewritten = rewrite_query_parameters(original_code, error_info) + + # 检查重写是否成功 + if rewritten and "code" in rewritten and rewritten["code"]: + print(f"\n第二次重写成功,新代码:\n{rewritten['code']}") + + # 增加重试计数 + current_retry += 1 + + # 使用重写后的代码进行下一次迭代 + code = rewritten["code"] + continue # 重要!继续下一次循环迭代 + else: + print("\n第二次重写未返回有效代码") + except Exception as e2: + print(f"再次重写失败: {e2}") + # 继续执行,返回原始错误 + + # RAG重写失败或未产生新代码,返回原始错误 + query_status = ( + "第一次查询失败,RAG重写也失败" + if current_retry == 0 + else f"第{current_retry+1}次查询失败,RAG重写也失败" + ) + print(f"\n{query_status}") + return { + "code": 1, + "message": f"{missing_node} 未找到,请检查该节点是否存在。", + "data": {"value": "", "code": original_code}, + "error_info": error_info, + "query_status": query_status, + } + + # 如果所有重试都失败 + print("\n所有重试都失败,无法找到匹配的结果") + query_status = "所有重试都失败" + return { + "code": 1, + "message": "所有重试都失败,无法找到匹配的结果", + "data": {"value": "", "code": original_code}, + "query_status": query_status, + } + + +def format_result(result): + """ + 格式化查询结果 + + Args: + result: 查询结果(可能为 list、dict 或其他类型) + + Returns: + str: 格式化后的结果 + """ + # 处理 project 对象 + if hasattr(result, "__module__") and result.__module__ == "project": + # 这是一个 project 模块中的对象 + attrs = {k: v for k, v in result.__dict__.items() if not k.startswith("_")} + return json.dumps(attrs, ensure_ascii=False, indent=2) + + # 处理 project 对象列表 + if isinstance(result, list) and all( + hasattr(item, "__module__") and item.__module__ == "project" for item in result if hasattr(item, "__module__") + ): + formatted_items = [] + for item in result: + if hasattr(item, "__dict__"): + attrs = {k: v for k, v in item.__dict__.items() if not k.startswith("_")} + formatted_items.append(attrs) + else: + formatted_items.append(str(item)) + + return json.dumps(formatted_items, ensure_ascii=False, indent=2) + + # 如果结果是字符串,可能包含调试信息,需要提取有用部分 + if isinstance(result, str): + # 尝试提取最终结果部分 + if "[]" in result: + return "未找到匹配的数据。" + + # 如果包含节点信息,提取关键部分 + import re + + node_match = re.search(r"找到.*?labels=.*?properties=(.*?)>", result) + if node_match: + try: + # 提取属性部分并格式化 + props_str = node_match.group(1).replace("'", '"') + import ast + + props = ast.literal_eval(props_str) + + formatted = "找到节点:\n" + for k, v in props.items(): + formatted += f" {k}: {v}\n" + return formatted + except: + pass + + # 如果包含查询结果数量 + count_match = re.search(r"查询结果数量: (\d+)", result) + if count_match: + count = count_match.group(1) + if count == "0": + return "未找到匹配的数据。" + + # 如果是列表 + if isinstance(result, list): + if not result: + return "未找到匹配的数据。" + + lines = [f"找到 {len(result)} 条匹配结果:"] + for i, item in enumerate(result, 1): + lines.append(f"\n结果 {i}:") + if hasattr(item, "items"): # 检查是否有items方法(字典或类似字典的对象) + try: + for k, v in item.items(): + lines.append(f" {k}: {v}") + except: + lines.append(f" {item}") + else: + lines.append(f" {item}") + return "\n".join(lines) + + # 如果是字典 + elif isinstance(result, dict): + lines = ["查询结果:"] + for k, v in result.items(): + lines.append(f" {k}: {v}") + return "\n".join(lines) + + # 其他类型 + else: + return str(result) + + +def format_dict_or_item(item): + """ + 格式化字典或其他对象 + + Args: + item: 字典或其他对象 + + Returns: + str: 格式化后的字符串 + """ + if isinstance(item, dict): + formatted = "" + for key, value in item.items(): + formatted += f" {key}: {value}\n" + return formatted + + return str(item) + + +def _extract_all_values(item, result_list): + """ + 递归提取字典中的所有值并添加到结果列表中 + + Args: + item: 字典或其他对象 + result_list: 用于存储结果的列表 + """ + if isinstance(item, dict): + for key, value in item.items(): + if key != "children": # 跳过children键,它会在主逻辑中单独处理 + if isinstance(value, str): + result_list.append(value) + elif isinstance(value, (dict, list)): + _extract_all_values(value, result_list) + elif isinstance(item, list): + for element in item: + _extract_all_values(element, result_list) + + +# 定义一个辅助函数来提取和处理相似参数 +def extract_similar_parameters(path_parts, knowledge_base): + """ + 从路径和知识库中提取相似参数 + + Args: + path_parts (list): 路径部分列表 + knowledge_base (list): 知识库数据 + + Returns: + str: 相似参数字符串 + """ + # 从路径中提取关键词 + extracted_parts = [] + for part in path_parts: + if "/" in part: + last_part = part.split("/")[-1].strip() + if last_part: + extracted_parts.append(last_part) + else: + extracted_parts.append(part) + + # 收集所有可能的相似参数 + similar_params = [] + + # 从knowledge_base中查找相关项及其子节点 + for part in extracted_parts: + for item in knowledge_base: + # 检查当前项是否匹配 + match_found = False + for key, value in item.items(): + if isinstance(value, str) and part.lower() in value.lower(): + match_found = True + break + + if match_found: + # 如果找到匹配项,提取所有值 + _extract_all_values(item, similar_params) + + # 特别处理子节点 + if "children" in item and isinstance(item["children"], list): + for child in item["children"]: + _extract_all_values(child, similar_params) + + # 移除重复项并排序 + similar_params = list(set(similar_params)) + similar_params.sort() + + print(f"找到的相似参数: {similar_params}") + + # 创建一个包含所有相似参数的字符串 + return ", ".join(similar_params) + + +# 在查询前先增加一个简单函数,专门提取字符串中的键值对值 +def extract_values_from_kb_string(kb_string): + """从知识库字符串中提取所有键值对的值""" + import re + + # 匹配所有键值对:"key": "value" 的模式 + # 这里我们直接取第二个捕获组,即值部分 + values = re.findall(r'"([^"]+)"\s*:\s*"([^"]+)"', kb_string) + + # 只保留值(第二个元素) + result = [match[1] for match in values] + + return result + + +question = { + "type": "query", + "value": "查找一下【工程数据/安装工程/安装/架空输电线路本体工程/杆塔工程/杆塔组立/铁塔、钢管杆组立】的类型为【主材】的【塔材】", +} +result = nl_query_to_function_call(question) +print(result) diff --git a/prompt_templates.py b/prompt_templates.py new file mode 100644 index 0000000..1392816 --- /dev/null +++ b/prompt_templates.py @@ -0,0 +1,76 @@ +from langchain.prompts import PromptTemplate + +FUNCTION_CALL_TEMPLATE = """ +你是一个专业的Python工程师。我会给你一个用户问题,你需要将其转换为对应的Python代码 + +可用工具: +{tools} + +工具名称: +{tool_names} + +# 工作流程 +1. 从用户问题中{query}提取关键信息(节点路径、节点类型、节点名称等) +2. 使用工具查询知识图谱结构以理解可用节点和节点属性 +3. 根据查询结果选择最匹配的{project_class_methods}中的方法 +4. 生成可直接执行的Python代码 + +# 代码模板(必须严格遵循) +def neo4j_find_function(): + project = ProjectBuilder.build() + status, data, error, helper_info = project.[SELECTED_METHOD]([PARAMETERS]) + return status, data, error, helper_info + +# 执行规则 +- 每次只能调用一个工具或生成最终代码 +- 参数必须从用户问题或知识图谱查询结果中提取 +- 必须确保生成的代码可以直接执行 +- 禁止修改代码模板结构 +- 禁止添加任何注释或解释 + +# 当前进度 +{agent_scratchpad} + +# 响应格式 +思考: 分析当前步骤需要做什么 +行动: 选择工具名称 +行动输入: 工具参数 +观察: 工具返回结果 + +...(重复直到准备好生成代码)... + +思考: 已收集足够信息,可以生成代码 +Final Answer: +def neo4j_find_function(): + project = ProjectBuilder.build() + status, data, error, helper_info = project.[SELECTED_METHOD]([PARAMETERS]) + return status, data, error, helper_info +""" + +FUNCTION_CALL_PROMPT = PromptTemplate.from_template(FUNCTION_CALL_TEMPLATE) + + +########################################################################################################################################################################### + +FUNCTION_RETURNS_LOOP_TEMPLATE = """ + +你是一个专业的Python工程师。我会给你一段错误python代码和错误信息,你需要帮我修复这段出错的代码 + +你的任务是: +1. 根据需要修改的代码{original_code}和代码的错误信息{error_info}来对代码和参数进行修改 +2. 如果错误信息中是代码的逻辑出现错误,那么就需要对代码本身整体结构进行修改 +3. 如果是代码中参数出现问题了,那么就需要结合错误信息中的帮助信息(helper_info)来对代码总的参数进行修改 +4. 修复后的代码应该完整,可以直接执行,并且能够返回查询结果 + +注意: +- 必须只输出最终的Python代码,不要添加任何解释、注释、推理过程或自然语言描述。 +- 不要以“以下是修正后的代码”、“修改如下”等语句开头。 +- 不要输出任何其他无关的内容。 +- 输出格式必须完全符合指定的函数模板。 +- 如果无法根据已有信息进行修改,请原样返回原始代码。 + +请输出你修补后的代码: +""" + + +FUNCTION_RETURNS_LOOP_PROMPT: PromptTemplate = PromptTemplate.from_template(FUNCTION_RETURNS_LOOP_TEMPLATE)