From d0ac7d00895fe6e9a933d60cba333e8da52b6fcd Mon Sep 17 00:00:00 2001 From: chentianrui Date: Thu, 19 Jun 2025 16:53:24 +0800 Subject: [PATCH] =?UTF-8?q?=E5=88=A0=E9=99=A4=20langchain=5Fneo4j.py?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- langchain_neo4j.py | 732 --------------------------------------------- 1 file changed, 732 deletions(-) delete mode 100644 langchain_neo4j.py diff --git a/langchain_neo4j.py b/langchain_neo4j.py deleted file mode 100644 index 1815cb7..0000000 --- a/langchain_neo4j.py +++ /dev/null @@ -1,732 +0,0 @@ -from langchain.chains import LLMChain -from langchain_openai import OpenAI -from langchain_experimental.utilities import PythonREPL -from project_implementation import ProjectBuilder -from prompt_templates import FUNCTION_CALL_PROMPT -import inspect -import project -import io -import sys -from parameter_rewriting import rewrite_query_parameters, KnowledgeGraphProcessor -import json -from langchain.agents import Tool, AgentExecutor, create_react_agent - -from llm import llm - - -# 获取ProjectTookiIt类的方法定义 -def get_project_class_methods(): - """ - 从project模块中提取ProjectTookiIt类的方法定义 - - Returns: - str: 格式化后的方法定义字符串 - """ - project_class_code = inspect.getsource(project.ProjectTookiIt) - - lines = project_class_code.split("\n") - result_lines = [] - - in_class = False - skip_init = False - - for line in lines: - if line.strip().startswith("class ProjectTookiIt"): - in_class = True - result_lines.append(line) - elif in_class: - if line.strip().startswith("def __init__"): - skip_init = True - elif skip_init and line.strip() and not line.startswith(" " * 8): - skip_init = False - - if not skip_init: - result_lines.append(line) - - return "\n".join(result_lines) - - -# 创建动态提示模板 -project_class_methods = get_project_class_methods() - -# 创建 Chain -function_call_chain = LLMChain(llm=llm, prompt=FUNCTION_CALL_PROMPT, output_key="code") - -# Python 执行器 -repl = PythonREPL() - -# 创建知识图谱处理器实例 -kg_processor = KnowledgeGraphProcessor() - - -# 定义搜索知识库的工具 -def search_knowledge_and_node_definition(query): - """ - 在知识库中搜索关键词 - - Args: - query (str): 搜索关键词 - - Returns: - str: 搜索结果的JSON字符串 - """ - found_data = kg_processor._get_relevant_knowledge(query) - if found_data: - return json.dumps(found_data, ensure_ascii=False, indent=2) - else: - return f"未找到与'{query}'相关的信息" - - -# 创建工具列表 -tools = [ - Tool( - name="search_knowledge_and_node_definition", - func=search_knowledge_and_node_definition, - description="获取输入节点的知识图谱结构和对应节点定义类型代码。输入应该是一个节点类型名称。", - ), -] - -# 创建Agent -agent = create_react_agent(llm, tools, FUNCTION_CALL_PROMPT) - -# 创建Agent执行器 -agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True, handle_parsing_errors=True) - - -def nl_query_to_function_call(input_data): - """ - 将自然语言查询转换为函数调用并执行,或直接执行提供的代码 - - Args: - input_data (dict): 包含type和value的字典 - { - "type": "query|code", - "value": "查询内容或代码" - } - - Returns: - dict: 包含状态码、消息和数据的字典 - """ - input_type = input_data.get("type", "query") - input_value = input_data.get("value", "") - max_retries = 1 # 设置最大重试次数 - current_retry = 0 - - original_query = input_value # 保存原始查询用于RAG - - print(f"\n====== 开始处理查询 ======") - print(f"查询类型: {input_type}") - print(f"查询内容: {input_value}") - - while current_retry <= max_retries: - print(f"\n----- 尝试 #{current_retry + 1} -----") - - # 如果type是query,使用LLM生成代码 - if input_type == "query" and current_retry == 0: - - # 使用Agent执行查询 - agent_response = agent_executor.invoke( - { - "query": input_value, - "project_class_methods": project_class_methods, - } - ) - - # 从Agent响应中提取代码 - code = agent_response["output"] - print(f"\n生成的代码:\n{code}") - else: - print(f"\n使用重写后的代码:\n{code}") - code = code - - # 保存原始代码用于返回 - original_code = code - - # 执行生成的函数并捕获输出 - try: - # 创建一个新的命名空间来执行代码,包含必要的导入 - namespace = { - "ProjectBuilder": ProjectBuilder, # 添加ProjectBuilder到命名空间 - "project_implementation": __import__("project_implementation"), - "project": __import__("project"), - } - - # 执行生成的代码,定义neo4j_find_function函数 - exec(code, namespace) - - # 重定向stdout来捕获print输出 - old_stdout = sys.stdout - redirected_output = io.StringIO() - sys.stdout = redirected_output - - try: - # 执行函数并获取元组结果 - result_tuple = namespace["neo4j_find_function"]() - - # 确保结果是元组且包含4个元素 - if not isinstance(result_tuple, tuple) or len(result_tuple) != 4: - raise ValueError("函数应返回包含4个元素的元组(status, data, error, helper_info)") - - status, data, error, helper_info = result_tuple - - # 获取捕获的输出(如果有) - output = redirected_output.getvalue().strip() - - # 根据状态处理结果 - if status == "success": - query_status = ( - "第一次查询成功" if current_retry == 0 else f"第{current_retry+1}次查询成功(RAG重写后)" - ) - print(f"\n{query_status}") - return { - "code": 0, - "message": "成功", - "data": { - "value": data, # 直接使用data字典,不再需要JSON转换 - "code": original_code, - }, - "query_status": query_status, - } - else: - # 错误情况 - error_info = { - "error": error, - "helper_info": helper_info, - } - - # 重试逻辑 - if current_retry < max_retries: - print("\n尝试使用RAG重写查询...") - try: - rewritten = rewrite_query_parameters(original_code, error_info) - if rewritten and "code" in rewritten and rewritten["code"]: - print(f"\n重写成功,新代码:\n{rewritten['code']}") - current_retry += 1 - code = rewritten["code"] - continue # 重要!继续下一次循环迭代 - except Exception as e: - print(f"\nRAG重写失败: {e}") - # 尝试使用原始查询 - try: - rewritten = rewrite_query_parameters(original_code, error_info) - except Exception as e2: - print(f"再次重写失败: {e2}") - # 继续执行,返回原始错误 - - # 返回错误信息 - query_status = "第一次查询失败" if current_retry == 0 else f"第{current_retry+1}次查询失败" - print(f"\n{query_status}") - return { - "code": 1, - "message": error, - "data": {"value": "", "code": original_code}, - "error_info": error_info, - "query_status": query_status, - } - - # 如果不是新格式,则按照原有逻辑处理 - # 检查结果是否为空 - is_empty_result = True # 默认假设为空 - - # 先检查是否有明显的数据内容 - if output and output.strip().startswith("[{") and output.strip().endswith("}]"): - # 看起来是一个对象数组,尝试转换为有效的JSON格式 - try: - # 将单引号替换为双引号以使其成为有效的JSON - # 但要小心处理嵌套的引号 - import ast - - # 使用ast.literal_eval安全地将Python字符串表示转换为Python对象 - parsed_obj = ast.literal_eval(output) - if parsed_obj and len(parsed_obj) > 0: - is_empty_result = False - except: - # 如果解析失败,继续尝试其他方法 - pass - - # 如果上面的方法没有确定结果不为空,继续尝试JSON解析 - if is_empty_result: - try: - if output.strip(): - parsed_output = json.loads(output) - if parsed_output and ( - isinstance(parsed_output, list) - and len(parsed_output) > 0 - or isinstance(parsed_output, dict) - and len(parsed_output) > 0 - ): - is_empty_result = False - except json.JSONDecodeError: - # 不是有效的JSON,使用其他判断方法 - is_empty_result = ( - not output - or output.lower() == "none" - or output == "[]" - or "未找到" in output - or "None" in output - or result is None - ) - - # 如果结果为空,走重写流程 - if is_empty_result: - # 创建错误信息 - error_info = { - "error": "查询结果为空", - "helper_info": [], - "traceback": "查询执行成功但未返回数据", - } - - if current_retry < max_retries: - print("\n尝试使用RAG重写查询...") - try: - rewritten = rewrite_query_parameters(original_code, error_info) - - # 检查重写是否成功 - if rewritten and "code" in rewritten and rewritten["code"]: - print(f"\n重写成功,新代码:\n{rewritten['code']}") - - # 增加重试计数 - current_retry += 1 - - # 使用重写后的代码进行下一次迭代 - code = rewritten["code"] - continue # 重要!继续下一次循环迭代 - else: - print("\n重写未返回有效代码") - - except Exception as e: - print(f"\nRAG重写失败: {e}") - # 尝试使用原始查询 - try: - rewritten = rewrite_query_parameters(original_code, error_info) - - # 检查重写是否成功 - if rewritten and "code" in rewritten and rewritten["code"]: - print(f"\n第二次重写成功,新代码:\n{rewritten['code']}") - - # 增加重试计数 - current_retry += 1 - - # 使用重写后的代码进行下一次迭代 - code = rewritten["code"] - continue # 重要!继续下一次循环迭代 - else: - print("\n第二次重写未返回有效代码") - except Exception as e2: - print(f"再次重写失败: {e2}") - # 继续执行,返回原始错误 - - # RAG重写失败或未产生新代码,返回原始错误 - query_status = ( - "第一次查询失败,RAG重写也失败" - if current_retry == 0 - else f"第{current_retry+1}次查询失败,RAG重写也失败" - ) - print(f"\n{query_status}") - return { - "code": 1, - "message": "未找到匹配的节点,请检查该节点是否存在。", - "data": {"value": "", "code": original_code}, - "error_info": error_info, - "query_status": query_status, - } - - # 清理输出,只保留有用的结果部分 - clean_output = output - - # 如果输出包含查询结果数量和对象引用 - if "查询结果数量:" in output and "", output, re.DOTALL) - if node_match: - props_str = node_match.group(1).replace("'", '"') - try: - import ast - - props = ast.literal_eval(props_str) - clean_output = json.dumps(props, ensure_ascii=False, indent=2) - except: - pass - - # 如果有查询结果数量信息 - count_match = re.search(r"查询结果数量: (\d+)", output) - if count_match: - count = count_match.group(1) - if count == "0": - clean_output = "未找到匹配的数据。" - is_empty_result = True - elif not node_match: # 如果没有提取到节点属性但有结果 - clean_output = f"找到 {count} 条匹配结果" - - # 检查结果对象 - if result is not None: - if isinstance(result, list): - if not result: # 空列表 - is_empty_result = True - else: - # 处理非空列表 - formatted_items = [] - for item in result: - if hasattr(item, "__dict__"): - # 提取对象的所有属性 - attrs = {k: v for k, v in item.__dict__.items() if not k.startswith("_")} - formatted_items.append(attrs) - else: - formatted_items.append(str(item)) - - if not is_empty_result: # 只有在不是空结果时才返回成功 - query_status = ( - "第一次查询成功" - if current_retry == 0 - else f"第{current_retry+1}次查询成功(RAG重写后)" - ) - print(f"\n{query_status}") - return { - "code": 0, - "message": "成功", - "data": { - "value": json.dumps(formatted_items, ensure_ascii=False, indent=2), - "code": original_code, - }, - "query_status": query_status, - } - elif hasattr(result, "__dict__"): - # 单个对象 - attrs = {k: v for k, v in result.__dict__.items() if not k.startswith("_")} - - if not is_empty_result: # 只有在不是空结果时才返回成功 - query_status = ( - "第一次查询成功" if current_retry == 0 else f"第{current_retry+1}次查询成功(RAG重写后)" - ) - print(f"\n{query_status}") - return { - "code": 0, - "message": "成功", - "data": { - "value": json.dumps(attrs, ensure_ascii=False, indent=2), - "code": original_code, - }, - "query_status": query_status, - } - - # 如果没有对象属性但有清理后的输出,且不是空结果 - if ( - clean_output - and clean_output.lower() != "none" - and clean_output != "[]" - and "未找到" not in clean_output - and not is_empty_result - ): - query_status = ( - "第一次查询成功" if current_retry == 0 else f"第{current_retry+1}次查询成功(RAG重写后)" - ) - print(f"\n{query_status}") - return { - "code": 0, - "message": "成功", - "data": {"value": clean_output, "code": original_code}, - "query_status": query_status, - } - - finally: - # 恢复stdout - sys.stdout = old_stdout - - except Exception as e: - import traceback - - error_details = traceback.format_exc() - print(f"\n执行代码时出错: {error_details}") - - # 使用实际的错误信息创建error_info - error_info = { - "error": str(e), # 使用实际异常消息 - "helper_info": [], # 空的辅助信息 - "traceback": error_details, # 添加完整的堆栈跟踪 - } - - # 如果走到这里,说明结果为空或未找到匹配项,应该执行RAG重写流程 - print("\n查询未找到结果,尝试定位具体缺失节点...") - - # 解析原始查询路径中的最后一个节点名 - import re - - match = re.search(r"【([^】]+)】\s*$", original_query) - missing_node = match.group(1) if match else "未知节点" - - if current_retry < max_retries: - print("\n尝试使用RAG重写查询...") - try: - # 使用提取的值重写 - rewritten = rewrite_query_parameters(original_code, error_info) - - # 检查重写是否成功 - if rewritten and "code" in rewritten and rewritten["code"]: - print(f"\n重写成功,新代码:\n{rewritten['code']}") - - # 增加重试计数 - current_retry += 1 - - # 使用重写后的代码进行下一次迭代 - code = rewritten["code"] - continue # 重要!继续下一次循环迭代 - else: - print("\n重写未返回有效代码") - - except Exception as e: - print(f"\nRAG重写失败: {e}") - # 尝试使用原始查询 - try: - rewritten = rewrite_query_parameters(original_code, error_info) - - # 检查重写是否成功 - if rewritten and "code" in rewritten and rewritten["code"]: - print(f"\n第二次重写成功,新代码:\n{rewritten['code']}") - - # 增加重试计数 - current_retry += 1 - - # 使用重写后的代码进行下一次迭代 - code = rewritten["code"] - continue # 重要!继续下一次循环迭代 - else: - print("\n第二次重写未返回有效代码") - except Exception as e2: - print(f"再次重写失败: {e2}") - # 继续执行,返回原始错误 - - # RAG重写失败或未产生新代码,返回原始错误 - query_status = ( - "第一次查询失败,RAG重写也失败" - if current_retry == 0 - else f"第{current_retry+1}次查询失败,RAG重写也失败" - ) - print(f"\n{query_status}") - return { - "code": 1, - "message": f"{missing_node} 未找到,请检查该节点是否存在。", - "data": {"value": "", "code": original_code}, - "error_info": error_info, - "query_status": query_status, - } - - # 如果所有重试都失败 - print("\n所有重试都失败,无法找到匹配的结果") - query_status = "所有重试都失败" - return { - "code": 1, - "message": "所有重试都失败,无法找到匹配的结果", - "data": {"value": "", "code": original_code}, - "query_status": query_status, - } - - -def format_result(result): - """ - 格式化查询结果 - - Args: - result: 查询结果(可能为 list、dict 或其他类型) - - Returns: - str: 格式化后的结果 - """ - # 处理 project 对象 - if hasattr(result, "__module__") and result.__module__ == "project": - # 这是一个 project 模块中的对象 - attrs = {k: v for k, v in result.__dict__.items() if not k.startswith("_")} - return json.dumps(attrs, ensure_ascii=False, indent=2) - - # 处理 project 对象列表 - if isinstance(result, list) and all( - hasattr(item, "__module__") and item.__module__ == "project" for item in result if hasattr(item, "__module__") - ): - formatted_items = [] - for item in result: - if hasattr(item, "__dict__"): - attrs = {k: v for k, v in item.__dict__.items() if not k.startswith("_")} - formatted_items.append(attrs) - else: - formatted_items.append(str(item)) - - return json.dumps(formatted_items, ensure_ascii=False, indent=2) - - # 如果结果是字符串,可能包含调试信息,需要提取有用部分 - if isinstance(result, str): - # 尝试提取最终结果部分 - if "[]" in result: - return "未找到匹配的数据。" - - # 如果包含节点信息,提取关键部分 - import re - - node_match = re.search(r"找到.*?labels=.*?properties=(.*?)>", result) - if node_match: - try: - # 提取属性部分并格式化 - props_str = node_match.group(1).replace("'", '"') - import ast - - props = ast.literal_eval(props_str) - - formatted = "找到节点:\n" - for k, v in props.items(): - formatted += f" {k}: {v}\n" - return formatted - except: - pass - - # 如果包含查询结果数量 - count_match = re.search(r"查询结果数量: (\d+)", result) - if count_match: - count = count_match.group(1) - if count == "0": - return "未找到匹配的数据。" - - # 如果是列表 - if isinstance(result, list): - if not result: - return "未找到匹配的数据。" - - lines = [f"找到 {len(result)} 条匹配结果:"] - for i, item in enumerate(result, 1): - lines.append(f"\n结果 {i}:") - if hasattr(item, "items"): # 检查是否有items方法(字典或类似字典的对象) - try: - for k, v in item.items(): - lines.append(f" {k}: {v}") - except: - lines.append(f" {item}") - else: - lines.append(f" {item}") - return "\n".join(lines) - - # 如果是字典 - elif isinstance(result, dict): - lines = ["查询结果:"] - for k, v in result.items(): - lines.append(f" {k}: {v}") - return "\n".join(lines) - - # 其他类型 - else: - return str(result) - - -def format_dict_or_item(item): - """ - 格式化字典或其他对象 - - Args: - item: 字典或其他对象 - - Returns: - str: 格式化后的字符串 - """ - if isinstance(item, dict): - formatted = "" - for key, value in item.items(): - formatted += f" {key}: {value}\n" - return formatted - - return str(item) - - -def _extract_all_values(item, result_list): - """ - 递归提取字典中的所有值并添加到结果列表中 - - Args: - item: 字典或其他对象 - result_list: 用于存储结果的列表 - """ - if isinstance(item, dict): - for key, value in item.items(): - if key != "children": # 跳过children键,它会在主逻辑中单独处理 - if isinstance(value, str): - result_list.append(value) - elif isinstance(value, (dict, list)): - _extract_all_values(value, result_list) - elif isinstance(item, list): - for element in item: - _extract_all_values(element, result_list) - - -# 定义一个辅助函数来提取和处理相似参数 -def extract_similar_parameters(path_parts, knowledge_base): - """ - 从路径和知识库中提取相似参数 - - Args: - path_parts (list): 路径部分列表 - knowledge_base (list): 知识库数据 - - Returns: - str: 相似参数字符串 - """ - # 从路径中提取关键词 - extracted_parts = [] - for part in path_parts: - if "/" in part: - last_part = part.split("/")[-1].strip() - if last_part: - extracted_parts.append(last_part) - else: - extracted_parts.append(part) - - # 收集所有可能的相似参数 - similar_params = [] - - # 从knowledge_base中查找相关项及其子节点 - for part in extracted_parts: - for item in knowledge_base: - # 检查当前项是否匹配 - match_found = False - for key, value in item.items(): - if isinstance(value, str) and part.lower() in value.lower(): - match_found = True - break - - if match_found: - # 如果找到匹配项,提取所有值 - _extract_all_values(item, similar_params) - - # 特别处理子节点 - if "children" in item and isinstance(item["children"], list): - for child in item["children"]: - _extract_all_values(child, similar_params) - - # 移除重复项并排序 - similar_params = list(set(similar_params)) - similar_params.sort() - - print(f"找到的相似参数: {similar_params}") - - # 创建一个包含所有相似参数的字符串 - return ", ".join(similar_params) - - -# 在查询前先增加一个简单函数,专门提取字符串中的键值对值 -def extract_values_from_kb_string(kb_string): - """从知识库字符串中提取所有键值对的值""" - import re - - # 匹配所有键值对:"key": "value" 的模式 - # 这里我们直接取第二个捕获组,即值部分 - values = re.findall(r'"([^"]+)"\s*:\s*"([^"]+)"', kb_string) - - # 只保留值(第二个元素) - result = [match[1] for match in values] - - return result - - -question = { - "type": "query", - "value": "查找一下【工程数据/安装工程/安装/架空输电线路本体工程/杆塔工程/杆塔组立/铁塔、钢管杆组立】的类型为【主材】的【塔材】", -} -result = nl_query_to_function_call(question) -print(result)