上传文件至 /
This commit is contained in:
@@ -0,0 +1,492 @@
|
||||
from langchain.chains import LLMChain
|
||||
from langchain_openai import OpenAI
|
||||
from langchain_experimental.utilities import PythonREPL
|
||||
from project_implementation import ProjectBuilder
|
||||
from prompt_templates import FUNCTION_CALL_PROMPT
|
||||
import inspect
|
||||
import project
|
||||
import io
|
||||
import sys
|
||||
from parameter_rewriting import rewrite_query_parameters, KnowledgeGraphProcessor
|
||||
import json
|
||||
from langchain.agents import Tool, AgentExecutor, create_react_agent
|
||||
|
||||
from llm import llm
|
||||
|
||||
|
||||
# Collect the method definitions of the ProjectTookiIt class
def get_project_class_methods():
    """
    Return the source of ``project.ProjectTookiIt`` with ``__init__`` removed.

    The class source is read with ``inspect.getsource``. The ``class`` header
    line and every following line are kept, except the ``__init__`` method:
    skipping starts at the ``def __init__`` line and stops at the first
    non-blank line that is not indented by at least eight spaces.

    Returns:
        str: the remaining source lines joined with newlines
    """
    source_lines = inspect.getsource(project.ProjectTookiIt).split("\n")
    init_body_indent = " " * 8

    kept = []
    inside_class = False
    skipping_init = False

    for text in source_lines:
        stripped = text.strip()

        if stripped.startswith("class ProjectTookiIt"):
            inside_class = True
            kept.append(text)
            continue

        if not inside_class:
            continue

        if stripped.startswith("def __init__"):
            skipping_init = True
        elif skipping_init and stripped and not text.startswith(init_body_indent):
            # First non-blank line outside the __init__ body ends the skip
            skipping_init = False

        if not skipping_init:
            kept.append(text)

    return "\n".join(kept)
|
||||
|
||||
|
||||
# Method source of ProjectTookiIt (computed once at import time); it is passed
# to the agent as the "project_class_methods" input in nl_query_to_function_call
project_class_methods = get_project_class_methods()

# Build the chain
# NOTE(review): function_call_chain is not referenced elsewhere in the visible code
function_call_chain = LLMChain(llm=llm, prompt=FUNCTION_CALL_PROMPT, output_key="code")

# Python executor
# NOTE(review): repl is not referenced elsewhere in the visible code
repl = PythonREPL()

# Create the knowledge-graph processor instance used by the search tool below
kg_processor = KnowledgeGraphProcessor()
|
||||
|
||||
|
||||
# Tool function: search the knowledge base
def search_knowledge_and_node_definition(query):
    """
    Search the knowledge graph for a keyword and return what was found.

    ``kg_processor._get_relevant_knowledge`` extracts *string literals* from
    Python code and returns a ``(knowledge_base, node_definitions)`` pair of
    strings. A bare node name contains no string literal, so when the first
    lookup yields nothing the query is retried quoted as a Python literal.

    Args:
        query (str): search keyword (a node type name)

    Returns:
        str: JSON array ``[knowledge_base, node_definitions]`` when anything
        was found, otherwise a "not found" message
    """
    knowledge_base, node_definitions = kg_processor._get_relevant_knowledge(query)

    if not (knowledge_base or node_definitions):
        # Retry with the raw text wrapped as a string literal so the
        # processor's literal extraction sees it as one search term.
        quoted_query = repr(str(query).strip())
        knowledge_base, node_definitions = kg_processor._get_relevant_knowledge(quoted_query)

    # The pair itself is always truthy, so test its contents instead.
    if knowledge_base or node_definitions:
        return json.dumps([knowledge_base, node_definitions], ensure_ascii=False, indent=2)

    return f"未找到与'{query}'相关的信息"
|
||||
|
||||
|
||||
# Tool list exposed to the agent (a single knowledge-graph search tool)
tools = [
    Tool(
        name="search_knowledge_and_node_definition",
        func=search_knowledge_and_node_definition,
        description="获取输入节点的知识图谱结构和对应节点定义类型代码。输入应该是一个节点类型名称。",
    ),
]

# Create the ReAct agent from the shared LLM, the tools and the function-call prompt
agent = create_react_agent(llm, tools, FUNCTION_CALL_PROMPT)

# Create the agent executor; handle_parsing_errors=True is passed so that
# output-parsing failures are handled by the executor
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True, handle_parsing_errors=True)
|
||||
|
||||
|
||||
def _execute_generated_code(code_str):
    """
    Execute generated code and call its ``neo4j_find_function`` entry point.

    Anything printed while the code runs is captured and returned under
    ``"output"``. stdout is restored by the context manager on every path.

    Args:
        code_str (str): Python source that must define ``neo4j_find_function``
            returning a 4-tuple ``(status, data, error, helper_info)``

    Returns:
        dict: on success ``status``, ``data``, ``error``, ``helper_info`` and
        ``output``; on any exception ``status="error"``, ``error``,
        ``helper_info=[]`` and ``traceback``
    """
    import contextlib
    import traceback

    print(f"开始执行代码: {code_str[:50]}...")
    captured_output = io.StringIO()

    try:
        namespace = {
            "ProjectBuilder": ProjectBuilder,
            "project_implementation": __import__("project_implementation"),
            "project": __import__("project"),
        }

        # SECURITY NOTE(review): exec() runs LLM-generated or caller-supplied
        # code with full interpreter privileges. Only use with trusted input
        # or inside a sandbox.
        with contextlib.redirect_stdout(captured_output):
            exec(code_str, namespace)

            # Make sure neo4j_find_function exists
            if "neo4j_find_function" not in namespace:
                raise ValueError("代码中未定义neo4j_find_function函数")

            result_tuple = namespace["neo4j_find_function"]()

        output = captured_output.getvalue().strip()

        if not isinstance(result_tuple, tuple) or len(result_tuple) != 4:
            raise ValueError("函数应返回包含4个元素的元组(status, data, error, helper_info)")

        status, data, error, helper_info = result_tuple
        print(f"执行结果: status={status}, data={data}, error={error}")

        return {
            "status": status,
            "data": data,
            "error": error,
            "helper_info": helper_info,
            "output": output,
        }
    except Exception as e:
        # Boundary handler: every failure of the generated code is reported
        # back to the caller so it can be sent to the LLM for repair.
        trace_text = traceback.format_exc()
        print(f"执行代码时出错: {e}")
        print(trace_text)

        return {
            "status": "error",
            "error": str(e),
            "helper_info": [],
            "traceback": trace_text,
        }


def nl_query_to_function_call(input_data):
    """
    Turn a natural-language query into a function call and run it, or run
    the supplied code directly.

    For ``type == "query"`` the agent generates the code; otherwise ``value``
    is used as the code. The code is executed once; if its status is not
    ``"success"`` the error details are passed to ``rewrite_query_parameters``
    and the repaired code is executed a second time.

    Args:
        input_data (dict): dictionary with ``type`` and ``value``
            {
                "type": "query|code",
                "value": "query text or code"
            }

    Returns:
        dict: ``code`` (0 on success, 1 on failure), ``message``, ``data``
        (``value`` and the executed ``code``), ``query_status`` and, on
        execution failure, ``error_info``
    """
    input_type = input_data.get("type", "query")
    input_value = input_data.get("value", "")

    print("\n====== 开始执行查询函数 ======")
    print(f"查询类型: {input_type}")
    print(f"查询内容: {input_value}")

    # First attempt: let the agent generate the code
    if input_type == "query":
        print("\n----- 第一次尝试:使用Agent生成代码 -----")

        try:
            agent_response = agent_executor.invoke(
                {
                    "query": input_value,
                    "project_class_methods": project_class_methods,
                }
            )
            code = agent_response["output"].strip()
            print(f"\n生成的代码:\n{code}")
        except Exception as e:
            # stdout has not been redirected at this point, so there is
            # nothing to restore; just report the failure.
            return {
                "code": 1,
                "message": f"Agent生成代码失败: {e}",
                "data": {"value": "", "code": ""},
                "query_status": "Agent生成代码失败",
            }
    else:
        code = input_value.strip()
        print(f"\n使用提供的代码:\n{code}")

    original_code = code

    # First execution
    print("开始执行函数...")
    print("准备执行代码...")
    result = _execute_generated_code(code)
    print(f"执行完成,结果状态: {result.get('status', 'unknown')}")

    if result["status"] == "success":
        print("执行成功,准备返回结果...")
        final_result = {
            "code": 0,
            "message": "成功",
            "data": {
                "value": result["data"],
                "code": original_code,
            },
            "query_status": "第一次查询成功",
        }
    else:
        print(f"执行失败,错误: {result.get('error', 'unknown')}")

        # Record the failure details for the repair prompt
        error_info = {
            "error": result["error"],
            "helper_info": result.get("helper_info", []),
            "traceback": result.get("traceback", ""),
        }

        print("\n第一次查询失败,尝试使用LLM修复代码...")

        # Ask the LLM to repair the code
        rewritten = rewrite_query_parameters(original_code=original_code, error_info=error_info)
        fixed_code = rewritten.get("code", "").strip()

        if not fixed_code:
            print("\nLLM未返回有效代码")
            final_result = {
                "code": 1,
                "message": error_info["error"],
                "data": {"value": "", "code": original_code},
                "error_info": error_info,
                "query_status": "第一次查询失败,LLM未返回有效代码",
            }
        else:
            print(f"\nLLM修复后的代码:\n{fixed_code}")

            # Second execution with the repaired code
            result = _execute_generated_code(fixed_code)

            if result["status"] == "success":
                print("\nLLM修复后查询成功")
                # default=str keeps this debug print from raising when the
                # data contains objects json cannot serialize.
                print(f"\n执行结果详情: {json.dumps(result, ensure_ascii=False, indent=2, default=str)}")
                final_result = {
                    "code": 0,
                    "message": "成功",
                    "data": {
                        "value": result["data"],
                        "code": fixed_code,
                    },
                    "query_status": "LLM修复后查询成功",
                }
            else:
                print("\nLLM修复后查询仍然失败")
                final_result = {
                    "code": 1,
                    "message": result["error"],
                    "data": {"value": "", "code": fixed_code},
                    "error_info": {
                        "error": result["error"],
                        "helper_info": result.get("helper_info", []),
                    },
                    "query_status": "LLM修复后查询仍然失败",
                }

    print("函数执行完毕,准备返回最终结果")
    return final_result
|
||||
|
||||
|
||||
def format_result(result):
    """
    Format a query result as text.

    Handling order:
      1. an object whose ``__module__`` is ``"project"`` -> JSON of its
         public attributes;
      2. a list containing at least one such object (and no object from any
         other module) -> JSON list of public attributes, items without
         ``__dict__`` rendered with ``str``;
      3. a string -> "not found" message when it contains ``[]`` or reports
         a result count of 0, a formatted node when a node repr is found;
      4. any other list / dict -> a line-per-entry listing;
      5. anything else -> ``str(result)``.

    Args:
        result: query result (list, dict, str, project object or other)

    Returns:
        str: formatted result
    """
    import ast
    import re

    def _is_project_object(obj):
        return getattr(obj, "__module__", None) == "project"

    def _public_attrs(obj):
        return {k: v for k, v in obj.__dict__.items() if not k.startswith("_")}

    # A single project object
    if _is_project_object(result):
        return json.dumps(_public_attrs(result), ensure_ascii=False, indent=2)

    # A list of project objects. any() is required in addition to all():
    # for an empty list or a list of plain dicts/strings the filtered all()
    # is vacuously true and would hide the generic list branch below.
    if (
        isinstance(result, list)
        and any(_is_project_object(item) for item in result)
        and all(_is_project_object(item) for item in result if hasattr(item, "__module__"))
    ):
        formatted_items = [
            _public_attrs(item) if hasattr(item, "__dict__") else str(item)
            for item in result
        ]
        return json.dumps(formatted_items, ensure_ascii=False, indent=2)

    # A string may carry debug output; pull out the useful part
    if isinstance(result, str):
        if "[]" in result:
            return "未找到匹配的数据。"

        node_match = re.search(r"找到.*?labels=.*?properties=(.*?)>", result)
        if node_match:
            try:
                props = ast.literal_eval(node_match.group(1).replace("'", '"'))
                return "找到节点:\n" + "".join(f" {k}: {v}\n" for k, v in props.items())
            except (ValueError, SyntaxError, TypeError, AttributeError):
                # Properties were not a parseable dict literal; fall through
                # to the remaining checks (best effort).
                pass

        count_match = re.search(r"查询结果数量: (\d+)", result)
        if count_match and count_match.group(1) == "0":
            return "未找到匹配的数据。"

    if isinstance(result, list):
        if not result:
            return "未找到匹配的数据。"

        lines = [f"找到 {len(result)} 条匹配结果:"]
        for i, item in enumerate(result, 1):
            lines.append(f"\n结果 {i}:")
            if hasattr(item, "items"):  # dict or dict-like object
                try:
                    for k, v in item.items():
                        lines.append(f" {k}: {v}")
                except Exception:
                    # Deliberate best effort: an object with an unusual
                    # items() is rendered as a whole instead.
                    lines.append(f" {item}")
            else:
                lines.append(f" {item}")
        return "\n".join(lines)

    if isinstance(result, dict):
        lines = ["查询结果:"]
        lines.extend(f" {k}: {v}" for k, v in result.items())
        return "\n".join(lines)

    return str(result)
|
||||
|
||||
|
||||
def format_dict_or_item(item):
    """
    Render a dictionary as one ``key: value`` line per entry.

    Args:
        item: a dictionary or any other object

    Returns:
        str: for a dict, the entries as `` key: value`` lines, each ending
        with a newline (empty string for an empty dict); for anything else,
        ``str(item)``
    """
    if not isinstance(item, dict):
        return str(item)

    return "".join(f" {key}: {value}\n" for key, value in item.items())
|
||||
|
||||
|
||||
def _extract_all_values(item, result_list):
|
||||
"""
|
||||
递归提取字典中的所有值并添加到结果列表中
|
||||
|
||||
Args:
|
||||
item: 字典或其他对象
|
||||
result_list: 用于存储结果的列表
|
||||
"""
|
||||
if isinstance(item, dict):
|
||||
for key, value in item.items():
|
||||
if key != "children": # 跳过children键,它会在主逻辑中单独处理
|
||||
if isinstance(value, str):
|
||||
result_list.append(value)
|
||||
elif isinstance(value, (dict, list)):
|
||||
_extract_all_values(value, result_list)
|
||||
elif isinstance(item, list):
|
||||
for element in item:
|
||||
_extract_all_values(element, result_list)
|
||||
|
||||
|
||||
# Helper that extracts and processes similar parameters
def extract_similar_parameters(path_parts, knowledge_base):
    """
    Collect candidate parameter values for the given path parts.

    Each path part is reduced to a keyword: the stripped text after its last
    ``/`` (dropped when empty), or the part itself when it has no ``/``.
    Every knowledge-base item that has a string value containing a keyword
    (case-insensitive) contributes its string values, plus those of the
    entries in its ``children`` list.

    Args:
        path_parts (list): path fragments
        knowledge_base (list): knowledge-base items (dictionaries)

    Returns:
        str: the de-duplicated, sorted values joined with ", "
    """
    # Reduce every path fragment to its keyword
    keywords = []
    for fragment in path_parts:
        if "/" not in fragment:
            keywords.append(fragment)
            continue
        tail = fragment.split("/")[-1].strip()
        if tail:
            keywords.append(tail)

    # Gather values from every matching item and its children
    collected = []
    for keyword in keywords:
        needle = keyword.lower()
        for entry in knowledge_base:
            matched = any(
                isinstance(content, str) and needle in content.lower()
                for _, content in entry.items()
            )
            if not matched:
                continue

            _extract_all_values(entry, collected)

            # Children are skipped by the helper, so walk them explicitly
            if "children" in entry and isinstance(entry["children"], list):
                for child in entry["children"]:
                    _extract_all_values(child, collected)

    # De-duplicate and sort
    similar_params = sorted(set(collected))

    print(f"找到的相似参数: {similar_params}")

    return ", ".join(similar_params)
|
||||
|
||||
|
||||
# Small helper used before querying: pull the values out of "key": "value" pairs
def extract_values_from_kb_string(kb_string):
    """Return the value part of every ``"key": "value"`` pair in the string."""
    import re

    # Group 1 is the key and group 2 the value; only the value is kept
    pair_pattern = r'"([^"]+)"\s*:\s*"([^"]+)"'

    return [value for _key, value in re.findall(pair_pattern, kb_string)]
|
||||
|
||||
|
||||
# question = {
|
||||
# "type": "query",
|
||||
# "value": "查找一下【工程数据/安装工程/安装/架空输电线路本体工程/杆塔工程/杆塔组立/铁塔、钢管杆组立】的类型为【主材】的【塔材】",
|
||||
# }
|
||||
# result = nl_query_to_function_call(question)
|
||||
# print("\n最终结果:")
|
||||
# print(result)
|
||||
@@ -0,0 +1,327 @@
|
||||
import json
|
||||
import ast
|
||||
import inspect
|
||||
|
||||
from typing import Dict, Any, List, Optional, Union
|
||||
from langchain.agents import Tool, initialize_agent, AgentType
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.schema import BaseOutputParser
|
||||
from langchain.schema import BaseRetriever
|
||||
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
|
||||
from langchain.schema import Document
|
||||
from langchain.chains import LLMChain
|
||||
from langchain.llms.base import BaseLLM
|
||||
|
||||
from prompt_templates import FUNCTION_RETURNS_LOOP_PROMPT
|
||||
from llm import llm as base_llm
|
||||
import project # 明确导入project模块
|
||||
|
||||
|
||||
class CodeOnlyOutputParser(BaseOutputParser):
    """Output parser that keeps only the code following the last "Final Answer:" marker."""

    def parse(self, text: str) -> dict:
        """
        Return ``{"code": ...}`` for the given LLM output.

        The code is the stripped text after the last ``Final Answer:``
        marker, or the whole stripped text when the marker is absent.
        """
        # rpartition yields the full text as its last element when the
        # marker is missing, so both cases share one expression.
        _, _, tail = text.rpartition("Final Answer:")
        return {"code": tail.strip()}

    def get_format_instructions(self) -> str:
        """Return the instruction telling the model to output only code."""
        return "只输出最终的Python代码,不要包含其他解释或内容。"
|
||||
|
||||
|
||||
class KnowledgeGraphProcessor:
    """
    Knowledge-graph processor.

    Loads a JSON knowledge graph, looks up nodes by path or keyword, collects
    the source of matching ``project`` classes, and asks the LLM to repair
    failing query code.
    """

    def __init__(self, kg_file="kg_simple_hierarchy.json"):
        """Load the knowledge graph from *kg_file* and store the repair prompt."""
        self.kg_data = self._load_kg_data(kg_file)
        self.prompt = FUNCTION_RETURNS_LOOP_PROMPT

    def _load_kg_data(self, file_path: str) -> Dict[str, Any]:
        """
        Load the knowledge-graph JSON file.

        Returns the parsed JSON, or an empty dict (after printing a message)
        when the file is missing or cannot be read or parsed.
        """
        try:
            with open(file_path, "r", encoding="utf-8") as f:
                return json.load(f)
        except FileNotFoundError:
            print(f"警告: 未找到 {file_path} 文件")
            return {}
        except Exception as e:
            # Best effort: the processor stays usable with an empty graph.
            print(f"加载知识图谱数据时出错: {e}")
            return {}

    def _extract_parameters_from_code(self, code: str) -> List[str]:
        """
        Return every string literal that appears in *code*.

        Returns an empty list (after printing a message) when the code
        cannot be parsed.
        """
        try:
            tree = ast.parse(code)
        except Exception as e:
            print(f"解析代码时出错: {e}")
            return []

        # ast.Constant covers string literals on Python 3.8+. The former
        # ast.Str check is dropped because that class was removed in 3.12,
        # where it raised and made this method always return [].
        return [
            node.value
            for node in ast.walk(tree)
            if isinstance(node, ast.Constant) and isinstance(node.value, str)
        ]

    def _find_node_by_path(self, path: str, data: Union[Dict[str, Any], List[Any]] = None) -> Optional[Dict[str, Any]]:
        """
        Find a node by a ``/``-separated path.

        A dictionary matches a path segment when any of its string values
        contains the segment. After a match the remaining segments are
        resolved through that node's ``children``; without a match the full
        path is tried on each child.

        Args:
            path: node path, e.g. "工程数据/安装过程/安装/架空输电线路本体工程/杆塔工程"
            data: node to start from; the whole graph when omitted

        Returns:
            The node matching the last segment (including its children), or
            None when the path cannot be resolved.
        """
        if data is None:
            data = self.kg_data

        # A list is searched element by element
        if isinstance(data, list):
            for item in data:
                result = self._find_node_by_path(path, item)
                if result:
                    return result
            return None

        if not isinstance(data, dict):
            return None

        path_parts = path.split("/")
        current_part = path_parts[0]

        matches_here = any(
            isinstance(value, str) and current_part in value for value in data.values()
        )

        if matches_here:
            if len(path_parts) == 1:
                # Reached the final segment
                return data
            # Resolve the remaining segments below this node only
            remaining_path = "/".join(path_parts[1:])
            for child in data.get("children", []):
                result = self._find_node_by_path(remaining_path, child)
                if result:
                    return result
            return None

        # No match at this level: try the full path on the children
        for child in data.get("children", []):
            result = self._find_node_by_path(path, child)
            if result:
                return result

        return None

    def _search_in_kg(self, parameter: str, data: Union[Dict[str, Any], List[Any]] = None) -> Optional[Dict[str, Any]]:
        """
        Search the knowledge graph for *parameter*.

        A parameter containing ``/`` is first resolved as a path over the
        whole graph; if that fails (or there is no ``/``) the first
        dictionary with a string value containing the parameter is returned.
        """
        if "/" in parameter:
            result = self._find_node_by_path(parameter)
            if result:
                return result

        # The path lookup is done once here instead of on every recursive
        # step of the keyword search, which repeated the same failed lookup.
        return self._keyword_search(parameter, self.kg_data if data is None else data)

    def _keyword_search(self, parameter: str, data: Union[Dict[str, Any], List[Any]]) -> Optional[Dict[str, Any]]:
        """Depth-first search for the first dict holding a string value that contains *parameter*."""
        if isinstance(data, list):
            for item in data:
                if isinstance(item, (dict, list)):
                    result = self._keyword_search(parameter, item)
                    if result:
                        return result
            return None

        if isinstance(data, dict):
            for value in data.values():
                if isinstance(value, str):
                    if parameter in value:
                        return data
                elif isinstance(value, (list, dict)):
                    result = self._keyword_search(parameter, value)
                    if result:
                        return result

        return None

    def _get_node_definition(self, node_type: str) -> str:
        """
        Return the source of ``project.<node_type>``.

        Returns a message containing "未找到" when the attribute does not
        exist and a message containing "出错" when the source cannot be read.
        """
        try:
            # Check whether the project module has this attribute
            if hasattr(project, node_type):
                return inspect.getsource(getattr(project, node_type))
            return f"未找到类型 {node_type} 的定义"
        except NameError:
            return f"获取类型 {node_type} 定义时出错: project模块未正确导入"
        except Exception as e:
            return f"获取类型 {node_type} 定义时出错: {e}"

    def _extract_node_types(self, data: Dict[str, Any], node_types: set):
        """Recursively add every dictionary key found in *data* to *node_types*."""
        if not isinstance(data, dict):
            return

        for key, value in data.items():
            # Every key is treated as a potential node type
            node_types.add(key)

            if isinstance(value, dict):
                self._extract_node_types(value, node_types)
            elif isinstance(value, list):
                for item in value:
                    if isinstance(item, dict):
                        self._extract_node_types(item, node_types)

    def _collect_for_term(self, label: str, term: str) -> tuple:
        """
        Look up one search term and build its two text sections.

        Returns ``(knowledge_text, definitions_text)``; both are empty when
        the term is not found in the graph.
        """
        found_data = self._search_in_kg(term)
        if not found_data:
            return "", ""

        knowledge_text = (
            f"{label} '{term}' 相关信息:\n"
            + json.dumps(found_data, ensure_ascii=False, indent=2)
            + "\n\n"
        )

        node_types = set()
        self._extract_node_types(found_data, node_types)

        definition_sections = []
        # Sorted so the generated text is deterministic between runs
        for node_type in sorted(node_types):
            definition = self._get_node_definition(node_type)
            if definition and "未找到" not in definition and "出错" not in definition:
                definition_sections.append(f"类型 {node_type} 定义:\n{definition}\n\n")

        return knowledge_text, "".join(definition_sections)

    def _get_relevant_knowledge(self, code: str) -> tuple[str, str]:
        """
        Collect knowledge-base content and node definitions for *code*.

        Every string literal in the code is used as a search term. A literal
        containing ``【...】`` groups is split into those groups, and for a
        group containing ``/`` only the last path segment is searched.

        Returns:
            ``(knowledge_base, node_definitions)`` as two strings
        """
        knowledge_sections = []
        definition_sections = []

        for param in self._extract_parameters_from_code(code):
            if "【" in param:
                # Take the text of every 【...】 group
                bracket_parts = []
                for part in param.split("【"):
                    if "】" in part:
                        clean_part = part.split("】")[0].strip()
                        if clean_part:
                            bracket_parts.append(clean_part)

                search_terms = []
                for path in bracket_parts:
                    if "/" in path:
                        # Only the last path segment is searched
                        last_part = path.split("/")[-1].strip()
                        if last_part:
                            search_terms.append(last_part)
                    else:
                        search_terms.append(path)
                label = "节点"
            else:
                # Plain parameter
                search_terms = [param]
                label = "参数"

            for term in search_terms:
                knowledge_text, definitions_text = self._collect_for_term(label, term)
                knowledge_sections.append(knowledge_text)
                definition_sections.append(definitions_text)

        return "".join(knowledge_sections), "".join(definition_sections)

    def process_query(self, original_code: str, error_info: Dict[str, str] = None):
        """
        Ask the LLM to repair *original_code* using *error_info*.

        Args:
            original_code: the code that failed
            error_info: dictionary describing the failure; a rewrite is only
                attempted when it contains the key ``"error"``

        Returns:
            dict: ``{"code": repaired_code}``, or ``{"code": original_code}``
            when no usable error information was given
        """
        print("\n====== RAG重写查询开始 ======")
        print(f"原始代码: {original_code}")
        print(f"错误信息: {error_info}")

        needs_rewrite = bool(error_info) and isinstance(error_info, dict) and "error" in error_info
        if not needs_rewrite:
            print("没有错误信息,不需要重写")
            return {"code": original_code}

        input_vars = {
            "original_code": original_code,
            # default=str: helper_info comes from project code and may hold
            # values json cannot serialize; they are only prompt text here.
            "error_info": json.dumps(error_info, ensure_ascii=False, default=str),
        }

        llm_chain = LLMChain(llm=base_llm, prompt=FUNCTION_RETURNS_LOOP_PROMPT)
        response = llm_chain.invoke(input_vars)

        print(f"LLM响应: {response}")

        return {"code": response["text"].strip()}
|
||||
|
||||
|
||||
# Module-level processor instance (created lazily on first use)
_processor = None


def _get_processor():
    """Return the shared KnowledgeGraphProcessor, creating it on the first call."""
    global _processor
    if _processor is not None:
        return _processor
    _processor = KnowledgeGraphProcessor()
    return _processor
|
||||
|
||||
|
||||
def rewrite_query_parameters(original_code: str, error_info: Dict[str, str] = None) -> Dict[str, str]:
    """
    Public entry point for repairing failing query code.

    Delegates to ``process_query`` of the shared processor instance.

    Args:
        original_code: the original code string
        error_info: error-information dictionary (optional); a rewrite is
            only attempted when it contains the key "error"

    Returns:
        Dict: a dictionary with the single key "code" holding the rewritten
        code, or the original code when no rewrite was attempted

    Example:
        result = rewrite_query_parameters(
            'search_node("动态费用")',
            error_info={"error": "node not found", "helper_info": []},
        )
        print(result["code"])  # rewritten code
    """
    processor = _get_processor()
    return processor.process_query(original_code, error_info)
|
||||
@@ -0,0 +1,80 @@
|
||||
from langchain.prompts import PromptTemplate
|
||||
|
||||
# Prompt for the ReAct agent that turns a user question into query code.
# Template variables: {tools}, {tool_names} and {agent_scratchpad} are the
# ones a ReAct agent prompt must provide; {query} and {project_class_methods}
# are supplied by the caller.
# The step labels in the response format are the English keywords
# "Thought / Action / Action Input / Observation / Final Answer": the ReAct
# output parser matches "Action:" and "Action Input:" literally, so
# translated labels cannot be parsed into tool calls.
FUNCTION_CALL_TEMPLATE = """
你是一个专业的Python工程师。我会给你一个用户问题,你需要将其转换为对应的Python代码

可用工具:
{tools}

工具名称:
{tool_names}

# 工作流程
1. 从用户问题中{query}提取关键信息(节点路径、节点类型、节点名称等)
2. 使用工具查询知识图谱结构以理解可用节点和节点属性
3. 根据查询结果选择最匹配的{project_class_methods}中的方法
4. 生成可直接执行的Python代码

# 代码模板(必须严格遵循)
def neo4j_find_function():
    project = ProjectBuilder.build()
    status, data, error, helper_info = project.[SELECTED_METHOD]([PARAMETERS])
    return status, data, error, helper_info

# 执行规则
- 每次只能调用一个工具或生成最终代码
- 参数必须从用户问题或知识图谱查询结果中提取
- 必须确保生成的代码可以直接执行
- 禁止修改代码模板结构
- 禁止添加任何注释或解释
- 禁止在代码前加上```python字样
- 禁止在代码后加上```字样

# 当前进度
{agent_scratchpad}

# 响应格式(下列英文标签必须原样使用)
Thought: 分析当前步骤需要做什么
Action: 选择工具名称
Action Input: 工具参数
Observation: 工具返回结果

...(重复直到准备好生成代码)...

Thought: 已收集足够信息,可以生成代码
Final Answer:
def neo4j_find_function():
    project = ProjectBuilder.build()
    status, data, error, helper_info = project.[SELECTED_METHOD]([PARAMETERS])
    return status, data, error, helper_info
"""
|
||||
|
||||
# Prompt object built from FUNCTION_CALL_TEMPLATE; its input variables are
# taken from the {placeholders} in the template text
FUNCTION_CALL_PROMPT = PromptTemplate.from_template(FUNCTION_CALL_TEMPLATE)
|
||||
|
||||
|
||||
###########################################################################################################################################################################
|
||||
|
||||
# Prompt used to ask the LLM to repair code that failed to execute.
# Template variables: {original_code} and {error_info}.
# The function template is spelled out because the notes require the output
# to match "the specified function template".
FUNCTION_RETURNS_LOOP_TEMPLATE = """

你是一个专业的Python工程师。我会给你一段错误python代码和错误信息,你需要帮我修复这段出错的代码

你的任务是:
1. 根据需要修改的代码{original_code}和代码的错误信息{error_info}来对代码和参数进行修改
2. 如果错误信息中是代码的逻辑出现错误,那么就需要对代码本身整体结构进行修改
3. 如果是代码中参数出现问题了,那么就需要结合错误信息中的帮助信息(helper_info)来对代码中的参数进行修改
4. 修复后的代码应该完整,可以直接执行,并且能够返回查询结果

函数模板(必须严格遵循):
def neo4j_find_function():
    project = ProjectBuilder.build()
    status, data, error, helper_info = project.[SELECTED_METHOD]([PARAMETERS])
    return status, data, error, helper_info

注意:
- 必须只输出最终的Python代码,不要添加任何解释、注释、推理过程或自然语言描述。
- 不要以“以下是修正后的代码”、“修改如下”等语句开头。
- 不要输出任何其他无关的内容。
- 输出格式必须完全符合指定的函数模板。
- 如果无法根据已有信息进行修改,请原样返回原始代码。
- 禁止在代码前加上```python字样
- 禁止在代码后加上```字样

请输出你修补后的代码:
"""
|
||||
|
||||
|
||||
# Prompt object built from FUNCTION_RETURNS_LOOP_TEMPLATE
FUNCTION_RETURNS_LOOP_PROMPT: PromptTemplate = PromptTemplate.from_template(FUNCTION_RETURNS_LOOP_TEMPLATE)
|
||||
Reference in New Issue
Block a user