Files
QueryRewrite/rag2_0/dify/DifyQueryRetrieval_api.py
T

126 lines
3.7 KiB
Python

# from gevent import monkey
# monkey.patch_all()
import os
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from typing import Dict, List, Any, Optional
import asyncio
from dotenv import load_dotenv
import json
import time
import datetime
import logging
# 加载环境变量
load_dotenv()
import sys
sys.path.append(os.getcwd())
from rag2_0.dify.DifyQueryRetrieval import DifyQueryRetrieval
# Root logging: INFO level, single stderr stream handler, timestamped format.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(
    level=logging.INFO,
    format=_LOG_FORMAT,
    handlers=[logging.StreamHandler()],
)
# Raise the threshold of chatty client libraries so only their warnings/errors show.
for _noisy_logger_name in ("httpx", "openai"):
    logging.getLogger(_noisy_logger_name).setLevel(logging.WARNING)
del _noisy_logger_name
logger = logging.getLogger(__name__)
# Request body schema for the /retrieve endpoint.
class RetrieveRequest(BaseModel):
    """Body of a ``POST /retrieve`` request.

    ``query_list`` and ``data_set_list`` each carry several items packed into
    one string, separated by the ``<sub_query>`` and ``<dataset>`` markers
    respectively; the endpoint splits them apart.
    """

    # The user's query as originally submitted.
    original_query: str
    # Sub-queries joined with "<sub_query>".
    query_list: str
    # Dataset identifiers joined with "<dataset>" (presumably Dify dataset IDs — confirm).
    data_set_list: str
    # Either a dict or its JSON-encoded string form; defaults to the string "{}".
    query_expand_dict: dict | str = Field("{}")
# FastAPI application object served by Uvicorn.
app = FastAPI(
    title="Dify查询检索服务",
    description="基于Dify的异步查询检索服务",
    version="1.0",
)

# CORS policy: any origin, method and header, with credentials allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive — confirm this is intended before exposing the service to browsers.
_CORS_SETTINGS = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)
# Shared DifyQueryRetrieval instance; stays None until startup_event has run.
dify_query_retrieval = None


@app.on_event("startup")
async def startup_event():
    """Create the shared ``DifyQueryRetrieval`` instance at application startup.

    Configuration comes from the environment:

    * ``DIFY_DATASET_KEY`` - passed as ``dify_dataset_key``.
    * ``DIFY_BSAE_URL`` (historical, misspelled name) or ``DIFY_BASE_URL`` -
      passed as ``dify_base_url``. The legacy name is consulted first so
      existing deployments behave exactly as before; the correctly spelled
      name is only used when the legacy one is unset or empty.

    Missing values are still passed through (as ``None``), but a warning is
    logged so the misconfiguration is visible at startup rather than on the
    first request.
    """
    global dify_query_retrieval
    dataset_key = os.getenv("DIFY_DATASET_KEY")
    # Legacy typo first for backward compatibility, then the corrected name.
    base_url = os.getenv("DIFY_BSAE_URL") or os.getenv("DIFY_BASE_URL")
    if not dataset_key:
        logger.warning("DIFY_DATASET_KEY is not set")
    if not base_url:
        logger.warning("Neither DIFY_BSAE_URL nor DIFY_BASE_URL is set")
    dify_query_retrieval = DifyQueryRetrieval(
        dify_dataset_key=dataset_key,
        dify_base_url=base_url,
    )
    logger.info("DifyQueryRetrieval初始化完成")
# Health-check endpoint.
@app.get("/health", summary="健康检查")
async def health_check():
    # Always answers with a constant payload; it does not probe Dify or any
    # other dependency, so it only signals that the process is serving requests.
    payload = dict(status="ok")
    return payload
def _split_field(raw: str, delimiter: str) -> List[str]:
    """Split a delimiter-joined request field, dropping blank segments.

    Blank segments come from empty input or stray/trailing delimiters
    (e.g. ``"a<sub_query>"``); non-blank items are returned unchanged.
    """
    return [item for item in raw.split(delimiter) if item.strip()]


def _parse_query_expand_dict(value: dict | str) -> dict:
    """Normalize ``query_expand_dict`` to a dict.

    A dict is returned as-is. A string is decoded as JSON; a blank string is
    treated as an empty object.

    Raises:
        HTTPException: 400 if the string is not valid JSON or does not decode
            to a JSON object. This is a client error, not a server failure.
    """
    if not isinstance(value, str):
        return value
    if not value.strip():
        return {}
    try:
        parsed = json.loads(value)
    except json.JSONDecodeError as e:
        logger.warning(f"Invalid query_expand_dict JSON: {e}")
        raise HTTPException(
            status_code=400,
            detail=f"query_expand_dict is not valid JSON: {e}",
        ) from e
    if not isinstance(parsed, dict):
        raise HTTPException(
            status_code=400,
            detail="query_expand_dict must be a JSON object",
        )
    return parsed


@app.post("/retrieve", summary="异步检索API")
async def retrieve(request: RetrieveRequest):
    """Run asynchronous retrieval for one request.

    Splits ``query_list`` on ``<sub_query>`` and ``data_set_list`` on
    ``<dataset>``, dropping blank segments. Decodes ``query_expand_dict`` when
    it arrives as a JSON string. Then delegates to
    ``DifyQueryRetrieval.retrieve_api_async`` with ``top_k=5``.

    Args:
        request: request body with the original query, the packed sub-query
            list and the packed dataset list.

    Returns:
        Whatever ``retrieve_api_async`` returns, serialized by FastAPI.

    Raises:
        HTTPException: 400 for a malformed ``query_expand_dict``;
            503 if the retrieval client has not been initialized;
            500 if retrieval itself fails.
    """
    # Validation runs outside the broad try/except below so that client errors
    # keep their 4xx status instead of being rewrapped as 500.
    query_list = _split_field(request.query_list, "<sub_query>")
    data_set_list = _split_field(request.data_set_list, "<dataset>")
    query_expand_dict = _parse_query_expand_dict(request.query_expand_dict)

    if dify_query_retrieval is None:
        # startup_event has not run (or failed); calling a method on None would
        # otherwise surface as an opaque AttributeError/500.
        raise HTTPException(
            status_code=503,
            detail="DifyQueryRetrieval is not initialized",
        )

    try:
        # perf_counter is monotonic, so the logged duration is immune to clock changes.
        start_time = time.perf_counter()
        results = await dify_query_retrieval.retrieve_api_async(
            request.original_query,
            query_list,
            data_set_list,
            query_expand_dict=query_expand_dict,
            top_k=5,
        )
        elapsed = time.perf_counter() - start_time
        logger.info(f"异步检索总耗时: {elapsed:.2f}秒")
        return results
    except Exception as e:
        # Top-level boundary: log with traceback and report as a server error.
        logger.exception(f"异步检索出错: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    # Launch the service with Uvicorn. The app is referenced by import string,
    # so the ``rag2_0`` package must be importable (presumably by starting from
    # the project root — confirm against the deployment setup).
    import uvicorn

    server_options = {
        "host": "0.0.0.0",
        "port": 8002,
        "reload": False,
        "workers": 1,
        "log_level": "info",
    }
    uvicorn.run("rag2_0.dify.DifyQueryRetrieval_api:app", **server_options)
    # For production, more workers can be started from the command line, e.g.:
    # uvicorn rag2_0.dify.DifyQueryRetrieval_api:app --host 0.0.0.0 --port 8002 --workers 10