# NOTE(review): this text was recovered from a deletion diff of
# parameter_rewriting.py (index 99609df..0000000, hunk -1,383 +0,0); the diff
# header and the leading "-" markers were dropped to restore valid Python.
"""Rewrite failing query code with help from a JSON knowledge graph.

The module extracts the string literals used in a piece of generated code,
looks each of them up in a hierarchical knowledge-graph JSON file, collects the
source of matching classes from the ``project`` module, and hands all of that
to a LangChain agent that is asked to produce corrected code.
"""

import json
import ast
import inspect

from typing import Dict, Any, List, Optional, Union
from langchain.agents import Tool, initialize_agent, AgentType
from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain.schema import BaseOutputParser
from langchain.schema import BaseRetriever
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.schema import Document
from langchain.chains import LLMChain
from langchain.llms.base import BaseLLM

from prompt_templates import FUNCTION_RETURNS_LOOP_PROMPT
from llm import llm as base_llm
import project  # imported explicitly: node classes are looked up on it by name


class CodeOnlyOutputParser(BaseOutputParser):
    """Output parser that keeps only the code portion of an LLM reply."""

    def parse(self, text: str) -> dict:
        """Return ``{"code": ...}`` holding the text after the last "Final Answer:".

        When the marker is absent the whole (stripped) text is used.
        """
        # str.split yields the full text as its only element when the marker
        # is missing, so taking [-1] covers both cases.
        code_part = text.split("Final Answer:")[-1].strip()
        return {"code": code_part}

    def get_format_instructions(self) -> str:
        """Return the instruction telling the model to emit code only."""
        return "只输出最终的Python代码,不要包含其他解释或内容。"


class KnowledgeGraphProcessor:
    """Looks up code parameters in a knowledge graph and drives the rewrite agent."""

    def __init__(self, kg_file="kg_simple_hierarchy.json"):
        """Load the knowledge graph from *kg_file* and remember the prompt template."""
        self.kg_data = self._load_kg_data(kg_file)
        self.prompt = FUNCTION_RETURNS_LOOP_PROMPT

    def _load_kg_data(self, file_path: str) -> Dict[str, Any]:
        """Load the knowledge-graph JSON file.

        Returns the parsed JSON, or an empty dict (after printing a warning)
        when the file is missing or cannot be read/parsed.
        """
        try:
            with open(file_path, "r", encoding="utf-8") as f:
                return json.load(f)
        except FileNotFoundError:
            print(f"警告: 未找到 {file_path} 文件")
            return {}
        except Exception as e:
            # Deliberately best-effort: a broken graph file must not prevent
            # the processor from being constructed.
            print(f"加载知识图谱数据时出错: {e}")
            return {}

    def _extract_parameters_from_code(self, code: str) -> List[str]:
        """Return every string literal found in *code*, in ``ast.walk`` order.

        Returns an empty list (after printing the error) when *code* cannot be
        parsed.
        """
        try:
            tree = ast.parse(code)
        except Exception as e:
            # Best-effort: the code under inspection is machine-generated and
            # may not be valid Python.
            print(f"解析代码时出错: {e}")
            return []

        # ast.Constant is the only node type needed; the deprecated ast.Str
        # alias is avoided because it is removed in newer Python versions.
        return [
            node.value
            for node in ast.walk(tree)
            if isinstance(node, ast.Constant) and isinstance(node.value, str)
        ]

    @staticmethod
    def _child_nodes(node: Dict[str, Any]) -> List[Any]:
        """Return the dict/list entries of ``node["children"]`` (empty if none)."""
        children = node.get("children")
        if not isinstance(children, list):
            return []
        return [child for child in children if isinstance(child, (dict, list))]

    def _find_node_by_path(
        self, path: str, data: Optional[Union[Dict[str, Any], List[Any]]] = None
    ) -> Optional[Dict[str, Any]]:
        """Find a node by a "/"-separated path.

        Args:
            path: Node path, e.g. "工程数据/安装过程/安装/架空输电线路本体工程/杆塔工程".
            data: Sub-tree to search; defaults to the whole knowledge graph.

        Returns:
            The matching node (a dict, including its children) or None.
        """
        if data is None:
            data = self.kg_data
        return self._match_path(path.split("/"), data)

    def _match_path(
        self, parts: List[str], data: Union[Dict[str, Any], List[Any]]
    ) -> Optional[Dict[str, Any]]:
        """Recursive worker for `_find_node_by_path` operating on split path parts."""
        if isinstance(data, list):
            for item in data:
                # Skip scalars/None: they cannot hold a node.
                if isinstance(item, (dict, list)):
                    hit = self._match_path(parts, item)
                    if hit:
                        return hit
            return None

        if not isinstance(data, dict):
            return None

        head, rest = parts[0], parts[1:]
        # A node matches when any of its string values contains the path part.
        matches_here = any(
            isinstance(value, str) and head in value for value in data.values()
        )
        if matches_here and not rest:
            return data

        # After a match, descend with the remaining parts; otherwise keep
        # looking for the full path further down the tree.
        remaining = rest if matches_here else parts
        for child in self._child_nodes(data):
            hit = self._match_path(remaining, child)
            if hit:
                return hit
        return None

    def _search_in_kg(
        self, parameter: str, data: Optional[Union[Dict[str, Any], List[Any]]] = None
    ) -> Optional[Dict[str, Any]]:
        """Search the knowledge graph for *parameter*.

        A parameter containing "/" is first resolved as a path from the graph
        root; if that fails (or there is no "/"), the first dict that has a
        string value containing *parameter* is returned.

        Args:
            parameter: Path or keyword to look for.
            data: Sub-tree for the keyword search; defaults to the whole graph.

        Returns:
            The matching dict, or None.
        """
        if "/" in parameter:
            hit = self._find_node_by_path(parameter)
            if hit:
                return hit

        if data is None:
            data = self.kg_data
        # The path lookup above always starts at the root, so it is done once
        # here instead of at every recursion step.
        return self._keyword_search(parameter, data)

    def _keyword_search(
        self, keyword: str, data: Union[Dict[str, Any], List[Any]]
    ) -> Optional[Dict[str, Any]]:
        """Depth-first search for a dict with a string value containing *keyword*."""
        if isinstance(data, list):
            for item in data:
                if isinstance(item, (dict, list)):
                    hit = self._keyword_search(keyword, item)
                    if hit:
                        return hit
            return None

        if not isinstance(data, dict):
            return None

        for value in data.values():
            if isinstance(value, str):
                if keyword in value:
                    return data
            elif isinstance(value, (dict, list)):
                hit = self._keyword_search(keyword, value)
                if hit:
                    return hit
        return None

    def _get_node_definition(self, node_type: str) -> str:
        """Return the source of ``project.<node_type>`` or a message describing the failure."""
        try:
            if hasattr(project, node_type):
                cls = getattr(project, node_type)
                source = inspect.getsource(cls)
                return source
            else:
                return f"未找到类型 {node_type} 的定义"
        except NameError:
            return f"获取类型 {node_type} 定义时出错: project模块未正确导入"
        except Exception as e:
            return f"获取类型 {node_type} 定义时出错: {e}"

    def _lookup_class_source(self, node_type: str) -> Optional[str]:
        """Return the source of ``project.<node_type>``, or None when unavailable.

        Unlike `_get_node_definition` this signals failure with None, so callers
        do not have to guess from the text whether the lookup succeeded.
        """
        if not hasattr(project, node_type):
            return None
        try:
            return inspect.getsource(getattr(project, node_type))
        except Exception:
            # Best-effort: attributes without retrievable source are skipped.
            return None

    def _extract_node_types(self, data: Dict[str, Any], node_types: set):
        """Add every key found anywhere in *data* to *node_types* (in place).

        Every key is treated as a candidate node type; nested dicts and dicts
        inside lists are visited recursively.
        """
        if not isinstance(data, dict):
            return

        for key, value in data.items():
            node_types.add(key)

            if isinstance(value, dict):
                self._extract_node_types(value, node_types)
            elif isinstance(value, list):
                for item in value:
                    if isinstance(item, dict):
                        self._extract_node_types(item, node_types)

    @staticmethod
    def _bracketed_terms(param: str) -> List[str]:
        """Return the search terms enclosed in 【】 within *param*.

        For each bracketed text only the last "/"-separated segment is kept;
        empty results are dropped.
        """
        terms = []
        for chunk in param.split("【"):
            if "】" not in chunk:
                continue
            inner = chunk.split("】")[0].strip()
            # rsplit leaves the text unchanged when it contains no "/".
            term = inner.rsplit("/", 1)[-1].strip()
            if term:
                terms.append(term)
        return terms

    def _collect_knowledge(
        self, term: str, label: str, kb_parts: List[str], def_parts: List[str]
    ) -> None:
        """Look up *term* and append its graph data and class sources to the part lists."""
        found = self._search_in_kg(term)
        if not found:
            return

        kb_parts.append(f"{label} '{term}' 相关信息:\n")
        kb_parts.append(json.dumps(found, ensure_ascii=False, indent=2) + "\n\n")

        node_types: set = set()
        self._extract_node_types(found, node_types)
        # Sorted so the generated prompt text is the same from run to run.
        for node_type in sorted(node_types):
            source = self._lookup_class_source(node_type)
            if source:
                def_parts.append(f"类型 {node_type} 定义:\n{source}\n\n")

    def _get_relevant_knowledge(self, code: str) -> tuple[str, str]:
        """Collect knowledge-graph entries and class definitions relevant to *code*.

        Returns:
            A ``(knowledge_base, node_definitions)`` pair of text blocks built
            from the string literals that occur in *code*.
        """
        kb_parts: List[str] = []
        def_parts: List[str] = []

        for param in self._extract_parameters_from_code(code):
            if "【" in param:
                for term in self._bracketed_terms(param):
                    self._collect_knowledge(term, "节点", kb_parts, def_parts)
            else:
                self._collect_knowledge(param, "参数", kb_parts, def_parts)

        return "".join(kb_parts), "".join(def_parts)

    def process_query(
        self,
        original_query: str,
        original_code: str,
        error_info: Optional[Dict[str, str]] = None,
    ) -> Dict[str, str]:
        """Run the rewrite agent on a query and its failing code.

        Args:
            original_query: The original query string.
            original_code: The original code string.
            error_info: Optional dict describing the error.

        Returns:
            A dict with keys "query" and "code". "code" is the agent's output,
            or *original_code* when the response cannot be interpreted.
        """
        print("\n====== RAG重写查询开始 ======")
        print(f"原始查询: {original_query}")
        print(f"原始代码: {original_code}")
        print(f"错误信息: {error_info}")

        tools = [
            Tool(
                name="search_knowledge_base",
                func=lambda x: json.dumps(self._search_in_kg(x), ensure_ascii=False),
                description="用于在知识图谱中搜索节点信息",
            ),
            Tool(
                name="get_node_definition",
                func=lambda x: self._get_node_definition(x),
                description="获取指定类型的节点定义",
            ),
        ]

        llm = base_llm

        # NOTE(review): `prompt` and `output_parser` are passed as extra keyword
        # arguments of initialize_agent rather than through `agent_kwargs`;
        # confirm against the installed LangChain version that they actually
        # reach the agent.
        agent_executor = initialize_agent(
            tools,
            llm,
            agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
            verbose=True,
            handle_parsing_errors=True,
            prompt=self.prompt,
            output_parser=CodeOnlyOutputParser(),
        )

        knowledge_base, node_definitions = self._get_relevant_knowledge(original_code)

        print(f"知识库内容: {knowledge_base}")
        print(f"节点定义: {node_definitions}")

        input_vars = {
            "input": original_query,
            "original_query": original_query,
            "error_info": json.dumps(error_info, ensure_ascii=False) if error_info else "",
            "original_code": original_code,
            "KnowledgeBase": knowledge_base,
            "NodeDefinition": node_definitions,
        }

        print(f"Agent输入变量: {input_vars}")

        response = agent_executor.invoke(input_vars)

        print(f"Agent响应: {response}")

        # The output may be either the parser's {"code": ...} dict or plain text.
        if isinstance(response, dict) and "output" in response:
            output = response["output"]
            if isinstance(output, dict) and "code" in output:
                return {"query": original_query, "code": output["code"]}
            if isinstance(output, str):
                return {"query": original_query, "code": output}

        # Fall back to the unmodified code when the response is not understood.
        print("警告: 无法解析 Agent 响应,返回原始代码")
        return {"query": original_query, "code": original_code}


# Module-wide processor instance, created on first use.
_processor = None


def _get_processor():
    """Return the shared KnowledgeGraphProcessor, creating it on first call."""
    global _processor
    if _processor is None:
        _processor = KnowledgeGraphProcessor()
    return _processor


def rewrite_query_parameters(
    original_query: str,
    original_code: str,
    error_info: Optional[Dict[str, str]] = None,
) -> Dict[str, str]:
    """Public entry point: rewrite the parameters of a query's code.

    Args:
        original_query: The original query string.
        original_code: The original code string.
        error_info: Optional dict describing the error.

    Returns:
        A dict with the keys "query" and "code".

    Example:
        result = rewrite_query_parameters("查询动态费用", 'search_node("动态费用")')
        print(result["query"])  # the query
        print(result["code"])   # the rewritten code
    """
    processor = _get_processor()
    return processor.process_query(original_query, original_code, error_info)