Files
langchain_KG/langchain_neo4j.py
T
2025-06-19 16:53:47 +08:00

493 lines
15 KiB
Python

from langchain.chains import LLMChain
from langchain_openai import OpenAI
from langchain_experimental.utilities import PythonREPL
from project_implementation import ProjectBuilder
from prompt_templates import FUNCTION_CALL_PROMPT
import inspect
import project
import io
import sys
from parameter_rewriting import rewrite_query_parameters, KnowledgeGraphProcessor
import json
from langchain.agents import Tool, AgentExecutor, create_react_agent
from llm import llm
# Extract the method definitions of the ProjectTookiIt class
def get_project_class_methods():
    """
    Return the source of ``project.ProjectTookiIt`` with ``__init__`` removed.

    The class source is obtained with ``inspect.getsource`` and filtered line
    by line: everything from the ``class ProjectTookiIt`` line onward is kept,
    except the ``def __init__`` line and the lines that follow it until the
    first non-blank line that is not indented by at least 8 spaces.

    Returns:
        str: The remaining source lines joined with newlines.
    """
    source_lines = inspect.getsource(project.ProjectTookiIt).split("\n")
    body_indent = " " * 8  # lines of a method body are expected to start with this
    kept = []
    inside_class = False
    skipping_init = False

    for raw in source_lines:
        stripped = raw.strip()

        if stripped.startswith("class ProjectTookiIt"):
            inside_class = True
            kept.append(raw)
            continue

        if not inside_class:
            # Anything before the class header is dropped.
            continue

        if stripped.startswith("def __init__"):
            skipping_init = True
        elif skipping_init and stripped and not raw.startswith(body_indent):
            # First non-blank line outside the __init__ body ends the skip;
            # that line itself is kept by the check below.
            skipping_init = False

        if not skipping_init:
            kept.append(raw)

    return "\n".join(kept)
# Source text of ProjectTookiIt's methods; it is passed to the agent as the
# "project_class_methods" input when a natural-language query is handled.
project_class_methods = get_project_class_methods()
# Build the LLM chain from the shared llm and FUNCTION_CALL_PROMPT; its result
# is stored under the "code" output key.
function_call_chain = LLMChain(llm=llm, prompt=FUNCTION_CALL_PROMPT, output_key="code")
# Python executor (REPL)
repl = PythonREPL()
# Knowledge-graph processor instance used by the knowledge search tool
kg_processor = KnowledgeGraphProcessor()
# Tool function: search the knowledge base
def search_knowledge_and_node_definition(query):
    """
    Look up a keyword in the knowledge base.

    Args:
        query (str): The search keyword.

    Returns:
        str: The lookup result serialized as pretty-printed JSON (non-ASCII
        characters kept as-is), or a "not found" message when the lookup
        returns nothing truthy.
    """
    matches = kg_processor._get_relevant_knowledge(query)
    if not matches:
        return f"未找到与'{query}'相关的信息"
    return json.dumps(matches, ensure_ascii=False, indent=2)
# Tools made available to the agent (currently only the knowledge search tool)
tools = [
    Tool(
        name="search_knowledge_and_node_definition",
        func=search_knowledge_and_node_definition,
        description="获取输入节点的知识图谱结构和对应节点定义类型代码。输入应该是一个节点类型名称。",
    ),
]
# Create the ReAct agent from the shared llm, the tools above and FUNCTION_CALL_PROMPT
agent = create_react_agent(llm, tools, FUNCTION_CALL_PROMPT)
# Create the agent executor; verbose logging is on and parsing errors are
# handled by the executor instead of being raised
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True, handle_parsing_errors=True)
def _execute_generated_code(code_str):
    """
    Execute query code and call the ``neo4j_find_function`` it defines.

    The code is run with ``exec`` in a namespace that exposes
    ``ProjectBuilder``, ``project_implementation`` and ``project``. While the
    code and ``neo4j_find_function`` run, ``sys.stdout`` is redirected to an
    in-memory buffer; it is always restored afterwards.

    Args:
        code_str (str): Python source that must define ``neo4j_find_function``.

    Returns:
        dict: On a well-formed result, the keys ``status``, ``data``,
        ``error``, ``helper_info`` (the four elements returned by
        ``neo4j_find_function``) plus ``output`` (captured stdout, stripped).
        On any exception, ``status`` is ``"error"`` and the dict carries
        ``error``, an empty ``helper_info`` list and ``traceback``.
    """
    import traceback

    print(f"开始执行代码: {code_str[:50]}...")
    try:
        namespace = {
            "ProjectBuilder": ProjectBuilder,
            "project_implementation": __import__("project_implementation"),
            "project": __import__("project"),
        }
        captured = io.StringIO()
        saved_stdout = sys.stdout
        sys.stdout = captured
        try:
            # SECURITY(review): this executes LLM-generated or caller-supplied
            # code with full interpreter privileges. Only feed it trusted input
            # or run it inside a sandbox.
            exec(code_str, namespace)
            # The generated code is required to define neo4j_find_function.
            if "neo4j_find_function" not in namespace:
                raise ValueError("代码中未定义neo4j_find_function函数")
            result_tuple = namespace["neo4j_find_function"]()
        finally:
            # Restore stdout on every exit path, including SystemExit raised
            # by the executed code.
            sys.stdout = saved_stdout

        output = captured.getvalue().strip()
        if not isinstance(result_tuple, tuple) or len(result_tuple) != 4:
            raise ValueError("函数应返回包含4个元素的元组(status, data, error, helper_info)")
        status, data, error, helper_info = result_tuple
        print(f"执行结果: status={status}, data={data}, error={error}")
        return {
            "status": status,
            "data": data,
            "error": error,
            "helper_info": helper_info,
            "output": output,
        }
    except Exception as e:
        trace_text = traceback.format_exc()
        print(f"执行代码时出错: {e}")
        print(trace_text)
        return {
            "status": "error",
            "error": str(e),
            "helper_info": [],
            "traceback": trace_text,
        }


def nl_query_to_function_call(input_data):
    """
    Turn a natural-language query into code and run it, or run given code.

    For ``type == "query"`` the agent generates the code; for any other type
    the ``value`` is used as the code directly. The code is executed once; if
    the execution status is not ``"success"``, ``rewrite_query_parameters`` is
    asked for a fixed version, which is executed once more.

    Args:
        input_data (dict): Dictionary with ``type`` and ``value``:
            {
                "type": "query|code",
                "value": "query text or code"
            }
            ``type`` defaults to ``"query"`` and ``value`` to ``""``.

    Returns:
        dict: ``code`` (0 on success, 1 on failure), ``message``, ``data``
        (``{"value": ..., "code": ...}``) and ``query_status``. Failures that
        happen after code execution also carry ``error_info``.
    """
    input_type = input_data.get("type", "query")
    input_value = input_data.get("value", "")
    print("\n====== 开始执行查询函数 ======")
    print(f"查询类型: {input_type}")
    print(f"查询内容: {input_value}")

    # Step 1: obtain the code, generating it with the agent for "query" input.
    if input_type == "query":
        print("\n----- 第一次尝试:使用Agent生成代码 -----")
        try:
            agent_response = agent_executor.invoke(
                {
                    "query": input_value,
                    "project_class_methods": project_class_methods,
                }
            )
            code = agent_response["output"].strip()
            print(f"\n生成的代码:\n{code}")
        except Exception as e:
            # stdout is not redirected at this point, so there is nothing to
            # restore; just report the failure to the caller.
            return {
                "code": 1,
                "message": f"Agent生成代码失败: {e}",
                "data": {"value": "", "code": ""},
                "query_status": "Agent生成代码失败",
            }
    else:
        code = input_value.strip()
        print(f"\n使用提供的代码:\n{code}")
    original_code = code

    # Step 2: first execution.
    print("开始执行函数...")
    print("准备执行代码...")
    result = _execute_generated_code(code)
    print(f"执行完成,结果状态: {result.get('status', 'unknown')}")

    if result["status"] == "success":
        print("执行成功,准备返回结果...")
        final_result = {
            "code": 0,
            "message": "成功",
            "data": {
                "value": result["data"],
                "code": original_code,
            },
            "query_status": "第一次查询成功",
        }
    else:
        print(f"执行失败,错误: {result.get('error', 'unknown')}")
        # Step 3: record the failure and ask the LLM for repaired code.
        error_info = {
            "error": result["error"],
            "helper_info": result.get("helper_info", []),
            "traceback": result.get("traceback", ""),
        }
        print("\n第一次查询失败,尝试使用LLM修复代码...")
        rewritten = rewrite_query_parameters(original_code=original_code, error_info=error_info)
        # Treat a missing or None "code" entry the same as an empty string.
        fixed_code = (rewritten.get("code") or "").strip()

        if not fixed_code:
            print("\nLLM未返回有效代码")
            final_result = {
                "code": 1,
                "message": error_info["error"],
                "data": {"value": "", "code": original_code},
                "error_info": error_info,
                "query_status": "第一次查询失败,LLM未返回有效代码",
            }
        else:
            print(f"\nLLM修复后的代码:\n{fixed_code}")
            # Step 4: second execution with the repaired code.
            result = _execute_generated_code(fixed_code)
            if result["status"] == "success":
                print("\nLLM修复后查询成功")
                # default=str keeps this debug dump from raising when the
                # result data holds objects json cannot serialize.
                print(f"\n执行结果详情: {json.dumps(result, ensure_ascii=False, indent=2, default=str)}")
                final_result = {
                    "code": 0,
                    "message": "成功",
                    "data": {
                        "value": result["data"],
                        "code": fixed_code,
                    },
                    "query_status": "LLM修复后查询成功",
                }
            else:
                print("\nLLM修复后查询仍然失败")
                final_result = {
                    "code": 1,
                    "message": result["error"],
                    "data": {"value": "", "code": fixed_code},
                    "error_info": {
                        "error": result["error"],
                        "helper_info": result.get("helper_info", []),
                    },
                    "query_status": "LLM修复后查询仍然失败",
                }

    print("函数执行完毕,准备返回最终结果")
    return final_result
def format_result(result):
    """
    Format a query result as text.

    Handling order:
      1. An object whose ``__module__`` is ``"project"``: its public
         attributes (names not starting with ``_``) as pretty-printed JSON.
      2. A list that contains at least one object exposing ``__module__`` and
         in which every such object belongs to ``"project"``: a JSON list of
         public attributes (items without ``__dict__`` are stringified).
      3. A string: recognised debug patterns are turned into a "not found"
         message or a node summary; otherwise the string is returned as-is.
      4. Any other list or dict: a line-per-entry listing.
      5. Everything else: ``str(result)``.

    Args:
        result: The query result (list, dict, str, project object or other).

    Returns:
        str: The formatted result.
    """
    import ast
    import re

    def _public_attrs(obj):
        """Return obj's instance attributes whose names do not start with '_'."""
        return {k: v for k, v in obj.__dict__.items() if not k.startswith("_")}

    # Single project object.
    if hasattr(result, "__module__") and result.__module__ == "project":
        return json.dumps(_public_attrs(result), ensure_ascii=False, indent=2)

    # List of project objects. At least one item must actually expose
    # __module__: all() over an empty selection is vacuously true, which would
    # otherwise send [] and plain lists of dicts/strings down this branch.
    if isinstance(result, list):
        module_names = [item.__module__ for item in result if hasattr(item, "__module__")]
        if module_names and all(name == "project" for name in module_names):
            formatted_items = [
                _public_attrs(item) if hasattr(item, "__dict__") else str(item)
                for item in result
            ]
            return json.dumps(formatted_items, ensure_ascii=False, indent=2)

    # Strings may carry debug output; pull out the useful part when possible.
    if isinstance(result, str):
        if "[]" in result:
            return "未找到匹配的数据。"

        node_match = re.search(r"找到.*?labels=.*?properties=(.*?)>", result)
        if node_match:
            try:
                props_str = node_match.group(1).replace("'", '"')
                props = ast.literal_eval(props_str)
                formatted = "找到节点:\n"
                for k, v in props.items():
                    formatted += f" {k}: {v}\n"
                return formatted
            except (ValueError, SyntaxError, TypeError, AttributeError, MemoryError, RecursionError):
                # Best effort only: properties that cannot be parsed as a
                # dict literal fall through to the remaining checks.
                pass

        count_match = re.search(r"查询结果数量: (\d+)", result)
        if count_match and count_match.group(1) == "0":
            return "未找到匹配的数据。"

    if isinstance(result, list):
        if not result:
            return "未找到匹配的数据。"
        lines = [f"找到 {len(result)} 条匹配结果:"]
        for i, item in enumerate(result, 1):
            lines.append(f"\n结果 {i}:")
            if hasattr(item, "items"):  # dict or dict-like object
                try:
                    for k, v in item.items():
                        lines.append(f" {k}: {v}")
                except Exception:
                    # items() did not behave like a mapping; show the raw item.
                    lines.append(f" {item}")
            else:
                lines.append(f" {item}")
        return "\n".join(lines)

    if isinstance(result, dict):
        lines = ["查询结果:"]
        for k, v in result.items():
            lines.append(f" {k}: {v}")
        return "\n".join(lines)

    return str(result)
def format_dict_or_item(item):
    """
    Format a dictionary or any other object as text.

    Args:
        item: A dictionary or any other object.

    Returns:
        str: For a dict, one `` key: value`` line per entry, each terminated
        by a newline (an empty dict gives ``""``); for anything else,
        ``str(item)``.
    """
    if not isinstance(item, dict):
        return str(item)
    # Build the text in one pass instead of repeated string concatenation.
    return "".join(f" {key}: {value}\n" for key, value in item.items())
def _extract_all_values(item, result_list):
"""
递归提取字典中的所有值并添加到结果列表中
Args:
item: 字典或其他对象
result_list: 用于存储结果的列表
"""
if isinstance(item, dict):
for key, value in item.items():
if key != "children": # 跳过children键,它会在主逻辑中单独处理
if isinstance(value, str):
result_list.append(value)
elif isinstance(value, (dict, list)):
_extract_all_values(value, result_list)
elif isinstance(item, list):
for element in item:
_extract_all_values(element, result_list)
# Helper that extracts and post-processes similar parameters
def extract_similar_parameters(path_parts, knowledge_base):
    """
    Collect similar parameters for a path from the knowledge base.

    Each path part is reduced to a keyword: the text after the last ``/``
    (stripped; dropped when empty), or the part itself when it has no ``/``.
    A knowledge-base item matches a keyword when any of its top-level string
    values contains the keyword, ignoring case. For every match, the item's
    string values and those of its direct ``children`` entries are collected.

    Args:
        path_parts (list): The path parts.
        knowledge_base (list): The knowledge-base data.

    Returns:
        str: The de-duplicated, sorted values joined with ``", "``.
    """
    # Keywords derived from the path.
    keywords = []
    for part in path_parts:
        if "/" not in part:
            keywords.append(part)
            continue
        tail = part.split("/")[-1].strip()
        if tail:
            keywords.append(tail)

    # Gather values from every matching item and its children.
    collected = []
    for keyword in keywords:
        for entry in knowledge_base:
            is_match = any(
                isinstance(candidate, str) and keyword.lower() in candidate.lower()
                for candidate in entry.values()
            )
            if not is_match:
                continue
            _extract_all_values(entry, collected)
            # "children" is skipped by _extract_all_values, so walk it here.
            if "children" in entry and isinstance(entry["children"], list):
                for child in entry["children"]:
                    _extract_all_values(child, collected)

    # De-duplicate and sort.
    similar_params = sorted(set(collected))
    print(f"找到的相似参数: {similar_params}")
    return ", ".join(similar_params)
# Small helper used before querying: pulls the values out of key/value pairs in a string
def extract_values_from_kb_string(kb_string):
    """Extract the values of all ``"key": "value"`` pairs found in a knowledge-base string."""
    import re

    # Group 1 is the key and group 2 the value; only non-empty, quoted
    # strings on both sides are matched.
    pair_pattern = re.compile(r'"([^"]+)"\s*:\s*"([^"]+)"')
    return [found.group(2) for found in pair_pattern.finditer(kb_string)]
# question = {
# "type": "query",
# "value": "查找一下【工程数据/安装工程/安装/架空输电线路本体工程/杆塔工程/杆塔组立/铁塔、钢管杆组立】的类型为【主材】的【塔材】",
# }
# result = nl_query_to_function_call(question)
# print("\n最终结果:")
# print(result)