from typing import List, Dict, Any, Optional from graph.graph_query import KnowledgeGraphQuerier from utils.embedding import TextEmbedder import logging import sys import numpy as np # 配置日志 logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', handlers=[logging.StreamHandler(sys.stdout)] ) logger = logging.getLogger("GraphRetriever") class GraphRetriever: def __init__(self, graph_querier: KnowledgeGraphQuerier): """ 初始化图检索器 参数: graph_querier: 知识图谱查询器 """ self.graph_querier = graph_querier self.embedding_model = TextEmbedder() logger.info("GraphRetriever初始化完成") def retrieve(self, query: str, top_k: int = 5, slots: Dict[str, Any] = None) -> List[Dict[str, Any]]: """ 检索与查询相关的信息 参数: query: 用户查询 top_k: 返回的结果数量 slots: 槽位信息,用于检索 返回: 检索结果列表 """ logger.info(f"开始检索,查询: '{query}',top_k: {top_k}") # 1. 关键词搜索 logger.info("步骤1: 执行关键词搜索") keyword_results = [] # 使用槽位信息进行检索 if slots and isinstance(slots, dict) and len(slots) > 0: logger.info(f"使用槽位信息进行检索: {slots}") # 从槽位中提取关键词 slot_keywords = [] for slot_name, slot_value in slots.items(): if slot_value and isinstance(slot_value, str): slot_keywords.append(slot_value) logger.info(f"从槽位中提取的关键词: {slot_keywords}") # 使用槽位关键词进行检索 for keyword in slot_keywords: logger.info(f"使用槽位关键词: '{keyword}'") results = self.graph_querier.search_by_keyword(keyword) logger.info(f"槽位关键词'{keyword}'搜索结果数量: {len(results)}") if results: for node in results: # 转换为文本并添加到结果中 text = self._node_to_text(node) if text: keyword_results.append({ "node": node, "text": text }) logger.info(f"找到结果,当前总结果数量: {len(keyword_results)}") # 如果已经找到足够多的结果,可以停止搜索 if len(keyword_results) >= top_k * 2: break # 如果槽位检索没有结果或没有提供槽位,使用原始查询 if not keyword_results: logger.info("使用原始查询进行检索") nodes = self.graph_querier.search_by_keyword(query) logger.info(f"原始查询搜索结果数量: {len(nodes)}") for node in nodes: # 转换为文本并添加到结果中 text = self._node_to_text(node) if text: keyword_results.append({ "node": node, "text": text }) # 2. 计算相似度并排序 logger.info("步骤2: 计算相似度并排序") if not keyword_results: logger.warning("没有找到任何结果") return [] try: # 计算查询与每个结果的相似度 # 修改这里,使用正确的方法名 query_embedding = self.embedding_model.get_embedding(query) for result in keyword_results: text = result["text"] text_embedding = self.embedding_model.get_embedding(text) similarity = self._compute_similarity(query_embedding, text_embedding) result["similarity"] = similarity # 按相似度排序 sorted_results = sorted(keyword_results, key=lambda x: x["similarity"], reverse=True) except AttributeError as e: # 如果嵌入模型方法不可用,则跳过相似度计算,直接返回结果 logger.warning(f"嵌入计算失败: {str(e)},跳过相似度排序") for result in keyword_results: result["similarity"] = 0.0 sorted_results = keyword_results[:top_k] # 返回top_k个结果 return sorted_results[:top_k] def _node_to_text(self, node: Dict[str, Any]) -> str: """ 将节点转换为文本表示 参数: node: 节点信息 返回: 节点的文本表示 """ if not node: return None # 获取节点类型 node_type = "未知类型" if "labels" in node and node["labels"]: if isinstance(node["labels"], list) and len(node["labels"]) > 0: node_type = node["labels"][0] elif isinstance(node["labels"], str): node_type = node["labels"] # 获取节点名称 name = node.get("original_name", "") or node.get("display_name", "") or node.get("name", "未知名称") # 获取节点描述 description = node.get("描述", "") # 获取节点路径 path = "" if "path_to_root" in node: if isinstance(node["path_to_root"], list): path = " > ".join(node["path_to_root"]) else: path = str(node["path_to_root"]) # 构建文本表示 text_parts = [f"类型: {node_type}", f"名称: {name}"] if path: text_parts.append(f"路径: {path}") if description: text_parts.append(f"描述: {description}") # 添加其他重要属性 for key, value in node.items(): if key not in ["labels", "original_name", "display_name", "name", "描述", "path_to_root"] and value: if not isinstance(value, (list, dict)): # 只添加简单类型的属性 text_parts.append(f"{key}: {value}") return "\n".join(text_parts) def _compute_similarity(self, embedding1: np.ndarray, embedding2: np.ndarray) -> float: """ 计算两个嵌入向量的余弦相似度 参数: embedding1: 第一个嵌入向量 embedding2: 第二个嵌入向量 返回: 余弦相似度 """ # 计算余弦相似度 dot_product = np.dot(embedding1, embedding2) norm1 = np.linalg.norm(embedding1) norm2 = np.linalg.norm(embedding2) if norm1 == 0 or norm2 == 0: return 0.0 return dot_product / (norm1 * norm2)