重构DifyQueryRetrieval_api.py为FastAPI应用,新增异步检索API和健康检查端点,优化错误处理和日志记录,同时更新DifyQueryRetrieval类以支持异步检索功能,提升整体性能和可维护性。

This commit is contained in:
2025-07-08 08:38:14 +08:00
parent 1f3e97d081
commit 5a8d042360
2 changed files with 256 additions and 28 deletions
+166 -4
View File
@@ -6,6 +6,8 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, Dict, Any, Optional from typing import List, Dict, Any, Optional
import logging import logging
import time import time
import asyncio
import httpx
sys.path.append(os.getcwd()) sys.path.append(os.getcwd())
from rag2_0.intent_recognition.DataModels import Classification from rag2_0.intent_recognition.DataModels import Classification
@@ -52,6 +54,28 @@ class DifyQueryRetrieval:
logging.error(f"检索数据集 {dataset_name} 时出错: {str(e)}", exc_info=True) logging.error(f"检索数据集 {dataset_name} 时出错: {str(e)}", exc_info=True)
return [] return []
async def retrieve_by_dataset_async(self, query: str, dataset_name: str) -> List[Dict[str, Any]]:
    """Asynchronous counterpart of ``retrieve_by_dataset``.

    The synchronous ``retrieve_by_dataset`` is executed through
    ``asyncio.to_thread`` and its result is returned unchanged.

    Args:
        query: Query string.
        dataset_name: Name of the dataset to search.

    Returns:
        The list returned by ``retrieve_by_dataset``, or an empty list when
        the call raised an exception (the error is logged with traceback).
    """
    try:
        # Wrap the synchronous method so it runs in a worker thread.
        return await asyncio.to_thread(self.retrieve_by_dataset, query, dataset_name)
    except Exception as e:
        # Best effort: one failing dataset yields no documents instead of
        # propagating the error to the caller.
        logging.error(f"异步检索数据集 {dataset_name} 时出错: {str(e)}", exc_info=True)
        return []
def retrieve(self, original_query: str, query_list: List[str], classification: Classification, software_name: str) -> Optional[List[Dict[str, Any]]]: def retrieve(self, original_query: str, query_list: List[str], classification: Classification, software_name: str) -> Optional[List[Dict[str, Any]]]:
datasets = self.get_datasets_by_classification(classification, software_name) datasets = self.get_datasets_by_classification(classification, software_name)
if len(datasets) == 0: if len(datasets) == 0:
@@ -59,6 +83,25 @@ class DifyQueryRetrieval:
return self.retrieve_api(original_query, query_list, datasets) return self.retrieve_api(original_query, query_list, datasets)
async def retrieve_async(self, original_query: str, query_list: List[str], classification: Classification, software_name: str) -> Optional[List[Dict[str, Any]]]:
    """Asynchronous counterpart of ``retrieve``.

    Args:
        original_query: The original user query.
        query_list: Queries to run against every resolved dataset.
        classification: Classification passed to
            ``get_datasets_by_classification``.
        software_name: Software name passed to
            ``get_datasets_by_classification``.

    Returns:
        The result of ``retrieve_api_async`` for the resolved datasets, or
        ``None`` when no dataset is resolved.
    """
    target_datasets = self.get_datasets_by_classification(classification, software_name)
    if len(target_datasets) > 0:
        return await self.retrieve_api_async(original_query, query_list, target_datasets)
    return None
def retrieve_api(self, original_query: str, query_list: List[str],data_set_list: List[str])->List[Dict[str, Any]]: def retrieve_api(self, original_query: str, query_list: List[str],data_set_list: List[str])->List[Dict[str, Any]]:
all_documents=[] all_documents=[]
# 使用线程池替代无限制创建线程 # 使用线程池替代无限制创建线程
@@ -103,6 +146,65 @@ class DifyQueryRetrieval:
return processed_documents return processed_documents
async def retrieve_api_async(self, original_query: str, query_list: List[str], data_set_list: List[str]) -> List[Dict[str, Any]]:
    """Asynchronous counterpart of ``retrieve_api`` built on ``asyncio.gather``.

    One ``retrieve_by_dataset_async`` call is scheduled per (query, dataset)
    pair, the hits are de-duplicated by segment id, and the survivors are
    handed to ``data_post_processor_async`` for re-ranking.

    Args:
        original_query: The original user query, used for re-ranking.
        query_list: Queries to run against each dataset.
        data_set_list: Dataset names to search. Names missing from
            ``self._datasets_list`` are logged and skipped.

    Returns:
        The re-ranked documents, or an empty list when nothing was retrieved.
    """
    time_start = time.time()

    # Build one awaitable per valid (query, dataset) combination.
    tasks = []
    for query in query_list:
        for dataset in data_set_list:
            if dataset not in self._datasets_list:
                logging.error(f"dataset {dataset} not in datasets_list")
                continue
            tasks.append(self.retrieve_by_dataset_async(query, dataset))

    # Run everything concurrently; failures come back as values, not raises.
    results = await asyncio.gather(*tasks, return_exceptions=True) if tasks else []

    all_documents = []
    for result in results:
        # BaseException also covers CancelledError, which gather can return
        # as a value and which must not reach ``extend``.
        if isinstance(result, BaseException):
            # No exception is active here, so the instance itself is passed
            # as exc_info to get a real traceback in the log.
            logging.error(f"异步检索过程中发生错误: {str(result)}", exc_info=result)
        else:
            all_documents.extend(result)

    time_end = time.time()
    logging.info(f"异步检索耗时: {time_end - time_start:.2f}")

    # De-duplicate by segment id, keeping the first occurrence.
    unique_documents = {}
    for document in all_documents:
        unique_documents.setdefault(document['segment']['id'], document)
    deduplicated_documents = list(unique_documents.values())

    # Nothing to re-rank: skip the reranker call entirely.
    if not deduplicated_documents:
        return []

    # Re-rank everything that was retrieved.
    time_start = time.time()
    processed_documents = await self.data_post_processor_async(original_query, deduplicated_documents)
    time_end = time.time()
    logging.info(f"异步检索后重排序耗时: {time_end - time_start:.2f}")

    return processed_documents
def data_post_processor(self, query: str, all_documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]: def data_post_processor(self, query: str, all_documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
reranker_model = XinferenceReRankerModel() reranker_model = XinferenceReRankerModel()
documents = [document['segment']['content'] for document in all_documents] documents = [document['segment']['content'] for document in all_documents]
@@ -138,6 +240,52 @@ class DifyQueryRetrieval:
new_all_documents.append(to_dify_document_format(cur_doc_info)) new_all_documents.append(to_dify_document_format(cur_doc_info))
return new_all_documents return new_all_documents
async def data_post_processor_async(self, query: str, all_documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Asynchronous counterpart of ``data_post_processor``.

    Re-ranks the retrieved documents against ``query`` with
    ``XinferenceReRankerModel.rerank_async`` (``top_k=5``) and converts each
    selected document into the Dify knowledge-document layout.

    Args:
        query: Query string used for re-ranking.
        all_documents: Retrieved documents. Each is read as a dict with
            ``dataset_id``, ``dataset_name`` and a ``segment`` mapping.

    Returns:
        The re-ranked documents as dicts with ``metadata``, ``title`` and
        ``content`` keys; an empty list when ``all_documents`` is empty.
    """
    # Nothing to rank: avoid creating a reranker client and calling it.
    if not all_documents:
        return []

    def to_dify_document_format(document: dict) -> dict:
        # Map one retrieved record onto the Dify knowledge-document layout.
        segment = document["segment"]
        source_doc = segment["document"]
        return {
            "metadata": {
                "_source": "knowledge",
                "dataset_id": document["dataset_id"],
                "dataset_name": document["dataset_name"],
                "document_id": segment["document_id"],
                "document_name": source_doc["name"],
                "data_source_type": source_doc["data_source_type"],
                "segment_id": segment["id"],
                "retriever_from": "api",
                "score": document.get("score", 0),
                "segment_hit_count": segment.get("hit_count", 0),
                "segment_word_count": segment.get("word_count", 0),
                "segment_position": segment.get("position", 0),
                "segment_index_node_hash": segment.get("index_node_hash", ""),
                "doc_metadata": source_doc.get("doc_metadata", None),
                "position": segment.get("position", 0)
            },
            "title": source_doc["name"],
            "content": segment["content"]
        }

    reranker_model = XinferenceReRankerModel()
    documents = [document['segment']['content'] for document in all_documents]

    # Use the asynchronous re-rank call.
    reranked_documents = await reranker_model.rerank_async(query, documents, top_k=5)

    new_all_documents = []
    for reranked_document in reranked_documents:
        # "index" points back into all_documents; the rerank score is stored
        # on that record before it is converted.
        cur_doc_info = all_documents[reranked_document["index"]]
        cur_doc_info["score"] = reranked_document["score"]
        new_all_documents.append(to_dify_document_format(cur_doc_info))
    return new_all_documents
def get_datasets_by_classification(self, classification: Classification, software_name: str) -> List[str]: def get_datasets_by_classification(self, classification: Classification, software_name: str) -> List[str]:
if classification.vertical_classification=="软件问题" or classification.vertical_classification=="业务问题": if classification.vertical_classification=="软件问题" or classification.vertical_classification=="业务问题":
software_name_list = self.software_to_dataset_map.keys() software_name_list = self.software_to_dataset_map.keys()
@@ -165,7 +313,21 @@ class DifyQueryRetrieval:
if __name__ == "__main__": if __name__ == "__main__":
dify_query_retrieval = DifyQueryRetrieval(api_key="dataset-skLjmPVonjHo119OWNf3kAmY", base_url="http://10.1.16.39/v1") dify_query_retrieval = DifyQueryRetrieval(api_key="dataset-skLjmPVonjHo119OWNf3kAmY", base_url="http://10.1.16.39/v1")
# datasets = dify_query_retrieval.retrieve("配网工程计价通D3软件如何新建工程?", Classification(vertical_classification="软件问题", sub_classification="软件功能"), "配网工程计价通D3") # datasets = dify_query_retrieval.retrieve("配网工程计价通D3软件如何新建工程?", Classification(vertical_classification="软件问题", sub_classification="软件功能"), "配网工程计价通D3")
datasets = dify_query_retrieval.retrieve_api("电力建设计价通软件如何批量修改设备价格?", # datasets = dify_query_retrieval.retrieve_api("电力建设计价通软件如何批量修改设备价格?",
["电力建设计价通软件如何批量修改设备价格?"], # ["电力建设计价通软件如何批量修改设备价格?"],
["电力建设计价通(2018)软件知识(new)"]) # ["电力建设计价通(2018)软件知识(new)"])
print(json.dumps(datasets, ensure_ascii=False, indent=2)) # print(json.dumps(datasets, ensure_ascii=False, indent=2))
# 测试异步API
async def test_async_api():
    """Call ``retrieve_api_async`` once for a single dataset and print the result as JSON."""
    question = "电力建设计价通软件如何批量修改设备价格?"
    target_datasets = ["电力建设计价通(2018)软件知识(new)"]
    documents = await dify_query_retrieval.retrieve_api_async(question, [question], target_datasets)
    print("异步API测试结果:")
    rendered = json.dumps(documents, ensure_ascii=False, indent=2)
    print(rendered)
# Run the async API smoke test (comment out the two lines below to skip it).
import asyncio
asyncio.run(test_async_api())
+90 -24
View File
@@ -1,19 +1,25 @@
# from gevent import monkey # from gevent import monkey
# monkey.patch_all() # monkey.patch_all()
from flask import Flask, request, Response
import os import os
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from typing import Dict, List, Any, Optional
import asyncio
from dotenv import load_dotenv from dotenv import load_dotenv
import json import json
import time import time
from gevent.lock import RLock
import datetime import datetime
import logging import logging
# 加载环境变量 # 加载环境变量
load_dotenv() load_dotenv()
from DifyQueryRetrieval import DifyQueryRetrieval
import sys
sys.path.append(os.getcwd())
from rag2_0.dify.DifyQueryRetrieval import DifyQueryRetrieval
logging.basicConfig( logging.basicConfig(
@@ -28,27 +34,87 @@ logging.getLogger('openai').setLevel(logging.WARNING)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
app = Flask(__name__) # 定义请求模型
class RetrieveRequest(BaseModel):
    """Request body of the ``/retrieve`` endpoint.

    The two list-valued inputs are sent as single strings; the endpoint
    splits them on the separators named in the field descriptions.
    """

    original_query: str = Field(..., description="Original user query.")
    query_list: str = Field(
        ...,
        description="Sub-queries joined with the '<sub_query>' separator.",
    )
    data_set_list: str = Field(
        ...,
        description="Dataset names joined with the '<dataset>' separator.",
    )
dify_query_retrieval = DifyQueryRetrieval(api_key="dataset-skLjmPVonjHo119OWNf3kAmY", base_url="http://10.1.16.39/v1") # 创建FastAPI应用
@app.route('/retrieve', methods=['POST']) app = FastAPI(
def retrieve(): title="Dify查询检索服务",
data = request.get_json(force=True) description="基于Dify的异步查询检索服务",
original_query_str = data.get('original_query') version="1.0"
query_list_str = data.get('query_list') )
data_set_list_str = data.get('data_set_list')
# 添加CORS中间件
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# 全局变量存储DifyQueryRetrieval实例
dify_query_retrieval = None
# 应用启动事件
@app.on_event("startup")
async def startup_event():
    """Create the shared ``DifyQueryRetrieval`` instance at application startup.

    The API key and base URL are read from the ``DIFY_DATASET_API_KEY`` and
    ``DIFY_BASE_URL`` environment variables; when unset, the previously
    hard-coded values are used.
    """
    # NOTE(review): recent FastAPI releases document lifespan handlers as the
    # replacement for on_event — confirm against the pinned version.
    global dify_query_retrieval
    # SECURITY: the fallback literals only keep existing deployments working.
    # A credential should not live in source; set the environment variables
    # and rotate this key.
    api_key = os.getenv("DIFY_DATASET_API_KEY", "dataset-skLjmPVonjHo119OWNf3kAmY")
    base_url = os.getenv("DIFY_BASE_URL", "http://10.1.16.39/v1")
    # Initialise the module-level retrieval instance used by the endpoints.
    dify_query_retrieval = DifyQueryRetrieval(api_key=api_key, base_url=base_url)
    logger.info("DifyQueryRetrieval初始化完成")
# 添加健康检查端点
@app.get("/health", summary="健康检查")
async def health_check():
    """Health-check endpoint; always returns ``{"status": "ok"}``."""
    status_payload = {"status": "ok"}
    return status_payload
@app.post("/retrieve", summary="异步检索API")
async def retrieve(request: RetrieveRequest):
    """Asynchronous retrieval endpoint.

    Args:
        request: Body carrying the original query, the sub-queries joined by
            ``<sub_query>`` and the dataset names joined by ``<dataset>``.

    Returns:
        The documents returned by ``retrieve_api_async``.

    Raises:
        HTTPException: With status 500 and the error text as detail when
            retrieval fails.
    """
    try:
        # Split the delimiter-joined strings. Blank fragments (for example
        # from a trailing separator) are dropped so that no empty query or
        # empty dataset name is sent on.
        query_list = [q for q in request.query_list.split("<sub_query>") if q.strip()]
        data_set_list = [d for d in request.data_set_list.split("<dataset>") if d.strip()]

        # Run the asynchronous retrieval and log how long it took.
        start_time = time.time()
        results = await dify_query_retrieval.retrieve_api_async(
            request.original_query,
            query_list,
            data_set_list
        )
        end_time = time.time()
        logger.info(f"异步检索总耗时: {end_time - start_time:.2f}")

        return results
    except Exception as e:
        logger.error(f"异步检索出错: {str(e)}", exc_info=True)
        # Chain the cause so the original traceback stays attached.
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__": if __name__ == "__main__":
# 开发环境使用Flask内置服务器 # 使用Uvicorn运行FastAPI应用
# 生产环境使用gunicorn支持高并发 uv run gunicorn -w 10 -k gevent -b 0.0.0.0:8001 rag2_0.dify.DifyQueryRetrieval_api:app import uvicorn
# uv run gunicorn -w 10 -k gevent --preload -b 0.0.0.0:8001 rag2_0.dify.DifyQueryRetrieval_api:app uvicorn.run("rag2_0.dify.DifyQueryRetrieval_api:app", host="0.0.0.0", port=8002, reload=False, workers=10, log_level="info")
app.run(host="0.0.0.0", port=8002, threaded=True) # # 使用uvicorn启动服务
# import uvicorn
# uvicorn.run(
# "rag2_0.dify.intent_recognition_api:app",
# host="0.0.0.0",
# port=8001,
# reload=False, # 开发环境启用热重载
# workers=1 # 生产环境可以增加worker数量
# )
# 生产环境可以使用以下命令启动:
# uvicorn rag2_0.dify.DifyQueryRetrieval_api:app --host 0.0.0.0 --port 8002 --workers 10