# QueryRewrite/rag2_0/dify/DifyQueryRetrieval.py

import asyncio
import json
import logging
import os
import sys
import time
from typing import Any, Dict, List, Optional

import httpx

from rag2_0.dify.dify_client.client import DifyClient, KnowledgeBaseClient
from rag2_0.tool.ModelTool import XinferenceReRankerModel
class DifyQueryRetrieval:
    """Retrieve segments from Dify knowledge-base datasets and rerank them.

    The dataset catalogue is fetched once at construction time and cached by
    dataset name; retrieval calls look datasets up in that cache.
    """

    def __init__(self, dify_dataset_key: str, dify_base_url: str):
        """Store the Dify credentials and fetch the dataset catalogue.

        Args:
            dify_dataset_key: Dify dataset API key.
            dify_base_url: Base URL of the Dify API.
        """
        self._dify_dataset_key = dify_dataset_key
        self._dify_base_url = dify_base_url
        # Network call: the catalogue is cached for the lifetime of the instance.
        self._datasets_list = self.get_datasets_list()

    def get_datasets_list(self, page_size: int = 50) -> Dict[str, Dict[str, Any]]:
        """Fetch every dataset visible to the key, keyed by dataset name.

        Pages through the listing until the API reports no further pages
        (`has_more` falsy or missing) or returns an empty page.

        Args:
            page_size: Number of datasets requested per page.

        Returns:
            Mapping of dataset name to the raw dataset dict returned by Dify.
        """
        client = KnowledgeBaseClient(api_key=self._dify_dataset_key, base_url=self._dify_base_url)
        datasets_by_name: Dict[str, Dict[str, Any]] = {}
        page = 1
        while True:
            response_json = client.list_datasets(page=page, page_size=page_size).json()
            page_data = response_json["data"]
            for dataset in page_data:
                datasets_by_name[dataset["name"]] = dataset
            # The empty-page check guards against looping forever on an inconsistent has_more flag.
            if not page_data or not response_json.get("has_more", False):
                break
            page += 1
        return datasets_by_name

    def retrieve_by_dataset(self, query: str, dataset_name: str,
                            metadata_filtering_conditions: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Retrieve documents for one query from one dataset.

        Errors (unknown dataset, network, bad response) are logged and turned
        into an empty document list.

        Args:
            query: Query string.
            dataset_name: Name of a dataset in the cached catalogue.
            metadata_filtering_conditions: Optional Dify metadata filter; applied
                only when non-empty.

        Returns:
            Dict with "query", "dataset_name" and "documents". Each document
            record is annotated with "dataset_id" and "dataset_name".
        """
        try:
            dataset_info = self._datasets_list[dataset_name]
            dataset_id = dataset_info["id"]
            cached_model = dataset_info["retrieval_model_dict"]
            # Work on a copy: the cached dict is shared by every call (including
            # concurrent to_thread workers), so mutating it would leak these
            # filters into later unfiltered retrievals.
            retrieval_model = dict(cached_model) if cached_model is not None else None
            if metadata_filtering_conditions:
                retrieval_model["metadata_filtering_conditions"] = metadata_filtering_conditions
            knowledge_base_client = KnowledgeBaseClient(api_key=self._dify_dataset_key,
                                                        base_url=self._dify_base_url,
                                                        dataset_id=dataset_id)
            response = knowledge_base_client.retrieve(query, retrieval_model=retrieval_model, timeout=300)
            retrieved_documents = response.json().get("records", [])
            # Tag each record with its source dataset.
            for retrieved_document in retrieved_documents:
                retrieved_document["dataset_id"] = dataset_id
                retrieved_document["dataset_name"] = dataset_name
            return {
                "query": query,
                "dataset_name": dataset_name,
                "documents": retrieved_documents
            }
        except Exception as e:
            logging.error(f"检索数据集 {dataset_name} 时出错: {str(e)}", exc_info=True)
            return {
                "query": query,
                "dataset_name": dataset_name,
                "documents": []
            }

    async def retrieve_by_dataset_async(self, query: str, dataset_name: str,
                                        metadata_filtering_conditions: Optional[Dict[str, Any]] = None
                                        ) -> Dict[str, Any]:
        """Async wrapper that runs `retrieve_by_dataset` in a worker thread.

        Args:
            query: Query string.
            dataset_name: Dataset name.
            metadata_filtering_conditions: Optional Dify metadata filter.

        Returns:
            Same shape as `retrieve_by_dataset`. Empty documents are returned on any error.
        """
        try:
            return await asyncio.to_thread(
                self.retrieve_by_dataset,
                query,
                dataset_name,
                metadata_filtering_conditions
            )
        except Exception as e:
            logging.error(f"异步检索数据集 {dataset_name} 时出错: {str(e)}", exc_info=True)
            return {
                "query": query,
                "dataset_name": dataset_name,
                "documents": []
            }

    async def retrieve_api_async(self,
                                 original_query: str,
                                 query_list: List[str],
                                 data_set_list: List[str],
                                 query_expand_dict: Any = None,
                                 top_k: int = 5,
                                 metadata_filtering_conditions: Optional[Dict[str, Any]] = None
                                 ) -> Dict[str, Any]:
        """Retrieve concurrently for every (query, dataset) pair, dedupe, and rerank.

        Args:
            original_query: Query used for reranking.
            query_list: Queries to retrieve with.
            data_set_list: Dataset names. Names missing from the catalogue are logged and skipped.
            query_expand_dict: Mapping of query type to a query or list of queries,
                or a JSON string encoding one. `None` or invalid input yields empty hit stats.
            top_k: Number of documents kept after reranking.
            metadata_filtering_conditions: Optional Dify metadata filter.

        Returns:
            Dict with:
                "documents": the reranked documents in Dify format.
                "query_hit_stats": maps each query type to the de-duplicated
                    titles (last "/" component) of reranked documents that its
                    queries retrieved.
        """
        all_documents: List[Dict[str, Any]] = []
        # Maps each query to the segment ids it retrieved.
        query_document_mapping: Dict[str, List[str]] = {}

        # Validate dataset names once instead of once per (query, dataset) pair.
        valid_datasets = []
        for dataset in data_set_list:
            if dataset not in self._datasets_list:
                logging.error(f"dataset {dataset} not in datasets_list")
                continue
            valid_datasets.append(dataset)

        time_start = time.perf_counter()
        tasks = [
            self.retrieve_by_dataset_async(query, dataset, metadata_filtering_conditions)
            for query in query_list
            for dataset in valid_datasets
        ]
        results = await asyncio.gather(*tasks, return_exceptions=True)
        for result in results:
            if isinstance(result, BaseException):
                # exc_info=result logs the task's own traceback; exc_info=True
                # outside an except block would log nothing useful.
                logging.error(f"异步检索过程中发生错误: {str(result)}", exc_info=result)
                continue
            documents = result["documents"]
            query_document_mapping.setdefault(result["query"], []).extend(
                document["segment"]["id"] for document in documents
            )
            all_documents.extend(documents)
        logging.info(f"异步检索耗时: {time.perf_counter() - time_start:.2f}秒")

        # De-duplicate by segment id, keeping the first occurrence.
        unique_documents: Dict[str, Dict[str, Any]] = {}
        for document in all_documents:
            unique_documents.setdefault(document["segment"]["id"], document)
        deduplicated_documents = list(unique_documents.values())
        if not deduplicated_documents:
            return {
                "documents": [],
                "query_hit_stats": {}
            }

        time_start = time.perf_counter()
        processed_documents = await self.data_post_processor_async(original_query, deduplicated_documents, top_k)
        logging.info(f"异步检索后重排序耗时: {time.perf_counter() - time_start:.2f}秒")

        query_hit_stats = self._build_query_hit_stats(
            processed_documents,
            query_document_mapping,
            self._normalize_query_expand_dict(query_expand_dict),
        )
        return {
            "documents": processed_documents,
            "query_hit_stats": query_hit_stats
        }

    @staticmethod
    def _normalize_query_expand_dict(query_expand_dict: Any) -> Dict[str, Any]:
        """Return `query_expand_dict` as a dict, decoding JSON strings; anything else becomes {}."""
        if isinstance(query_expand_dict, str):
            try:
                query_expand_dict = json.loads(query_expand_dict)
            except ValueError as e:
                logging.error(f"解析query_expand_dict失败: {str(e)}")
                return {}
        if isinstance(query_expand_dict, dict):
            return query_expand_dict
        if query_expand_dict is not None:
            logging.error(f"query_expand_dict must be a dict, got {type(query_expand_dict).__name__}")
        return {}

    @staticmethod
    def _build_query_hit_stats(processed_documents: List[Dict[str, Any]],
                               query_document_mapping: Dict[str, List[str]],
                               query_expand_dict: Dict[str, Any]) -> Dict[str, List[str]]:
        """For each query type, list the unique titles of reranked documents its queries hit, in rank order."""
        ranked = [
            (doc["metadata"]["segment_id"], doc["title"].split("/")[-1])
            for doc in processed_documents
        ]
        query_hit_stats: Dict[str, List[str]] = {}
        for query_type, queries in query_expand_dict.items():
            if not isinstance(queries, list):
                queries = [queries]
            hit_ids = set()
            for query in queries:
                hit_ids.update(query_document_mapping.get(query, ()))
            titles: List[str] = []
            seen_titles = set()
            for segment_id, title in ranked:
                # Several segments can share a title; report each title once.
                if segment_id in hit_ids and title not in seen_titles:
                    seen_titles.add(title)
                    titles.append(title)
            query_hit_stats[query_type] = titles
        return query_hit_stats

    @staticmethod
    def _to_dify_document_format(document: Dict[str, Any]) -> Dict[str, Any]:
        """Convert a Dify retrieval record (with "score" set) to the Dify knowledge output format."""
        segment = document["segment"]
        return {
            "metadata": {
                "_source": "knowledge",
                "dataset_id": document["dataset_id"],
                "dataset_name": document["dataset_name"],
                "document_id": segment["document_id"],
                "document_name": segment["document"]["name"],
                "data_source_type": segment["document"]["data_source_type"],
                "segment_id": segment["id"],
                "retriever_from": "api",
                "score": document.get("score", 0),
                "segment_hit_count": segment.get("hit_count", 0),
                "segment_word_count": segment.get("word_count", 0),
                "segment_position": segment.get("position", 0),
                "segment_index_node_hash": segment.get("index_node_hash", ""),
                "doc_metadata": segment.get("document", {}).get("doc_metadata", None),
                "position": segment.get("position", 0)
            },
            "title": segment["document"]["name"],
            "content": segment["content"]
        }

    async def data_post_processor_async(self, query: str, all_documents: List[Dict[str, Any]],
                                        top_k: int = 5) -> List[Dict[str, Any]]:
        """Rerank documents against `query` and convert the top_k to Dify format.

        The input dicts are not modified; the rerank score is attached to a shallow copy.

        Args:
            query: Query used for reranking.
            all_documents: Dify retrieval records (with "segment" and dataset tags).
            top_k: Number of documents the reranker should return.

        Returns:
            Reranked documents in Dify format, or [] when there is nothing to rerank.
        """
        if not all_documents:
            return []
        reranker_model = XinferenceReRankerModel()
        contents = [document["segment"]["content"] for document in all_documents]
        # Assumes each reranker result carries "index" (position in `contents`) and "score".
        reranked_documents = await reranker_model.rerank_async(query, contents, top_k=top_k)
        return [
            self._to_dify_document_format({**all_documents[item["index"]], "score": item["score"]})
            for item in reranked_documents
        ]
if __name__ == "__main__":
    # Manual smoke test against a live Dify instance.
    # Credentials come from the environment rather than being hard-coded in source.
    dataset_key = os.environ.get("DIFY_DATASET_KEY")
    if not dataset_key:
        sys.exit("Please set the DIFY_DATASET_KEY environment variable.")
    base_url = os.environ.get("DIFY_BASE_URL", "http://10.1.16.39/v1")
    # The constructor's parameters are dify_dataset_key / dify_base_url.
    dify_query_retrieval = DifyQueryRetrieval(dify_dataset_key=dataset_key, dify_base_url=base_url)

    async def test_async_api():
        """Run one sample retrieval and print the result as JSON."""
        datasets = await dify_query_retrieval.retrieve_api_async(
            "电力建设计价通软件如何批量修改设备价格?",
            ["电力建设计价通软件如何批量修改设备价格?"],
            ["电力建设计价通(2018)软件知识(new)"],
            {},  # query_expand_dict: no query types to compute hit stats for
        )
        print("异步API测试结果:")
        print(json.dumps(datasets, ensure_ascii=False, indent=2))

    asyncio.run(test_async_api())