151 lines
4.7 KiB
Python
151 lines
4.7 KiB
Python
# from gevent import monkey
|
|
# monkey.patch_all()
|
|
|
|
import os
|
|
from fastapi import FastAPI, HTTPException, Request
|
|
from fastapi.responses import JSONResponse
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from pydantic import BaseModel, Field, ConfigDict
|
|
from typing import Dict, List, Any, Optional
|
|
import asyncio
|
|
|
|
from dotenv import load_dotenv
|
|
import json
|
|
import time
|
|
import datetime
|
|
import logging
|
|
# 加载环境变量
|
|
load_dotenv()
|
|
|
|
import sys
|
|
sys.path.append(os.getcwd())
|
|
from rag2_0.dify.DifyQueryRetrieval import DifyQueryRetrieval
|
|
|
|
def _configure_logging():
    """Set up root logging to the console and to a dated log file.

    The log directory is created first so the file handler can open its
    target. The date in the file name is computed once, when this function
    runs, so a long-running process keeps writing to the same file.
    """
    # Make sure the log directory exists before a FileHandler opens a file in it.
    os.makedirs('data/logs', exist_ok=True)

    day_stamp = datetime.datetime.now().strftime("%Y%m%d")
    console_handler = logging.StreamHandler()
    file_handler = logging.FileHandler(
        f'data/logs/dify_query_retrieval_{day_stamp}.log',
        encoding='utf-8',
    )

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - [%(thread)d] - %(levelname)s - %(message)s',
        handlers=[console_handler, file_handler],
    )

    # Only let warnings and above through from these two client libraries.
    for library_name in ('httpx', 'openai'):
        logging.getLogger(library_name).setLevel(logging.WARNING)


_configure_logging()

# Module-level logger used by the handlers in this file.
logger = logging.getLogger(__name__)
|
|
|
|
# Request body schema for the POST /retrieve endpoint.
# Documented with comments rather than a class docstring so the schema
# generated from this model stays exactly as it was.
class RetrieveRequest(BaseModel):
    # The user's original query text; forwarded to the retriever as-is.
    original_query: str

    # One or more sub-queries packed into a single string; the /retrieve
    # handler splits it on the "<sub_query>" separator.
    query_list: str

    # One or more dataset names packed into a single string; the /retrieve
    # handler splits it on the "<dataset>" separator.
    data_set_list: str

    # Either a dict, or a JSON-encoded string that the handler decodes.
    query_expand_dict: dict | str = "{}"

    # Passed to the retriever as its `top_k` argument.
    topk: int = 4

    # Forwarded unchanged to the retriever; see the example in the /retrieve
    # docstring for the shape callers are sending.
    metadata_filtering_conditions: dict = Field(default={})
|
|
|
|
# Create the FastAPI application. The title/description strings are shown in
# the generated API docs and are kept exactly as they were.
app = FastAPI(
    version="1.0",
    title="Dify查询检索服务",
    description="基于Dify的异步查询检索服务",
)

# Register the CORS middleware with fully permissive settings.
# NOTE(review): FastAPI's CORS documentation says wildcard origins/methods/
# headers should not be combined with allow_credentials=True. Left as-is to
# avoid changing behaviour for existing clients — confirm whether credentials
# are actually needed, and if so list the allowed origins explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)

# Module-level slot for the shared DifyQueryRetrieval instance. It stays None
# until the startup handler assigns the real object.
dify_query_retrieval = None
|
|
|
|
# Application startup hook.
@app.on_event("startup")
async def startup_event():
    """Create the shared DifyQueryRetrieval instance when the app starts.

    Configuration is read from environment variables:

    * ``DIFY_DATASET_KEY`` - passed as ``dify_dataset_key``.
    * ``DIFY_BSAE_URL`` - passed as ``dify_base_url``. This name is the
      historical (misspelled) one and is still read first so existing
      deployments behave exactly as before. When it is not set, the
      correctly spelled ``DIFY_BASE_URL`` is used instead.

    A warning is logged when either value is missing; the instance is still
    created so startup behaviour is unchanged.
    """
    global dify_query_retrieval

    dataset_key = os.getenv("DIFY_DATASET_KEY")

    # Legacy misspelled name wins for backward compatibility; only fall back
    # to the correctly spelled variable when the legacy one is absent.
    base_url = os.getenv("DIFY_BSAE_URL")
    if base_url is None:
        base_url = os.getenv("DIFY_BASE_URL")

    if not dataset_key or not base_url:
        logger.warning(
            "Dify configuration is incomplete: DIFY_DATASET_KEY set=%s, "
            "DIFY_BSAE_URL/DIFY_BASE_URL set=%s",
            bool(dataset_key),
            bool(base_url),
        )

    dify_query_retrieval = DifyQueryRetrieval(
        dify_dataset_key=dataset_key,
        dify_base_url=base_url,
    )
    logger.info("DifyQueryRetrieval初始化完成")
|
|
|
|
# Health-check endpoint: GET /health answers with a fixed JSON payload.
# Documented with comments rather than a docstring so the endpoint's
# generated OpenAPI description stays unchanged.
@app.get("/health", summary="健康检查")
async def health_check():
    # Constant response; nothing is probed or computed here.
    payload = {"status": "ok"}
    return payload
|
|
|
|
@app.post("/retrieve", summary="异步检索API")
async def retrieve(request: RetrieveRequest):
    """Run an asynchronous retrieval through the shared DifyQueryRetrieval.

    ``query_list`` is split on ``<sub_query>`` and ``data_set_list`` on
    ``<dataset>``. ``query_expand_dict`` may be a dict or a JSON string; a
    string is decoded before use. The parsed values, together with
    ``original_query``, ``topk`` and ``metadata_filtering_conditions``, are
    passed to ``retrieve_api_async`` and its result is returned unchanged.

    Args:
        request: Request body carrying the original query, the packed
            sub-query list, the packed dataset list and optional settings.

    Returns:
        The value returned by ``retrieve_api_async``.

    Raises:
        HTTPException: 400 when ``query_expand_dict`` is a string that is not
            valid JSON; 503 when the retrieval instance has not been created
            yet; 500 for any other error raised during retrieval.

    Example request body (the values are sample data)::

        {
            "data_set_list": "主网造价知识(new)",
            "original_query": "钻孔灌注桩防沉台、承台基础适用于什么定额",
            "query_expand_dict": "{}",
            "query_list": "钻孔灌注桩防沉台、承台基础适用于什么定额",
            "topk": 20,
            "metadata_filtering_conditions": {
                "logical_operator": "and",
                "conditions": [
                    {
                        "name": "doc_class",
                        "comparison_operator": "is",
                        "value": "定额章节说明"
                    }
                ]
            }
        }
    """
    # Fail with a clear status instead of an AttributeError on None when the
    # startup handler has not run.
    if dify_query_retrieval is None:
        raise HTTPException(
            status_code=503,
            detail="DifyQueryRetrieval is not initialized",
        )

    # Decode the expansion mapping up front: bad JSON is a client error (400),
    # not a server failure.
    query_expand_dict = request.query_expand_dict
    if isinstance(query_expand_dict, str):
        try:
            query_expand_dict = json.loads(query_expand_dict)
        except json.JSONDecodeError as e:
            logger.warning("Invalid query_expand_dict JSON: %s", e)
            raise HTTPException(
                status_code=400,
                detail=f"query_expand_dict is not valid JSON: {e}",
            ) from e

    # Unpack the separator-delimited query and dataset lists.
    query_list = request.query_list.split("<sub_query>")
    data_set_list = request.data_set_list.split("<dataset>")

    # perf_counter is monotonic, so the measured duration cannot be skewed by
    # system clock adjustments.
    start_time = time.perf_counter()
    try:
        results = await dify_query_retrieval.retrieve_api_async(
            request.original_query,
            query_list,
            data_set_list,
            query_expand_dict=query_expand_dict,
            top_k=request.topk,
            metadata_filtering_conditions=request.metadata_filtering_conditions,
        )
    except HTTPException:
        # Already a well-formed HTTP error; do not re-wrap it as a 500.
        raise
    except Exception as e:
        logger.error(f"异步检索出错: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e

    elapsed = time.perf_counter() - start_time
    logger.info(f"异步检索总耗时: {elapsed:.2f}秒")
    return results
|
|
|
|
if __name__ == "__main__":
    # Serve the app with Uvicorn when this file is executed directly.
    # Imported here so that merely importing this module does not need uvicorn.
    import uvicorn

    uvicorn.run(
        "rag2_0.api.DifyQueryRetrieval_api:app",
        host="0.0.0.0",
        port=8002,
        reload=False,
        workers=1,
        log_level="info",
    )
|
|
# # 使用uvicorn启动服务
|
|
# import uvicorn
|
|
# uvicorn.run(
|
|
# "rag2_0.api.DifyQueryRetrieval_api:app",
|
|
# host="0.0.0.0",
|
|
# port=8001,
|
|
# reload=False, # 开发环境启用热重载
|
|
# workers=1 # 生产环境可以增加worker数量
|
|
# )
|
|
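# In production the service can be started with the following command:
# uvicorn rag2_0.dify.DifyQueryRetrieval_api:app --host 0.0.0.0 --port 8002 --workers 10
# NOTE(review): this module path (rag2_0.dify...) differs from the one passed
# to uvicorn.run in the __main__ guard (rag2_0.api...) - confirm which is correct.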