Files
GraphRAG/retriever.py
T
2025-03-31 17:28:23 +08:00

197 lines
7.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from typing import List, Dict, Any, Optional
from graph.graph_query import KnowledgeGraphQuerier
from utils.embedding import TextEmbedder
import logging
import sys
import numpy as np
# Configure root logging to emit INFO-and-above records on stdout with a
# timestamped format (basicConfig does nothing if the root logger already has
# handlers), then create the named logger used by this module.
logging.basicConfig(
    handlers=[logging.StreamHandler(stream=sys.stdout)],
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger("GraphRetriever")
class GraphRetriever:
    """Retrieve knowledge-graph nodes relevant to a query and rank them.

    Candidate nodes are gathered through keyword search on the supplied
    graph querier, rendered to a text form, and ordered by the cosine
    similarity between the query embedding and each candidate's text
    embedding.
    """

    # Node keys that ``_node_to_text`` renders explicitly; they are skipped
    # when the remaining simple attributes are appended.
    _RESERVED_NODE_KEYS = frozenset(
        {"labels", "original_name", "display_name", "name", "描述", "path_to_root"}
    )

    # The annotation is quoted (forward reference) so that defining the class
    # does not require the project type to be evaluated.
    def __init__(self, graph_querier: "KnowledgeGraphQuerier"):
        """Initialise the retriever.

        Args:
            graph_querier: Knowledge-graph querier; its ``search_by_keyword``
                method is used to fetch candidate nodes.
        """
        self.graph_querier = graph_querier
        self.embedding_model = TextEmbedder()
        logger.info("GraphRetriever初始化完成")

    def retrieve(
        self,
        query: str,
        top_k: int = 5,
        slots: Optional[Dict[str, Any]] = None,
    ) -> List[Dict[str, Any]]:
        """Return up to ``top_k`` results relevant to ``query``.

        Non-empty string slot values are searched as keywords first; when
        that yields nothing (or no slots are given) the raw query is searched
        instead. Candidates with identical text are kept only once.

        Args:
            query: The user query.
            top_k: Maximum number of results to return; a non-positive value
                yields an empty list.
            slots: Optional slot dictionary whose string values are used as
                search keywords.

        Returns:
            A list of dicts with keys ``"node"``, ``"text"`` and
            ``"similarity"``, sorted by descending similarity. If the
            embedding step raises ``AttributeError`` the results keep their
            search order and every similarity is ``0.0``.
        """
        logger.info(f"开始检索,查询: '{query}',top_k: {top_k}")
        if top_k <= 0:
            # A negative slice bound would silently drop items from the tail.
            return []

        # Step 1: keyword search.
        logger.info("步骤1: 执行关键词搜索")
        keyword_results: List[Dict[str, Any]] = []
        seen_texts: set = set()

        if slots and isinstance(slots, dict):
            logger.info(f"使用槽位信息进行检索: {slots}")
            slot_keywords = [
                value for value in slots.values()
                if value and isinstance(value, str)
            ]
            logger.info(f"从槽位中提取的关键词: {slot_keywords}")

            for keyword in slot_keywords:
                logger.info(f"使用槽位关键词: '{keyword}'")
                results = self.graph_querier.search_by_keyword(keyword) or []
                logger.info(f"槽位关键词'{keyword}'搜索结果数量: {len(results)}")
                if not results:
                    continue
                self._collect_candidates(results, keyword_results, seen_texts)
                logger.info(f"找到结果,当前总结果数量: {len(keyword_results)}")
                # Stop early once there are comfortably more candidates
                # than will be returned.
                if len(keyword_results) >= top_k * 2:
                    break

        # Fall back to the raw query when the slots produced nothing.
        if not keyword_results:
            logger.info("使用原始查询进行检索")
            nodes = self.graph_querier.search_by_keyword(query) or []
            logger.info(f"原始查询搜索结果数量: {len(nodes)}")
            self._collect_candidates(nodes, keyword_results, seen_texts)

        # Step 2: score by similarity and sort.
        logger.info("步骤2: 计算相似度并排序")
        if not keyword_results:
            logger.warning("没有找到任何结果")
            return []

        try:
            query_embedding = self.embedding_model.get_embedding(query)
            for result in keyword_results:
                text_embedding = self.embedding_model.get_embedding(result["text"])
                result["similarity"] = self._compute_similarity(
                    query_embedding, text_embedding
                )
            sorted_results = sorted(
                keyword_results, key=lambda item: item["similarity"], reverse=True
            )
        except AttributeError as e:
            # The embedding model does not expose the expected method: skip
            # ranking and return the candidates in search order.
            logger.warning(f"嵌入计算失败: {str(e)},跳过相似度排序")
            for result in keyword_results:
                result["similarity"] = 0.0
            sorted_results = keyword_results

        return sorted_results[:top_k]

    def _collect_candidates(
        self,
        nodes: List[Dict[str, Any]],
        collected: List[Dict[str, Any]],
        seen_texts: set,
    ) -> None:
        """Append the text form of each node to ``collected``, skipping nodes
        with no text and nodes whose text has already been collected."""
        for node in nodes:
            text = self._node_to_text(node)
            if text and text not in seen_texts:
                seen_texts.add(text)
                collected.append({"node": node, "text": text})

    def _node_to_text(self, node: Dict[str, Any]) -> Optional[str]:
        """Render a node as a multi-line text description.

        Args:
            node: Node attribute mapping.

        Returns:
            Newline-joined ``key: value`` lines (type, name, optional path,
            optional description, then remaining simple attributes), or
            ``None`` when ``node`` is empty.
        """
        if not node:
            return None

        # Node type: first label of a non-empty list, or a non-empty string.
        node_type = "未知类型"
        labels = node.get("labels")
        if isinstance(labels, list) and labels:
            node_type = labels[0]
        elif isinstance(labels, str) and labels:
            node_type = labels

        # Node name: first non-empty candidate, else the placeholder (also
        # used when a "name" key exists but is empty or None).
        name = (
            node.get("original_name")
            or node.get("display_name")
            or node.get("name")
            or "未知名称"
        )

        description = node.get("描述", "")

        # Node path: list elements are stringified so non-string parts do not
        # break the join; a None path is treated as absent.
        raw_path = node.get("path_to_root")
        if isinstance(raw_path, list):
            path = " > ".join(str(part) for part in raw_path)
        elif raw_path is not None:
            path = str(raw_path)
        else:
            path = ""

        text_parts = [f"类型: {node_type}", f"名称: {name}"]
        if path:
            text_parts.append(f"路径: {path}")
        if description:
            text_parts.append(f"描述: {description}")

        # Append the remaining truthy attributes of simple (non-list,
        # non-dict) type.
        for key, value in node.items():
            if key in self._RESERVED_NODE_KEYS or not value:
                continue
            if isinstance(value, (list, dict)):
                continue
            text_parts.append(f"{key}: {value}")

        return "\n".join(text_parts)

    def _compute_similarity(self, embedding1: np.ndarray, embedding2: np.ndarray) -> float:
        """Return the cosine similarity of two embedding vectors.

        Inputs are converted to flat float arrays, so plain sequences and
        single-row 2-D embeddings are accepted.

        Args:
            embedding1: First embedding vector.
            embedding2: Second embedding vector.

        Returns:
            The cosine similarity as a Python ``float``; ``0.0`` when either
            vector has zero norm.
        """
        vec1 = np.asarray(embedding1, dtype=float).ravel()
        vec2 = np.asarray(embedding2, dtype=float).ravel()
        norm1 = np.linalg.norm(vec1)
        norm2 = np.linalg.norm(vec2)
        if norm1 == 0 or norm2 == 0:
            return 0.0
        return float(np.dot(vec1, vec2) / (norm1 * norm2))