Files
langchain_KG/langchain_neo4j.py
T
2025-06-03 13:45:48 +08:00

578 lines
22 KiB
Python

from langchain.chains import LLMChain
from langchain_openai import OpenAI
from langchain_experimental.utilities import PythonREPL
from project_implementation import ProjectBuilder
from prompt_templates import FUNCTION_CALL_PROMPT
import inspect
import project
import io
import sys
from parameter_rewriting import rewrite_query_parameters, KnowledgeGraphProcessor
import json
from langchain.agents import Tool, AgentExecutor, create_react_agent
from langchain.prompts import PromptTemplate
from llm import llm
# 获取ProjectTookiIt类的方法定义
def get_project_class_methods():
"""
从project模块中提取ProjectTookiIt类的方法定义
Returns:
str: 格式化后的方法定义字符串
"""
project_class_code = inspect.getsource(project.ProjectTookiIt)
lines = project_class_code.split("\n")
result_lines = []
in_class = False
skip_init = False
for line in lines:
if line.strip().startswith("class ProjectTookiIt"):
in_class = True
result_lines.append(line)
elif in_class:
if line.strip().startswith("def __init__"):
skip_init = True
elif skip_init and line.strip() and not line.startswith(" " * 8):
skip_init = False
if not skip_init:
result_lines.append(line)
return "\n".join(result_lines)
# 创建动态提示模板
project_class_methods = get_project_class_methods()
# 创建 Chain
function_call_chain = LLMChain(llm=llm, prompt=FUNCTION_CALL_PROMPT, output_key="code")
# Python 执行器
repl = PythonREPL()
# 创建知识图谱处理器实例
kg_processor = KnowledgeGraphProcessor()
# 定义搜索知识库的工具
def search_knowledge_base(query):
"""
在知识库中搜索关键词
Args:
query (str): 搜索关键词
Returns:
str: 搜索结果的JSON字符串
"""
found_data = kg_processor._search_in_kg(query)
if found_data:
return json.dumps(found_data, ensure_ascii=False, indent=2)
else:
return f"未找到与'{query}'相关的信息"
# 定义获取节点定义的工具
def get_node_definition(node_type):
"""
获取节点类型的定义
Args:
node_type (str): 节点类型名称
Returns:
str: 节点类型定义
"""
definition = kg_processor._get_node_definition(node_type)
return definition
# 创建工具列表
tools = [
Tool(
name="search_knowledge_base",
func=search_knowledge_base,
description="在知识库中搜索关键词,返回相关信息。输入应该是一个搜索关键词。",
),
Tool(
name="get_node_definition",
func=get_node_definition,
description="获取节点类型的定义,返回类型定义代码。输入应该是一个节点类型名称。",
),
]
# 创建Agent
agent = create_react_agent(llm, tools, FUNCTION_CALL_PROMPT)
# 创建Agent执行器
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True, handle_parsing_errors=True)
def nl_query_to_function_call(input_data):
"""
将自然语言查询转换为函数调用并执行,或直接执行提供的代码
Args:
input_data (dict): 包含type和value的字典
{
"type": "query|code",
"value": "查询内容或代码"
}
Returns:
dict: 包含状态码、消息和数据的字典
"""
input_type = input_data.get("type", "query")
input_value = input_data.get("value", "")
max_retries = 1 # 设置最大重试次数
current_retry = 0
original_query = input_value # 保存原始查询用于RAG
print(f"\n====== 开始处理查询 ======")
print(f"查询类型: {input_type}")
print(f"查询内容: {input_value}")
while current_retry <= max_retries:
print(f"\n----- 尝试 #{current_retry + 1} -----")
# 如果type是query,使用LLM生成代码
if input_type == "query" and current_retry == 0:
# 从查询中提取关键部分
import re
# 提取【】中的内容
path_parts = re.findall(r"【([^】]+)】", input_value)
# 创建临时代码
temp_code = f'search("{input_value}")'
# 对于每个路径,只添加最后一个部分
for part in path_parts:
if "/" in part:
# 提取路径中的最后一个部分
last_part = part.split("/")[-1].strip()
if last_part:
temp_code += f'\nsearch("{last_part}")'
else:
# 如果没有/,直接使用整个部分
temp_code += f'\nsearch("{part}")'
# 获取知识库内容和节点定义
knowledge_base, node_definitions = kg_processor._get_relevant_knowledge(temp_code)
# 使用Agent执行查询
agent_response = agent_executor.invoke(
{
"query": input_value,
"project_class_methods": project_class_methods,
"KnowledgeBase": knowledge_base,
"NodeDefinition": node_definitions,
}
)
# 从Agent响应中提取代码
code = agent_response["output"]
print(f"\n生成的代码:\n{code}")
else:
print(f"\n使用重写后的代码:\n{input_value}")
code = input_value
# 保存原始代码用于返回
original_code = code
# 执行生成的函数并捕获输出
try:
# 创建一个新的命名空间来执行代码,包含必要的导入
namespace = {
"ProjectBuilder": ProjectBuilder, # 添加ProjectBuilder到命名空间
"project_implementation": __import__("project_implementation"),
"project": __import__("project"),
}
# 执行生成的代码,定义neo4j_find_function函数
exec(code, namespace)
# 重定向stdout来捕获print输出
old_stdout = sys.stdout
redirected_output = io.StringIO()
sys.stdout = redirected_output
try:
# 执行函数
print("\n执行代码...")
result = namespace["neo4j_find_function"]()
# 获取捕获的输出
output = redirected_output.getvalue().strip()
print(f"\n原始输出:\n{output}")
# 检查结果是否为空
is_empty_result = (
not output
or output.lower() == "none"
or output == "[]"
or "未找到" in output
or "[]" in output
or "None" in output
or result is None
)
# 如果结果为空,走重写流程
if is_empty_result:
print("\n查询未找到结果,尝试定位具体缺失节点...")
# 解析原始查询路径中的最后一个节点名
import re
match = re.search(r"【([^】]+)】\s*$", original_query)
missing_node = match.group(1) if match else "未知节点"
error_info = {
"error_type": "NodeNotFoundError",
"error_message": f"{missing_node} 未找到,请检查该节点是否存在。",
"missing_node": missing_node,
"original_query": original_query,
"executed_code": original_code,
}
print("结构化错误信息:")
print(json.dumps(error_info, ensure_ascii=False, indent=2))
if current_retry < max_retries:
print("\n尝试使用RAG重写查询...")
try:
# 使用RAG重写查询和代码,并传递错误信息
rewritten = rewrite_query_parameters(original_query, original_code, error_info)
print(f"\nRAG重写结果: {json.dumps(rewritten, ensure_ascii=False, indent=2)}")
# 更新查询和代码
if "query" in rewritten and "code" in rewritten and rewritten["code"] != original_code:
print("\nRAG重写成功,使用新代码重试...")
input_value = rewritten["code"] # 直接使用重写后的代码
input_type = "code" # 切换到代码模式
current_retry += 1
continue # 继续下一次循环
else:
print("\nRAG重写未产生新代码,返回原始错误")
except Exception as e:
print(f"\nRAG重写失败: {e}")
# 记录错误但继续执行
# RAG重写失败或未产生新代码,返回原始错误
query_status = (
"第一次查询失败,RAG重写也失败"
if current_retry == 0
else f"第{current_retry+1}次查询失败,RAG重写也失败"
)
print(f"\n{query_status}")
return {
"code": 1,
"message": f"{missing_node} 未找到,请检查该节点是否存在。",
"data": {"value": "", "code": original_code},
"error_info": error_info,
"query_status": query_status,
}
# 清理输出,只保留有用的结果部分
clean_output = output
# 如果输出包含查询结果数量和对象引用
if "查询结果数量:" in output and "<project." in output:
# 提取查询结果部分
import re
# 尝试提取节点属性
node_match = re.search(r"找到节点: <Node.*?properties=({.*?})>", output, re.DOTALL)
if node_match:
props_str = node_match.group(1).replace("'", '"')
try:
import ast
props = ast.literal_eval(props_str)
clean_output = json.dumps(props, ensure_ascii=False, indent=2)
except:
pass
# 如果有查询结果数量信息
count_match = re.search(r"查询结果数量: (\d+)", output)
if count_match:
count = count_match.group(1)
if count == "0":
clean_output = "未找到匹配的数据。"
is_empty_result = True
elif not node_match: # 如果没有提取到节点属性但有结果
clean_output = f"找到 {count} 条匹配结果"
# 检查结果对象
if result is not None:
if isinstance(result, list):
if not result: # 空列表
is_empty_result = True
else:
# 处理非空列表
formatted_items = []
for item in result:
if hasattr(item, "__dict__"):
# 提取对象的所有属性
attrs = {k: v for k, v in item.__dict__.items() if not k.startswith("_")}
formatted_items.append(attrs)
else:
formatted_items.append(str(item))
if not is_empty_result: # 只有在不是空结果时才返回成功
query_status = (
"第一次查询成功"
if current_retry == 0
else f"第{current_retry+1}次查询成功(RAG重写后)"
)
print(f"\n{query_status}")
return {
"code": 0,
"message": "成功",
"data": {
"value": json.dumps(formatted_items, ensure_ascii=False, indent=2),
"code": original_code,
},
"query_status": query_status,
}
elif hasattr(result, "__dict__"):
# 单个对象
attrs = {k: v for k, v in result.__dict__.items() if not k.startswith("_")}
if not is_empty_result: # 只有在不是空结果时才返回成功
query_status = (
"第一次查询成功" if current_retry == 0 else f"第{current_retry+1}次查询成功(RAG重写后)"
)
print(f"\n{query_status}")
return {
"code": 0,
"message": "成功",
"data": {
"value": json.dumps(attrs, ensure_ascii=False, indent=2),
"code": original_code,
},
"query_status": query_status,
}
# 如果没有对象属性但有清理后的输出,且不是空结果
if (
clean_output
and clean_output.lower() != "none"
and clean_output != "[]"
and "未找到" not in clean_output
and not is_empty_result
):
query_status = (
"第一次查询成功" if current_retry == 0 else f"第{current_retry+1}次查询成功(RAG重写后)"
)
print(f"\n{query_status}")
return {
"code": 0,
"message": "成功",
"data": {"value": clean_output, "code": original_code},
"query_status": query_status,
}
finally:
# 恢复stdout
sys.stdout = old_stdout
except Exception as e:
import traceback
error_details = traceback.format_exc()
print(f"\n执行代码时出错: {error_details}")
# 如果走到这里,说明结果为空或未找到匹配项,应该执行RAG重写流程
print("\n查询未找到结果,尝试定位具体缺失节点...")
# 解析原始查询路径中的最后一个节点名
import re
match = re.search(r"【([^】]+)】\s*$", original_query)
missing_node = match.group(1) if match else "未知节点"
error_info = {
"error_type": "NodeNotFoundError",
"error_message": f"{missing_node} 未找到,请检查该节点是否存在。",
"missing_node": missing_node,
"original_query": original_query,
"executed_code": original_code,
}
print("结构化错误信息:")
print(json.dumps(error_info, ensure_ascii=False, indent=2))
if current_retry < max_retries:
print("\n尝试使用RAG重写查询...")
try:
# 使用RAG重写查询和代码,并传递错误信息
rewritten = rewrite_query_parameters(original_query, original_code, error_info)
print(f"\nRAG重写结果: {json.dumps(rewritten, ensure_ascii=False, indent=2)}")
# 更新查询和代码
if "query" in rewritten and "code" in rewritten and rewritten["code"] != original_code:
print("\nRAG重写成功,使用新代码重试...")
input_value = rewritten["code"] # 直接使用重写后的代码
input_type = "code" # 切换到代码模式
current_retry += 1
continue # 继续下一次循环
else:
print("\nRAG重写未产生新代码,返回原始错误")
except Exception as e:
print(f"\nRAG重写失败: {e}")
# 记录错误但继续执行
# RAG重写失败或未产生新代码,返回原始错误
query_status = (
"第一次查询失败,RAG重写也失败"
if current_retry == 0
else f"第{current_retry+1}次查询失败,RAG重写也失败"
)
print(f"\n{query_status}")
return {
"code": 1,
"message": f"{missing_node} 未找到,请检查该节点是否存在。",
"data": {"value": "", "code": original_code},
"error_info": error_info,
"query_status": query_status,
}
# 如果所有重试都失败
print("\n所有重试都失败,无法找到匹配的结果")
query_status = "所有重试都失败"
return {
"code": 1,
"message": "所有重试都失败,无法找到匹配的结果",
"data": {"value": "", "code": original_code},
"query_status": query_status,
}
def format_result(result):
"""
格式化查询结果
Args:
result: 查询结果(可能为 list、dict 或其他类型)
Returns:
str: 格式化后的结果
"""
# 处理 project 对象
if hasattr(result, "__module__") and result.__module__ == "project":
# 这是一个 project 模块中的对象
attrs = {k: v for k, v in result.__dict__.items() if not k.startswith("_")}
return json.dumps(attrs, ensure_ascii=False, indent=2)
# 处理 project 对象列表
if isinstance(result, list) and all(
hasattr(item, "__module__") and item.__module__ == "project" for item in result if hasattr(item, "__module__")
):
formatted_items = []
for item in result:
if hasattr(item, "__dict__"):
attrs = {k: v for k, v in item.__dict__.items() if not k.startswith("_")}
formatted_items.append(attrs)
else:
formatted_items.append(str(item))
return json.dumps(formatted_items, ensure_ascii=False, indent=2)
# 如果结果是字符串,可能包含调试信息,需要提取有用部分
if isinstance(result, str):
# 尝试提取最终结果部分
if "[]" in result:
return "未找到匹配的数据。"
# 如果包含节点信息,提取关键部分
import re
node_match = re.search(r"找到.*?labels=.*?properties=(.*?)>", result)
if node_match:
try:
# 提取属性部分并格式化
props_str = node_match.group(1).replace("'", '"')
import ast
props = ast.literal_eval(props_str)
formatted = "找到节点:\n"
for k, v in props.items():
formatted += f" {k}: {v}\n"
return formatted
except:
pass
# 如果包含查询结果数量
count_match = re.search(r"查询结果数量: (\d+)", result)
if count_match:
count = count_match.group(1)
if count == "0":
return "未找到匹配的数据。"
# 如果是列表
if isinstance(result, list):
if not result:
return "未找到匹配的数据。"
lines = [f"找到 {len(result)} 条匹配结果:"]
for i, item in enumerate(result, 1):
lines.append(f"\n结果 {i}:")
if hasattr(item, "items"): # 检查是否有items方法(字典或类似字典的对象)
try:
for k, v in item.items():
lines.append(f" {k}: {v}")
except:
lines.append(f" {item}")
else:
lines.append(f" {item}")
return "\n".join(lines)
# 如果是字典
elif isinstance(result, dict):
lines = ["查询结果:"]
for k, v in result.items():
lines.append(f" {k}: {v}")
return "\n".join(lines)
# 其他类型
else:
return str(result)
def format_dict_or_item(item):
"""
格式化字典或其他对象
Args:
item: 字典或其他对象
Returns:
str: 格式化后的字符串
"""
if isinstance(item, dict):
formatted = ""
for key, value in item.items():
formatted += f" {key}: {value}\n"
return formatted
return str(item)
question = {
"type": "query",
"value": "查找一下【工程数据/安装工程/安装/架空输电线路本体工程/杆塔工程/杆塔组立/铁塔、钢管杆组立】的类型为【主材】的【阿巴阿巴】",
}
result = nl_query_to_function_call(question)
print(result)