首次提交:上传本地文件夹
This commit is contained in:
+197
@@ -0,0 +1,197 @@
|
||||
from typing import List, Dict, Any, Optional
|
||||
from graph.graph_query import KnowledgeGraphQuerier
|
||||
from utils.embedding import TextEmbedder
|
||||
import logging
|
||||
import sys
|
||||
import numpy as np
|
||||
|
||||
# Configure logging: INFO level, timestamped records written to stdout.
logging.basicConfig(
    handlers=[logging.StreamHandler(sys.stdout)],
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# Logger used by everything in this module.
logger = logging.getLogger("GraphRetriever")
|
||||
|
||||
class GraphRetriever:
    """Retrieve knowledge-graph nodes relevant to a free-text query.

    Candidates are gathered by keyword search on the graph querier (slot
    values first, the raw query as a fallback) and then ranked by cosine
    similarity between the query embedding and each node's text form.
    """

    # Node keys that ``_node_to_text`` renders explicitly; they are skipped
    # when the remaining simple-valued properties are appended.
    _RESERVED_NODE_KEYS = frozenset(
        ("labels", "original_name", "display_name", "name", "描述", "path_to_root")
    )

    def __init__(self, graph_querier: "KnowledgeGraphQuerier"):
        """Initialise the graph retriever.

        Args:
            graph_querier: Knowledge-graph querier used for keyword search.
        """
        self.graph_querier = graph_querier
        self.embedding_model = TextEmbedder()
        logger.info("GraphRetriever初始化完成")

    def retrieve(
        self,
        query: str,
        top_k: int = 5,
        slots: Optional[Dict[str, Any]] = None,
    ) -> List[Dict[str, Any]]:
        """Retrieve the information most relevant to ``query``.

        Args:
            query: The user query.
            top_k: Maximum number of results to return. A non-positive value
                yields an empty list.
            slots: Optional slot information; every non-empty string value is
                used as a search keyword before falling back to ``query``.

        Returns:
            At most ``top_k`` dicts with the keys ``"node"``, ``"text"`` and
            ``"similarity"``, ordered by descending similarity. When the
            embedding model raises ``AttributeError`` the similarity is set
            to ``0.0`` and the search order is kept.
        """
        logger.info(f"开始检索,查询: '{query}',top_k: {top_k}")

        # A negative slice bound would silently drop trailing results.
        if top_k <= 0:
            return []

        # 1. Keyword search
        logger.info("步骤1: 执行关键词搜索")

        keyword_results: List[Dict[str, Any]] = []
        # Texts already collected, so a node matched by several keywords is
        # only reported once.
        seen_texts: set = set()

        # Search with the slot values first.
        if slots and isinstance(slots, dict):
            logger.info(f"使用槽位信息进行检索: {slots}")

            # Only non-empty string slot values are usable as keywords.
            slot_keywords = [
                slot_value
                for slot_value in slots.values()
                if slot_value and isinstance(slot_value, str)
            ]
            logger.info(f"从槽位中提取的关键词: {slot_keywords}")

            for keyword in slot_keywords:
                logger.info(f"使用槽位关键词: '{keyword}'")
                # Treat a None return from the querier as "no matches".
                results = self.graph_querier.search_by_keyword(keyword) or []
                logger.info(f"槽位关键词'{keyword}'搜索结果数量: {len(results)}")

                if results:
                    self._collect_results(results, keyword_results, seen_texts)
                    logger.info(f"找到结果,当前总结果数量: {len(keyword_results)}")

                    # Stop searching once there are enough candidates.
                    if len(keyword_results) >= top_k * 2:
                        break

        # Fall back to the raw query when no slots were given or the slot
        # search produced nothing.
        if not keyword_results:
            logger.info("使用原始查询进行检索")
            nodes = self.graph_querier.search_by_keyword(query) or []
            logger.info(f"原始查询搜索结果数量: {len(nodes)}")
            self._collect_results(nodes, keyword_results, seen_texts)

        # 2. Compute similarities and sort
        logger.info("步骤2: 计算相似度并排序")

        if not keyword_results:
            logger.warning("没有找到任何结果")
            return []

        try:
            query_embedding = self.embedding_model.get_embedding(query)

            for result in keyword_results:
                text_embedding = self.embedding_model.get_embedding(result["text"])
                result["similarity"] = self._compute_similarity(
                    query_embedding, text_embedding
                )

            ranked = sorted(
                keyword_results, key=lambda item: item["similarity"], reverse=True
            )
        except AttributeError as e:
            # The embedding model does not offer the expected method: skip
            # the similarity ranking and keep the search order.
            logger.warning(f"嵌入计算失败: {str(e)},跳过相似度排序")
            for result in keyword_results:
                result["similarity"] = 0.0
            ranked = keyword_results

        return ranked[:top_k]

    def _collect_results(
        self,
        nodes: List[Dict[str, Any]],
        results: List[Dict[str, Any]],
        seen_texts: set,
    ) -> None:
        """Append a ``{"node", "text"}`` entry for each new, renderable node.

        Nodes with no text form, or whose text is already in ``seen_texts``,
        are skipped. ``results`` and ``seen_texts`` are updated in place.
        """
        for node in nodes:
            text = self._node_to_text(node)
            if text and text not in seen_texts:
                seen_texts.add(text)
                results.append({"node": node, "text": text})

    def _node_to_text(self, node: Dict[str, Any]) -> Optional[str]:
        """Convert a node into its text representation.

        Args:
            node: Node information.

        Returns:
            Newline-separated ``key: value`` lines (type, name, optional path
            and description, then every other truthy non-list/non-dict
            property), or ``None`` when ``node`` is empty.
        """
        if not node:
            return None

        # Node type: first label of a list, or the label string itself.
        node_type = "未知类型"
        labels = node.get("labels")
        if isinstance(labels, list) and labels:
            node_type = labels[0]
        elif isinstance(labels, str) and labels:
            node_type = labels

        # Node name: first non-empty candidate. The placeholder also applies
        # when the "name" key exists but holds an empty value.
        name = (
            node.get("original_name")
            or node.get("display_name")
            or node.get("name")
            or "未知名称"
        )

        # Node description
        description = node.get("描述", "")

        # Node path: elements are stringified so non-string entries do not
        # break the join; a missing or None path is omitted.
        path = ""
        raw_path = node.get("path_to_root")
        if isinstance(raw_path, list):
            path = " > ".join(map(str, raw_path))
        elif raw_path is not None:
            path = str(raw_path)

        # Build the text representation
        text_parts = [f"类型: {node_type}", f"名称: {name}"]

        if path:
            text_parts.append(f"路径: {path}")

        if description:
            text_parts.append(f"描述: {description}")

        # Append the other properties, simple (non-list/dict) values only.
        for key, value in node.items():
            if key in self._RESERVED_NODE_KEYS or not value:
                continue
            if not isinstance(value, (list, dict)):
                text_parts.append(f"{key}: {value}")

        return "\n".join(text_parts)

    def _compute_similarity(self, embedding1: np.ndarray, embedding2: np.ndarray) -> float:
        """Compute the cosine similarity of two embedding vectors.

        Args:
            embedding1: The first embedding vector.
            embedding2: The second embedding vector.

        Returns:
            The cosine similarity as a built-in ``float``; ``0.0`` when
            either vector has zero norm.
        """
        vec1 = np.asarray(embedding1, dtype=float)
        vec2 = np.asarray(embedding2, dtype=float)

        norm1 = np.linalg.norm(vec1)
        norm2 = np.linalg.norm(vec2)

        # Avoid dividing by zero for all-zero vectors.
        if norm1 == 0 or norm2 == 0:
            return 0.0

        # Convert the NumPy scalar so the value is a plain Python float.
        return float(np.dot(vec1, vec2) / (norm1 * norm2))
|
||||
Reference in New Issue
Block a user