首次提交:上传本地文件夹

This commit is contained in:
ruxia
2025-03-31 17:28:23 +08:00
commit 0de349447c
439 changed files with 36643 additions and 0 deletions
+197
View File
@@ -0,0 +1,197 @@
from typing import List, Dict, Any, Optional
from graph.graph_query import KnowledgeGraphQuerier
from utils.embedding import TextEmbedder
import logging
import sys
import numpy as np
# 配置日志
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger("GraphRetriever")
class GraphRetriever:
def __init__(self, graph_querier: KnowledgeGraphQuerier):
"""
初始化图检索器
参数:
graph_querier: 知识图谱查询器
"""
self.graph_querier = graph_querier
self.embedding_model = TextEmbedder()
logger.info("GraphRetriever初始化完成")
def retrieve(self, query: str, top_k: int = 5, slots: Dict[str, Any] = None) -> List[Dict[str, Any]]:
"""
检索与查询相关的信息
参数:
query: 用户查询
top_k: 返回的结果数量
slots: 槽位信息,用于检索
返回:
检索结果列表
"""
logger.info(f"开始检索,查询: '{query}'top_k: {top_k}")
# 1. 关键词搜索
logger.info("步骤1: 执行关键词搜索")
keyword_results = []
# 使用槽位信息进行检索
if slots and isinstance(slots, dict) and len(slots) > 0:
logger.info(f"使用槽位信息进行检索: {slots}")
# 从槽位中提取关键词
slot_keywords = []
for slot_name, slot_value in slots.items():
if slot_value and isinstance(slot_value, str):
slot_keywords.append(slot_value)
logger.info(f"从槽位中提取的关键词: {slot_keywords}")
# 使用槽位关键词进行检索
for keyword in slot_keywords:
logger.info(f"使用槽位关键词: '{keyword}'")
results = self.graph_querier.search_by_keyword(keyword)
logger.info(f"槽位关键词'{keyword}'搜索结果数量: {len(results)}")
if results:
for node in results:
# 转换为文本并添加到结果中
text = self._node_to_text(node)
if text:
keyword_results.append({
"node": node,
"text": text
})
logger.info(f"找到结果,当前总结果数量: {len(keyword_results)}")
# 如果已经找到足够多的结果,可以停止搜索
if len(keyword_results) >= top_k * 2:
break
# 如果槽位检索没有结果或没有提供槽位,使用原始查询
if not keyword_results:
logger.info("使用原始查询进行检索")
nodes = self.graph_querier.search_by_keyword(query)
logger.info(f"原始查询搜索结果数量: {len(nodes)}")
for node in nodes:
# 转换为文本并添加到结果中
text = self._node_to_text(node)
if text:
keyword_results.append({
"node": node,
"text": text
})
# 2. 计算相似度并排序
logger.info("步骤2: 计算相似度并排序")
if not keyword_results:
logger.warning("没有找到任何结果")
return []
try:
# 计算查询与每个结果的相似度
# 修改这里,使用正确的方法名
query_embedding = self.embedding_model.get_embedding(query)
for result in keyword_results:
text = result["text"]
text_embedding = self.embedding_model.get_embedding(text)
similarity = self._compute_similarity(query_embedding, text_embedding)
result["similarity"] = similarity
# 按相似度排序
sorted_results = sorted(keyword_results, key=lambda x: x["similarity"], reverse=True)
except AttributeError as e:
# 如果嵌入模型方法不可用,则跳过相似度计算,直接返回结果
logger.warning(f"嵌入计算失败: {str(e)},跳过相似度排序")
for result in keyword_results:
result["similarity"] = 0.0
sorted_results = keyword_results[:top_k]
# 返回top_k个结果
return sorted_results[:top_k]
def _node_to_text(self, node: Dict[str, Any]) -> str:
"""
将节点转换为文本表示
参数:
node: 节点信息
返回:
节点的文本表示
"""
if not node:
return None
# 获取节点类型
node_type = "未知类型"
if "labels" in node and node["labels"]:
if isinstance(node["labels"], list) and len(node["labels"]) > 0:
node_type = node["labels"][0]
elif isinstance(node["labels"], str):
node_type = node["labels"]
# 获取节点名称
name = node.get("original_name", "") or node.get("display_name", "") or node.get("name", "未知名称")
# 获取节点描述
description = node.get("描述", "")
# 获取节点路径
path = ""
if "path_to_root" in node:
if isinstance(node["path_to_root"], list):
path = " > ".join(node["path_to_root"])
else:
path = str(node["path_to_root"])
# 构建文本表示
text_parts = [f"类型: {node_type}", f"名称: {name}"]
if path:
text_parts.append(f"路径: {path}")
if description:
text_parts.append(f"描述: {description}")
# 添加其他重要属性
for key, value in node.items():
if key not in ["labels", "original_name", "display_name", "name", "描述", "path_to_root"] and value:
if not isinstance(value, (list, dict)): # 只添加简单类型的属性
text_parts.append(f"{key}: {value}")
return "\n".join(text_parts)
def _compute_similarity(self, embedding1: np.ndarray, embedding2: np.ndarray) -> float:
"""
计算两个嵌入向量的余弦相似度
参数:
embedding1: 第一个嵌入向量
embedding2: 第二个嵌入向量
返回:
余弦相似度
"""
# 计算余弦相似度
dot_product = np.dot(embedding1, embedding2)
norm1 = np.linalg.norm(embedding1)
norm2 = np.linalg.norm(embedding2)
if norm1 == 0 or norm2 == 0:
return 0.0
return dot_product / (norm1 * norm2)