From 78dc1673aa27ae0c6ee874a56883574054dd1d04 Mon Sep 17 00:00:00 2001 From: ouyangyouzhang Date: Fri, 29 Aug 2025 09:18:24 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E4=B8=80=E4=B8=AA=E7=BB=9F?= =?UTF-8?q?=E4=B8=80=E7=9A=84=E8=84=9A=E6=9C=AC=E7=AE=A1=E7=90=86=E6=9C=8D?= =?UTF-8?q?=E5=8A=A1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- manage_services.sh | 269 ++++++++++++++++++++++++ pyproject.toml | 1 + rag2_0/api/AnswerType_api.py | 290 ++++++++++++-------------- rag2_0/api/DifyQueryRetrieval_api.py | 2 +- rag2_0/api/intent_recognition_api.py | 2 +- rag2_0/api/query_dinge_qingdan_api.py | 143 ++++++------- start_AnswerType.sh | 18 -- start_DifyQueryRetrieval_api.sh | 18 -- start_intent_recognition_api.sh | 18 -- uv.lock | 14 ++ 10 files changed, 484 insertions(+), 291 deletions(-) create mode 100755 manage_services.sh delete mode 100755 start_AnswerType.sh delete mode 100755 start_DifyQueryRetrieval_api.sh delete mode 100755 start_intent_recognition_api.sh diff --git a/manage_services.sh b/manage_services.sh new file mode 100755 index 0000000..4880976 --- /dev/null +++ b/manage_services.sh @@ -0,0 +1,269 @@ +#!/usr/bin/env bash + +# 统一管理脚本:启动/停止/查看 四个 API 服务 +# 支持服务: +# - intent -> rag2_0.api.intent_recognition_api:app (port 8001, workers 25) +# - dify -> rag2_0.api.DifyQueryRetrieval_api:app (port 8002, workers 25) +# - answertype -> rag2_0.api.AnswerType_api:app (port 8003, workers 1) +# - qingdan -> rag2_0.api.query_dinge_qingdan_api:app (port 8005, workers 1) + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +# 定义服务配置:会话名 与 启动命令 +SERVICE_NAMES=(intent dify answertype qingdan) + +service_port() { + case "$1" in + intent) echo "8001" ;; + dify) echo "8002" ;; + answertype) echo "8003" ;; + qingdan) echo "8005" ;; + *) echo "" ;; + esac +} + +session_name() { + case "$1" in + intent) echo "intent_recognition_api" ;; + dify) echo "DifyQueryRetrieval_api" ;; + answertype) echo "AnswerType" ;; + qingdan) echo "query_dinge_qingdan_api" ;; + *) echo "" ;; + esac +} + +start_command() { + case "$1" in + intent) + echo "cd \"$SCRIPT_DIR\" && uv run uvicorn rag2_0.api.intent_recognition_api:app --host 0.0.0.0 --port 8001 --workers 25" ;; + dify) + echo "cd \"$SCRIPT_DIR\" && uv run uvicorn rag2_0.api.DifyQueryRetrieval_api:app --host 0.0.0.0 --port 8002 --workers 25" ;; + answertype) + echo "cd \"$SCRIPT_DIR\" && uv run uvicorn rag2_0.api.AnswerType_api:app --host 0.0.0.0 --port 8003 --workers 1" ;; + qingdan) + echo "cd \"$SCRIPT_DIR\" && uv run uvicorn rag2_0.api.query_dinge_qingdan_api:app --host 0.0.0.0 --port 8005 --workers 4" ;; + *) echo "" ;; + esac +} + +exists_session() { + # 使用严格匹配,避免误判 + local name="$1" + if screen -ls 2>/dev/null | grep -q "\\.${name}\\s"; then + return 0 + fi + return 1 +} + +# 按端口获取监听该端口的任意一个PID,优先用 ss,其次 lsof +pids_on_port() { + local port="$1" + # 从 ss 提取 pid 列表 + local ss_pids + ss_pids=$(ss -lptn 2>/dev/null \ + | grep -E ":${port}\\b" \ + | awk '{print $NF}' \ + | sed -n 's/.*pid=\([0-9]\+\),.*/\1/p' \ + | sort -u) + if [[ -n "$ss_pids" ]]; then + echo "$ss_pids" + return 0 + fi + # 从 lsof 提取 pid 列表 + if command -v lsof >/dev/null 2>&1; then + local lsof_pids + lsof_pids=$(lsof -nP -i :"$port" -sTCP:LISTEN -t 2>/dev/null | sort -u) + if [[ -n "$lsof_pids" ]]; then + echo "$lsof_pids" + return 0 + fi + fi + return 1 +} + +# 根据端口优雅终止(TERM)并在必要时强制(KILL)清理进程 +kill_by_port() { + local port="$1" + local pids + pids=$(pids_on_port "$port" || true) + if [[ -z "$pids" ]]; then + return 0 + fi + echo "[清理] 端口 $port 仍被占用,发送 SIGTERM 到: $pids" + kill -TERM $pids 2>/dev/null || true + sleep 2 + # 再次检查 + local left + left=$(pids_on_port "$port" || true) + if [[ -n "$left" ]]; then + echo "[强制] 端口 $port 仍占用,发送 SIGKILL 到: $left" + kill -KILL $left 2>/dev/null || true + fi +} + +start_service() { + local svc="$1" + local sname + sname="$(session_name "$svc")" + if [[ -z "$sname" ]]; then echo "未知服务: $svc"; return 2; fi + if exists_session "$sname"; then + echo "[跳过] 会话 '$sname' 已存在" + return 0 + fi + local cmd + cmd="$(start_command "$svc")" + if [[ -z "$cmd" ]]; then echo "未配置启动命令: $svc"; return 2; fi + screen -dmS "$sname" bash -c "$cmd" + echo "[启动] $svc -> screen 会话 '$sname'" +} + +stop_service() { + local svc="$1" + local sname + sname="$(session_name "$svc")" + local port + port="$(service_port "$svc")" + if [[ -z "$sname" || -z "$port" ]]; then echo "未知服务: $svc"; return 2; fi + + # 1) 先尝试关闭 screen 会话 + if exists_session "$sname"; then + screen -S "$sname" -X quit || true + echo "[停止] $svc -> '$sname'" + else + echo "[提示] 未发现 screen 会话: $sname" + fi + + # 2) 等待释放端口 + sleep 2 + + # 3) 如果仍占用,按端口清理 + if ss -lptn 2>/dev/null | grep -E -q ":${port}\\b" || (command -v lsof >/dev/null 2>&1 && lsof -i :"$port" -sTCP:LISTEN >/dev/null 2>&1); then + kill_by_port "$port" + fi +} + +status_service() { + local svc="$1" + local sname + sname="$(session_name "$svc")" + if [[ -z "$sname" ]]; then echo "未知服务: $svc"; return 2; fi + if exists_session "$sname"; then + echo "[运行中] $svc -> '$sname'" + else + echo "[未运行] $svc" + fi +} + +attach_service() { + local svc="$1" + local sname + sname="$(session_name "$svc")" + if [[ -z "$sname" ]]; then echo "未知服务: $svc"; return 2; fi + if exists_session "$sname"; then + echo "附着到会话: $sname (退出: Ctrl+A 然后 D)" + screen -r "$sname" + else + echo "服务未运行: $svc" + return 1 + fi +} + +start_all() { + for s in "${SERVICE_NAMES[@]}"; do + start_service "$s" + done +} + +stop_all() { + for s in "${SERVICE_NAMES[@]}"; do + stop_service "$s" + done +} + +status_all() { + for s in "${SERVICE_NAMES[@]}"; do + status_service "$s" + done +} + +restart_service() { + local svc="$1" + stop_service "$svc" + # 等待会话释放 + sleep 1 + start_service "$svc" +} + +usage() { + cat < [service] + +command: + start [svc] 启动指定服务;不指定时启动全部 + stop [svc] 停止指定服务;不指定时停止全部 + restart [svc] 重启指定服务;不指定时重启全部 + status 查看所有服务状态 + attach 附着到指定服务的 screen 会话 + force-stop [svc] 强制结束进程(按端口终止);不指定时对全部执行 + +service 可选值: + intent | dify | answertype | qingdan +EOF +} + +main() { + local cmd="${1:-}"; shift || true + case "$cmd" in + start) + local svc="${1:-all}" + if [[ "$svc" == "all" ]]; then start_all; else start_service "$svc"; fi + ;; + stop) + local svc="${1:-all}" + if [[ "$svc" == "all" ]]; then stop_all; else stop_service "$svc"; fi + ;; + restart) + local svc="${1:-all}" + if [[ "$svc" == "all" ]]; then + for s in "${SERVICE_NAMES[@]}"; do restart_service "$s"; done + else + restart_service "$svc" + fi + ;; + status) + status_all + ;; + attach) + local svc="${1:-}" + if [[ -z "$svc" ]]; then echo "请指定服务"; usage; exit 2; fi + attach_service "$svc" + ;; + force-stop) + local svc="${1:-all}" + if [[ "$svc" == "all" ]]; then + for s in "${SERVICE_NAMES[@]}"; do + # 仅按端口强制清理 + p="$(service_port "$s")" + if [[ -n "$p" ]]; then kill_by_port "$p"; fi + done + else + local p + p="$(service_port "$svc")" + if [[ -z "$p" ]]; then echo "未知服务: $svc"; exit 2; fi + kill_by_port "$p" + fi + ;; + ""|-h|--help|help) + usage + ;; + *) + echo "未知命令: $cmd" >&2 + usage + exit 2 + ;; + esac +} + +main "$@" diff --git a/pyproject.toml b/pyproject.toml index a8d1975..663d6a7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,6 +6,7 @@ readme = "README.md" requires-python = ">=3.11" dependencies = [ "bs4>=0.0.2", + "aiosqlite>=0.20.0", "faiss-cpu>=1.11.0", "fastapi>=0.115.14", "flask>=3.1.1", diff --git a/rag2_0/api/AnswerType_api.py b/rag2_0/api/AnswerType_api.py index faae6dd..8825b0e 100644 --- a/rag2_0/api/AnswerType_api.py +++ b/rag2_0/api/AnswerType_api.py @@ -8,8 +8,7 @@ from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel, Field from typing import Dict, List, Any, Optional import asyncio -import threading -import queue +import aiosqlite import sqlite3 from contextlib import closing @@ -38,90 +37,71 @@ DATA_DIR = os.path.join(os.getcwd(), "data") DB_DIR = os.path.join(DATA_DIR, "db") DB_FILE = os.path.join(DB_DIR, "answer_logs.db") -# 创建异步日志队列和工作线程 -log_queue = queue.Queue() -worker_thread = None -db_lock = threading.Lock() # 数据库操作锁 +# 创建异步日志队列和后台任务 +log_queue: asyncio.Queue = asyncio.Queue() +worker_task: asyncio.Task | None = None -# 初始化数据库 -def init_database(): +# 初始化数据库(异步) +async def init_database(): os.makedirs(DB_DIR, exist_ok=True) - - with db_lock: - with closing(sqlite3.connect(DB_FILE)) as conn: - cursor = conn.cursor() - - # 创建查询类型表 - cursor.execute(''' + async with aiosqlite.connect(DB_FILE) as conn: + await conn.execute(''' CREATE TABLE IF NOT EXISTS query_types ( id INTEGER PRIMARY KEY AUTOINCREMENT, query_type TEXT NOT NULL, workflow_run_id TEXT NOT NULL, timestamp TEXT NOT NULL ) - ''') - - # 创建点踩原因表 - cursor.execute(''' + ''') + await conn.execute(''' CREATE TABLE IF NOT EXISTS dislike_reasons ( id INTEGER PRIMARY KEY AUTOINCREMENT, dislike_reason TEXT NOT NULL, workflow_run_id TEXT NOT NULL, timestamp TEXT NOT NULL ) - ''') - - conn.commit() - + ''') + await conn.commit() logger.info("数据库初始化完成") -# 后台工作线程函数 -def log_worker(): - while True: - try: - # 从队列获取数据,设置超时以允许线程退出 - data = log_queue.get(timeout=1.0) - if data is None: # 接收到退出信号 - # 处理剩余数据后再退出 - while not log_queue.empty(): - data = log_queue.get_nowait() - if data is None: # 跳过额外的停止信号 - continue - process_log_data(data) - break - - process_log_data(data) - log_queue.task_done() - except queue.Empty: - continue - except Exception as e: - logger.error(f"保存查询数据时出错: {str(e)}", exc_info=True) - -# 提取数据处理逻辑到单独函数 -def process_log_data(data): +# 后台异步工作协程 +async def log_worker(): try: - with db_lock: - with closing(sqlite3.connect(DB_FILE)) as conn: - cursor = conn.cursor() - - if "dislike_reason" in data: - # 保存点踩原因 - cursor.execute( - "INSERT INTO dislike_reasons (dislike_reason, workflow_run_id, timestamp) VALUES (?, ?, ?)", - (data["dislike_reason"], data["workflow_run_id"], data["timestamp"]) - ) - table_name = "dislike_reasons" - else: - # 保存查询类型 - cursor.execute( - "INSERT INTO query_types (query_type, workflow_run_id, timestamp) VALUES (?, ?, ?)", - (data["query_type"], data["workflow_run_id"], data["timestamp"]) - ) - table_name = "query_types" - - conn.commit() - - logger.info(f"成功保存数据到表: {table_name}") + while True: + data = await log_queue.get() + if data is None: + # 排空剩余数据后退出 + while not log_queue.empty(): + pending = log_queue.get_nowait() + if pending is None: + continue + await process_log_data(pending) + break + await process_log_data(data) + log_queue.task_done() + except asyncio.CancelledError: + logger.info("日志工作任务被取消,尝试优雅退出...") + except Exception as e: + logger.error(f"保存查询数据时出错: {str(e)}", exc_info=True) + +# 提取数据处理逻辑到单独异步函数 +async def process_log_data(data): + try: + async with aiosqlite.connect(DB_FILE) as conn: + if "dislike_reason" in data: + await conn.execute( + "INSERT INTO dislike_reasons (dislike_reason, workflow_run_id, timestamp) VALUES (?, ?, ?)", + (data["dislike_reason"], data["workflow_run_id"], data["timestamp"]) + ) + table_name = "dislike_reasons" + else: + await conn.execute( + "INSERT INTO query_types (query_type, workflow_run_id, timestamp) VALUES (?, ?, ?)", + (data["query_type"], data["workflow_run_id"], data["timestamp"]) + ) + table_name = "query_types" + await conn.commit() + logger.info(f"成功保存数据到表: {table_name}") except Exception as e: logger.error(f"处理日志数据时出错: {str(e)}", exc_info=True) @@ -184,28 +164,27 @@ app.add_middleware( # 应用启动事件 @app.on_event("startup") async def startup_event(): - global worker_thread + global worker_task # 初始化数据库 - init_database() - - # 启动后台工作线程 - worker_thread = threading.Thread(target=log_worker, daemon=True) - worker_thread.start() - logger.info("后台日志工作线程已启动") + await init_database() + # 启动后台异步任务 + worker_task = asyncio.create_task(log_worker()) + logger.info("后台日志工作任务已启动") # 应用关闭事件 @app.on_event("shutdown") -def shutdown_event(): - global worker_thread - if worker_thread: - # 发送退出信号 - log_queue.put(None) - # 等待工作线程处理剩余数据 - worker_thread.join(timeout=10.0) - if worker_thread.is_alive(): - logger.warning("工作线程未在超时时间内退出") - else: - logger.info("后台日志工作线程已停止") +async def shutdown_event(): + global worker_task + if worker_task: + # 发送退出信号并等待任务结束 + await log_queue.put(None) + await asyncio.sleep(0) # 让出执行权给worker处理退出 + try: + await asyncio.wait_for(worker_task, timeout=10.0) + logger.info("后台日志工作任务已停止") + except asyncio.TimeoutError: + worker_task.cancel() + logger.warning("工作任务未在超时时间内退出,已取消") # 添加健康检查端点 @app.get("/health", summary="健康检查") @@ -226,9 +205,9 @@ async def query_type(query_type: str, workflow_run_id:str): "workflow_run_id": workflow_run_id } - # 将数据放入队列 + # 将数据放入异步队列 try: - log_queue.put(query_data) + await log_queue.put(query_data) success = True logger.info(f"查询数据已加入队列,当前队列大小: {log_queue.qsize()}") except Exception as e: @@ -256,9 +235,9 @@ async def dislike_reason(reason: str, workflow_run_id: str): "workflow_run_id": workflow_run_id } - # 将数据放入队列 + # 将数据放入异步队列 try: - log_queue.put(dislike_data) + await log_queue.put(dislike_data) success = True logger.info(f"点踩原因数据已加入队列,当前队列大小: {log_queue.qsize()}") except Exception as e: @@ -276,35 +255,34 @@ async def dislike_reason(reason: str, workflow_run_id: str): @app.get("/stats", summary="查询统计数据") async def get_stats(): try: - with db_lock: - with closing(sqlite3.connect(DB_FILE)) as conn: - conn.row_factory = sqlite3.Row # 启用字典行工厂 - cursor = conn.cursor() - - # 查询类型统计 - cursor.execute("SELECT COUNT(*) as count FROM query_types") - query_count = cursor.fetchone()['count'] - - # 点踩原因统计 - cursor.execute("SELECT COUNT(*) as count FROM dislike_reasons") - dislike_count = cursor.fetchone()['count'] - - # 最近5条查询记录 - cursor.execute(""" - SELECT query_type, workflow_run_id, timestamp - FROM query_types - ORDER BY id DESC LIMIT 5 - """) - recent_queries = [dict(row) for row in cursor.fetchall()] - - # 最近5条点踩记录 - cursor.execute(""" - SELECT dislike_reason, workflow_run_id, timestamp - FROM dislike_reasons - ORDER BY id DESC LIMIT 5 - """) - recent_dislikes = [dict(row) for row in cursor.fetchall()] - + async with aiosqlite.connect(DB_FILE) as conn: + conn.row_factory = sqlite3.Row + # 查询类型统计 + cursor = await conn.execute("SELECT COUNT(*) as count FROM query_types") + row = await cursor.fetchone() + query_count = row['count'] if row else 0 + # 点踩原因统计 + cursor = await conn.execute("SELECT COUNT(*) as count FROM dislike_reasons") + row = await cursor.fetchone() + dislike_count = row['count'] if row else 0 + # 最近5条查询记录 + cursor = await conn.execute( + """ + SELECT query_type, workflow_run_id, timestamp + FROM query_types + ORDER BY id DESC LIMIT 5 + """ + ) + recent_queries = [dict(r) for r in await cursor.fetchall()] + # 最近5条点踩记录 + cursor = await conn.execute( + """ + SELECT dislike_reason, workflow_run_id, timestamp + FROM dislike_reasons + ORDER BY id DESC LIMIT 5 + """ + ) + recent_dislikes = [dict(r) for r in await cursor.fetchall()] return { "statistics": { "total_queries": query_count, @@ -322,25 +300,21 @@ async def get_stats(): @app.get("/query_by_workflow_id", summary="根据工作流ID获取查询类型数据") async def get_query_by_workflow_id(workflow_run_id: str): try: - with db_lock: - with closing(sqlite3.connect(DB_FILE)) as conn: - conn.row_factory = sqlite3.Row # 启用字典行工厂 - cursor = conn.cursor() - - # 查询指定工作流ID的查询类型数据 - cursor.execute(""" - SELECT id, query_type, workflow_run_id, timestamp - FROM query_types - WHERE workflow_run_id = ? - ORDER BY id DESC - """, (workflow_run_id,)) - - results = [dict(row) for row in cursor.fetchall()] - - if not results: - return {"message": f"未找到工作流ID为 {workflow_run_id} 的查询类型数据", "data": []} - - return {"data": results} + async with aiosqlite.connect(DB_FILE) as conn: + conn.row_factory = sqlite3.Row + cursor = await conn.execute( + """ + SELECT id, query_type, workflow_run_id, timestamp + FROM query_types + WHERE workflow_run_id = ? + ORDER BY id DESC + """, + (workflow_run_id,) + ) + results = [dict(r) for r in await cursor.fetchall()] + if not results: + return {"message": f"未找到工作流ID为 {workflow_run_id} 的查询类型数据", "data": []} + return {"data": results} except Exception as e: logger.error(f"根据工作流ID获取查询类型数据时出错: {str(e)}", exc_info=True) raise HTTPException(status_code=500, detail=f"根据工作流ID获取查询类型数据时出错: {str(e)}") @@ -348,25 +322,21 @@ async def get_query_by_workflow_id(workflow_run_id: str): @app.get("/dislike_by_workflow_id", summary="根据工作流ID获取点踩原因数据") async def get_dislike_by_workflow_id(workflow_run_id: str): try: - with db_lock: - with closing(sqlite3.connect(DB_FILE)) as conn: - conn.row_factory = sqlite3.Row # 启用字典行工厂 - cursor = conn.cursor() - - # 查询指定工作流ID的点踩原因数据 - cursor.execute(""" - SELECT id, dislike_reason, workflow_run_id, timestamp - FROM dislike_reasons - WHERE workflow_run_id = ? - ORDER BY id DESC - """, (workflow_run_id,)) - - results = [dict(row) for row in cursor.fetchall()] - - if not results: - return {"message": f"未找到工作流ID为 {workflow_run_id} 的点踩原因数据", "data": []} - - return {"data": results} + async with aiosqlite.connect(DB_FILE) as conn: + conn.row_factory = sqlite3.Row + cursor = await conn.execute( + """ + SELECT id, dislike_reason, workflow_run_id, timestamp + FROM dislike_reasons + WHERE workflow_run_id = ? + ORDER BY id DESC + """, + (workflow_run_id,) + ) + results = [dict(r) for r in await cursor.fetchall()] + if not results: + return {"message": f"未找到工作流ID为 {workflow_run_id} 的点踩原因数据", "data": []} + return {"data": results} except Exception as e: logger.error(f"根据工作流ID获取点踩原因数据时出错: {str(e)}", exc_info=True) raise HTTPException(status_code=500, detail=f"根据工作流ID获取点踩原因数据时出错: {str(e)}") @@ -374,6 +344,6 @@ async def get_dislike_by_workflow_id(workflow_run_id: str): if __name__ == "__main__": # 使用Uvicorn运行FastAPI应用 import uvicorn - uvicorn.run("rag2_0.dify.AnswerType:app", host="0.0.0.0", port=8003, reload=False, workers=1, log_level="info") + uvicorn.run("rag2_0.api.AnswerType_api:app", host="0.0.0.0", port=8003, reload=False, workers=1, log_level="info") # 生产环境可以使用以下命令启动: - # uvicorn rag2_0.dify.AnswerType:app --host 0.0.0.0 --port 8003 --workers 1 \ No newline at end of file + # uvicorn rag2_0.api.AnswerType_api:app --host 0.0.0.0 --port 8003 --workers 1 \ No newline at end of file diff --git a/rag2_0/api/DifyQueryRetrieval_api.py b/rag2_0/api/DifyQueryRetrieval_api.py index 4ef9c88..966a999 100644 --- a/rag2_0/api/DifyQueryRetrieval_api.py +++ b/rag2_0/api/DifyQueryRetrieval_api.py @@ -118,7 +118,7 @@ async def retrieve(request: RetrieveRequest): if __name__ == "__main__": # 使用Uvicorn运行FastAPI应用 import uvicorn - uvicorn.run("rag2_0.api.DifyQueryRetrieval_api:app", host="0.0.0.0", port=9002, reload=False, workers=1, log_level="info") + uvicorn.run("rag2_0.api.DifyQueryRetrieval_api:app", host="0.0.0.0", port=8002, reload=False, workers=1, log_level="info") # # 使用uvicorn启动服务 # import uvicorn # uvicorn.run( diff --git a/rag2_0/api/intent_recognition_api.py b/rag2_0/api/intent_recognition_api.py index f88aea7..2573806 100755 --- a/rag2_0/api/intent_recognition_api.py +++ b/rag2_0/api/intent_recognition_api.py @@ -175,7 +175,7 @@ if __name__ == "__main__": # 使用uvicorn启动服务 import uvicorn uvicorn.run( - "rag2_0.dify.intent_recognition_api:app", + "rag2_0.api.intent_recognition_api:app", host="0.0.0.0", port=9001, reload=True, # 开发环境启用热重载 diff --git a/rag2_0/api/query_dinge_qingdan_api.py b/rag2_0/api/query_dinge_qingdan_api.py index 04b671b..b074bcd 100644 --- a/rag2_0/api/query_dinge_qingdan_api.py +++ b/rag2_0/api/query_dinge_qingdan_api.py @@ -3,6 +3,8 @@ from fastapi import FastAPI, HTTPException, Query from pydantic import BaseModel from typing import List, Optional, Dict, Any import uvicorn +import asyncio +import aiosqlite import sqlite3 import sys import os @@ -73,11 +75,13 @@ class QingDanDingEQueryService: db_file=self.db_path ) - def get_similar_names_by_vector(self, query_text:str, vector_db:SQLiteVSS, field_map:dict, top_k:int=3, scope:str=None): - """使用向量检索获取相似名称""" + async def get_similar_names_by_vector(self, query_text:str, vector_db:SQLiteVSS, field_map:dict, top_k:int=3, scope:str=None): + """使用向量检索获取相似名称(异步包装)""" try: - # 使用向量数据库进行相似性搜索 - results = vector_db.similarity_search_with_score(query=query_text, k=30) + # 使用线程池包装同步向量检索,避免阻塞事件循环 + results = await asyncio.to_thread( + vector_db.similarity_search_with_score, query_text, 30 + ) # 提取结果中的元数据 similar_items = [] @@ -90,16 +94,16 @@ class QingDanDingEQueryService: metadata['similarity_score'] = float(score) similar_items.append(metadata) - # 按相似度分数排序,分数高的排前面 + # 分数越小越相似(SQLiteVSS 多为距离),已有代码按升序排序 similar_items.sort(key=lambda x: x['similarity_score']) return similar_items[:top_k] except Exception as e: print(f"向量检索出错: {str(e)}") return [] - def get_db_connection(self): - """获取数据库连接""" - conn = sqlite3.connect(self.db_path) + async def get_db_connection(self): + """获取数据库连接(aiosqlite 异步)""" + conn = await aiosqlite.connect(self.db_path) conn.row_factory = sqlite3.Row # 设置行工厂,使结果可以通过列名访问 return conn @@ -138,12 +142,9 @@ class QingDanDingEQueryService: # 合并结果,完全匹配的排在前面 return exact_matches + partial_matches - def query_ding_e_by_name(self, name, scope=None): + async def query_ding_e_by_name(self, name, scope=None): """根据定额名称查询定额子目表中详情信息,使用向量检索扩大查询范围""" try: - conn = self.get_db_connection() - cursor = conn.cursor() - # 获取表名和字段映射 zimu_table = ExcelToSQLiteProcessor.ding_e_table_names["定额子目"] mulu_table = ExcelToSQLiteProcessor.ding_e_table_names["定额目录"] @@ -151,7 +152,7 @@ class QingDanDingEQueryService: field_map = ExcelToSQLiteProcessor.ding_e_field_map # 1. 先使用向量检索获取相似名称 - similar_items = self.get_similar_names_by_vector(query_text=name, + similar_items = await self.get_similar_names_by_vector(query_text=name, vector_db=self.ding_e_vector_db, field_map=field_map, scope=scope) @@ -190,18 +191,16 @@ class QingDanDingEQueryService: query += f" AND attr.{field_map['适用范围']} LIKE ?" params.append(f'%{scope}%') - cursor.execute(query, params) - - # 获取结果 - results = cursor.fetchall() - data = [dict(row) for row in results] + async with await self.get_db_connection() as conn: + cursor = await conn.execute(query, params) + # 获取结果 + results = await cursor.fetchall() + data = [dict(row) for row in results] # 对结果进行排序,将全字匹配的排在前面 data = self.sort_results_by_exact_match(data, name, field_map['名称']) data = data[:self.top_k] - conn.close() - if not data: return {"success": True, "message": "未找到匹配的定额信息", "data": []} @@ -219,12 +218,10 @@ class QingDanDingEQueryService: except Exception as e: return {"success": False, "message": f"查询出错: {str(e)}"} - def query_ding_e_by_code(self, code, scope=None): + async def query_ding_e_by_code(self, code, scope=None): """根据定额编码查询定额子目表中详情信息""" try: code = code.upper() - conn = self.get_db_connection() - cursor = conn.cursor() # 获取表名和字段映射 zimu_table = ExcelToSQLiteProcessor.ding_e_table_names["定额子目"] @@ -255,18 +252,16 @@ class QingDanDingEQueryService: query += f" AND attr.{field_map['适用范围']} LIKE ?" params.append(f'%{scope}%') - cursor.execute(query, params) - - # 获取结果 - results = cursor.fetchall() - data = [dict(row) for row in results] + async with await self.get_db_connection() as conn: + cursor = await conn.execute(query, params) + # 获取结果 + results = await cursor.fetchall() + data = [dict(row) for row in results] # 对结果进行排序,将全字匹配的排在前面 data = self.sort_results_by_exact_match(data, code, field_map['编码']) data = data[:self.top_k] - conn.close() - if not data: return {"success": True, "message": "未找到匹配的定额信息", "data": []} @@ -284,12 +279,9 @@ class QingDanDingEQueryService: except Exception as e: return {"success": False, "message": f"查询出错: {str(e)}"} - def query_qing_dan_by_name(self, name, scope=None): + async def query_qing_dan_by_name(self, name, scope=None): """根据清单名称查询清单子目表中详情信息,使用向量检索扩大查询范围""" try: - conn = self.get_db_connection() - cursor = conn.cursor() - # 获取表名和字段映射 zimu_table = ExcelToSQLiteProcessor.qing_dan_table_names["清单子目"] mulu_table = ExcelToSQLiteProcessor.qing_dan_table_names["清单目录"] @@ -297,7 +289,7 @@ class QingDanDingEQueryService: field_map = ExcelToSQLiteProcessor.qing_dan_field_map # 1. 先使用向量检索获取相似名称 - similar_items = self.get_similar_names_by_vector(query_text=name, vector_db=self.qing_dan_vector_db, field_map=field_map, scope=scope) + similar_items = await self.get_similar_names_by_vector(query_text=name, vector_db=self.qing_dan_vector_db, field_map=field_map, scope=scope) similar_names = [item['mc'] for item in similar_items] # 构建查询条件,始终包含原始名称的模糊匹配 @@ -333,18 +325,16 @@ class QingDanDingEQueryService: query += f" AND attr.{field_map['适用范围']} LIKE ?" params.append(f'%{scope}%') - cursor.execute(query, params) - - # 获取结果 - results = cursor.fetchall() - data = [dict(row) for row in results] + async with await self.get_db_connection() as conn: + cursor = await conn.execute(query, params) + # 获取结果 + results = await cursor.fetchall() + data = [dict(row) for row in results] # 对结果进行排序,将全字匹配的排在前面 data = self.sort_results_by_exact_match(data, name, field_map['名称']) data = data[:self.top_k] - conn.close() - if not data: return {"success": True, "message": "未找到匹配的清单信息", "data": []} @@ -362,12 +352,10 @@ class QingDanDingEQueryService: except Exception as e: return {"success": False, "message": f"查询出错: {str(e)}"} - def query_qing_dan_by_code(self, code, scope=None): + async def query_qing_dan_by_code(self, code, scope=None): """根据清单编码查询清单子目表中详情信息""" try: code = code.upper() - conn = self.get_db_connection() - cursor = conn.cursor() # 获取表名和字段映射 zimu_table = ExcelToSQLiteProcessor.qing_dan_table_names["清单子目"] @@ -398,18 +386,16 @@ class QingDanDingEQueryService: query += f" AND attr.{field_map['适用范围']} LIKE ?" params.append(f'%{scope}%') - cursor.execute(query, params) - - # 获取结果 - results = cursor.fetchall() - data = [dict(row) for row in results] + async with await self.get_db_connection() as conn: + cursor = await conn.execute(query, params) + # 获取结果 + results = await cursor.fetchall() + data = [dict(row) for row in results] # 对结果进行排序,将全字匹配的排在前面 data = self.sort_results_by_exact_match(data, code, field_map['编码']) data = data[:self.top_k] - conn.close() - if not data: return {"success": True, "message": "未找到匹配的清单信息", "data": []} @@ -427,8 +413,8 @@ class QingDanDingEQueryService: except Exception as e: return {"success": False, "message": f"查询出错: {str(e)}"} - def batch_query(self, requests:BatchQueryRequest): - """批量查询接口,支持向量检索""" + async def batch_query(self, requests:BatchQueryRequest): + """批量查询接口,支持向量检索(并发执行)""" dinge_results = [] qingdan_results = [] tracking_dict = {} # 用于跟踪已查询过的项目,避免重复 @@ -439,41 +425,48 @@ class QingDanDingEQueryService: qingdan_info = requests.dinge_qingdan_info.qingdan_info scope = requests.scope + dinge_tasks = [] + qingdan_tasks = [] + # 处理定额编码查询 for code in dinge_info.dinge_code_list or []: key = f"dinge_code_{code}_{scope}" if key not in tracking_dict: - result = self.query_ding_e_by_code(code, scope) - if result["success"] and result["data"]: - dinge_results.extend(result["data"]) - tracking_dict[key] = True + dinge_tasks.append(self.query_ding_e_by_code(code, scope)) + tracking_dict[key] = True # 处理定额名称查询 for name in dinge_info.dinge_name_list or []: key = f"dinge_name_{name}_{scope}" if key not in tracking_dict: - result = self.query_ding_e_by_name(name, scope) - if result["success"] and result["data"]: - dinge_results.extend(result["data"]) - tracking_dict[key] = True + dinge_tasks.append(self.query_ding_e_by_name(name, scope)) + tracking_dict[key] = True # 处理清单编码查询 for code in qingdan_info.qingdan_code_list or []: key = f"qingdan_code_{code}_{scope}" if key not in tracking_dict: - result = self.query_qing_dan_by_code(code, scope) - if result["success"] and result["data"]: - qingdan_results.extend(result["data"]) - tracking_dict[key] = True + qingdan_tasks.append(self.query_qing_dan_by_code(code, scope)) + tracking_dict[key] = True # 处理清单名称查询 for name in qingdan_info.qingdan_name_list or []: key = f"qingdan_name_{name}_{scope}" if key not in tracking_dict: - result = self.query_qing_dan_by_name(name, scope) - if result["success"] and result["data"]: + qingdan_tasks.append(self.query_qing_dan_by_name(name, scope)) + tracking_dict[key] = True + + # 并发执行 + if dinge_tasks: + dinge_outs = await asyncio.gather(*dinge_tasks) + for result in dinge_outs: + if result and result.get("success") and result.get("data"): + dinge_results.extend(result["data"]) + if qingdan_tasks: + qingdan_outs = await asyncio.gather(*qingdan_tasks) + for result in qingdan_outs: + if result and result.get("success") and result.get("data"): qingdan_results.extend(result["data"]) - tracking_dict[key] = True # 限制返回结果数量 dinge_results = dinge_results[:self.top_k] @@ -505,7 +498,7 @@ async def query_ding_e_by_name( name: str = Query(..., description="定额名称"), scope: Optional[str] = Query(None, description="适用范围") ): - result = query_service.query_ding_e_by_name(name, scope) + result = await query_service.query_ding_e_by_name(name, scope) if not result["success"]: raise HTTPException(status_code=500, detail=result["message"]) return QueryResponse(**result) @@ -516,7 +509,7 @@ async def query_ding_e_by_code( code: str = Query(..., description="定额编码"), scope: Optional[str] = Query(None, description="适用范围") ): - result = query_service.query_ding_e_by_code(code, scope) + result = await query_service.query_ding_e_by_code(code, scope) if not result["success"]: raise HTTPException(status_code=500, detail=result["message"]) return QueryResponse(**result) @@ -527,7 +520,7 @@ async def query_qing_dan_by_name( name: str = Query(..., description="清单名称"), scope: Optional[str] = Query(None, description="适用范围") ): - result = query_service.query_qing_dan_by_name(name, scope) + result = await query_service.query_qing_dan_by_name(name, scope) if not result["success"]: raise HTTPException(status_code=500, detail=result["message"]) return QueryResponse(**result) @@ -538,7 +531,7 @@ async def query_qing_dan_by_code( code: str = Query(..., description="清单编码"), scope: Optional[str] = Query(None, description="适用范围") ): - result = query_service.query_qing_dan_by_code(code, scope) + result = await query_service.query_qing_dan_by_code(code, scope) if not result["success"]: raise HTTPException(status_code=500, detail=result["message"]) return QueryResponse(**result) @@ -546,7 +539,7 @@ async def query_qing_dan_by_code( # 5. 批量查询定额和清单信息 @app.post("/api/batch_query", response_model=BatchQueryResponse) async def batch_query(request: BatchQueryRequest): - result = query_service.batch_query(request) + result = await query_service.batch_query(request) if not result["success"]: raise HTTPException(status_code=500, detail=result["message"]) return BatchQueryResponse(**result) @@ -564,4 +557,4 @@ def main(): if __name__ == "__main__": main() - # uvicorn rag2_0.dify.query_dinge_qingdan_api:app --host 0.0.0.0 --port 8005 --workers 10 \ No newline at end of file + # uvicorn rag2_0.api.query_dinge_qingdan_api:app --host 0.0.0.0 --port 8005 --workers 10 \ No newline at end of file diff --git a/start_AnswerType.sh b/start_AnswerType.sh deleted file mode 100755 index c478200..0000000 --- a/start_AnswerType.sh +++ /dev/null @@ -1,18 +0,0 @@ -#!/bin/bash - -# 获取当前脚本所在的绝对路径 -SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" - -# 检查是否已经存在名为AnswerType的screen会话 -if screen -ls | grep -q "\.AnswerType\s"; then - echo "Screen session 'AnswerType' already exists." -else - # 启动一个名为AnswerType的screen会话,并在其中执行后续命令 - screen -dmS AnswerType bash -c " - cd \"$SCRIPT_DIR\" - uv run uvicorn rag2_0.api.AnswerType_api:app --host 0.0.0.0 --port 8003 --workers 1 - " - - # 输出提示信息 - echo "Started screen session 'AnswerType' and executed the command." -fi \ No newline at end of file diff --git a/start_DifyQueryRetrieval_api.sh b/start_DifyQueryRetrieval_api.sh deleted file mode 100755 index 70545fe..0000000 --- a/start_DifyQueryRetrieval_api.sh +++ /dev/null @@ -1,18 +0,0 @@ -#!/bin/bash - -# 获取当前脚本所在的绝对路径 -SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" - -# 检查是否已经存在名为DifyQueryRetrieval_api的screen会话 -if screen -ls | grep "DifyQueryRetrieval_api"; then - echo "Screen session 'DifyQueryRetrieval_api' already exists." -else - # 启动一个名为DifyQueryRetrieval_api的screen会话,并在其中执行后续命令 - screen -dmS DifyQueryRetrieval_api bash -c " - cd \"$SCRIPT_DIR\" - uv run uvicorn rag2_0.api.DifyQueryRetrieval_api:app --host 0.0.0.0 --port 8002 --workers 25 - " - - # 输出提示信息 - echo "Started screen session 'DifyQueryRetrieval_api' and executed the command." -fi \ No newline at end of file diff --git a/start_intent_recognition_api.sh b/start_intent_recognition_api.sh deleted file mode 100755 index ba0f564..0000000 --- a/start_intent_recognition_api.sh +++ /dev/null @@ -1,18 +0,0 @@ -#!/bin/bash - -# 获取当前脚本所在的绝对路径 -SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" - -# 检查是否已经存在名为xinference的screen会话 -if screen -ls | grep "intent_recognition_api"; then - echo "Screen session 'intent_recognition_api' already exists." -else - # 启动一个名为intent_recognition_api的screen会话,并在其中执行后续命令 - screen -dmS intent_recognition_api bash -c " - cd \"$SCRIPT_DIR\" - uv run uvicorn rag2_0.api.intent_recognition_api:app --host 0.0.0.0 --port 8001 --workers 25 - " - - # 输出提示信息 - echo "Started screen session 'intent_recognition_api' and executed the command." -fi \ No newline at end of file diff --git a/uv.lock b/uv.lock index aaf142d..3c6a780 100644 --- a/uv.lock +++ b/uv.lock @@ -97,6 +97,18 @@ wheels = [ { url = "https://mirrors.aliyun.com/pypi/packages/ec/6a/bc7e17a3e87a2985d3e8f4da4cd0f481060eb78fb08596c42be62c90a4d9/aiosignal-1.3.2-py2.py3-none-any.whl", hash = "sha256:45cde58e409a301715980c2b01d0c28bdde3770d8290b5eb2173759d9acb31a5" }, ] +[[package]] +name = "aiosqlite" +version = "0.21.0" +source = { registry = "https://mirrors.aliyun.com/pypi/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://mirrors.aliyun.com/pypi/packages/13/7d/8bca2bf9a247c2c5dfeec1d7a5f40db6518f88d314b8bca9da29670d2671/aiosqlite-0.21.0.tar.gz", hash = "sha256:131bb8056daa3bc875608c631c678cda73922a2d4ba8aec373b19f18c17e7aa3" } +wheels = [ + { url = "https://mirrors.aliyun.com/pypi/packages/f5/10/6c25ed6de94c49f88a91fa5018cb4c0f3625f31d5be9f771ebe5cc7cd506/aiosqlite-0.21.0-py3-none-any.whl", hash = "sha256:2549cf4057f95f53dcba16f2b64e8e2791d7e1adedb13197dd8ed77bb226d7d0" }, +] + [[package]] name = "annotated-types" version = "0.7.0" @@ -1519,6 +1531,7 @@ name = "rag2-0" version = "0.1.0" source = { virtual = "." } dependencies = [ + { name = "aiosqlite" }, { name = "bs4" }, { name = "faiss-cpu" }, { name = "fastapi" }, @@ -1548,6 +1561,7 @@ dependencies = [ [package.metadata] requires-dist = [ + { name = "aiosqlite", specifier = ">=0.20.0" }, { name = "bs4", specifier = ">=0.0.2" }, { name = "faiss-cpu", specifier = ">=1.11.0" }, { name = "fastapi", specifier = ">=0.115.14" },