95 lines
2.7 KiB
Python
95 lines
2.7 KiB
Python
import numpy as np
|
|
from typing import List, Dict, Any
|
|
from .llm import embedding as embedding_model
|
|
|
|
class TextEmbedder:
|
|
def __init__(self):
|
|
"""初始化文本嵌入器,使用自定义的embedding模型"""
|
|
pass
|
|
|
|
def get_embedding(self, text: str) -> np.ndarray:
|
|
"""
|
|
获取文本的嵌入向量
|
|
|
|
参数:
|
|
text: 输入文本
|
|
|
|
返回:
|
|
嵌入向量
|
|
"""
|
|
return embedding_model.embed(text)
|
|
|
|
def get_embeddings(self, texts: List[str]) -> List[np.ndarray]:
|
|
"""
|
|
批量获取文本的嵌入向量
|
|
|
|
参数:
|
|
texts: 输入文本列表
|
|
|
|
返回:
|
|
嵌入向量列表
|
|
"""
|
|
return [embedding_model.embed(text) for text in texts]
|
|
|
|
def compute_similarity(self, text1: str, text2: str) -> float:
|
|
"""
|
|
计算两段文本的相似度
|
|
|
|
参数:
|
|
text1: 第一段文本
|
|
text2: 第二段文本
|
|
|
|
返回:
|
|
相似度分数 (0-1)
|
|
"""
|
|
emb1 = self.get_embedding(text1)
|
|
emb2 = self.get_embedding(text2)
|
|
return self._cosine_similarity(emb1, emb2)
|
|
|
|
def find_most_similar(self, query: str, candidates: List[str], top_k: int = 5) -> List[Dict[str, Any]]:
|
|
"""
|
|
找出与查询最相似的候选文本
|
|
|
|
参数:
|
|
query: 查询文本
|
|
candidates: 候选文本列表
|
|
top_k: 返回的最相似文本数量
|
|
|
|
返回:
|
|
最相似文本及其相似度的列表
|
|
"""
|
|
query_emb = self.get_embedding(query)
|
|
candidate_embs = self.get_embeddings(candidates)
|
|
|
|
similarities = []
|
|
for i, emb in enumerate(candidate_embs):
|
|
similarity = self._cosine_similarity(query_emb, emb)
|
|
similarities.append((i, similarity))
|
|
|
|
# 按相似度降序排序
|
|
similarities.sort(key=lambda x: x[1], reverse=True)
|
|
|
|
results = []
|
|
for i, sim in similarities[:top_k]:
|
|
results.append({
|
|
"text": candidates[i],
|
|
"similarity": float(sim)
|
|
})
|
|
|
|
return results
|
|
|
|
def _cosine_similarity(self, vec1: np.ndarray, vec2: np.ndarray) -> float:
|
|
"""
|
|
计算两个向量的余弦相似度
|
|
|
|
参数:
|
|
vec1: 第一个向量
|
|
vec2: 第二个向量
|
|
|
|
返回:
|
|
余弦相似度 (0-1)
|
|
"""
|
|
dot_product = np.dot(vec1, vec2)
|
|
norm1 = np.linalg.norm(vec1)
|
|
norm2 = np.linalg.norm(vec2)
|
|
return dot_product / (norm1 * norm2) |