重构DifyQueryRetrieval_api.py为FastAPI应用,新增异步检索API和健康检查端点,优化错误处理和日志记录,同时更新DifyQueryRetrieval类以支持异步检索功能,提升整体性能和可维护性。

This commit is contained in:
2025-07-08 08:38:14 +08:00
parent 1f3e97d081
commit 5a8d042360
2 changed files with 256 additions and 28 deletions
+166 -4
View File
@@ -6,6 +6,8 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, Dict, Any, Optional
import logging
import time
import asyncio
import httpx
sys.path.append(os.getcwd())
from rag2_0.intent_recognition.DataModels import Classification
@@ -52,6 +54,28 @@ class DifyQueryRetrieval:
logging.error(f"检索数据集 {dataset_name} 时出错: {str(e)}", exc_info=True)
return []
async def retrieve_by_dataset_async(self, query: str, dataset_name: str) -> List[Dict[str, Any]]:
    """Awaitable wrapper around the synchronous ``retrieve_by_dataset``.

    Args:
        query: Query string forwarded to ``retrieve_by_dataset``.
        dataset_name: Name of the dataset to search.

    Returns:
        Whatever ``retrieve_by_dataset`` returns, or an empty list when the
        call raises (the failure is logged, not propagated).
    """
    try:
        # Run the blocking lookup in a worker thread via asyncio.to_thread.
        hits = await asyncio.to_thread(self.retrieve_by_dataset, query, dataset_name)
    except Exception as err:
        logging.error(f"异步检索数据集 {dataset_name} 时出错: {err!s}", exc_info=True)
        return []
    else:
        return hits
def retrieve(self, original_query: str, query_list: List[str], classification: Classification, software_name: str) -> Optional[List[Dict[str, Any]]]:
datasets = self.get_datasets_by_classification(classification, software_name)
if len(datasets) == 0:
@@ -59,6 +83,25 @@ class DifyQueryRetrieval:
return self.retrieve_api(original_query, query_list, datasets)
async def retrieve_async(self, original_query: str, query_list: List[str], classification: Classification, software_name: str) -> Optional[List[Dict[str, Any]]]:
    """Asynchronous counterpart of ``retrieve``.

    Args:
        original_query: The original query, passed on to ``retrieve_api_async``.
        query_list: Queries to run against each selected dataset.
        classification: Classification passed to ``get_datasets_by_classification``.
        software_name: Software name passed to ``get_datasets_by_classification``.

    Returns:
        The result of ``retrieve_api_async`` for the selected datasets, or
        ``None`` when no dataset is selected.
    """
    target_datasets = self.get_datasets_by_classification(classification, software_name)
    if len(target_datasets) > 0:
        return await self.retrieve_api_async(original_query, query_list, target_datasets)
    # No dataset matched: there is nothing to query.
    return None
def retrieve_api(self, original_query: str, query_list: List[str],data_set_list: List[str])->List[Dict[str, Any]]:
all_documents=[]
# 使用线程池替代无限制创建线程
@@ -103,6 +146,65 @@ class DifyQueryRetrieval:
return processed_documents
async def retrieve_api_async(self, original_query: str, query_list: List[str], data_set_list: List[str])->List[Dict[str, Any]]:
    """Asynchronous counterpart of ``retrieve_api``.

    Issues one ``retrieve_by_dataset_async`` call per (query, dataset) pair,
    merges the hits, removes duplicates by ``document['segment']['id']``
    (first occurrence wins) and hands the remainder to
    ``data_post_processor_async`` together with ``original_query``.

    Args:
        original_query: Query passed to ``data_post_processor_async``.
        query_list: Queries to run against every dataset.
        data_set_list: Dataset names to search. Names that are not in
            ``self._datasets_list`` are logged and skipped.

    Returns:
        The documents returned by ``data_post_processor_async``, or an empty
        list when no documents were retrieved.
    """
    # perf_counter is monotonic, so the measured duration cannot be skewed
    # by wall-clock adjustments.
    started = time.perf_counter()

    pending = []
    for query in query_list:
        for dataset in data_set_list:
            if dataset not in self._datasets_list:
                logging.error(f"dataset {dataset} not in datasets_list")
                continue
            pending.append(self.retrieve_by_dataset_async(query, dataset))

    # Run every lookup concurrently; failures come back as result values.
    results = await asyncio.gather(*pending, return_exceptions=True)

    all_documents = []
    for result in results:
        # gather(return_exceptions=True) can also hand back BaseException
        # subclasses (e.g. CancelledError); those must not reach extend().
        if isinstance(result, BaseException):
            # We are outside an ``except`` block, so the exception instance
            # has to be passed explicitly for the traceback to be logged.
            logging.error(f"异步检索过程中发生错误: {str(result)}", exc_info=result)
        else:
            all_documents.extend(result)
    logging.info(f"异步检索耗时: {time.perf_counter() - started:.2f}")

    # De-duplicate by segment id, keeping the first occurrence.
    unique_documents = {}
    for document in all_documents:
        unique_documents.setdefault(document['segment']['id'], document)
    deduplicated_documents = list(unique_documents.values())

    # Nothing retrieved: skip the reranker call entirely.
    if not deduplicated_documents:
        return []

    # Rerank everything that was retrieved.
    started = time.perf_counter()
    processed_documents = await self.data_post_processor_async(original_query, deduplicated_documents)
    logging.info(f"异步检索后重排序耗时: {time.perf_counter() - started:.2f}")
    return processed_documents
def data_post_processor(self, query: str, all_documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
reranker_model = XinferenceReRankerModel()
documents = [document['segment']['content'] for document in all_documents]
@@ -138,6 +240,52 @@ class DifyQueryRetrieval:
new_all_documents.append(to_dify_document_format(cur_doc_info))
return new_all_documents
async def data_post_processor_async(self, query: str, all_documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Asynchronous counterpart of ``data_post_processor``.

    Sends the segment contents to the reranker (``top_k=5``) and converts
    each returned hit into a dict with ``metadata``, ``title`` and
    ``content`` keys.

    Args:
        query: Query string passed to the reranker.
        all_documents: Retrieved documents; each selected one gets its
            ``score`` entry overwritten in place with the reranker score.

    Returns:
        One converted document per reranker hit, in reranker order.
    """
    def _as_dify_record(doc: dict) -> dict:
        # Reshape one retrieved document into the output record layout.
        dataset_id = doc["dataset_id"]
        dataset_name = doc["dataset_name"]
        segment = doc["segment"]
        metadata = {
            "_source": "knowledge",
            "dataset_id": dataset_id,
            "dataset_name": dataset_name,
            "document_id": segment['document_id'],
            "document_name": segment["document"]["name"],
            "data_source_type": segment["document"]["data_source_type"],
            "segment_id": segment["id"],
            "retriever_from": "api",
            "score": doc.get("score", 0),
            "segment_hit_count": segment.get("hit_count", 0),
            "segment_word_count": segment.get("word_count", 0),
            "segment_position": segment.get("position", 0),
            "segment_index_node_hash": segment.get("index_node_hash", ""),
            "doc_metadata": segment.get("document", {}).get("doc_metadata", None),
            "position": segment.get("position", 0),
        }
        return {
            "metadata": metadata,
            "title": segment["document"]["name"],
            "content": segment["content"],
        }

    reranker = XinferenceReRankerModel()
    contents = [entry['segment']['content'] for entry in all_documents]
    # Await the reranker's asynchronous entry point.
    ranked_hits = await reranker.rerank_async(query, contents, top_k=5)

    converted = []
    for hit in ranked_hits:
        # "index" refers back to the position in all_documents.
        source = all_documents[hit["index"]]
        source["score"] = hit["score"]
        converted.append(_as_dify_record(source))
    return converted
def get_datasets_by_classification(self, classification: Classification, software_name: str) -> List[str]:
if classification.vertical_classification=="软件问题" or classification.vertical_classification=="业务问题":
software_name_list = self.software_to_dataset_map.keys()
@@ -165,7 +313,21 @@ class DifyQueryRetrieval:
if __name__ == "__main__":
# Manual smoke test: builds a DifyQueryRetrieval client and prints retrieval results as JSON.
# NOTE(review): the dataset API key and base URL are hard-coded below; consider loading them from the environment.
dify_query_retrieval = DifyQueryRetrieval(api_key="dataset-skLjmPVonjHo119OWNf3kAmY", base_url="http://10.1.16.39/v1")
# datasets = dify_query_retrieval.retrieve("配网工程计价通D3软件如何新建工程?", Classification(vertical_classification="软件问题", sub_classification="软件功能"), "配网工程计价通D3")
# NOTE(review): this synchronous call repeats the commented-out call further down; it looks like residue of a diff, so confirm which variant is intended.
datasets = dify_query_retrieval.retrieve_api("电力建设计价通软件如何批量修改设备价格?",
["电力建设计价通软件如何批量修改设备价格?"],
["电力建设计价通(2018)软件知识(new)"])
print(json.dumps(datasets, ensure_ascii=False, indent=2))
# datasets = dify_query_retrieval.retrieve_api("电力建设计价通软件如何批量修改设备价格?",
# ["电力建设计价通软件如何批量修改设备价格?"],
# ["电力建设计价通(2018)软件知识(new)"])
# print(json.dumps(datasets, ensure_ascii=False, indent=2))
# Exercise the asynchronous API.
async def test_async_api():
# Runs retrieve_api_async with one query against one dataset and prints the result as JSON.
datasets = await dify_query_retrieval.retrieve_api_async(
"电力建设计价通软件如何批量修改设备价格?",
["电力建设计价通软件如何批量修改设备价格?"],
["电力建设计价通(2018)软件知识(new)"]
)
print("异步API测试结果:")
print(json.dumps(datasets, ensure_ascii=False, indent=2))
# Original note said to uncomment the lines below to test the async API, but they are already active.
# asyncio is also imported near the top of the file, so this import is redundant.
import asyncio
asyncio.run(test_async_api())