# Listing metadata from the source viewer: 197 lines, 6.3 KiB, Python.
# Module dependencies. The stray "|" separator lines left by the source
# viewer were syntax errors and have been removed; every original import is
# kept, grouped as standard library first, then third-party.
import os
from typing import Any, Dict, List, Optional

from neo4j import GraphDatabase


class KnowledgeGraphQuerier:
    """Query helper for a knowledge graph stored in Neo4j.

    Every query method opens its own short-lived session on the shared
    driver. Instances can be used as context managers so the connection is
    released even when a query raises::

        with KnowledgeGraphQuerier(uri, auth) as querier:
            querier.search_by_keyword("...")
    """

    def __init__(self, uri, auth):
        """Connect to the database.

        Args:
            uri: Neo4j database URI.
            auth: Neo4j credentials as a ``(username, password)`` tuple.
        """
        self.uri = uri
        self.auth = auth
        self.driver = GraphDatabase.driver(uri, auth=auth)
        # Kept so existing callers that read ``querier.session`` keep
        # working; the query methods of this class no longer use it.
        self.session = self.driver.session()
        print(f"已连接到Neo4j数据库: {uri}")

    def __enter__(self):
        """Return the querier itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the connection on leaving a ``with`` block.

        Exceptions raised inside the block are not suppressed.
        """
        self.close()

    def close(self):
        """Close the shared session and the driver.

        The ``getattr`` guards let this run on an instance whose
        ``__init__`` failed part-way through.
        """
        session = getattr(self, "session", None)
        driver = getattr(self, "driver", None)
        try:
            if session:
                session.close()
        finally:
            # Release the driver even if closing the session raised.
            if driver:
                driver.close()
        print("已关闭Neo4j数据库连接")

    @staticmethod
    def _is_safe_name(value, separator: str) -> bool:
        """Tell whether ``value`` can be spliced into Cypher unquoted.

        ``value`` must be a string made of one or more identifiers joined
        by ``separator`` (``:`` for labels, ``|`` for relationship types).
        ``str.isidentifier`` is used as a conservative allowlist: it accepts
        Unicode letters, digits and underscores and rejects whitespace,
        quotes, backticks and every other character needed to break out of
        the pattern.
        """
        if not isinstance(value, str):
            return False
        return all(part.isidentifier() for part in value.split(separator))

    @classmethod
    def _require_safe_name(cls, value, what: str, separator: str) -> str:
        """Return ``value`` unchanged, or raise ``ValueError`` if unsafe."""
        if not cls._is_safe_name(value, separator):
            raise ValueError(f"invalid {what}: {value!r}")
        return value

    @classmethod
    def _related_query(cls, relationship_type, max_depth) -> str:
        """Build the Cypher text used by :meth:`get_related_nodes`.

        Labels, relationship types and hop bounds cannot be passed as query
        parameters, so both values are validated here before being
        formatted into the pattern.

        Raises:
            ValueError: If ``max_depth`` is not an integer >= 1 or
                ``relationship_type`` is not a safe identifier.
        """
        try:
            depth = int(max_depth)
        except (TypeError, ValueError) as err:
            raise ValueError(
                f"max_depth must be a positive integer, got {max_depth!r}"
            ) from err
        if depth < 1:
            raise ValueError(
                f"max_depth must be a positive integer, got {max_depth!r}"
            )

        if relationship_type:
            rel = cls._require_safe_name(
                relationship_type, "relationship_type", "|"
            )
            hop = f"[r:{rel}*1..{depth}]"
        else:
            hop = f"[r*1..{depth}]"

        return f"""
            MATCH (n)-{hop}-(related)
            WHERE n.original_name = $name OR n.display_name = $name
            RETURN related
        """

    def _collect_nodes(self, query: str, key: str,
                       parameters: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run ``query`` and return column ``key`` of each row as a dict."""
        with self.driver.session() as session:
            # Consume the result inside the ``with`` block, while the
            # session is still open.
            return [dict(record[key]) for record in session.run(query, parameters)]

    def get_node_by_name(self, name: str, node_type: Optional[str] = None) -> List[Dict[str, Any]]:
        """Find nodes whose ``original_name`` or ``display_name`` equals ``name``.

        Args:
            name: Node name to match exactly.
            node_type: Optional node label restricting the match. Several
                labels may be joined with ``:``.

        Returns:
            One property dict per matching node.

        Raises:
            ValueError: If ``node_type`` is given but is not a safe
                identifier. Raised before any session is opened.
        """
        if node_type:
            label = self._require_safe_name(node_type, "node_type", ":")
            pattern = f"(n:{label})"
        else:
            pattern = "(n)"

        query = f"""
            MATCH {pattern}
            WHERE n.original_name = $name OR n.display_name = $name
            RETURN n
        """
        return self._collect_nodes(query, "n", {"name": name})

    def get_related_nodes(self, node_name: str, relationship_type: Optional[str] = None,
                          max_depth: int = 2) -> List[Dict[str, Any]]:
        """Get nodes reachable from the named node, in either direction.

        Args:
            node_name: Name of the start node, matched against
                ``original_name`` or ``display_name``.
            relationship_type: Optional relationship type to follow.
                Alternatives may be joined with ``|``.
            max_depth: Maximum number of hops; must be an integer >= 1.

        Returns:
            One property dict per row returned by the query. No
            de-duplication is applied.

        Raises:
            ValueError: If ``relationship_type`` or ``max_depth`` is
                invalid. Raised before any session is opened.
        """
        query = self._related_query(relationship_type, max_depth)
        return self._collect_nodes(query, "related", {"name": node_name})

    def search_by_keyword(self, keyword: str) -> List[Dict[str, Any]]:
        """Search nodes by keyword.

        A node matches when its ``name``, ``display_name``,
        ``original_name`` or ``描述`` property contains ``keyword``. At most
        50 nodes are returned.

        Args:
            keyword: Substring to search for.

        Returns:
            One property dict per matching node, each with an added
            ``"labels"`` entry listing the node's labels.
        """
        # Simplified query: plain node match, no path traversal.
        query = """
            MATCH (n)
            WHERE n.name CONTAINS $keyword
               OR n.display_name CONTAINS $keyword
               OR n.original_name CONTAINS $keyword
               OR n.描述 CONTAINS $keyword
            RETURN n
            LIMIT 50
        """
        nodes = []
        # Short-lived session, like every other query method here, instead
        # of the shared ``self.session``.
        with self.driver.session() as session:
            for record in session.run(query, {"keyword": keyword}):
                found = record["n"]
                node = dict(found)
                node["labels"] = list(found.labels)
                nodes.append(node)
        return nodes

    def get_path_between_nodes(self, start_name: str, end_name: str) -> List[List[Dict[str, Any]]]:
        """Get shortest paths between two named nodes.

        Both names are matched against ``original_name`` or
        ``display_name``; relationship direction is ignored.

        Args:
            start_name: Name of the start node.
            end_name: Name of the end node.

        Returns:
            A list of paths, each path being the list of property dicts of
            the nodes along it.
        """
        query = """
            MATCH p = shortestPath((a)-[*]-(b))
            WHERE (a.original_name = $start_name OR a.display_name = $start_name)
              AND (b.original_name = $end_name OR b.display_name = $end_name)
            RETURN p
        """
        parameters = {"start_name": start_name, "end_name": end_name}
        with self.driver.session() as session:
            return [
                [dict(node) for node in record["p"].nodes]
                for record in session.run(query, parameters)
            ]

    def get_function_details(self, function_name: str) -> Dict[str, Any]:
        """Get the details of a ``功能名称`` (function) node.

        Args:
            function_name: Value of the node's ``original_name`` property.

        Returns:
            The node's properties as a dict, or ``{}`` when no such node
            exists. When a path from a ``软件`` node down to the function is
            found, a ``"path"`` entry is added holding the ``original_name``
            of each node on that path joined with ``" > "``; nodes without
            that property are left out.
        """
        parameters = {"name": function_name}
        # LIMIT 1: ``single()`` is meant for results of at most one row.
        node_query = """
            MATCH (n:功能名称 {original_name: $name})
            RETURN n
            LIMIT 1
        """
        # Full path from the software root down to the function.
        path_query = """
            MATCH p = (root:软件)-[*]->(func:功能名称 {original_name: $name})
            RETURN p
            LIMIT 1
        """
        with self.driver.session() as session:
            record = session.run(node_query, parameters).single()
            if not record:
                return {}

            details = dict(record["n"])
            path_record = session.run(path_query, parameters).single()
            if path_record:
                names = (
                    dict(step).get("original_name", "")
                    for step in path_record["p"].nodes
                )
                details["path"] = " > ".join(filter(None, names))
            return details