"""Rewrite query code with help from a JSON knowledge graph and an LLM agent."""

import json
import ast
import inspect
from typing import Dict, Any, List, Optional, Union

from langchain.agents import Tool, initialize_agent, AgentType
from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain.schema import BaseOutputParser
from langchain.schema import BaseRetriever
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.schema import Document
from langchain.chains import LLMChain
from langchain.llms.base import BaseLLM

from prompt_templates import FUNCTION_RETURNS_LOOP_PROMPT
from llm import llm as base_llm
import project  # Imported explicitly: _get_node_definition looks classes up on it.


class CodeOnlyOutputParser(BaseOutputParser):
    """Output parser that keeps only the code portion of a model reply."""

    def parse(self, text: str) -> dict:
        """Return ``{"code": ...}`` with the text after the last "Final Answer:".

        When the marker is absent, the whole (stripped) text is used.
        """
        # rpartition yields ("", "", text) when the marker is missing, so the
        # last element is the right answer in both cases.
        code_part = text.rpartition("Final Answer:")[2].strip()
        return {"code": code_part}

    def get_format_instructions(self) -> str:
        """Return the instruction string telling the model to emit code only."""
        return "只输出最终的Python代码,不要包含其他解释或内容。"


class KnowledgeGraphProcessor:
    """Look up nodes in a JSON knowledge graph and drive the rewrite agent."""

    def __init__(self, kg_file="kg_simple_hierarchy.json"):
        """Load the knowledge graph from *kg_file* and remember the prompt."""
        self.kg_data = self._load_kg_data(kg_file)
        self.prompt = FUNCTION_RETURNS_LOOP_PROMPT

    def _load_kg_data(self, file_path: str) -> Dict[str, Any]:
        """Load the knowledge graph JSON from *file_path*.

        Loading is best-effort: any failure is reported on stdout and an
        empty dict is returned instead of raising.
        """
        try:
            with open(file_path, "r", encoding="utf-8") as f:
                return json.load(f)
        except FileNotFoundError:
            print(f"警告: 未找到 {file_path} 文件")
            return {}
        except Exception as e:
            print(f"加载知识图谱数据时出错: {e}")
            return {}

    def _extract_parameters_from_code(self, code: str) -> List[str]:
        """Return every string literal found in *code*, in ``ast.walk`` order.

        All string constants are collected, so docstrings inside *code* are
        included as well. If *code* cannot be parsed, the error is printed
        and an empty list is returned.
        """
        try:
            tree = ast.parse(code)
        except Exception as e:
            print(f"解析代码时出错: {e}")
            return []
        # ast.Constant is the only node type to test: ast.Str has been
        # deprecated since Python 3.8 and is removed in 3.14, where touching
        # it raised AttributeError and silently produced an empty result.
        return [
            node.value
            for node in ast.walk(tree)
            if isinstance(node, ast.Constant) and isinstance(node.value, str)
        ]

    def _find_node_by_path(
        self, path: str, data: Optional[Union[Dict[str, Any], List[Any]]] = None
    ) -> Optional[Dict[str, Any]]:
        """Find a node by a "/"-separated path.

        A dict matches a path segment when any of its string values contains
        that segment (substring match). After a match, the remaining segments
        are resolved among the entries of the dict's "children" list. When
        the dict does not match, the full path is retried on each child.

        Args:
            path: Node path, e.g. "工程数据/安装过程/安装/架空输电线路本体工程/杆塔工程".
            data: Subtree to search; ``None`` means the whole loaded graph.

        Returns:
            The matched node dict (including its children), or ``None``.
        """
        if data is None:
            data = self.kg_data

        if isinstance(data, list):
            for item in data:
                # Only containers can hold a match. Skipping everything else
                # also keeps a JSON null from being mistaken for "search from
                # the root", which previously recursed without bound.
                if isinstance(item, (dict, list)):
                    result = self._find_node_by_path(path, item)
                    if result:
                        return result
            return None

        if not isinstance(data, dict):
            return None

        head, separator, rest = path.partition("/")
        matches_here = any(
            isinstance(value, str) and head in value for value in data.values()
        )
        if matches_here and not separator:
            # Last segment matched: this is the node we were looking for.
            return data

        # Matched with segments left -> resolve the rest below this node.
        # Not matched -> keep looking for the full path further down.
        remaining = rest if matches_here else path
        children = data.get("children")
        if isinstance(children, list):
            for child in children:
                if isinstance(child, (dict, list)):
                    result = self._find_node_by_path(remaining, child)
                    if result:
                        return result
        return None

    def _search_in_kg(
        self, parameter: str, data: Optional[Union[Dict[str, Any], List[Any]]] = None
    ) -> Optional[Dict[str, Any]]:
        """Search the knowledge graph for *parameter*.

        A parameter containing "/" is first resolved as a path against the
        whole graph. If that fails (or there is no "/"), *data* is searched
        for the first dict that has a string value containing *parameter*.

        Args:
            parameter: Path or keyword to look for.
            data: Subtree for the keyword search; ``None`` means the whole
                loaded graph.

        Returns:
            The first matching node dict, or ``None``.
        """
        # The path lookup always walks the full graph, so it runs once here
        # rather than again at every level of the keyword recursion.
        if "/" in parameter:
            result = self._find_node_by_path(parameter)
            if result:
                return result

        if data is None:
            data = self.kg_data
        return self._keyword_search(parameter, data)

    def _keyword_search(self, parameter: str, data: Any) -> Optional[Dict[str, Any]]:
        """Depth-first search for a dict with a string value containing *parameter*."""
        if isinstance(data, list):
            for item in data:
                if isinstance(item, (dict, list)):
                    hit = self._keyword_search(parameter, item)
                    if hit:
                        return hit
            return None

        if isinstance(data, dict):
            # Values are visited in key order, so a nested container that
            # precedes a matching string is searched first.
            for value in data.values():
                if isinstance(value, str):
                    if parameter in value:
                        return data
                elif isinstance(value, (list, dict)):
                    hit = self._keyword_search(parameter, value)
                    if hit:
                        return hit
        return None

    def _get_node_definition(self, node_type: str) -> str:
        """Return the source of ``project.<node_type>``, or a message string.

        This never raises: a missing attribute and any lookup failure are
        both reported through the returned text.
        """
        try:
            # Check whether the project module exposes this name at all.
            if hasattr(project, node_type):
                cls = getattr(project, node_type)
                return inspect.getsource(cls)
            return f"未找到类型 {node_type} 的定义"
        except NameError:
            return f"获取类型 {node_type} 定义时出错: project模块未正确导入"
        except Exception as e:
            return f"获取类型 {node_type} 定义时出错: {e}"

    def _extract_node_types(self, data: Dict[str, Any], node_types: set):
        """Recursively add every dict key under *data* to *node_types* (in place)."""
        if not isinstance(data, dict):
            return
        for key, value in data.items():
            # Every key is treated as a candidate node type.
            node_types.add(key)
            if isinstance(value, dict):
                self._extract_node_types(value, node_types)
            elif isinstance(value, list):
                for item in value:
                    if isinstance(item, dict):
                        self._extract_node_types(item, node_types)

    def _describe_match(self, label: str, term: str) -> tuple[str, str]:
        """Search for *term* and render what was found.

        Returns:
            ``(knowledge_text, definitions_text)``; both are empty strings
            when nothing matches *term*.
        """
        found_data = self._search_in_kg(term)
        if not found_data:
            return "", ""

        knowledge_text = (
            f"{label} '{term}' 相关信息:\n"
            + json.dumps(found_data, ensure_ascii=False, indent=2)
            + "\n\n"
        )

        node_types: set = set()
        self._extract_node_types(found_data, node_types)
        definition_chunks = []
        # Sorted so the generated text does not depend on set iteration order.
        for node_type in sorted(node_types):
            definition = self._get_node_definition(node_type)
            # NOTE(review): failures are recognised by the marker words that
            # _get_node_definition puts in its messages, so a class whose
            # source contains either word is dropped as well.
            if definition and "未找到" not in definition and "出错" not in definition:
                definition_chunks.append(f"类型 {node_type} 定义:\n{definition}\n\n")
        return knowledge_text, "".join(definition_chunks)

    def _get_relevant_knowledge(self, code: str) -> tuple[str, str]:
        """Collect knowledge-graph text and node definitions for *code*.

        Each string literal in *code* is looked up in the graph. A literal
        containing "【" is treated as holding one or more 【...】 groups, and
        only the last "/"-separated segment of each group is searched. Any
        other literal is searched as a whole.

        Returns:
            ``(knowledge_base, node_definitions)`` as concatenated text.
        """
        knowledge_chunks = []
        definition_chunks = []

        for param in self._extract_parameters_from_code(code):
            if "【" in param:
                label = "节点"
                terms = []
                for part in param.split("【"):
                    if "】" not in part:
                        continue
                    bracketed = part.split("】")[0].strip()
                    if not bracketed:
                        continue
                    # Only the last path segment is searched; with no "/",
                    # that is the whole bracketed text.
                    last_segment = bracketed.rsplit("/", 1)[-1].strip()
                    if last_segment:
                        terms.append(last_segment)
            else:
                label = "参数"
                terms = [param]

            for term in terms:
                knowledge_text, definitions_text = self._describe_match(label, term)
                knowledge_chunks.append(knowledge_text)
                definition_chunks.append(definitions_text)

        return "".join(knowledge_chunks), "".join(definition_chunks)

    def process_query(
        self,
        original_query: str,
        original_code: str,
        error_info: Optional[Dict[str, str]] = None,
    ) -> Dict[str, str]:
        """Run the agent on a query/code pair and return the rewritten code.

        Args:
            original_query: The original query string.
            original_code: The original code string.
            error_info: Optional error details; passed to the agent as JSON.

        Returns:
            ``{"query": original_query, "code": ...}``. "code" is the agent's
            output when it can be recognised, otherwise ``original_code``.
        """
        print("\n====== RAG重写查询开始 ======")
        print(f"原始查询: {original_query}")
        print(f"原始代码: {original_code}")
        print(f"错误信息: {error_info}")

        # Tools the agent may call while reasoning.
        tools = [
            Tool(
                name="search_knowledge_base",
                func=lambda x: json.dumps(self._search_in_kg(x), ensure_ascii=False),
                description="用于在知识图谱中搜索节点信息",
            ),
            Tool(
                name="get_node_definition",
                func=self._get_node_definition,
                description="获取指定类型的节点定义",
            ),
        ]

        # Use the shared base LLM.
        llm = base_llm

        agent_executor = initialize_agent(
            tools,
            llm,
            agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
            verbose=True,
            handle_parsing_errors=True,
            prompt=self.prompt,
            output_parser=CodeOnlyOutputParser(),
        )

        # Build the agent's input variables.
        knowledge_base, node_definitions = self._get_relevant_knowledge(original_code)
        print(f"知识库内容: {knowledge_base}")
        print(f"节点定义: {node_definitions}")

        input_vars = {
            "input": original_query,
            "original_query": original_query,
            "error_info": json.dumps(error_info, ensure_ascii=False) if error_info else "",
            "original_code": original_code,
            "KnowledgeBase": knowledge_base,
            "NodeDefinition": node_definitions,
        }
        print(f"Agent输入变量: {input_vars}")

        response = agent_executor.invoke(input_vars)
        print(f"Agent响应: {response}")

        # The output may be the parser's {"code": ...} dict or a plain string.
        if isinstance(response, dict) and "output" in response:
            output = response["output"]
            if isinstance(output, dict) and "code" in output:
                return {"query": original_query, "code": output["code"]}
            if isinstance(output, str):
                return {"query": original_query, "code": output}

        # Unrecognised response shape: fall back to the original code.
        print("警告: 无法解析 Agent 响应,返回原始代码")
        return {"query": original_query, "code": original_code}


# Module-wide processor instance, created on first use.
_processor = None


def _get_processor():
    """Return the shared KnowledgeGraphProcessor, creating it on first call."""
    global _processor
    if _processor is None:
        _processor = KnowledgeGraphProcessor()
    return _processor


def rewrite_query_parameters(
    original_query: str,
    original_code: str,
    error_info: Optional[Dict[str, str]] = None,
) -> Dict[str, str]:
    """Public entry point: rewrite the parameters used in a query's code.

    Args:
        original_query: The original query string.
        original_code: The original code string.
        error_info: Optional error details.

    Returns:
        A dict with the keys "query" and "code".

    Example:
        result = rewrite_query_parameters("查询动态费用", 'search_node("动态费用")')
        print(result["query"])  # the query
        print(result["code"])   # the rewritten code
    """
    processor = _get_processor()
    return processor.process_query(original_query, original_code, error_info)