1、修改api文件位置
2、意图识别继承langfuse
This commit is contained in:
@@ -0,0 +1,132 @@
|
||||
# from gevent import monkey
|
||||
# monkey.patch_all()
|
||||
|
||||
import os
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from typing import Dict, List, Any, Optional
|
||||
import asyncio
|
||||
|
||||
from dotenv import load_dotenv
|
||||
import json
|
||||
import time
|
||||
import datetime
|
||||
import logging
|
||||
# 加载环境变量
|
||||
load_dotenv()
|
||||
|
||||
import sys
|
||||
sys.path.append(os.getcwd())
|
||||
from rag2_0.dify.DifyQueryRetrieval import DifyQueryRetrieval
|
||||
|
||||
# 确保日志目录存在
|
||||
os.makedirs('data/logs', exist_ok=True)
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - [%(thread)d] - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.StreamHandler(),
|
||||
logging.FileHandler(f'data/logs/dify_query_retrieval_{datetime.datetime.now().strftime("%Y%m%d")}.log', encoding='utf-8')
|
||||
]
|
||||
)
|
||||
logging.getLogger('httpx').setLevel(logging.WARNING)
|
||||
logging.getLogger('openai').setLevel(logging.WARNING)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 定义请求模型
|
||||
class RetrieveRequest(BaseModel):
|
||||
original_query: str
|
||||
query_list: str
|
||||
data_set_list: str
|
||||
query_expand_dict: dict | str = Field(default="{}")
|
||||
topk: int = Field(default=4)
|
||||
metadata_filtering_conditions : dict = Field(default={})
|
||||
|
||||
# 创建FastAPI应用
|
||||
app = FastAPI(
|
||||
title="Dify查询检索服务",
|
||||
description="基于Dify的异步查询检索服务",
|
||||
version="1.0"
|
||||
)
|
||||
|
||||
# 添加CORS中间件
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# 全局变量存储DifyQueryRetrieval实例
|
||||
dify_query_retrieval = None
|
||||
|
||||
# 应用启动事件
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
global dify_query_retrieval
|
||||
# 初始化DifyQueryRetrieval实例
|
||||
dify_query_retrieval = DifyQueryRetrieval(dify_dataset_key=os.getenv("DIFY_DATASET_KEY"), dify_base_url=os.getenv("DIFY_BSAE_URL"))
|
||||
logger.info("DifyQueryRetrieval初始化完成")
|
||||
|
||||
# 添加健康检查端点
|
||||
@app.get("/health", summary="健康检查")
|
||||
async def health_check():
|
||||
return {"status": "ok"}
|
||||
|
||||
@app.post("/retrieve", summary="异步检索API")
|
||||
async def retrieve(request: RetrieveRequest):
|
||||
"""
|
||||
异步检索API
|
||||
|
||||
Args:
|
||||
request: 包含原始查询、查询列表和数据集列表的请求对象
|
||||
|
||||
Returns:
|
||||
检索结果
|
||||
"""
|
||||
try:
|
||||
# 解析查询列表和数据集列表
|
||||
query_list = request.query_list.split("<sub_query>")
|
||||
data_set_list = request.data_set_list.split("<dataset>")
|
||||
if isinstance(request.query_expand_dict, str):
|
||||
query_expand_dict = json.loads(request.query_expand_dict)
|
||||
else:
|
||||
query_expand_dict = request.query_expand_dict
|
||||
# 调用异步检索方法
|
||||
start_time = time.time()
|
||||
results = await dify_query_retrieval.retrieve_api_async(
|
||||
request.original_query,
|
||||
query_list,
|
||||
data_set_list,
|
||||
query_expand_dict=query_expand_dict,
|
||||
top_k=request.topk,
|
||||
metadata_filtering_conditions=request.metadata_filtering_conditions
|
||||
)
|
||||
end_time = time.time()
|
||||
|
||||
logger.info(f"异步检索总耗时: {end_time - start_time:.2f}秒")
|
||||
return results
|
||||
except Exception as e:
|
||||
logger.error(f"异步检索出错: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 使用Uvicorn运行FastAPI应用
|
||||
import uvicorn
|
||||
uvicorn.run("rag2_0.api.DifyQueryRetrieval_api:app", host="0.0.0.0", port=9002, reload=False, workers=1, log_level="info")
|
||||
# # 使用uvicorn启动服务
|
||||
# import uvicorn
|
||||
# uvicorn.run(
|
||||
# "rag2_0.api.DifyQueryRetrieval_api:app",
|
||||
# host="0.0.0.0",
|
||||
# port=8001,
|
||||
# reload=False, # 开发环境启用热重载
|
||||
# workers=1 # 生产环境可以增加worker数量
|
||||
# )
|
||||
# 生产环境可以使用以下命令启动:
|
||||
# uvicorn rag2_0.dify.DifyQueryRetrieval_api:app --host 0.0.0.0 --port 8002 --workers 10
|
||||
Reference in New Issue
Block a user