重构DifyQueryRetrieval_api.py为FastAPI应用,新增异步检索API和健康检查端点,优化错误处理和日志记录,同时更新DifyQueryRetrieval类以支持异步检索功能,提升整体性能和可维护性。
This commit is contained in:
@@ -6,6 +6,8 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
|
|||||||
from typing import List, Dict, Any, Optional
|
from typing import List, Dict, Any, Optional
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
|
import asyncio
|
||||||
|
import httpx
|
||||||
sys.path.append(os.getcwd())
|
sys.path.append(os.getcwd())
|
||||||
|
|
||||||
from rag2_0.intent_recognition.DataModels import Classification
|
from rag2_0.intent_recognition.DataModels import Classification
|
||||||
@@ -52,6 +54,28 @@ class DifyQueryRetrieval:
|
|||||||
logging.error(f"检索数据集 {dataset_name} 时出错: {str(e)}", exc_info=True)
|
logging.error(f"检索数据集 {dataset_name} 时出错: {str(e)}", exc_info=True)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
async def retrieve_by_dataset_async(self, query: str, dataset_name: str) -> List[Dict[str, Any]]:
    """Asynchronous wrapper around the synchronous ``retrieve_by_dataset``.

    The blocking call is handed to ``asyncio.to_thread`` so the event loop
    stays free while the retrieval runs in a worker thread.

    Args:
        query: Query string to search for.
        dataset_name: Name of the dataset to search in.

    Returns:
        The documents returned by ``retrieve_by_dataset``, or an empty list
        when the call raises.
    """
    try:
        return await asyncio.to_thread(self.retrieve_by_dataset, query, dataset_name)
    except Exception as e:
        # Best-effort: a failing dataset must not abort the whole fan-out,
        # so log with the traceback and report "no documents".
        logging.exception(f"异步检索数据集 {dataset_name} 时出错: {str(e)}")
        return []
|
||||||
|
|
||||||
def retrieve(self, original_query: str, query_list: List[str], classification: Classification, software_name: str) -> Optional[List[Dict[str, Any]]]:
|
def retrieve(self, original_query: str, query_list: List[str], classification: Classification, software_name: str) -> Optional[List[Dict[str, Any]]]:
|
||||||
datasets = self.get_datasets_by_classification(classification, software_name)
|
datasets = self.get_datasets_by_classification(classification, software_name)
|
||||||
if len(datasets) == 0:
|
if len(datasets) == 0:
|
||||||
@@ -59,6 +83,25 @@ class DifyQueryRetrieval:
|
|||||||
|
|
||||||
return self.retrieve_api(original_query, query_list, datasets)
|
return self.retrieve_api(original_query, query_list, datasets)
|
||||||
|
|
||||||
|
async def retrieve_async(self, original_query: str, query_list: List[str], classification: Classification, software_name: str) -> Optional[List[Dict[str, Any]]]:
    """Asynchronous counterpart of ``retrieve``.

    Resolves the datasets for the given classification and software name via
    ``get_datasets_by_classification`` and delegates the actual search to
    ``retrieve_api_async``.

    Args:
        original_query: The original user query.
        query_list: Queries to run against each dataset.
        classification: Classification used to select datasets.
        software_name: Software name used to select datasets.

    Returns:
        The result of ``retrieve_api_async``, or ``None`` when no dataset
        matches.
    """
    datasets = self.get_datasets_by_classification(classification, software_name)
    if not datasets:
        # Nothing to search; callers get None rather than an empty list.
        return None

    return await self.retrieve_api_async(original_query, query_list, datasets)
|
||||||
|
|
||||||
def retrieve_api(self, original_query: str, query_list: List[str],data_set_list: List[str])->List[Dict[str, Any]]:
|
def retrieve_api(self, original_query: str, query_list: List[str],data_set_list: List[str])->List[Dict[str, Any]]:
|
||||||
all_documents=[]
|
all_documents=[]
|
||||||
# 使用线程池替代无限制创建线程
|
# 使用线程池替代无限制创建线程
|
||||||
@@ -103,6 +146,65 @@ class DifyQueryRetrieval:
|
|||||||
|
|
||||||
return processed_documents
|
return processed_documents
|
||||||
|
|
||||||
|
async def retrieve_api_async(self, original_query: str, query_list: List[str], data_set_list: List[str]) -> List[Dict[str, Any]]:
    """Asynchronous counterpart of ``retrieve_api`` built on asyncio tasks.

    Runs ``retrieve_by_dataset_async`` for every (query, dataset) pair
    concurrently, merges the results, removes duplicates by segment id and
    passes the remainder to ``data_post_processor_async``.

    Args:
        original_query: The original user query, forwarded to the
            post-processor.
        query_list: Queries to run against each dataset.
        data_set_list: Dataset names to search. Names not present in
            ``self._datasets_list`` are logged and skipped.

    Returns:
        Whatever ``data_post_processor_async`` returns for the deduplicated
        documents.
    """
    retrieval_started = time.time()

    # One coroutine per (query, dataset) pair.
    tasks = []
    for query in query_list:
        for dataset in data_set_list:
            if dataset not in self._datasets_list:
                logging.error(f"dataset {dataset} not in datasets_list")
                continue
            tasks.append(self.retrieve_by_dataset_async(query, dataset))

    # return_exceptions=True keeps one failing task from discarding the
    # results of all the others.
    results = await asyncio.gather(*tasks, return_exceptions=True)

    all_documents = []
    for result in results:
        # gather() can hand back BaseException subclasses (e.g.
        # CancelledError) as results, so test against BaseException.
        if isinstance(result, BaseException):
            # There is no active exception here, so the traceback has to be
            # taken from the instance itself.
            logging.error(f"异步检索过程中发生错误: {str(result)}", exc_info=result)
        else:
            all_documents.extend(result)

    logging.info(f"异步检索耗时: {time.time() - retrieval_started:.2f}秒")

    # Deduplicate by segment id, keeping the first occurrence.
    unique_documents = {}
    for document in all_documents:
        unique_documents.setdefault(document['segment']['id'], document)
    deduplicated_documents = list(unique_documents.values())

    # Rerank everything that was retrieved.
    rerank_started = time.time()
    processed_documents = await self.data_post_processor_async(original_query, deduplicated_documents)
    logging.info(f"异步检索后重排序耗时: {time.time() - rerank_started:.2f}秒")

    return processed_documents
|
||||||
|
|
||||||
def data_post_processor(self, query: str, all_documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
def data_post_processor(self, query: str, all_documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||||
reranker_model = XinferenceReRankerModel()
|
reranker_model = XinferenceReRankerModel()
|
||||||
documents = [document['segment']['content'] for document in all_documents]
|
documents = [document['segment']['content'] for document in all_documents]
|
||||||
@@ -138,6 +240,52 @@ class DifyQueryRetrieval:
|
|||||||
new_all_documents.append(to_dify_document_format(cur_doc_info))
|
new_all_documents.append(to_dify_document_format(cur_doc_info))
|
||||||
return new_all_documents
|
return new_all_documents
|
||||||
|
|
||||||
|
async def data_post_processor_async(self, query: str, all_documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Asynchronous counterpart of ``data_post_processor``.

    Reranks the segment contents of ``all_documents`` against ``query`` with
    ``XinferenceReRankerModel.rerank_async`` (``top_k=5``) and converts each
    reranked hit into the Dify document layout.

    Args:
        query: Query string the documents are reranked against.
        all_documents: Retrieved documents; each must provide
            ``document['segment']['content']``.

    Returns:
        The reranked documents in Dify format, in the order returned by the
        reranker. An empty input yields an empty list without contacting
        the reranker. The input dictionaries are not modified.
    """
    if not all_documents:
        # Nothing to rerank: skip building the model and the remote call.
        return []

    def to_dify_document_format(document: dict) -> dict:
        """Map one retrieved record onto the Dify knowledge-document shape."""
        segment = document["segment"]
        source_document = segment["document"]
        return {
            "metadata": {
                "_source": "knowledge",
                "dataset_id": document["dataset_id"],
                "dataset_name": document["dataset_name"],
                "document_id": segment["document_id"],
                "document_name": source_document["name"],
                "data_source_type": source_document["data_source_type"],
                "segment_id": segment["id"],
                "retriever_from": "api",
                "score": document.get("score", 0),
                "segment_hit_count": segment.get("hit_count", 0),
                "segment_word_count": segment.get("word_count", 0),
                "segment_position": segment.get("position", 0),
                "segment_index_node_hash": segment.get("index_node_hash", ""),
                "doc_metadata": source_document.get("doc_metadata", None),
                "position": segment.get("position", 0)
            },
            "title": source_document["name"],
            "content": segment["content"]
        }

    reranker_model = XinferenceReRankerModel()
    contents = [document['segment']['content'] for document in all_documents]
    reranked_documents = await reranker_model.rerank_async(query, contents, top_k=5)

    # Each hit refers back to its source by "index"; attach the rerank score
    # to a shallow copy so the caller's dictionaries stay untouched.
    return [
        to_dify_document_format({**all_documents[hit["index"]], "score": hit["score"]})
        for hit in reranked_documents
    ]
|
||||||
|
|
||||||
def get_datasets_by_classification(self, classification: Classification, software_name: str) -> List[str]:
|
def get_datasets_by_classification(self, classification: Classification, software_name: str) -> List[str]:
|
||||||
if classification.vertical_classification=="软件问题" or classification.vertical_classification=="业务问题":
|
if classification.vertical_classification=="软件问题" or classification.vertical_classification=="业务问题":
|
||||||
software_name_list = self.software_to_dataset_map.keys()
|
software_name_list = self.software_to_dataset_map.keys()
|
||||||
@@ -165,7 +313,21 @@ class DifyQueryRetrieval:
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
dify_query_retrieval = DifyQueryRetrieval(api_key="dataset-skLjmPVonjHo119OWNf3kAmY", base_url="http://10.1.16.39/v1")
|
dify_query_retrieval = DifyQueryRetrieval(api_key="dataset-skLjmPVonjHo119OWNf3kAmY", base_url="http://10.1.16.39/v1")
|
||||||
# datasets = dify_query_retrieval.retrieve("配网工程计价通D3软件如何新建工程?", Classification(vertical_classification="软件问题", sub_classification="软件功能"), "配网工程计价通D3")
|
# datasets = dify_query_retrieval.retrieve("配网工程计价通D3软件如何新建工程?", Classification(vertical_classification="软件问题", sub_classification="软件功能"), "配网工程计价通D3")
|
||||||
datasets = dify_query_retrieval.retrieve_api("电力建设计价通软件如何批量修改设备价格?",
|
# datasets = dify_query_retrieval.retrieve_api("电力建设计价通软件如何批量修改设备价格?",
|
||||||
["电力建设计价通软件如何批量修改设备价格?"],
|
# ["电力建设计价通软件如何批量修改设备价格?"],
|
||||||
["电力建设计价通(2018)软件知识(new)"])
|
# ["电力建设计价通(2018)软件知识(new)"])
|
||||||
print(json.dumps(datasets, ensure_ascii=False, indent=2))
|
# print(json.dumps(datasets, ensure_ascii=False, indent=2))
|
||||||
|
|
||||||
|
# 测试异步API
|
||||||
|
async def test_async_api():
|
||||||
|
datasets = await dify_query_retrieval.retrieve_api_async(
|
||||||
|
"电力建设计价通软件如何批量修改设备价格?",
|
||||||
|
["电力建设计价通软件如何批量修改设备价格?"],
|
||||||
|
["电力建设计价通(2018)软件知识(new)"]
|
||||||
|
)
|
||||||
|
print("异步API测试结果:")
|
||||||
|
print(json.dumps(datasets, ensure_ascii=False, indent=2))
|
||||||
|
|
||||||
|
# 如果需要测试异步API,取消下面的注释
|
||||||
|
import asyncio
|
||||||
|
asyncio.run(test_async_api())
|
||||||
@@ -1,19 +1,25 @@
|
|||||||
# from gevent import monkey
|
# from gevent import monkey
|
||||||
# monkey.patch_all()
|
# monkey.patch_all()
|
||||||
|
|
||||||
from flask import Flask, request, Response
|
|
||||||
import os
|
import os
|
||||||
|
from fastapi import FastAPI, HTTPException, Request
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from typing import Dict, List, Any, Optional
|
||||||
|
import asyncio
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
from gevent.lock import RLock
|
|
||||||
|
|
||||||
|
|
||||||
import datetime
|
import datetime
|
||||||
import logging
|
import logging
|
||||||
# 加载环境变量
|
# 加载环境变量
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
from DifyQueryRetrieval import DifyQueryRetrieval
|
|
||||||
|
import sys
|
||||||
|
sys.path.append(os.getcwd())
|
||||||
|
from rag2_0.dify.DifyQueryRetrieval import DifyQueryRetrieval
|
||||||
|
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
@@ -28,27 +34,87 @@ logging.getLogger('openai').setLevel(logging.WARNING)
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
app = Flask(__name__)
|
# 定义请求模型
|
||||||
|
class RetrieveRequest(BaseModel):
    """Request body of the ``/retrieve`` endpoint.

    ``query_list`` and ``data_set_list`` are single strings, not JSON arrays:
    the endpoint splits them on ``<sub_query>`` and ``<dataset>``.
    """

    original_query: str = Field(..., description="Original user query.")
    query_list: str = Field(
        ...,
        description="Sub-queries joined with the '<sub_query>' separator.",
    )
    data_set_list: str = Field(
        ...,
        description="Dataset names joined with the '<dataset>' separator.",
    )
|
||||||
|
|
||||||
dify_query_retrieval = DifyQueryRetrieval(api_key="dataset-skLjmPVonjHo119OWNf3kAmY", base_url="http://10.1.16.39/v1")
|
# 创建FastAPI应用
|
||||||
@app.route('/retrieve', methods=['POST'])
|
app = FastAPI(
|
||||||
def retrieve():
|
title="Dify查询检索服务",
|
||||||
data = request.get_json(force=True)
|
description="基于Dify的异步查询检索服务",
|
||||||
original_query_str = data.get('original_query')
|
version="1.0"
|
||||||
query_list_str = data.get('query_list')
|
)
|
||||||
data_set_list_str = data.get('data_set_list')
|
|
||||||
|
# 添加CORS中间件
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["*"],
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# 全局变量存储DifyQueryRetrieval实例
|
||||||
|
dify_query_retrieval = None
|
||||||
|
|
||||||
|
# 应用启动事件
|
||||||
|
@app.on_event("startup")
async def startup_event():
    """Create the module-level ``DifyQueryRetrieval`` instance at startup.

    The API key and base URL are read from the ``DIFY_DATASET_API_KEY`` and
    ``DIFY_BASE_URL`` environment variables (``load_dotenv()`` runs at import
    time, so a ``.env`` file works as well). When a variable is unset the
    previously hard-coded value is used, so existing deployments keep
    working unchanged.
    """
    global dify_query_retrieval
    # NOTE(review): the fallbacks below keep a credential in source control;
    # move it to the environment and drop the literal once deployments set it.
    api_key = os.getenv("DIFY_DATASET_API_KEY", "dataset-skLjmPVonjHo119OWNf3kAmY")
    base_url = os.getenv("DIFY_BASE_URL", "http://10.1.16.39/v1")
    dify_query_retrieval = DifyQueryRetrieval(api_key=api_key, base_url=base_url)
    logger.info("DifyQueryRetrieval初始化完成")
|
||||||
|
|
||||||
|
# 添加健康检查端点
|
||||||
|
@app.get("/health", summary="健康检查")
|
||||||
|
async def health_check():
|
||||||
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
@app.post("/retrieve", summary="异步检索API")
|
||||||
|
async def retrieve(request: RetrieveRequest):
|
||||||
|
"""
|
||||||
|
异步检索API
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: 包含原始查询、查询列表和数据集列表的请求对象
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
检索结果
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
query_list = query_list_str.split("<sub_query>")
|
# 解析查询列表和数据集列表
|
||||||
data_set_list = data_set_list_str.split("<dataset>")
|
query_list = request.query_list.split("<sub_query>")
|
||||||
results = dify_query_retrieval.retrieve_api(original_query_str, query_list, data_set_list)
|
data_set_list = request.data_set_list.split("<dataset>")
|
||||||
return Response(json.dumps(results, ensure_ascii=False), content_type='application/json; charset=utf-8')
|
|
||||||
|
# 调用异步检索方法
|
||||||
|
start_time = time.time()
|
||||||
|
results = await dify_query_retrieval.retrieve_api_async(
|
||||||
|
request.original_query,
|
||||||
|
query_list,
|
||||||
|
data_set_list
|
||||||
|
)
|
||||||
|
end_time = time.time()
|
||||||
|
|
||||||
|
logger.info(f"异步检索总耗时: {end_time - start_time:.2f}秒")
|
||||||
|
return results
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"检索出错: {str(e)}", exc_info=True)
|
logger.error(f"异步检索出错: {str(e)}", exc_info=True)
|
||||||
return Response(json.dumps({"error": str(e)}, ensure_ascii=False), content_type='application/json; charset=utf-8', status=500)
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# 开发环境使用Flask内置服务器
|
# 使用Uvicorn运行FastAPI应用
|
||||||
# 生产环境使用gunicorn支持高并发 uv run gunicorn -w 10 -k gevent -b 0.0.0.0:8001 rag2_0.dify.DifyQueryRetrieval_api:app
|
import uvicorn
|
||||||
# uv run gunicorn -w 10 -k gevent --preload -b 0.0.0.0:8001 rag2_0.dify.DifyQueryRetrieval_api:app
|
uvicorn.run("rag2_0.dify.DifyQueryRetrieval_api:app", host="0.0.0.0", port=8002, reload=False, workers=10, log_level="info")
|
||||||
app.run(host="0.0.0.0", port=8002, threaded=True)
|
# # 使用uvicorn启动服务
|
||||||
|
# import uvicorn
|
||||||
|
# uvicorn.run(
|
||||||
|
# "rag2_0.dify.intent_recognition_api:app",
|
||||||
|
# host="0.0.0.0",
|
||||||
|
# port=8001,
|
||||||
|
# reload=False, # 开发环境启用热重载
|
||||||
|
# workers=1 # 生产环境可以增加worker数量
|
||||||
|
# )
|
||||||
|
# 生产环境可以使用以下命令启动:
|
||||||
|
# uvicorn rag2_0.dify.DifyQueryRetrieval_api:app --host 0.0.0.0 --port 8002 --workers 10
|
||||||
Reference in New Issue
Block a user