删除 kg_lab_6.13/parameter_rewriting.py
This commit is contained in:
@@ -1,383 +0,0 @@
|
||||
import json
|
||||
import ast
|
||||
import inspect
|
||||
|
||||
from typing import Dict, Any, List, Optional, Union
|
||||
from langchain.agents import Tool, initialize_agent, AgentType
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.schema import BaseOutputParser
|
||||
from langchain.schema import BaseRetriever
|
||||
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
|
||||
from langchain.schema import Document
|
||||
from langchain.chains import LLMChain
|
||||
from langchain.llms.base import BaseLLM
|
||||
|
||||
from prompt_templates import FUNCTION_RETURNS_LOOP_PROMPT
|
||||
from llm import llm as base_llm
|
||||
import project # 明确导入project模块
|
||||
|
||||
|
||||
class CodeOnlyOutputParser(BaseOutputParser):
    """Output parser that keeps only the code portion of a model reply."""

    def parse(self, text: str) -> dict:
        """Return ``{"code": ...}`` holding the text after the last "Final Answer:".

        When the marker is absent, ``rpartition`` leaves the whole text in the
        tail, so the entire (stripped) reply is used as the code.
        """
        _, _, tail = text.rpartition("Final Answer:")
        return {"code": tail.strip()}

    def get_format_instructions(self) -> str:
        """Return the instruction telling the model to emit only Python code."""
        return "只输出最终的Python代码,不要包含其他解释或内容。"
|
||||
|
||||
|
||||
class KnowledgeGraphProcessor:
    """Gather knowledge-graph context for a code snippet and ask an agent to rewrite it.

    The knowledge graph is loaded from a JSON file into ``self.kg_data`` and is
    searched by substring match over the string values of its nested
    dicts/lists.
    """

    def __init__(self, kg_file="kg_simple_hierarchy.json"):
        """Load the knowledge graph from *kg_file* and keep the rewrite prompt."""
        self.kg_data = self._load_kg_data(kg_file)
        self.prompt = FUNCTION_RETURNS_LOOP_PROMPT

    def _load_kg_data(self, file_path: str) -> Dict[str, Any]:
        """Read and parse the JSON knowledge graph at *file_path*.

        Loading is best-effort: a missing file or any other failure prints a
        message and yields an empty dict instead of raising.
        """
        try:
            with open(file_path, "r", encoding="utf-8") as f:
                return json.load(f)
        except FileNotFoundError:
            print(f"警告: 未找到 {file_path} 文件")
            return {}
        except Exception as e:
            print(f"加载知识图谱数据时出错: {e}")
            return {}

    def _extract_parameters_from_code(self, code: str) -> List[str]:
        """Return every string literal found in *code*, in ``ast.walk`` order.

        Best-effort: if *code* cannot be parsed, the error is printed and an
        empty list is returned.
        """
        try:
            tree = ast.parse(code)
        except Exception as e:
            print(f"解析代码时出错: {e}")
            return []

        # String literals are ast.Constant nodes. The legacy ast.Str alias was
        # removed in Python 3.12, so referencing it would raise AttributeError.
        return [
            node.value
            for node in ast.walk(tree)
            if isinstance(node, ast.Constant) and isinstance(node.value, str)
        ]

    def _find_node_by_path(
        self,
        path: str,
        data: Optional[Union[Dict[str, Any], List[Any]]] = None,
    ) -> Optional[Dict[str, Any]]:
        """Locate a node by a "/"-separated path of name fragments.

        A dict matches a path segment when any of its string values contains
        that segment. After a match, the remaining segments are resolved
        against the dict's ``"children"``. A dict that does not match the first
        segment has its ``"children"`` searched with the full path.

        Args:
            path: Node path, e.g. "工程数据/安装过程/安装/架空输电线路本体工程/杆塔工程".
            data: Sub-tree to search; defaults to the whole knowledge graph.

        Returns:
            The matching node dict (including its children), or ``None``.
        """
        if data is None:
            data = self.kg_data

        path_parts = path.split("/")
        current_part = path_parts[0]

        if isinstance(data, list):
            for item in data:
                result = self._find_node_by_path(path, item)
                if result:
                    return result
            return None

        if not isinstance(data, dict):
            return None

        matched = any(
            isinstance(value, str) and current_part in value
            for value in data.values()
        )

        if matched:
            if len(path_parts) == 1:
                # Last segment matched: this is the target node.
                return data
            # More segments remain: they must resolve below this node.
            remaining = "/".join(path_parts[1:])
            for child in data.get("children", ()):
                result = self._find_node_by_path(remaining, child)
                if result:
                    return result
            return None

        # No match at this level: try the full path further down the tree.
        for child in data.get("children", ()):
            result = self._find_node_by_path(path, child)
            if result:
                return result
        return None

    def _search_in_kg(
        self,
        parameter: str,
        data: Optional[Union[Dict[str, Any], List[Any]]] = None,
    ) -> Optional[Dict[str, Any]]:
        """Find the first node related to *parameter*.

        If *parameter* contains "/", a path lookup over the whole graph is
        tried first. Otherwise, or when that fails, *data* (default: the whole
        graph) is searched depth-first for a dict with a string value that
        contains *parameter*.

        Returns:
            The first matching dict, or ``None`` when nothing matches.
        """
        # The path lookup always starts from the root, so it runs once here
        # rather than again at every level of the keyword recursion.
        if "/" in parameter:
            result = self._find_node_by_path(parameter)
            if result:
                return result

        if data is None:
            data = self.kg_data
        return self._search_by_keyword(parameter, data)

    def _search_by_keyword(self, parameter: str, data: Any) -> Optional[Dict[str, Any]]:
        """Depth-first search for a dict holding a string value containing *parameter*."""
        if isinstance(data, list):
            for item in data:
                if isinstance(item, (dict, list)):
                    result = self._search_by_keyword(parameter, item)
                    if result:
                        return result
            return None

        if isinstance(data, dict):
            for value in data.values():
                if isinstance(value, str):
                    if parameter in value:
                        return data
                elif isinstance(value, (dict, list)):
                    result = self._search_by_keyword(parameter, value)
                    if result:
                        return result

        return None

    def _get_node_definition(self, node_type: str) -> str:
        """Return the source code of ``project.<node_type>``, or a message on failure.

        Failures are reported as text (containing "未找到" or "出错") rather than
        raised; callers filter on those markers.
        """
        try:
            if hasattr(project, node_type):
                cls = getattr(project, node_type)
                return inspect.getsource(cls)
            return f"未找到类型 {node_type} 的定义"
        except NameError:
            # The module-level name `project` is not available.
            return f"获取类型 {node_type} 定义时出错: project模块未正确导入"
        except Exception as e:
            return f"获取类型 {node_type} 定义时出错: {e}"

    def _extract_node_types(self, data: Dict[str, Any], node_types: set):
        """Add every dict key found anywhere under *data* to *node_types* (in place).

        Every key is treated as a candidate type name; non-dict input is ignored.
        """
        if not isinstance(data, dict):
            return

        for key, value in data.items():
            node_types.add(key)

            if isinstance(value, dict):
                self._extract_node_types(value, node_types)
            elif isinstance(value, list):
                for item in value:
                    if isinstance(item, dict):
                        self._extract_node_types(item, node_types)

    @staticmethod
    def _bracket_terms(param: str) -> List[str]:
        """Return the search terms named by the 【...】 groups inside *param*.

        For a group holding a "/"-separated path only the last segment is
        used; empty groups and empty last segments are skipped.
        """
        terms = []
        for part in param.split("【"):
            if "】" not in part:
                continue
            path = part.split("】")[0].strip()
            if not path:
                continue
            if "/" in path:
                last_part = path.split("/")[-1].strip()
                if last_part:
                    terms.append(last_part)
            else:
                terms.append(path)
        return terms

    def _collect_knowledge(self, term: str, label: str) -> tuple[str, str]:
        """Return (knowledge text, type-definition text) for one search term.

        Both strings are empty when *term* is not found in the graph.
        """
        found_data = self._search_in_kg(term)
        if not found_data:
            return "", ""

        knowledge = (
            f"{label} '{term}' 相关信息:\n"
            + json.dumps(found_data, ensure_ascii=False, indent=2)
            + "\n\n"
        )

        node_types = set()
        self._extract_node_types(found_data, node_types)

        definitions = []
        # Sorted so the generated text is deterministic across runs.
        for node_type in sorted(node_types):
            definition = self._get_node_definition(node_type)
            # Skip the failure messages produced by _get_node_definition.
            if definition and "未找到" not in definition and "出错" not in definition:
                definitions.append(f"类型 {node_type} 定义:\n{definition}\n\n")

        return knowledge, "".join(definitions)

    def _get_relevant_knowledge(self, code: str) -> tuple[str, str]:
        """Build knowledge-base text and node-type definitions for *code*.

        Every string literal in *code* is looked up in the graph. Literals
        containing 【...】 groups are looked up per group (last path segment
        only) and labelled "节点"; all other literals are looked up whole and
        labelled "参数".

        Returns:
            ``(knowledge_base, node_definitions)`` as concatenated text.
        """
        knowledge_parts = []
        definition_parts = []

        for param in self._extract_parameters_from_code(code):
            if "【" in param:
                lookups = [(term, "节点") for term in self._bracket_terms(param)]
            else:
                lookups = [(param, "参数")]

            for term, label in lookups:
                knowledge, definitions = self._collect_knowledge(term, label)
                knowledge_parts.append(knowledge)
                definition_parts.append(definitions)

        return "".join(knowledge_parts), "".join(definition_parts)

    def process_query(
        self,
        original_query: str,
        original_code: str,
        error_info: Optional[Dict[str, str]] = None,
    ) -> Dict[str, str]:
        """Run the agent on a query/code pair and return the rewritten code.

        Args:
            original_query: The original query string.
            original_code: The original code string.
            error_info: Optional error details; JSON-encoded into the agent input.

        Returns:
            A dict with ``"query"`` (always the original query) and ``"code"``
            (the agent's output, or the original code when the response cannot
            be interpreted).
        """
        print("\n====== RAG重写查询开始 ======")
        print(f"原始查询: {original_query}")
        print(f"原始代码: {original_code}")
        print(f"错误信息: {error_info}")

        # Tools the agent may call while reasoning.
        tools = [
            Tool(
                name="search_knowledge_base",
                func=lambda x: json.dumps(self._search_in_kg(x), ensure_ascii=False),
                description="用于在知识图谱中搜索节点信息",
            ),
            Tool(
                name="get_node_definition",
                func=lambda x: self._get_node_definition(x),
                description="获取指定类型的节点定义",
            ),
        ]

        llm = base_llm

        agent_executor = initialize_agent(
            tools,
            llm,
            agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
            verbose=True,
            handle_parsing_errors=True,
            prompt=self.prompt,
            output_parser=CodeOnlyOutputParser(),
        )

        knowledge_base, node_definitions = self._get_relevant_knowledge(original_code)

        print(f"知识库内容: {knowledge_base}")
        print(f"节点定义: {node_definitions}")

        input_vars = {
            "input": original_query,
            "original_query": original_query,
            "error_info": json.dumps(error_info, ensure_ascii=False) if error_info else "",
            "original_code": original_code,
            "KnowledgeBase": knowledge_base,
            "NodeDefinition": node_definitions,
        }

        print(f"Agent输入变量: {input_vars}")

        response = agent_executor.invoke(input_vars)

        print(f"Agent响应: {response}")

        # The output may be a parsed {"code": ...} dict or a plain string.
        if isinstance(response, dict) and "output" in response:
            output = response["output"]
            if isinstance(output, dict) and "code" in output:
                return {"query": original_query, "code": output["code"]}
            if isinstance(output, str):
                return {"query": original_query, "code": output}

        # Unrecognised response shape: fall back to the original code.
        print("警告: 无法解析 Agent 响应,返回原始代码")
        return {"query": original_query, "code": original_code}
|
||||
|
||||
|
||||
# Module-wide processor instance, created lazily on first use.
_processor = None


def _get_processor():
    """Return the shared KnowledgeGraphProcessor, creating it on the first call."""
    global _processor
    if _processor is not None:
        return _processor
    _processor = KnowledgeGraphProcessor()
    return _processor
|
||||
|
||||
|
||||
def rewrite_query_parameters(
    original_query: str, original_code: str, error_info: Dict[str, str] = None
) -> Dict[str, str]:
    """Public entry point: rewrite the parameters of a query's code.

    Delegates to ``process_query`` on the shared processor instance.

    Args:
        original_query: The original query string.
        original_code: The original code string.
        error_info: Error details dictionary (optional).

    Returns:
        Dict: a dictionary holding the resulting ``query`` and ``code``.

    Example:
        result = rewrite_query_parameters("查询动态费用", 'search_node("动态费用")')
        print(result["query"])  # resulting query
        print(result["code"])   # resulting code
    """
    return _get_processor().process_query(original_query, original_code, error_info)
|
||||
Reference in New Issue
Block a user