首次提交:上传本地文件夹
This commit is contained in:
@@ -0,0 +1,197 @@
|
||||
import os
|
||||
from neo4j import GraphDatabase
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
class KnowledgeGraphQuerier:
|
||||
def __init__(self, uri, auth):
|
||||
"""
|
||||
初始化知识图谱查询器
|
||||
|
||||
参数:
|
||||
uri: Neo4j数据库URI
|
||||
auth: Neo4j认证信息,格式为(username, password)
|
||||
"""
|
||||
self.uri = uri
|
||||
self.auth = auth
|
||||
self.driver = GraphDatabase.driver(uri, auth=auth)
|
||||
# 初始化session属性
|
||||
self.session = self.driver.session()
|
||||
print(f"已连接到Neo4j数据库: {uri}")
|
||||
|
||||
def close(self):
|
||||
"""关闭数据库连接"""
|
||||
if hasattr(self, 'session') and self.session:
|
||||
self.session.close()
|
||||
if hasattr(self, 'driver') and self.driver:
|
||||
self.driver.close()
|
||||
print("已关闭Neo4j数据库连接")
|
||||
|
||||
def get_node_by_name(self, name: str, node_type: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
通过名称查询节点
|
||||
|
||||
参数:
|
||||
name: 节点名称
|
||||
node_type: 节点类型(可选)
|
||||
|
||||
返回:
|
||||
节点列表
|
||||
"""
|
||||
with self.driver.session() as session:
|
||||
if node_type:
|
||||
query = f"""
|
||||
MATCH (n:{node_type})
|
||||
WHERE n.original_name = $name OR n.display_name = $name
|
||||
RETURN n
|
||||
"""
|
||||
else:
|
||||
query = """
|
||||
MATCH (n)
|
||||
WHERE n.original_name = $name OR n.display_name = $name
|
||||
RETURN n
|
||||
"""
|
||||
|
||||
result = session.run(query, name=name)
|
||||
nodes = []
|
||||
for record in result:
|
||||
node = record["n"]
|
||||
nodes.append(dict(node))
|
||||
|
||||
return nodes
|
||||
|
||||
def get_related_nodes(self, node_name: str, relationship_type: Optional[str] = None,
|
||||
max_depth: int = 2) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取与指定节点相关的节点
|
||||
|
||||
参数:
|
||||
node_name: 节点名称
|
||||
relationship_type: 关系类型(可选)
|
||||
max_depth: 最大深度
|
||||
|
||||
返回:
|
||||
相关节点列表
|
||||
"""
|
||||
with self.driver.session() as session:
|
||||
if relationship_type:
|
||||
query = f"""
|
||||
MATCH (n)-[r:{relationship_type}*1..{max_depth}]-(related)
|
||||
WHERE n.original_name = $name OR n.display_name = $name
|
||||
RETURN related
|
||||
"""
|
||||
else:
|
||||
query = f"""
|
||||
MATCH (n)-[r*1..{max_depth}]-(related)
|
||||
WHERE n.original_name = $name OR n.display_name = $name
|
||||
RETURN related
|
||||
"""
|
||||
|
||||
result = session.run(query, name=node_name)
|
||||
related_nodes = []
|
||||
for record in result:
|
||||
node = record["related"]
|
||||
related_nodes.append(dict(node))
|
||||
|
||||
return related_nodes
|
||||
|
||||
def search_by_keyword(self, keyword: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
通过关键词搜索知识图谱
|
||||
|
||||
参数:
|
||||
keyword: 搜索关键词
|
||||
|
||||
返回:
|
||||
匹配的节点列表
|
||||
"""
|
||||
# 使用简化的查询,不包含路径查询
|
||||
query = """
|
||||
MATCH (n)
|
||||
WHERE n.name CONTAINS $keyword
|
||||
OR n.display_name CONTAINS $keyword
|
||||
OR n.original_name CONTAINS $keyword
|
||||
OR n.描述 CONTAINS $keyword
|
||||
RETURN n
|
||||
LIMIT 50
|
||||
"""
|
||||
|
||||
result = self.session.run(query, keyword=keyword)
|
||||
nodes = []
|
||||
|
||||
for record in result:
|
||||
node = dict(record["n"])
|
||||
node["labels"] = list(record["n"].labels)
|
||||
nodes.append(node)
|
||||
|
||||
return nodes
|
||||
|
||||
def get_path_between_nodes(self, start_name: str, end_name: str) -> List[List[Dict[str, Any]]]:
|
||||
"""
|
||||
获取两个节点之间的路径
|
||||
|
||||
参数:
|
||||
start_name: 起始节点名称
|
||||
end_name: 结束节点名称
|
||||
|
||||
返回:
|
||||
路径列表,每个路径是节点字典的列表
|
||||
"""
|
||||
with self.driver.session() as session:
|
||||
query = """
|
||||
MATCH p = shortestPath((a)-[*]-(b))
|
||||
WHERE (a.original_name = $start_name OR a.display_name = $start_name)
|
||||
AND (b.original_name = $end_name OR b.display_name = $end_name)
|
||||
RETURN p
|
||||
"""
|
||||
|
||||
result = session.run(query, start_name=start_name, end_name=end_name)
|
||||
paths = []
|
||||
for record in result:
|
||||
path = record["p"]
|
||||
path_nodes = []
|
||||
for node in path.nodes:
|
||||
path_nodes.append(dict(node))
|
||||
paths.append(path_nodes)
|
||||
|
||||
return paths
|
||||
|
||||
def get_function_details(self, function_name: str) -> Dict[str, Any]:
|
||||
"""
|
||||
获取功能详情
|
||||
|
||||
参数:
|
||||
function_name: 功能名称
|
||||
|
||||
返回:
|
||||
功能详情字典,包含功能描述和路径
|
||||
"""
|
||||
with self.driver.session() as session:
|
||||
query = """
|
||||
MATCH (n:功能名称 {original_name: $name})
|
||||
RETURN n
|
||||
"""
|
||||
|
||||
result = session.run(query, name=function_name)
|
||||
record = result.single()
|
||||
if record:
|
||||
node = dict(record["n"])
|
||||
|
||||
# 获取功能的完整路径
|
||||
path_query = """
|
||||
MATCH p = (root:软件)-[*]->(func:功能名称 {original_name: $name})
|
||||
RETURN p
|
||||
"""
|
||||
path_result = session.run(path_query, name=function_name)
|
||||
path_record = path_result.single()
|
||||
|
||||
if path_record:
|
||||
path = path_record["p"]
|
||||
node_path = []
|
||||
for path_node in path.nodes:
|
||||
node_path.append(dict(path_node).get("original_name", ""))
|
||||
|
||||
node["path"] = " > ".join(filter(None, node_path))
|
||||
|
||||
return node
|
||||
|
||||
return {}
|
||||
Reference in New Issue
Block a user