import os
import re
from neo4j import GraphDatabase
from typing import List, Dict, Any, Optional


class KnowledgeGraphQuerier:
    """Read-only query helper over a Neo4j knowledge graph.

    A node "matches a name" when either its ``original_name`` or its
    ``display_name`` property equals that name. Every query method opens its
    own short-lived session. The querier can also be used as a context
    manager, in which case ``close()`` is called on exit.
    """

    # Splits a label / relationship-type expression such as "A|B" or "A:B"
    # into names and the delimiters between them (delimiters are kept).
    _NAME_DELIMITERS = re.compile(r"([:|])")

    def __init__(self, uri: str, auth: Any):
        """Connect to the Neo4j database.

        Args:
            uri: Neo4j database URI.
            auth: Neo4j credentials, e.g. a ``(username, password)`` tuple.
        """
        self.uri = uri
        self.auth = auth
        self.driver = GraphDatabase.driver(uri, auth=auth)
        # Shared session kept for callers that use ``self.session`` directly;
        # the query methods below open their own per-call sessions instead of
        # sharing this one.
        self.session = self.driver.session()
        print(f"已连接到Neo4j数据库: {uri}")

    def __enter__(self) -> "KnowledgeGraphQuerier":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()

    def close(self):
        """Close the shared session (if any) and the driver."""
        # getattr guards against a partially initialised instance, e.g. when
        # GraphDatabase.driver() raised inside __init__.
        session = getattr(self, "session", None)
        if session:
            session.close()
        driver = getattr(self, "driver", None)
        if driver:
            driver.close()
        print("已关闭Neo4j数据库连接")

    @classmethod
    def _quote_names(cls, spec: str) -> str:
        """Backtick-quote every name in a label / relationship-type expression.

        Labels and relationship types cannot be passed as query parameters, so
        they have to be spliced into the Cypher text. Quoting each name (with
        embedded backticks doubled) stops a caller-supplied value from
        injecting arbitrary Cypher. ``:`` and ``|`` delimiters are preserved,
        so expressions like "A:B" or "A|B" keep their meaning.

        Raises:
            ValueError: if ``spec`` contains no name at all.
        """
        pieces = []
        has_name = False
        for token in cls._NAME_DELIMITERS.split(spec.strip().lstrip(":")):
            if token in (":", "|"):
                pieces.append(token)
                continue
            name = token.strip()
            if name:
                pieces.append("`" + name.replace("`", "``") + "`")
                has_name = True
        if not has_name:
            raise ValueError(f"invalid label or relationship type: {spec!r}")
        return "".join(pieces)

    def _fetch_nodes(self, query: str, key: str, **params: Any) -> List[Dict[str, Any]]:
        """Run ``query`` in a fresh session and return ``record[key]`` of every row as a dict."""
        with self.driver.session() as session:
            # Materialise inside the ``with`` so the result is fully consumed
            # before the session is closed.
            return [dict(record[key]) for record in session.run(query, **params)]

    def get_node_by_name(self, name: str, node_type: Optional[str] = None) -> List[Dict[str, Any]]:
        """Find nodes by name.

        Args:
            name: Node name, compared with ``original_name`` and ``display_name``.
            node_type: Optional node label (or label expression) to restrict
                the match to.

        Returns:
            The property dicts of the matching nodes.
        """
        label = f":{self._quote_names(node_type)}" if node_type else ""
        query = f"""
        MATCH (n{label})
        WHERE n.original_name = $name OR n.display_name = $name
        RETURN n
        """
        return self._fetch_nodes(query, "n", name=name)

    def get_related_nodes(self, node_name: str, relationship_type: Optional[str] = None,
                          max_depth: int = 2) -> List[Dict[str, Any]]:
        """Get the nodes reachable from a named node, ignoring edge direction.

        Args:
            node_name: Name of the start node.
            relationship_type: Optional relationship type (or "A|B" expression)
                that every traversed relationship must have.
            max_depth: Maximum number of hops (at least 1 to get any result).

        Returns:
            The property dicts of the distinct related nodes, excluding the
            start node itself.
        """
        # max_depth is spliced into the query text; int() rejects anything
        # that is not a plain number.
        depth = int(max_depth)
        if depth < 1:
            # A hop range of 1..depth is empty, so nothing can match.
            return []
        rel_type = f":{self._quote_names(relationship_type)}" if relationship_type else ""
        query = f"""
        MATCH (n)-[{rel_type}*1..{depth}]-(related)
        WHERE (n.original_name = $name OR n.display_name = $name) AND related <> n
        RETURN DISTINCT related
        """
        return self._fetch_nodes(query, "related", name=node_name)

    def search_by_keyword(self, keyword: str) -> List[Dict[str, Any]]:
        """Search nodes whose name-like or description properties contain a keyword.

        Args:
            keyword: Substring to look for in ``name``, ``display_name``,
                ``original_name`` or ``描述``.

        Returns:
            Up to 50 matching nodes; each dict holds the node's properties plus
            a ``"labels"`` key listing its labels.
        """
        # Simplified query without any path lookups.
        query = """
        MATCH (n)
        WHERE n.name CONTAINS $keyword
           OR n.display_name CONTAINS $keyword
           OR n.original_name CONTAINS $keyword
           OR n.描述 CONTAINS $keyword
        RETURN n
        LIMIT 50
        """
        with self.driver.session() as session:
            nodes = []
            for record in session.run(query, keyword=keyword):
                node = dict(record["n"])
                node["labels"] = list(record["n"].labels)
                nodes.append(node)
            return nodes

    def get_path_between_nodes(self, start_name: str, end_name: str) -> List[List[Dict[str, Any]]]:
        """Get the shortest (undirected) path between two named nodes.

        Args:
            start_name: Name of the start node.
            end_name: Name of the end node.

        Returns:
            One path per matching (start, end) node pair; each path is the list
            of its nodes' property dicts. Pairs where start and end are the
            same node are skipped, because Neo4j's shortestPath rejects them.
        """
        query = """
        MATCH (a)
        WHERE a.original_name = $start_name OR a.display_name = $start_name
        MATCH (b)
        WHERE (b.original_name = $end_name OR b.display_name = $end_name) AND a <> b
        MATCH p = shortestPath((a)-[*]-(b))
        RETURN p
        """
        with self.driver.session() as session:
            result = session.run(query, start_name=start_name, end_name=end_name)
            return [[dict(node) for node in record["p"].nodes] for record in result]

    def get_function_details(self, function_name: str) -> Dict[str, Any]:
        """Get the details of a 功能名称 (function) node.

        Args:
            function_name: ``original_name`` of the function node.

        Returns:
            The node's properties. When a directed path from a 软件 node to it
            exists, a ``"path"`` key holds the non-empty ``original_name`` of
            each node on the first such path, joined by " > ". Returns an empty
            dict when no such function node exists.
        """
        with self.driver.session() as session:
            # LIMIT 1 keeps the "first record wins" behaviour explicit and
            # avoids single() warning about extra records.
            record = session.run(
                """
                MATCH (n:功能名称 {original_name: $name})
                RETURN n
                LIMIT 1
                """,
                name=function_name,
            ).single()
            if record is None:
                return {}
            node = dict(record["n"])

            # Full path from the software root; the unbounded pattern can match
            # many paths, so stop at the first one.
            path_record = session.run(
                """
                MATCH p = (root:软件)-[*]->(func:功能名称 {original_name: $name})
                RETURN p
                LIMIT 1
                """,
                name=function_name,
            ).single()
            if path_record is not None:
                names = (dict(path_node).get("original_name", "") for path_node in path_record["p"].nodes)
                node["path"] = " > ".join(filter(None, names))
            return node