上传文件至 /

This commit is contained in:
2025-06-19 16:53:47 +08:00
parent 39e1f822a6
commit e36e970be1
3 changed files with 899 additions and 0 deletions
+492
View File
@@ -0,0 +1,492 @@
from langchain.chains import LLMChain
from langchain_openai import OpenAI
from langchain_experimental.utilities import PythonREPL
from project_implementation import ProjectBuilder
from prompt_templates import FUNCTION_CALL_PROMPT
import inspect
import project
import io
import sys
from parameter_rewriting import rewrite_query_parameters, KnowledgeGraphProcessor
import json
from langchain.agents import Tool, AgentExecutor, create_react_agent
from llm import llm
# 获取ProjectTookiIt类的方法定义
def get_project_class_methods():
"""
从project模块中提取ProjectTookiIt类的方法定义
Returns:
str: 格式化后的方法定义字符串
"""
project_class_code = inspect.getsource(project.ProjectTookiIt)
lines = project_class_code.split("\n")
result_lines = []
in_class = False
skip_init = False
for line in lines:
if line.strip().startswith("class ProjectTookiIt"):
in_class = True
result_lines.append(line)
elif in_class:
if line.strip().startswith("def __init__"):
skip_init = True
elif skip_init and line.strip() and not line.startswith(" " * 8):
skip_init = False
if not skip_init:
result_lines.append(line)
return "\n".join(result_lines)
# 创建动态提示模板
project_class_methods = get_project_class_methods()
# 创建 Chain
function_call_chain = LLMChain(llm=llm, prompt=FUNCTION_CALL_PROMPT, output_key="code")
# Python 执行器
repl = PythonREPL()
# 创建知识图谱处理器实例
kg_processor = KnowledgeGraphProcessor()
# 定义搜索知识库的工具
def search_knowledge_and_node_definition(query):
"""
在知识库中搜索关键词
Args:
query (str): 搜索关键词
Returns:
str: 搜索结果的JSON字符串
"""
found_data = kg_processor._get_relevant_knowledge(query)
if found_data:
return json.dumps(found_data, ensure_ascii=False, indent=2)
else:
return f"未找到与'{query}'相关的信息"
# 创建工具列表
tools = [
Tool(
name="search_knowledge_and_node_definition",
func=search_knowledge_and_node_definition,
description="获取输入节点的知识图谱结构和对应节点定义类型代码。输入应该是一个节点类型名称。",
),
]
# 创建Agent
agent = create_react_agent(llm, tools, FUNCTION_CALL_PROMPT)
# 创建Agent执行器
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True, handle_parsing_errors=True)
def nl_query_to_function_call(input_data):
"""
将自然语言查询转换为函数调用并执行,或直接执行提供的代码
Args:
input_data (dict): 包含type和value的字典
{
"type": "query|code",
"value": "查询内容或代码"
}
Returns:
dict: 包含状态码、消息和数据的字典
"""
input_type = input_data.get("type", "query")
input_value = input_data.get("value", "")
original_query = input_value
print(f"\n====== 开始执行查询函数 ======")
print(f"查询类型: {input_type}")
print(f"查询内容: {input_value}")
# 第一次执行:使用Agent生成代码
if input_type == "query":
print("\n----- 第一次尝试:使用Agent生成代码 -----")
try:
agent_response = agent_executor.invoke(
{
"query": input_value,
"project_class_methods": project_class_methods,
}
)
code = agent_response["output"].strip()
print(f"\n生成的代码:\n{code}")
except Exception as e:
# 确保恢复stdout
sys.stdout = old_stdout
return {
"code": 1,
"message": f"Agent生成代码失败: {e}",
"data": {"value": "", "code": ""},
"query_status": "Agent生成代码失败",
}
else:
code = input_value.strip()
print(f"\n使用提供的代码:\n{code}")
original_code = code
def execute_code(code_str):
"""封装代码执行逻辑"""
print(f"开始执行代码: {code_str[:50]}...")
try:
namespace = {
"ProjectBuilder": ProjectBuilder,
"project_implementation": __import__("project_implementation"),
"project": __import__("project"),
}
old_stdout = sys.stdout
redirected_output = io.StringIO()
sys.stdout = redirected_output
exec(code_str, namespace)
# 确保neo4j_find_function存在
if "neo4j_find_function" not in namespace:
raise ValueError("代码中未定义neo4j_find_function函数")
result_tuple = namespace["neo4j_find_function"]()
sys.stdout = old_stdout
output = redirected_output.getvalue().strip()
if not isinstance(result_tuple, tuple) or len(result_tuple) != 4:
raise ValueError("函数应返回包含4个元素的元组(status, data, error, helper_info)")
status, data, error, helper_info = result_tuple
# 添加更详细的日志
print(f"执行结果: status={status}, data={data}, error={error}")
return {
"status": status,
"data": data,
"error": error,
"helper_info": helper_info,
"output": output,
}
except Exception as e:
# 确保恢复stdout
sys.stdout = old_stdout
import traceback
print(f"执行代码时出错: {e}")
print(traceback.format_exc())
return {
"status": "error",
"error": str(e),
"helper_info": [],
"traceback": traceback.format_exc(),
}
# 第一次执行
print("开始执行函数...")
print("准备执行代码...")
result = execute_code(code)
print(f"执行完成,结果状态: {result.get('status', 'unknown')}")
# 检查是否需要重写
if result["status"] == "success":
print("执行成功,准备返回结果...")
# 返回成功结果...
final_result = {
"code": 0,
"message": "成功",
"data": {
"value": result["data"],
"code": original_code,
},
"query_status": "第一次查询成功",
}
else:
print(f"执行失败,错误: {result.get('error', 'unknown')}")
# 重写代码...
# 第一次失败,记录错误信息
error_info = {
"error": result["error"],
"helper_info": result.get("helper_info", []),
"traceback": result.get("traceback", ""),
}
print("\n第一次查询失败,尝试使用LLM修复代码...")
# 使用LLM直接修复代码
rewritten = rewrite_query_parameters(original_code=original_code, error_info=error_info)
fixed_code = rewritten.get("code", "").strip()
if not fixed_code:
print("\nLLM未返回有效代码")
final_result = {
"code": 1,
"message": error_info["error"],
"data": {"value": "", "code": original_code},
"error_info": error_info,
"query_status": "第一次查询失败,LLM未返回有效代码",
}
else:
print(f"\nLLM修复后的代码:\n{fixed_code}")
# 第二次执行修复后的代码
result = execute_code(fixed_code)
if result["status"] == "success":
print("\nLLM修复后查询成功")
print(f"\n执行结果详情: {json.dumps(result, ensure_ascii=False, indent=2)}")
final_result = {
"code": 0,
"message": "成功",
"data": {
"value": result["data"],
"code": fixed_code,
},
"query_status": "LLM修复后查询成功",
}
else:
print("\nLLM修复后查询仍然失败")
final_result = {
"code": 1,
"message": result["error"],
"data": {"value": "", "code": fixed_code},
"error_info": {
"error": result["error"],
"helper_info": result.get("helper_info", []),
},
"query_status": "LLM修复后查询仍然失败",
}
# 最后返回前
print("函数执行完毕,准备返回最终结果")
return final_result
def format_result(result):
"""
格式化查询结果
Args:
result: 查询结果(可能为 list、dict 或其他类型)
Returns:
str: 格式化后的结果
"""
# 处理 project 对象
if hasattr(result, "__module__") and result.__module__ == "project":
# 这是一个 project 模块中的对象
attrs = {k: v for k, v in result.__dict__.items() if not k.startswith("_")}
return json.dumps(attrs, ensure_ascii=False, indent=2)
# 处理 project 对象列表
if isinstance(result, list) and all(
hasattr(item, "__module__") and item.__module__ == "project" for item in result if hasattr(item, "__module__")
):
formatted_items = []
for item in result:
if hasattr(item, "__dict__"):
attrs = {k: v for k, v in item.__dict__.items() if not k.startswith("_")}
formatted_items.append(attrs)
else:
formatted_items.append(str(item))
return json.dumps(formatted_items, ensure_ascii=False, indent=2)
# 如果结果是字符串,可能包含调试信息,需要提取有用部分
if isinstance(result, str):
# 尝试提取最终结果部分
if "[]" in result:
return "未找到匹配的数据。"
# 如果包含节点信息,提取关键部分
import re
node_match = re.search(r"找到.*?labels=.*?properties=(.*?)>", result)
if node_match:
try:
# 提取属性部分并格式化
props_str = node_match.group(1).replace("'", '"')
import ast
props = ast.literal_eval(props_str)
formatted = "找到节点:\n"
for k, v in props.items():
formatted += f" {k}: {v}\n"
return formatted
except:
pass
# 如果包含查询结果数量
count_match = re.search(r"查询结果数量: (\d+)", result)
if count_match:
count = count_match.group(1)
if count == "0":
return "未找到匹配的数据。"
# 如果是列表
if isinstance(result, list):
if not result:
return "未找到匹配的数据。"
lines = [f"找到 {len(result)} 条匹配结果:"]
for i, item in enumerate(result, 1):
lines.append(f"\n结果 {i}:")
if hasattr(item, "items"): # 检查是否有items方法(字典或类似字典的对象)
try:
for k, v in item.items():
lines.append(f" {k}: {v}")
except:
lines.append(f" {item}")
else:
lines.append(f" {item}")
return "\n".join(lines)
# 如果是字典
elif isinstance(result, dict):
lines = ["查询结果:"]
for k, v in result.items():
lines.append(f" {k}: {v}")
return "\n".join(lines)
# 其他类型
else:
return str(result)
def format_dict_or_item(item):
"""
格式化字典或其他对象
Args:
item: 字典或其他对象
Returns:
str: 格式化后的字符串
"""
if isinstance(item, dict):
formatted = ""
for key, value in item.items():
formatted += f" {key}: {value}\n"
return formatted
return str(item)
def _extract_all_values(item, result_list):
"""
递归提取字典中的所有值并添加到结果列表中
Args:
item: 字典或其他对象
result_list: 用于存储结果的列表
"""
if isinstance(item, dict):
for key, value in item.items():
if key != "children": # 跳过children键,它会在主逻辑中单独处理
if isinstance(value, str):
result_list.append(value)
elif isinstance(value, (dict, list)):
_extract_all_values(value, result_list)
elif isinstance(item, list):
for element in item:
_extract_all_values(element, result_list)
# 定义一个辅助函数来提取和处理相似参数
def extract_similar_parameters(path_parts, knowledge_base):
"""
从路径和知识库中提取相似参数
Args:
path_parts (list): 路径部分列表
knowledge_base (list): 知识库数据
Returns:
str: 相似参数字符串
"""
# 从路径中提取关键词
extracted_parts = []
for part in path_parts:
if "/" in part:
last_part = part.split("/")[-1].strip()
if last_part:
extracted_parts.append(last_part)
else:
extracted_parts.append(part)
# 收集所有可能的相似参数
similar_params = []
# 从knowledge_base中查找相关项及其子节点
for part in extracted_parts:
for item in knowledge_base:
# 检查当前项是否匹配
match_found = False
for key, value in item.items():
if isinstance(value, str) and part.lower() in value.lower():
match_found = True
break
if match_found:
# 如果找到匹配项,提取所有值
_extract_all_values(item, similar_params)
# 特别处理子节点
if "children" in item and isinstance(item["children"], list):
for child in item["children"]:
_extract_all_values(child, similar_params)
# 移除重复项并排序
similar_params = list(set(similar_params))
similar_params.sort()
print(f"找到的相似参数: {similar_params}")
# 创建一个包含所有相似参数的字符串
return ", ".join(similar_params)
# 在查询前先增加一个简单函数,专门提取字符串中的键值对值
def extract_values_from_kb_string(kb_string):
"""从知识库字符串中提取所有键值对的值"""
import re
# 匹配所有键值对:"key": "value" 的模式
# 这里我们直接取第二个捕获组,即值部分
values = re.findall(r'"([^"]+)"\s*:\s*"([^"]+)"', kb_string)
# 只保留值(第二个元素)
result = [match[1] for match in values]
return result
# question = {
# "type": "query",
# "value": "查找一下【工程数据/安装工程/安装/架空输电线路本体工程/杆塔工程/杆塔组立/铁塔、钢管杆组立】的类型为【主材】的【塔材】",
# }
# result = nl_query_to_function_call(question)
# print("\n最终结果:")
# print(result)
+327
View File
@@ -0,0 +1,327 @@
import json
import ast
import inspect
from typing import Dict, Any, List, Optional, Union
from langchain.agents import Tool, initialize_agent, AgentType
from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain.schema import BaseOutputParser
from langchain.schema import BaseRetriever
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.schema import Document
from langchain.chains import LLMChain
from langchain.llms.base import BaseLLM
from prompt_templates import FUNCTION_RETURNS_LOOP_PROMPT
from llm import llm as base_llm
import project # 明确导入project模块
class CodeOnlyOutputParser(BaseOutputParser):
def parse(self, text: str) -> dict:
if "Final Answer:" in text:
code_part = text.split("Final Answer:")[-1].strip()
else:
code_part = text.strip()
return {"code": code_part}
def get_format_instructions(self) -> str:
return "只输出最终的Python代码,不要包含其他解释或内容。"
class KnowledgeGraphProcessor:
"""知识图谱处理器"""
def __init__(self, kg_file="kg_simple_hierarchy.json"):
"""初始化处理器"""
self.kg_data = self._load_kg_data(kg_file)
self.prompt = FUNCTION_RETURNS_LOOP_PROMPT
def _load_kg_data(self, file_path: str) -> Dict[str, Any]:
"""加载知识图谱数据"""
try:
with open(file_path, "r", encoding="utf-8") as f:
return json.load(f)
except FileNotFoundError:
print(f"警告: 未找到 {file_path} 文件")
return {}
except Exception as e:
print(f"加载知识图谱数据时出错: {e}")
return {}
def _extract_parameters_from_code(self, code: str) -> List[str]:
"""从代码中提取字符串参数"""
try:
tree = ast.parse(code)
parameters = []
for node in ast.walk(tree):
if isinstance(node, ast.Str):
parameters.append(node.s)
elif isinstance(node, ast.Constant) and isinstance(node.value, str):
parameters.append(node.value)
return parameters
except Exception as e:
print(f"解析代码时出错: {e}")
return []
def _find_node_by_path(self, path: str, data: Union[Dict[str, Any], List[Any]] = None) -> Optional[Dict[str, Any]]:
"""
根据路径查找节点
Args:
path: 节点路径,格式如 "工程数据/安装过程/安装/架空输电线路本体工程/杆塔工程"
data: 当前数据节点
Returns:
找到的节点及其子节点
"""
if data is None:
data = self.kg_data
path_parts = path.split("/")
current_part = path_parts[0]
# 如果data是列表,则遍历列表中的每个元素
if isinstance(data, list):
for item in data:
result = self._find_node_by_path(path, item)
if result:
return result
return None
# 如果data是字典,则按原来的逻辑处理
if isinstance(data, dict):
# 检查当前节点是否匹配
found_node = None
for key, value in data.items():
if isinstance(value, str) and current_part in value:
found_node = data
break
# 如果在当前层级找到了节点
if found_node:
# 如果还有更深的路径部分,则继续递归查找
if len(path_parts) > 1:
# 查找子节点
if "children" in found_node:
for child in found_node["children"]:
result = self._find_node_by_path("/".join(path_parts[1:]), child)
if result:
return result
return None
else:
# 已经找到最终节点
return found_node
# 如果当前层级没找到,则查找子节点
if "children" in data:
for child in data["children"]:
result = self._find_node_by_path(path, child)
if result:
return result
return None
def _search_in_kg(self, parameter: str, data: Union[Dict[str, Any], List[Any]] = None) -> Optional[Dict[str, Any]]:
"""在知识图谱中递归搜索参数"""
# 先尝试按路径查找
if "/" in parameter:
result = self._find_node_by_path(parameter)
if result:
return result
# 如果路径查找失败,则按关键词搜索
if data is None:
data = self.kg_data
# 如果是列表,遍历每个元素
if isinstance(data, list):
for item in data:
if isinstance(item, (dict, list)):
result = self._search_in_kg(parameter, item)
if result:
return result
return None
# 如果是字典,按原来的逻辑处理
if isinstance(data, dict):
for key, value in data.items():
if isinstance(value, str) and parameter in value:
return data
elif isinstance(value, list):
for item in value:
if isinstance(item, (dict, list)):
result = self._search_in_kg(parameter, item)
if result:
return result
elif isinstance(value, dict):
result = self._search_in_kg(parameter, value)
if result:
return result
return None
def _get_node_definition(self, node_type: str) -> str:
"""获取节点类型定义"""
try:
# 检查project模块中是否存在该类
if hasattr(project, node_type):
cls = getattr(project, node_type)
source = inspect.getsource(cls)
return source
else:
return f"未找到类型 {node_type} 的定义"
except NameError:
return f"获取类型 {node_type} 定义时出错: project模块未正确导入"
except Exception as e:
return f"获取类型 {node_type} 定义时出错: {e}"
def _extract_node_types(self, data: Dict[str, Any], node_types: set):
"""递归提取所有节点类型"""
if not isinstance(data, dict):
return
for key, value in data.items():
# 所有键都可能是节点类型
node_types.add(key)
if isinstance(value, dict):
self._extract_node_types(value, node_types)
elif isinstance(value, list):
for item in value:
if isinstance(item, dict):
self._extract_node_types(item, node_types)
def _get_relevant_knowledge(self, code: str) -> tuple[str, str]:
"""根据代码获取相关知识库内容和节点定义"""
parameters = self._extract_parameters_from_code(code)
knowledge_base = ""
node_definitions = ""
for param in parameters:
# 处理【】格式的路径,提取最后一个部分
if "" in param:
path_parts = []
parts = param.split("")
for part in parts:
if "" in part:
clean_part = part.split("")[0].strip()
if clean_part:
path_parts.append(clean_part)
# 对于每个路径,只取最后一个部分
for path in path_parts:
if "/" in path:
# 提取路径中的最后一个部分
last_part = path.split("/")[-1].strip()
if last_part:
found_data = self._search_in_kg(last_part)
if found_data:
knowledge_base += f"节点 '{last_part}' 相关信息:\n"
knowledge_base += json.dumps(found_data, ensure_ascii=False, indent=2) + "\n\n"
# 提取节点类型
node_types = set()
self._extract_node_types(found_data, node_types)
for node_type in node_types:
definition = self._get_node_definition(node_type)
if definition and "未找到" not in definition and "出错" not in definition:
node_definitions += f"类型 {node_type} 定义:\n{definition}\n\n"
else:
# 如果没有/,直接使用整个部分
found_data = self._search_in_kg(path)
if found_data:
knowledge_base += f"节点 '{path}' 相关信息:\n"
knowledge_base += json.dumps(found_data, ensure_ascii=False, indent=2) + "\n\n"
# 提取节点类型
node_types = set()
self._extract_node_types(found_data, node_types)
for node_type in node_types:
definition = self._get_node_definition(node_type)
if definition and "未找到" not in definition and "出错" not in definition:
node_definitions += f"类型 {node_type} 定义:\n{definition}\n\n"
else:
# 处理普通参数
found_data = self._search_in_kg(param)
if found_data:
knowledge_base += f"参数 '{param}' 相关信息:\n"
knowledge_base += json.dumps(found_data, ensure_ascii=False, indent=2) + "\n\n"
# 提取节点类型
node_types = set()
self._extract_node_types(found_data, node_types)
for node_type in node_types:
definition = self._get_node_definition(node_type)
if definition and "未找到" not in definition and "出错" not in definition:
node_definitions += f"类型 {node_type} 定义:\n{definition}\n\n"
return knowledge_base, node_definitions
def process_query(self, original_code: str, error_info: Dict[str, str] = None):
print("\n====== RAG重写查询开始 ======")
print(f"原始代码: {original_code}")
print(f"错误信息: {error_info}")
# 检查是否需要重写
if error_info and isinstance(error_info, dict) and "error" in error_info:
# 有错误信息,需要重写
input_vars = {
"original_code": original_code,
"error_info": json.dumps(error_info, ensure_ascii=False) if error_info else "",
}
# 使用LLM
llm_chain = LLMChain(llm=base_llm, prompt=FUNCTION_RETURNS_LOOP_PROMPT)
response = llm_chain.invoke(input_vars)
print(f"LLM响应: {response}")
# 提取代码
fixed_code = response["text"].strip()
return {"code": fixed_code}
else:
# 没有错误信息,不需要重写
print("没有错误信息,不需要重写")
return {"code": original_code}
# 全局处理器实例(延迟初始化)
_processor = None
def _get_processor():
"""获取处理器实例(单例模式)"""
global _processor
if _processor is None:
_processor = KnowledgeGraphProcessor()
return _processor
def rewrite_query_parameters(original_code: str, error_info: Dict[str, str] = None) -> Dict[str, str]:
"""
重写查询参数的外部接口函数
Args:
original_query: 原始查询字符串
original_code: 原始代码字符串
error_info: 错误信息字典 (可选)
Returns:
Dict: 包含修改后的query和code的字典
Example:
result = rewrite_query_parameters("查询动态费用", 'search_node("动态费用")')
print(result["query"]) # 修改后的查询
print(result["code"]) # 修改后的代码
"""
processor = _get_processor()
return processor.process_query(original_code, error_info)
+80
View File
@@ -0,0 +1,80 @@
from langchain.prompts import PromptTemplate
FUNCTION_CALL_TEMPLATE = """
你是一个专业的Python工程师。我会给你一个用户问题,你需要将其转换为对应的Python代码
可用工具:
{tools}
工具名称:
{tool_names}
# 工作流程
1. 从用户问题中{query}提取关键信息(节点路径、节点类型、节点名称等)
2. 使用工具查询知识图谱结构以理解可用节点和节点属性
3. 根据查询结果选择最匹配的{project_class_methods}中的方法
4. 生成可直接执行的Python代码
# 代码模板(必须严格遵循)
def neo4j_find_function():
project = ProjectBuilder.build()
status, data, error, helper_info = project.[SELECTED_METHOD]([PARAMETERS])
return status, data, error, helper_info
# 执行规则
- 每次只能调用一个工具或生成最终代码
- 参数必须从用户问题或知识图谱查询结果中提取
- 必须确保生成的代码可以直接执行
- 禁止修改代码模板结构
- 禁止添加任何注释或解释
- 禁止在代码前加上```python字样
- 禁止在代码后加上```字样
# 当前进度
{agent_scratchpad}
# 响应格式
思考: 分析当前步骤需要做什么
行动: 选择工具名称
行动输入: 工具参数
观察: 工具返回结果
...(重复直到准备好生成代码)...
思考: 已收集足够信息,可以生成代码
Final Answer:
def neo4j_find_function():
project = ProjectBuilder.build()
status, data, error, helper_info = project.[SELECTED_METHOD]([PARAMETERS])
return status, data, error, helper_info
"""
FUNCTION_CALL_PROMPT = PromptTemplate.from_template(FUNCTION_CALL_TEMPLATE)
###########################################################################################################################################################################
FUNCTION_RETURNS_LOOP_TEMPLATE = """
你是一个专业的Python工程师。我会给你一段错误python代码和错误信息,你需要帮我修复这段出错的代码
你的任务是:
1. 根据需要修改的代码{original_code}和代码的错误信息{error_info}来对代码和参数进行修改
2. 如果错误信息中是代码的逻辑出现错误,那么就需要对代码本身整体结构进行修改
3. 如果是代码中参数出现问题了,那么就需要结合错误信息中的帮助信息(helper_info)来对代码总的参数进行修改
4. 修复后的代码应该完整,可以直接执行,并且能够返回查询结果
注意:
- 必须只输出最终的Python代码,不要添加任何解释、注释、推理过程或自然语言描述。
- 不要以“以下是修正后的代码”、“修改如下”等语句开头。
- 不要输出任何其他无关的内容。
- 输出格式必须完全符合指定的函数模板。
- 如果无法根据已有信息进行修改,请原样返回原始代码。
- 禁止在代码前加上```python字样
- 禁止在代码后加上```字样
请输出你修补后的代码:
"""
FUNCTION_RETURNS_LOOP_PROMPT: PromptTemplate = PromptTemplate.from_template(FUNCTION_RETURNS_LOOP_TEMPLATE)