重构DifyQueryRetrieval_api.py为FastAPI应用,新增异步检索API和健康检查端点,优化错误处理和日志记录,同时更新DifyQueryRetrieval类以支持异步检索功能,提升整体性能和可维护性。
This commit is contained in:
@@ -1,19 +1,25 @@
|
||||
# from gevent import monkey
|
||||
# monkey.patch_all()
|
||||
|
||||
from flask import Flask, request, Response
|
||||
import os
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Dict, List, Any, Optional
|
||||
import asyncio
|
||||
|
||||
from dotenv import load_dotenv
|
||||
import json
|
||||
import time
|
||||
from gevent.lock import RLock
|
||||
|
||||
|
||||
import datetime
|
||||
import logging
|
||||
# 加载环境变量
|
||||
load_dotenv()
|
||||
from DifyQueryRetrieval import DifyQueryRetrieval
|
||||
|
||||
import sys
|
||||
sys.path.append(os.getcwd())
|
||||
from rag2_0.dify.DifyQueryRetrieval import DifyQueryRetrieval
|
||||
|
||||
|
||||
logging.basicConfig(
|
||||
@@ -28,27 +34,87 @@ logging.getLogger('openai').setLevel(logging.WARNING)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
app = Flask(__name__)
|
||||
# 定义请求模型
|
||||
class RetrieveRequest(BaseModel):
|
||||
original_query: str
|
||||
query_list: str
|
||||
data_set_list: str
|
||||
|
||||
dify_query_retrieval = DifyQueryRetrieval(api_key="dataset-skLjmPVonjHo119OWNf3kAmY", base_url="http://10.1.16.39/v1")
|
||||
@app.route('/retrieve', methods=['POST'])
|
||||
def retrieve():
|
||||
data = request.get_json(force=True)
|
||||
original_query_str = data.get('original_query')
|
||||
query_list_str = data.get('query_list')
|
||||
data_set_list_str = data.get('data_set_list')
|
||||
# 创建FastAPI应用
|
||||
app = FastAPI(
|
||||
title="Dify查询检索服务",
|
||||
description="基于Dify的异步查询检索服务",
|
||||
version="1.0"
|
||||
)
|
||||
|
||||
# 添加CORS中间件
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# 全局变量存储DifyQueryRetrieval实例
|
||||
dify_query_retrieval = None
|
||||
|
||||
# 应用启动事件
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
global dify_query_retrieval
|
||||
# 初始化DifyQueryRetrieval实例
|
||||
dify_query_retrieval = DifyQueryRetrieval(api_key="dataset-skLjmPVonjHo119OWNf3kAmY", base_url="http://10.1.16.39/v1")
|
||||
logger.info("DifyQueryRetrieval初始化完成")
|
||||
|
||||
# 添加健康检查端点
|
||||
@app.get("/health", summary="健康检查")
|
||||
async def health_check():
|
||||
return {"status": "ok"}
|
||||
|
||||
@app.post("/retrieve", summary="异步检索API")
|
||||
async def retrieve(request: RetrieveRequest):
|
||||
"""
|
||||
异步检索API
|
||||
|
||||
Args:
|
||||
request: 包含原始查询、查询列表和数据集列表的请求对象
|
||||
|
||||
Returns:
|
||||
检索结果
|
||||
"""
|
||||
try:
|
||||
query_list = query_list_str.split("<sub_query>")
|
||||
data_set_list = data_set_list_str.split("<dataset>")
|
||||
results = dify_query_retrieval.retrieve_api(original_query_str, query_list, data_set_list)
|
||||
return Response(json.dumps(results, ensure_ascii=False), content_type='application/json; charset=utf-8')
|
||||
# 解析查询列表和数据集列表
|
||||
query_list = request.query_list.split("<sub_query>")
|
||||
data_set_list = request.data_set_list.split("<dataset>")
|
||||
|
||||
# 调用异步检索方法
|
||||
start_time = time.time()
|
||||
results = await dify_query_retrieval.retrieve_api_async(
|
||||
request.original_query,
|
||||
query_list,
|
||||
data_set_list
|
||||
)
|
||||
end_time = time.time()
|
||||
|
||||
logger.info(f"异步检索总耗时: {end_time - start_time:.2f}秒")
|
||||
return results
|
||||
except Exception as e:
|
||||
logger.error(f"检索出错: {str(e)}", exc_info=True)
|
||||
return Response(json.dumps({"error": str(e)}, ensure_ascii=False), content_type='application/json; charset=utf-8', status=500)
|
||||
|
||||
logger.error(f"异步检索出错: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 开发环境使用Flask内置服务器
|
||||
# 生产环境使用gunicorn支持高并发 uv run gunicorn -w 10 -k gevent -b 0.0.0.0:8001 rag2_0.dify.DifyQueryRetrieval_api:app
|
||||
# uv run gunicorn -w 10 -k gevent --preload -b 0.0.0.0:8001 rag2_0.dify.DifyQueryRetrieval_api:app
|
||||
app.run(host="0.0.0.0", port=8002, threaded=True)
|
||||
# 使用Uvicorn运行FastAPI应用
|
||||
import uvicorn
|
||||
uvicorn.run("rag2_0.dify.DifyQueryRetrieval_api:app", host="0.0.0.0", port=8002, reload=False, workers=10, log_level="info")
|
||||
# # 使用uvicorn启动服务
|
||||
# import uvicorn
|
||||
# uvicorn.run(
|
||||
# "rag2_0.dify.intent_recognition_api:app",
|
||||
# host="0.0.0.0",
|
||||
# port=8001,
|
||||
# reload=False, # 开发环境启用热重载
|
||||
# workers=1 # 生产环境可以增加worker数量
|
||||
# )
|
||||
# 生产环境可以使用以下命令启动:
|
||||
# uvicorn rag2_0.dify.DifyQueryRetrieval_api:app --host 0.0.0.0 --port 8002 --workers 10
|
||||
Reference in New Issue
Block a user