Files
GraphRAG/graph/graph_query.py
T
2025-03-31 17:28:23 +08:00

197 lines
6.3 KiB
Python

import os
from neo4j import GraphDatabase
from typing import List, Dict, Any, Optional
class KnowledgeGraphQuerier:
def __init__(self, uri, auth):
"""
初始化知识图谱查询器
参数:
uri: Neo4j数据库URI
auth: Neo4j认证信息,格式为(username, password)
"""
self.uri = uri
self.auth = auth
self.driver = GraphDatabase.driver(uri, auth=auth)
# 初始化session属性
self.session = self.driver.session()
print(f"已连接到Neo4j数据库: {uri}")
def close(self):
"""关闭数据库连接"""
if hasattr(self, 'session') and self.session:
self.session.close()
if hasattr(self, 'driver') and self.driver:
self.driver.close()
print("已关闭Neo4j数据库连接")
def get_node_by_name(self, name: str, node_type: Optional[str] = None) -> List[Dict[str, Any]]:
"""
通过名称查询节点
参数:
name: 节点名称
node_type: 节点类型(可选)
返回:
节点列表
"""
with self.driver.session() as session:
if node_type:
query = f"""
MATCH (n:{node_type})
WHERE n.original_name = $name OR n.display_name = $name
RETURN n
"""
else:
query = """
MATCH (n)
WHERE n.original_name = $name OR n.display_name = $name
RETURN n
"""
result = session.run(query, name=name)
nodes = []
for record in result:
node = record["n"]
nodes.append(dict(node))
return nodes
def get_related_nodes(self, node_name: str, relationship_type: Optional[str] = None,
max_depth: int = 2) -> List[Dict[str, Any]]:
"""
获取与指定节点相关的节点
参数:
node_name: 节点名称
relationship_type: 关系类型(可选)
max_depth: 最大深度
返回:
相关节点列表
"""
with self.driver.session() as session:
if relationship_type:
query = f"""
MATCH (n)-[r:{relationship_type}*1..{max_depth}]-(related)
WHERE n.original_name = $name OR n.display_name = $name
RETURN related
"""
else:
query = f"""
MATCH (n)-[r*1..{max_depth}]-(related)
WHERE n.original_name = $name OR n.display_name = $name
RETURN related
"""
result = session.run(query, name=node_name)
related_nodes = []
for record in result:
node = record["related"]
related_nodes.append(dict(node))
return related_nodes
def search_by_keyword(self, keyword: str) -> List[Dict[str, Any]]:
"""
通过关键词搜索知识图谱
参数:
keyword: 搜索关键词
返回:
匹配的节点列表
"""
# 使用简化的查询,不包含路径查询
query = """
MATCH (n)
WHERE n.name CONTAINS $keyword
OR n.display_name CONTAINS $keyword
OR n.original_name CONTAINS $keyword
OR n.描述 CONTAINS $keyword
RETURN n
LIMIT 50
"""
result = self.session.run(query, keyword=keyword)
nodes = []
for record in result:
node = dict(record["n"])
node["labels"] = list(record["n"].labels)
nodes.append(node)
return nodes
def get_path_between_nodes(self, start_name: str, end_name: str) -> List[List[Dict[str, Any]]]:
"""
获取两个节点之间的路径
参数:
start_name: 起始节点名称
end_name: 结束节点名称
返回:
路径列表,每个路径是节点字典的列表
"""
with self.driver.session() as session:
query = """
MATCH p = shortestPath((a)-[*]-(b))
WHERE (a.original_name = $start_name OR a.display_name = $start_name)
AND (b.original_name = $end_name OR b.display_name = $end_name)
RETURN p
"""
result = session.run(query, start_name=start_name, end_name=end_name)
paths = []
for record in result:
path = record["p"]
path_nodes = []
for node in path.nodes:
path_nodes.append(dict(node))
paths.append(path_nodes)
return paths
def get_function_details(self, function_name: str) -> Dict[str, Any]:
"""
获取功能详情
参数:
function_name: 功能名称
返回:
功能详情字典,包含功能描述和路径
"""
with self.driver.session() as session:
query = """
MATCH (n:功能名称 {original_name: $name})
RETURN n
"""
result = session.run(query, name=function_name)
record = result.single()
if record:
node = dict(record["n"])
# 获取功能的完整路径
path_query = """
MATCH p = (root:软件)-[*]->(func:功能名称 {original_name: $name})
RETURN p
"""
path_result = session.run(path_query, name=function_name)
path_record = path_result.single()
if path_record:
path = path_record["p"]
node_path = []
for path_node in path.nodes:
node_path.append(dict(path_node).get("original_name", ""))
node["path"] = " > ".join(filter(None, node_path))
return node
return {}