197 lines
7.1 KiB
Python
197 lines
7.1 KiB
Python
from typing import List, Dict, Any, Optional
|
||
from graph.graph_query import KnowledgeGraphQuerier
|
||
from utils.embedding import TextEmbedder
|
||
import logging
|
||
import sys
|
||
import numpy as np
|
||
|
||
# Logging setup: INFO level, timestamped "time - name - level - message"
# lines, written to stdout.
# NOTE(review): basicConfig configures the *root* logger at import time, which
# affects the whole application; consider moving this to the program entry point.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
_stdout_handler = logging.StreamHandler(sys.stdout)
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT, handlers=[_stdout_handler])

# Named logger used by GraphRetriever throughout this module.
logger = logging.getLogger("GraphRetriever")
|
||
|
||
class GraphRetriever:
    """Retrieve knowledge-graph nodes relevant to a user query.

    Candidates are gathered by keyword search on the graph (slot values first,
    the raw query as a fallback), rendered to text, de-duplicated, and then
    re-ranked by cosine similarity between the query embedding and each
    candidate's text embedding.
    """

    # Node keys that _node_to_text renders explicitly (or uses only for the
    # type/name lookup); they are skipped when dumping the remaining attributes.
    _RESERVED_KEYS = frozenset(
        ["labels", "original_name", "display_name", "name", "描述", "path_to_root"]
    )

    def __init__(self, graph_querier: "KnowledgeGraphQuerier"):
        """Create a retriever.

        Args:
            graph_querier: Knowledge-graph querier; must provide
                ``search_by_keyword(keyword)``.
        """
        self.graph_querier = graph_querier
        self.embedding_model = TextEmbedder()
        logger.info("GraphRetriever初始化完成")

    def retrieve(self, query: str, top_k: int = 5, slots: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
        """Retrieve the graph nodes most relevant to ``query``.

        Non-empty string slot values are searched first, one keyword at a
        time, stopping once at least ``2 * top_k`` unique candidates have been
        collected. If the slots yield nothing (or none are given), the raw
        query is searched instead. Candidates are de-duplicated by their text
        rendering and then ranked by embedding similarity to the query.

        Args:
            query: User query text.
            top_k: Maximum number of results to return; a non-positive value
                returns an empty list.
            slots: Optional slot mapping; only non-empty string values are used
                as search keywords.

        Returns:
            Up to ``top_k`` dicts with keys ``"node"`` (the raw node),
            ``"text"`` (its text rendering) and ``"similarity"`` (cosine
            similarity to the query, or 0.0 when embeddings are unavailable),
            ordered by descending similarity.
        """
        logger.info(f"开始检索,查询: '{query}',top_k: {top_k}")

        # A negative top_k would otherwise slice from the end of the list.
        if top_k <= 0:
            logger.warning(f"top_k 必须为正整数,当前为 {top_k},返回空结果")
            return []

        # Step 1: keyword search.
        logger.info("步骤1: 执行关键词搜索")
        keyword_results: List[Dict[str, Any]] = []
        seen_texts: set = set()

        if slots and isinstance(slots, dict):
            logger.info(f"使用槽位信息进行检索: {slots}")
            slot_keywords = [value for value in slots.values() if value and isinstance(value, str)]
            logger.info(f"从槽位中提取的关键词: {slot_keywords}")

            for keyword in slot_keywords:
                logger.info(f"使用槽位关键词: '{keyword}'")
                # Guard against a querier that returns None for "no match".
                results = self.graph_querier.search_by_keyword(keyword) or []
                logger.info(f"槽位关键词'{keyword}'搜索结果数量: {len(results)}")
                if not results:
                    continue
                self._collect_unique(results, keyword_results, seen_texts)
                logger.info(f"找到结果,当前总结果数量: {len(keyword_results)}")
                # Over-fetch (2x) so the similarity re-ranking has room to reorder.
                if len(keyword_results) >= top_k * 2:
                    break

        # Fall back to the raw query when slots were absent or found nothing.
        if not keyword_results:
            logger.info("使用原始查询进行检索")
            nodes = self.graph_querier.search_by_keyword(query) or []
            logger.info(f"原始查询搜索结果数量: {len(nodes)}")
            self._collect_unique(nodes, keyword_results, seen_texts)

        # Step 2: similarity ranking.
        logger.info("步骤2: 计算相似度并排序")
        if not keyword_results:
            logger.warning("没有找到任何结果")
            return []

        return self._rank_by_similarity(query, keyword_results)[:top_k]

    def _collect_unique(self, nodes, results: List[Dict[str, Any]], seen_texts: set) -> int:
        """Append ``{"node", "text"}`` entries for nodes whose text is new.

        Nodes that render to no text, or whose text is already in
        ``seen_texts``, are skipped. ``results`` and ``seen_texts`` are
        updated in place.

        Returns:
            The number of entries appended.
        """
        added = 0
        for node in nodes or []:
            text = self._node_to_text(node)
            if not text or text in seen_texts:
                continue
            seen_texts.add(text)
            results.append({"node": node, "text": text})
            added += 1
        return added

    def _rank_by_similarity(self, query: str, results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Set ``"similarity"`` on every result and return them best-first.

        If the embedding model lacks the expected method (AttributeError),
        every similarity is set to 0.0 and the original order is kept.
        """
        try:
            query_embedding = self.embedding_model.get_embedding(query)
            for result in results:
                text_embedding = self.embedding_model.get_embedding(result["text"])
                result["similarity"] = self._compute_similarity(query_embedding, text_embedding)
        except AttributeError as e:
            logger.warning(f"嵌入计算失败: {str(e)},跳过相似度排序")
            for result in results:
                result["similarity"] = 0.0
            return results
        return sorted(results, key=lambda item: item["similarity"], reverse=True)

    def _node_to_text(self, node: Optional[Dict[str, Any]]) -> Optional[str]:
        """Render a graph node as a multi-line text description.

        Lines, in order: type (first label), name, root path (if any),
        ``描述`` description (if any), then every other truthy scalar
        attribute as ``key: value`` (list/dict values are skipped).

        Args:
            node: Node mapping; presumably as returned by
                ``search_by_keyword`` (verify the schema against the querier).

        Returns:
            The text rendering, or ``None`` when ``node`` is empty/None.
        """
        if not node:
            return None

        labels = node.get("labels")
        if isinstance(labels, list) and labels:
            node_type = labels[0]
        elif isinstance(labels, str) and labels:
            node_type = labels
        else:
            node_type = "未知类型"

        # Prefer the most human-readable populated name.
        name = node.get("original_name", "") or node.get("display_name", "") or node.get("name", "未知名称")
        description = node.get("描述", "")

        raw_path = node.get("path_to_root")
        if isinstance(raw_path, list):
            # str() each hop so non-string path elements don't break join().
            path = " > ".join(str(hop) for hop in raw_path)
        elif raw_path is not None:
            path = str(raw_path)
        else:
            path = ""

        text_parts = [f"类型: {node_type}", f"名称: {name}"]
        if path:
            text_parts.append(f"路径: {path}")
        if description:
            text_parts.append(f"描述: {description}")

        text_parts.extend(
            f"{key}: {value}"
            for key, value in node.items()
            if key not in self._RESERVED_KEYS and value and not isinstance(value, (list, dict))
        )
        return "\n".join(text_parts)

    def _compute_similarity(self, embedding1: np.ndarray, embedding2: np.ndarray) -> float:
        """Return the cosine similarity of two embedding vectors.

        Inputs are flattened, so row vectors of shape ``(1, d)`` are accepted
        as well as 1-D vectors. Returns 0.0 if either vector has zero norm.
        The result is a built-in ``float`` (not a NumPy scalar), so it stays
        JSON-serializable whatever the embedding dtype.
        """
        vec1 = np.asarray(embedding1, dtype=np.float64).ravel()
        vec2 = np.asarray(embedding2, dtype=np.float64).ravel()

        norm1 = np.linalg.norm(vec1)
        norm2 = np.linalg.norm(vec2)
        if norm1 == 0 or norm2 == 0:
            return 0.0

        return float(np.dot(vec1, vec2) / (norm1 * norm2))