Files
langchain_KG/parameter_rewriting.py
2025-06-19 16:53:47 +08:00

328 lines
13 KiB
Python

import json
import ast
import inspect
from typing import Dict, Any, List, Optional, Union
from langchain.agents import Tool, initialize_agent, AgentType
from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain.schema import BaseOutputParser
from langchain.schema import BaseRetriever
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.schema import Document
from langchain.chains import LLMChain
from langchain.llms.base import BaseLLM
from prompt_templates import FUNCTION_RETURNS_LOOP_PROMPT
from llm import llm as base_llm
import project # 明确导入project模块
class CodeOnlyOutputParser(BaseOutputParser):
def parse(self, text: str) -> dict:
if "Final Answer:" in text:
code_part = text.split("Final Answer:")[-1].strip()
else:
code_part = text.strip()
return {"code": code_part}
def get_format_instructions(self) -> str:
return "只输出最终的Python代码,不要包含其他解释或内容。"
class KnowledgeGraphProcessor:
"""知识图谱处理器"""
def __init__(self, kg_file="kg_simple_hierarchy.json"):
"""初始化处理器"""
self.kg_data = self._load_kg_data(kg_file)
self.prompt = FUNCTION_RETURNS_LOOP_PROMPT
def _load_kg_data(self, file_path: str) -> Dict[str, Any]:
"""加载知识图谱数据"""
try:
with open(file_path, "r", encoding="utf-8") as f:
return json.load(f)
except FileNotFoundError:
print(f"警告: 未找到 {file_path} 文件")
return {}
except Exception as e:
print(f"加载知识图谱数据时出错: {e}")
return {}
def _extract_parameters_from_code(self, code: str) -> List[str]:
"""从代码中提取字符串参数"""
try:
tree = ast.parse(code)
parameters = []
for node in ast.walk(tree):
if isinstance(node, ast.Str):
parameters.append(node.s)
elif isinstance(node, ast.Constant) and isinstance(node.value, str):
parameters.append(node.value)
return parameters
except Exception as e:
print(f"解析代码时出错: {e}")
return []
def _find_node_by_path(self, path: str, data: Union[Dict[str, Any], List[Any]] = None) -> Optional[Dict[str, Any]]:
"""
根据路径查找节点
Args:
path: 节点路径,格式如 "工程数据/安装过程/安装/架空输电线路本体工程/杆塔工程"
data: 当前数据节点
Returns:
找到的节点及其子节点
"""
if data is None:
data = self.kg_data
path_parts = path.split("/")
current_part = path_parts[0]
# 如果data是列表,则遍历列表中的每个元素
if isinstance(data, list):
for item in data:
result = self._find_node_by_path(path, item)
if result:
return result
return None
# 如果data是字典,则按原来的逻辑处理
if isinstance(data, dict):
# 检查当前节点是否匹配
found_node = None
for key, value in data.items():
if isinstance(value, str) and current_part in value:
found_node = data
break
# 如果在当前层级找到了节点
if found_node:
# 如果还有更深的路径部分,则继续递归查找
if len(path_parts) > 1:
# 查找子节点
if "children" in found_node:
for child in found_node["children"]:
result = self._find_node_by_path("/".join(path_parts[1:]), child)
if result:
return result
return None
else:
# 已经找到最终节点
return found_node
# 如果当前层级没找到,则查找子节点
if "children" in data:
for child in data["children"]:
result = self._find_node_by_path(path, child)
if result:
return result
return None
def _search_in_kg(self, parameter: str, data: Union[Dict[str, Any], List[Any]] = None) -> Optional[Dict[str, Any]]:
"""在知识图谱中递归搜索参数"""
# 先尝试按路径查找
if "/" in parameter:
result = self._find_node_by_path(parameter)
if result:
return result
# 如果路径查找失败,则按关键词搜索
if data is None:
data = self.kg_data
# 如果是列表,遍历每个元素
if isinstance(data, list):
for item in data:
if isinstance(item, (dict, list)):
result = self._search_in_kg(parameter, item)
if result:
return result
return None
# 如果是字典,按原来的逻辑处理
if isinstance(data, dict):
for key, value in data.items():
if isinstance(value, str) and parameter in value:
return data
elif isinstance(value, list):
for item in value:
if isinstance(item, (dict, list)):
result = self._search_in_kg(parameter, item)
if result:
return result
elif isinstance(value, dict):
result = self._search_in_kg(parameter, value)
if result:
return result
return None
def _get_node_definition(self, node_type: str) -> str:
"""获取节点类型定义"""
try:
# 检查project模块中是否存在该类
if hasattr(project, node_type):
cls = getattr(project, node_type)
source = inspect.getsource(cls)
return source
else:
return f"未找到类型 {node_type} 的定义"
except NameError:
return f"获取类型 {node_type} 定义时出错: project模块未正确导入"
except Exception as e:
return f"获取类型 {node_type} 定义时出错: {e}"
def _extract_node_types(self, data: Dict[str, Any], node_types: set):
"""递归提取所有节点类型"""
if not isinstance(data, dict):
return
for key, value in data.items():
# 所有键都可能是节点类型
node_types.add(key)
if isinstance(value, dict):
self._extract_node_types(value, node_types)
elif isinstance(value, list):
for item in value:
if isinstance(item, dict):
self._extract_node_types(item, node_types)
def _get_relevant_knowledge(self, code: str) -> tuple[str, str]:
"""根据代码获取相关知识库内容和节点定义"""
parameters = self._extract_parameters_from_code(code)
knowledge_base = ""
node_definitions = ""
for param in parameters:
# 处理【】格式的路径,提取最后一个部分
if "【" in param:
path_parts = []
parts = param.split("【")
for part in parts:
if "】" in part:
clean_part = part.split("】")[0].strip()
if clean_part:
path_parts.append(clean_part)
# 对于每个路径,只取最后一个部分
for path in path_parts:
if "/" in path:
# 提取路径中的最后一个部分
last_part = path.split("/")[-1].strip()
if last_part:
found_data = self._search_in_kg(last_part)
if found_data:
knowledge_base += f"节点 '{last_part}' 相关信息:\n"
knowledge_base += json.dumps(found_data, ensure_ascii=False, indent=2) + "\n\n"
# 提取节点类型
node_types = set()
self._extract_node_types(found_data, node_types)
for node_type in node_types:
definition = self._get_node_definition(node_type)
if definition and "未找到" not in definition and "出错" not in definition:
node_definitions += f"类型 {node_type} 定义:\n{definition}\n\n"
else:
# 如果没有/,直接使用整个部分
found_data = self._search_in_kg(path)
if found_data:
knowledge_base += f"节点 '{path}' 相关信息:\n"
knowledge_base += json.dumps(found_data, ensure_ascii=False, indent=2) + "\n\n"
# 提取节点类型
node_types = set()
self._extract_node_types(found_data, node_types)
for node_type in node_types:
definition = self._get_node_definition(node_type)
if definition and "未找到" not in definition and "出错" not in definition:
node_definitions += f"类型 {node_type} 定义:\n{definition}\n\n"
else:
# 处理普通参数
found_data = self._search_in_kg(param)
if found_data:
knowledge_base += f"参数 '{param}' 相关信息:\n"
knowledge_base += json.dumps(found_data, ensure_ascii=False, indent=2) + "\n\n"
# 提取节点类型
node_types = set()
self._extract_node_types(found_data, node_types)
for node_type in node_types:
definition = self._get_node_definition(node_type)
if definition and "未找到" not in definition and "出错" not in definition:
node_definitions += f"类型 {node_type} 定义:\n{definition}\n\n"
return knowledge_base, node_definitions
def process_query(self, original_code: str, error_info: Dict[str, str] = None):
print("\n====== RAG重写查询开始 ======")
print(f"原始代码: {original_code}")
print(f"错误信息: {error_info}")
# 检查是否需要重写
if error_info and isinstance(error_info, dict) and "error" in error_info:
# 有错误信息,需要重写
input_vars = {
"original_code": original_code,
"error_info": json.dumps(error_info, ensure_ascii=False) if error_info else "",
}
# 使用LLM
llm_chain = LLMChain(llm=base_llm, prompt=FUNCTION_RETURNS_LOOP_PROMPT)
response = llm_chain.invoke(input_vars)
print(f"LLM响应: {response}")
# 提取代码
fixed_code = response["text"].strip()
return {"code": fixed_code}
else:
# 没有错误信息,不需要重写
print("没有错误信息,不需要重写")
return {"code": original_code}
# 全局处理器实例(延迟初始化)
_processor = None
def _get_processor():
"""获取处理器实例(单例模式)"""
global _processor
if _processor is None:
_processor = KnowledgeGraphProcessor()
return _processor
def rewrite_query_parameters(original_code: str, error_info: Dict[str, str] = None) -> Dict[str, str]:
"""
重写查询参数的外部接口函数
Args:
original_query: 原始查询字符串
original_code: 原始代码字符串
error_info: 错误信息字典 (可选)
Returns:
Dict: 包含修改后的query和code的字典
Example:
result = rewrite_query_parameters("查询动态费用", 'search_node("动态费用")')
print(result["query"]) # 修改后的查询
print(result["code"]) # 修改后的代码
"""
processor = _get_processor()
return processor.process_query(original_code, error_info)