添加一个统一的脚本管理服务

This commit is contained in:
2025-08-29 09:18:24 +08:00
parent 5ec18811d9
commit 78dc1673aa
10 changed files with 484 additions and 291 deletions
+269
View File
@@ -0,0 +1,269 @@
#!/usr/bin/env bash
# 统一管理脚本:启动/停止/查看 四个 API 服务
# 支持服务:
# - intent -> rag2_0.api.intent_recognition_api:app (port 8001, workers 25)
# - dify -> rag2_0.api.DifyQueryRetrieval_api:app (port 8002, workers 25)
# - answertype -> rag2_0.api.AnswerType_api:app (port 8003, workers 1)
# - qingdan -> rag2_0.api.query_dinge_qingdan_api:app (port 8005, workers 1)
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
# 定义服务配置:会话名 与 启动命令
SERVICE_NAMES=(intent dify answertype qingdan)
service_port() {
case "$1" in
intent) echo "8001" ;;
dify) echo "8002" ;;
answertype) echo "8003" ;;
qingdan) echo "8005" ;;
*) echo "" ;;
esac
}
session_name() {
case "$1" in
intent) echo "intent_recognition_api" ;;
dify) echo "DifyQueryRetrieval_api" ;;
answertype) echo "AnswerType" ;;
qingdan) echo "query_dinge_qingdan_api" ;;
*) echo "" ;;
esac
}
start_command() {
case "$1" in
intent)
echo "cd \"$SCRIPT_DIR\" && uv run uvicorn rag2_0.api.intent_recognition_api:app --host 0.0.0.0 --port 8001 --workers 25" ;;
dify)
echo "cd \"$SCRIPT_DIR\" && uv run uvicorn rag2_0.api.DifyQueryRetrieval_api:app --host 0.0.0.0 --port 8002 --workers 25" ;;
answertype)
echo "cd \"$SCRIPT_DIR\" && uv run uvicorn rag2_0.api.AnswerType_api:app --host 0.0.0.0 --port 8003 --workers 1" ;;
qingdan)
echo "cd \"$SCRIPT_DIR\" && uv run uvicorn rag2_0.api.query_dinge_qingdan_api:app --host 0.0.0.0 --port 8005 --workers 4" ;;
*) echo "" ;;
esac
}
exists_session() {
# 使用严格匹配,避免误判
local name="$1"
if screen -ls 2>/dev/null | grep -q "\\.${name}\\s"; then
return 0
fi
return 1
}
# 按端口获取监听该端口的任意一个PID,优先用 ss,其次 lsof
pids_on_port() {
local port="$1"
# 从 ss 提取 pid 列表
local ss_pids
ss_pids=$(ss -lptn 2>/dev/null \
| grep -E ":${port}\\b" \
| awk '{print $NF}' \
| sed -n 's/.*pid=\([0-9]\+\),.*/\1/p' \
| sort -u)
if [[ -n "$ss_pids" ]]; then
echo "$ss_pids"
return 0
fi
# 从 lsof 提取 pid 列表
if command -v lsof >/dev/null 2>&1; then
local lsof_pids
lsof_pids=$(lsof -nP -i :"$port" -sTCP:LISTEN -t 2>/dev/null | sort -u)
if [[ -n "$lsof_pids" ]]; then
echo "$lsof_pids"
return 0
fi
fi
return 1
}
# 根据端口优雅终止(TERM)并在必要时强制(KILL)清理进程
kill_by_port() {
local port="$1"
local pids
pids=$(pids_on_port "$port" || true)
if [[ -z "$pids" ]]; then
return 0
fi
echo "[清理] 端口 $port 仍被占用,发送 SIGTERM 到: $pids"
kill -TERM $pids 2>/dev/null || true
sleep 2
# 再次检查
local left
left=$(pids_on_port "$port" || true)
if [[ -n "$left" ]]; then
echo "[强制] 端口 $port 仍占用,发送 SIGKILL 到: $left"
kill -KILL $left 2>/dev/null || true
fi
}
start_service() {
local svc="$1"
local sname
sname="$(session_name "$svc")"
if [[ -z "$sname" ]]; then echo "未知服务: $svc"; return 2; fi
if exists_session "$sname"; then
echo "[跳过] 会话 '$sname' 已存在"
return 0
fi
local cmd
cmd="$(start_command "$svc")"
if [[ -z "$cmd" ]]; then echo "未配置启动命令: $svc"; return 2; fi
screen -dmS "$sname" bash -c "$cmd"
echo "[启动] $svc -> screen 会话 '$sname'"
}
stop_service() {
local svc="$1"
local sname
sname="$(session_name "$svc")"
local port
port="$(service_port "$svc")"
if [[ -z "$sname" || -z "$port" ]]; then echo "未知服务: $svc"; return 2; fi
# 1) 先尝试关闭 screen 会话
if exists_session "$sname"; then
screen -S "$sname" -X quit || true
echo "[停止] $svc -> '$sname'"
else
echo "[提示] 未发现 screen 会话: $sname"
fi
# 2) 等待释放端口
sleep 2
# 3) 如果仍占用,按端口清理
if ss -lptn 2>/dev/null | grep -E -q ":${port}\\b" || (command -v lsof >/dev/null 2>&1 && lsof -i :"$port" -sTCP:LISTEN >/dev/null 2>&1); then
kill_by_port "$port"
fi
}
status_service() {
local svc="$1"
local sname
sname="$(session_name "$svc")"
if [[ -z "$sname" ]]; then echo "未知服务: $svc"; return 2; fi
if exists_session "$sname"; then
echo "[运行中] $svc -> '$sname'"
else
echo "[未运行] $svc"
fi
}
attach_service() {
local svc="$1"
local sname
sname="$(session_name "$svc")"
if [[ -z "$sname" ]]; then echo "未知服务: $svc"; return 2; fi
if exists_session "$sname"; then
echo "附着到会话: $sname (退出: Ctrl+A 然后 D)"
screen -r "$sname"
else
echo "服务未运行: $svc"
return 1
fi
}
start_all() {
for s in "${SERVICE_NAMES[@]}"; do
start_service "$s"
done
}
stop_all() {
for s in "${SERVICE_NAMES[@]}"; do
stop_service "$s"
done
}
status_all() {
for s in "${SERVICE_NAMES[@]}"; do
status_service "$s"
done
}
restart_service() {
local svc="$1"
stop_service "$svc"
# 等待会话释放
sleep 1
start_service "$svc"
}
usage() {
cat <<EOF
用法: $0 <command> [service]
command:
start [svc] 启动指定服务;不指定时启动全部
stop [svc] 停止指定服务;不指定时停止全部
restart [svc] 重启指定服务;不指定时重启全部
status 查看所有服务状态
attach <svc> 附着到指定服务的 screen 会话
force-stop [svc] 强制结束进程(按端口终止);不指定时对全部执行
service 可选值:
intent | dify | answertype | qingdan
EOF
}
main() {
local cmd="${1:-}"; shift || true
case "$cmd" in
start)
local svc="${1:-all}"
if [[ "$svc" == "all" ]]; then start_all; else start_service "$svc"; fi
;;
stop)
local svc="${1:-all}"
if [[ "$svc" == "all" ]]; then stop_all; else stop_service "$svc"; fi
;;
restart)
local svc="${1:-all}"
if [[ "$svc" == "all" ]]; then
for s in "${SERVICE_NAMES[@]}"; do restart_service "$s"; done
else
restart_service "$svc"
fi
;;
status)
status_all
;;
attach)
local svc="${1:-}"
if [[ -z "$svc" ]]; then echo "请指定服务"; usage; exit 2; fi
attach_service "$svc"
;;
force-stop)
local svc="${1:-all}"
if [[ "$svc" == "all" ]]; then
for s in "${SERVICE_NAMES[@]}"; do
# 仅按端口强制清理
p="$(service_port "$s")"
if [[ -n "$p" ]]; then kill_by_port "$p"; fi
done
else
local p
p="$(service_port "$svc")"
if [[ -z "$p" ]]; then echo "未知服务: $svc"; exit 2; fi
kill_by_port "$p"
fi
;;
""|-h|--help|help)
usage
;;
*)
echo "未知命令: $cmd" >&2
usage
exit 2
;;
esac
}
main "$@"
+1
View File
@@ -6,6 +6,7 @@ readme = "README.md"
requires-python = ">=3.11" requires-python = ">=3.11"
dependencies = [ dependencies = [
"bs4>=0.0.2", "bs4>=0.0.2",
"aiosqlite>=0.20.0",
"faiss-cpu>=1.11.0", "faiss-cpu>=1.11.0",
"fastapi>=0.115.14", "fastapi>=0.115.14",
"flask>=3.1.1", "flask>=3.1.1",
+85 -115
View File
@@ -8,8 +8,7 @@ from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing import Dict, List, Any, Optional from typing import Dict, List, Any, Optional
import asyncio import asyncio
import threading import aiosqlite
import queue
import sqlite3 import sqlite3
from contextlib import closing from contextlib import closing
@@ -38,21 +37,15 @@ DATA_DIR = os.path.join(os.getcwd(), "data")
DB_DIR = os.path.join(DATA_DIR, "db") DB_DIR = os.path.join(DATA_DIR, "db")
DB_FILE = os.path.join(DB_DIR, "answer_logs.db") DB_FILE = os.path.join(DB_DIR, "answer_logs.db")
# 创建异步日志队列和工作线程 # 创建异步日志队列和后台任务
log_queue = queue.Queue() log_queue: asyncio.Queue = asyncio.Queue()
worker_thread = None worker_task: asyncio.Task | None = None
db_lock = threading.Lock() # 数据库操作锁
# 初始化数据库 # 初始化数据库(异步)
def init_database(): async def init_database():
os.makedirs(DB_DIR, exist_ok=True) os.makedirs(DB_DIR, exist_ok=True)
async with aiosqlite.connect(DB_FILE) as conn:
with db_lock: await conn.execute('''
with closing(sqlite3.connect(DB_FILE)) as conn:
cursor = conn.cursor()
# 创建查询类型表
cursor.execute('''
CREATE TABLE IF NOT EXISTS query_types ( CREATE TABLE IF NOT EXISTS query_types (
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
query_type TEXT NOT NULL, query_type TEXT NOT NULL,
@@ -60,9 +53,7 @@ def init_database():
timestamp TEXT NOT NULL timestamp TEXT NOT NULL
) )
''') ''')
await conn.execute('''
# 创建点踩原因表
cursor.execute('''
CREATE TABLE IF NOT EXISTS dislike_reasons ( CREATE TABLE IF NOT EXISTS dislike_reasons (
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
dislike_reason TEXT NOT NULL, dislike_reason TEXT NOT NULL,
@@ -70,57 +61,46 @@ def init_database():
timestamp TEXT NOT NULL timestamp TEXT NOT NULL
) )
''') ''')
await conn.commit()
conn.commit()
logger.info("数据库初始化完成") logger.info("数据库初始化完成")
# 后台工作线程函数 # 后台异步工作协程
def log_worker(): async def log_worker():
while True:
try: try:
# 从队列获取数据,设置超时以允许线程退出 while True:
data = log_queue.get(timeout=1.0) data = await log_queue.get()
if data is None: # 接收到退出信号 if data is None:
# 处理剩余数据后退出 # 排空剩余数据后退出
while not log_queue.empty(): while not log_queue.empty():
data = log_queue.get_nowait() pending = log_queue.get_nowait()
if data is None: # 跳过额外的停止信号 if pending is None:
continue continue
process_log_data(data) await process_log_data(pending)
break break
await process_log_data(data)
process_log_data(data)
log_queue.task_done() log_queue.task_done()
except queue.Empty: except asyncio.CancelledError:
continue logger.info("日志工作任务被取消,尝试优雅退出...")
except Exception as e: except Exception as e:
logger.error(f"保存查询数据时出错: {str(e)}", exc_info=True) logger.error(f"保存查询数据时出错: {str(e)}", exc_info=True)
# 提取数据处理逻辑到单独函数 # 提取数据处理逻辑到单独异步函数
def process_log_data(data): async def process_log_data(data):
try: try:
with db_lock: async with aiosqlite.connect(DB_FILE) as conn:
with closing(sqlite3.connect(DB_FILE)) as conn:
cursor = conn.cursor()
if "dislike_reason" in data: if "dislike_reason" in data:
# 保存点踩原因 await conn.execute(
cursor.execute(
"INSERT INTO dislike_reasons (dislike_reason, workflow_run_id, timestamp) VALUES (?, ?, ?)", "INSERT INTO dislike_reasons (dislike_reason, workflow_run_id, timestamp) VALUES (?, ?, ?)",
(data["dislike_reason"], data["workflow_run_id"], data["timestamp"]) (data["dislike_reason"], data["workflow_run_id"], data["timestamp"])
) )
table_name = "dislike_reasons" table_name = "dislike_reasons"
else: else:
# 保存查询类型 await conn.execute(
cursor.execute(
"INSERT INTO query_types (query_type, workflow_run_id, timestamp) VALUES (?, ?, ?)", "INSERT INTO query_types (query_type, workflow_run_id, timestamp) VALUES (?, ?, ?)",
(data["query_type"], data["workflow_run_id"], data["timestamp"]) (data["query_type"], data["workflow_run_id"], data["timestamp"])
) )
table_name = "query_types" table_name = "query_types"
await conn.commit()
conn.commit()
logger.info(f"成功保存数据到表: {table_name}") logger.info(f"成功保存数据到表: {table_name}")
except Exception as e: except Exception as e:
logger.error(f"处理日志数据时出错: {str(e)}", exc_info=True) logger.error(f"处理日志数据时出错: {str(e)}", exc_info=True)
@@ -184,28 +164,27 @@ app.add_middleware(
# 应用启动事件 # 应用启动事件
@app.on_event("startup") @app.on_event("startup")
async def startup_event(): async def startup_event():
global worker_thread global worker_task
# 初始化数据库 # 初始化数据库
init_database() await init_database()
# 启动后台异步任务
# 启动后台工作线程 worker_task = asyncio.create_task(log_worker())
worker_thread = threading.Thread(target=log_worker, daemon=True) logger.info("后台日志工作任务已启动")
worker_thread.start()
logger.info("后台日志工作线程已启动")
# 应用关闭事件 # 应用关闭事件
@app.on_event("shutdown") @app.on_event("shutdown")
def shutdown_event(): async def shutdown_event():
global worker_thread global worker_task
if worker_thread: if worker_task:
# 发送退出信号 # 发送退出信号并等待任务结束
log_queue.put(None) await log_queue.put(None)
# 等待工作线程处理剩余数据 await asyncio.sleep(0) # 让出执行权给worker处理退出
worker_thread.join(timeout=10.0) try:
if worker_thread.is_alive(): await asyncio.wait_for(worker_task, timeout=10.0)
logger.warning("工作线程未在超时时间内退出") logger.info("后台日志工作任务已停止")
else: except asyncio.TimeoutError:
logger.info("后台日志工作线程已停止") worker_task.cancel()
logger.warning("工作任务未在超时时间内退出,已取消")
# 添加健康检查端点 # 添加健康检查端点
@app.get("/health", summary="健康检查") @app.get("/health", summary="健康检查")
@@ -226,9 +205,9 @@ async def query_type(query_type: str, workflow_run_id:str):
"workflow_run_id": workflow_run_id "workflow_run_id": workflow_run_id
} }
# 将数据放入队列 # 将数据放入异步队列
try: try:
log_queue.put(query_data) await log_queue.put(query_data)
success = True success = True
logger.info(f"查询数据已加入队列,当前队列大小: {log_queue.qsize()}") logger.info(f"查询数据已加入队列,当前队列大小: {log_queue.qsize()}")
except Exception as e: except Exception as e:
@@ -256,9 +235,9 @@ async def dislike_reason(reason: str, workflow_run_id: str):
"workflow_run_id": workflow_run_id "workflow_run_id": workflow_run_id
} }
# 将数据放入队列 # 将数据放入异步队列
try: try:
log_queue.put(dislike_data) await log_queue.put(dislike_data)
success = True success = True
logger.info(f"点踩原因数据已加入队列,当前队列大小: {log_queue.qsize()}") logger.info(f"点踩原因数据已加入队列,当前队列大小: {log_queue.qsize()}")
except Exception as e: except Exception as e:
@@ -276,35 +255,34 @@ async def dislike_reason(reason: str, workflow_run_id: str):
@app.get("/stats", summary="查询统计数据") @app.get("/stats", summary="查询统计数据")
async def get_stats(): async def get_stats():
try: try:
with db_lock: async with aiosqlite.connect(DB_FILE) as conn:
with closing(sqlite3.connect(DB_FILE)) as conn: conn.row_factory = sqlite3.Row
conn.row_factory = sqlite3.Row # 启用字典行工厂
cursor = conn.cursor()
# 查询类型统计 # 查询类型统计
cursor.execute("SELECT COUNT(*) as count FROM query_types") cursor = await conn.execute("SELECT COUNT(*) as count FROM query_types")
query_count = cursor.fetchone()['count'] row = await cursor.fetchone()
query_count = row['count'] if row else 0
# 点踩原因统计 # 点踩原因统计
cursor.execute("SELECT COUNT(*) as count FROM dislike_reasons") cursor = await conn.execute("SELECT COUNT(*) as count FROM dislike_reasons")
dislike_count = cursor.fetchone()['count'] row = await cursor.fetchone()
dislike_count = row['count'] if row else 0
# 最近5条查询记录 # 最近5条查询记录
cursor.execute(""" cursor = await conn.execute(
"""
SELECT query_type, workflow_run_id, timestamp SELECT query_type, workflow_run_id, timestamp
FROM query_types FROM query_types
ORDER BY id DESC LIMIT 5 ORDER BY id DESC LIMIT 5
""") """
recent_queries = [dict(row) for row in cursor.fetchall()] )
recent_queries = [dict(r) for r in await cursor.fetchall()]
# 最近5条点踩记录 # 最近5条点踩记录
cursor.execute(""" cursor = await conn.execute(
"""
SELECT dislike_reason, workflow_run_id, timestamp SELECT dislike_reason, workflow_run_id, timestamp
FROM dislike_reasons FROM dislike_reasons
ORDER BY id DESC LIMIT 5 ORDER BY id DESC LIMIT 5
""") """
recent_dislikes = [dict(row) for row in cursor.fetchall()] )
recent_dislikes = [dict(r) for r in await cursor.fetchall()]
return { return {
"statistics": { "statistics": {
"total_queries": query_count, "total_queries": query_count,
@@ -322,24 +300,20 @@ async def get_stats():
@app.get("/query_by_workflow_id", summary="根据工作流ID获取查询类型数据") @app.get("/query_by_workflow_id", summary="根据工作流ID获取查询类型数据")
async def get_query_by_workflow_id(workflow_run_id: str): async def get_query_by_workflow_id(workflow_run_id: str):
try: try:
with db_lock: async with aiosqlite.connect(DB_FILE) as conn:
with closing(sqlite3.connect(DB_FILE)) as conn: conn.row_factory = sqlite3.Row
conn.row_factory = sqlite3.Row # 启用字典行工厂 cursor = await conn.execute(
cursor = conn.cursor() """
# 查询指定工作流ID的查询类型数据
cursor.execute("""
SELECT id, query_type, workflow_run_id, timestamp SELECT id, query_type, workflow_run_id, timestamp
FROM query_types FROM query_types
WHERE workflow_run_id = ? WHERE workflow_run_id = ?
ORDER BY id DESC ORDER BY id DESC
""", (workflow_run_id,)) """,
(workflow_run_id,)
results = [dict(row) for row in cursor.fetchall()] )
results = [dict(r) for r in await cursor.fetchall()]
if not results: if not results:
return {"message": f"未找到工作流ID为 {workflow_run_id} 的查询类型数据", "data": []} return {"message": f"未找到工作流ID为 {workflow_run_id} 的查询类型数据", "data": []}
return {"data": results} return {"data": results}
except Exception as e: except Exception as e:
logger.error(f"根据工作流ID获取查询类型数据时出错: {str(e)}", exc_info=True) logger.error(f"根据工作流ID获取查询类型数据时出错: {str(e)}", exc_info=True)
@@ -348,24 +322,20 @@ async def get_query_by_workflow_id(workflow_run_id: str):
@app.get("/dislike_by_workflow_id", summary="根据工作流ID获取点踩原因数据") @app.get("/dislike_by_workflow_id", summary="根据工作流ID获取点踩原因数据")
async def get_dislike_by_workflow_id(workflow_run_id: str): async def get_dislike_by_workflow_id(workflow_run_id: str):
try: try:
with db_lock: async with aiosqlite.connect(DB_FILE) as conn:
with closing(sqlite3.connect(DB_FILE)) as conn: conn.row_factory = sqlite3.Row
conn.row_factory = sqlite3.Row # 启用字典行工厂 cursor = await conn.execute(
cursor = conn.cursor() """
# 查询指定工作流ID的点踩原因数据
cursor.execute("""
SELECT id, dislike_reason, workflow_run_id, timestamp SELECT id, dislike_reason, workflow_run_id, timestamp
FROM dislike_reasons FROM dislike_reasons
WHERE workflow_run_id = ? WHERE workflow_run_id = ?
ORDER BY id DESC ORDER BY id DESC
""", (workflow_run_id,)) """,
(workflow_run_id,)
results = [dict(row) for row in cursor.fetchall()] )
results = [dict(r) for r in await cursor.fetchall()]
if not results: if not results:
return {"message": f"未找到工作流ID为 {workflow_run_id} 的点踩原因数据", "data": []} return {"message": f"未找到工作流ID为 {workflow_run_id} 的点踩原因数据", "data": []}
return {"data": results} return {"data": results}
except Exception as e: except Exception as e:
logger.error(f"根据工作流ID获取点踩原因数据时出错: {str(e)}", exc_info=True) logger.error(f"根据工作流ID获取点踩原因数据时出错: {str(e)}", exc_info=True)
@@ -374,6 +344,6 @@ async def get_dislike_by_workflow_id(workflow_run_id: str):
if __name__ == "__main__": if __name__ == "__main__":
# 使用Uvicorn运行FastAPI应用 # 使用Uvicorn运行FastAPI应用
import uvicorn import uvicorn
uvicorn.run("rag2_0.dify.AnswerType:app", host="0.0.0.0", port=8003, reload=False, workers=1, log_level="info") uvicorn.run("rag2_0.api.AnswerType_api:app", host="0.0.0.0", port=8003, reload=False, workers=1, log_level="info")
# 生产环境可以使用以下命令启动: # 生产环境可以使用以下命令启动:
# uvicorn rag2_0.dify.AnswerType:app --host 0.0.0.0 --port 8003 --workers 1 # uvicorn rag2_0.api.AnswerType_api:app --host 0.0.0.0 --port 8003 --workers 1
+1 -1
View File
@@ -118,7 +118,7 @@ async def retrieve(request: RetrieveRequest):
if __name__ == "__main__": if __name__ == "__main__":
# 使用Uvicorn运行FastAPI应用 # 使用Uvicorn运行FastAPI应用
import uvicorn import uvicorn
uvicorn.run("rag2_0.api.DifyQueryRetrieval_api:app", host="0.0.0.0", port=9002, reload=False, workers=1, log_level="info") uvicorn.run("rag2_0.api.DifyQueryRetrieval_api:app", host="0.0.0.0", port=8002, reload=False, workers=1, log_level="info")
# # 使用uvicorn启动服务 # # 使用uvicorn启动服务
# import uvicorn # import uvicorn
# uvicorn.run( # uvicorn.run(
+1 -1
View File
@@ -175,7 +175,7 @@ if __name__ == "__main__":
# 使用uvicorn启动服务 # 使用uvicorn启动服务
import uvicorn import uvicorn
uvicorn.run( uvicorn.run(
"rag2_0.dify.intent_recognition_api:app", "rag2_0.api.intent_recognition_api:app",
host="0.0.0.0", host="0.0.0.0",
port=9001, port=9001,
reload=True, # 开发环境启用热重载 reload=True, # 开发环境启用热重载
+57 -64
View File
@@ -3,6 +3,8 @@ from fastapi import FastAPI, HTTPException, Query
from pydantic import BaseModel from pydantic import BaseModel
from typing import List, Optional, Dict, Any from typing import List, Optional, Dict, Any
import uvicorn import uvicorn
import asyncio
import aiosqlite
import sqlite3 import sqlite3
import sys import sys
import os import os
@@ -73,11 +75,13 @@ class QingDanDingEQueryService:
db_file=self.db_path db_file=self.db_path
) )
def get_similar_names_by_vector(self, query_text:str, vector_db:SQLiteVSS, field_map:dict, top_k:int=3, scope:str=None): async def get_similar_names_by_vector(self, query_text:str, vector_db:SQLiteVSS, field_map:dict, top_k:int=3, scope:str=None):
"""使用向量检索获取相似名称""" """使用向量检索获取相似名称(异步包装)"""
try: try:
# 使用向量数据库进行相似性搜索 # 使用线程池包装同步向量检索,避免阻塞事件循环
results = vector_db.similarity_search_with_score(query=query_text, k=30) results = await asyncio.to_thread(
vector_db.similarity_search_with_score, query_text, 30
)
# 提取结果中的元数据 # 提取结果中的元数据
similar_items = [] similar_items = []
@@ -90,16 +94,16 @@ class QingDanDingEQueryService:
metadata['similarity_score'] = float(score) metadata['similarity_score'] = float(score)
similar_items.append(metadata) similar_items.append(metadata)
# 按相似度分数排序,分数高的排前面 # 分数越小越相似(SQLiteVSS 多为距离),已有代码按升序排序
similar_items.sort(key=lambda x: x['similarity_score']) similar_items.sort(key=lambda x: x['similarity_score'])
return similar_items[:top_k] return similar_items[:top_k]
except Exception as e: except Exception as e:
print(f"向量检索出错: {str(e)}") print(f"向量检索出错: {str(e)}")
return [] return []
def get_db_connection(self): async def get_db_connection(self):
"""获取数据库连接""" """获取数据库连接aiosqlite 异步)"""
conn = sqlite3.connect(self.db_path) conn = await aiosqlite.connect(self.db_path)
conn.row_factory = sqlite3.Row # 设置行工厂,使结果可以通过列名访问 conn.row_factory = sqlite3.Row # 设置行工厂,使结果可以通过列名访问
return conn return conn
@@ -138,12 +142,9 @@ class QingDanDingEQueryService:
# 合并结果,完全匹配的排在前面 # 合并结果,完全匹配的排在前面
return exact_matches + partial_matches return exact_matches + partial_matches
def query_ding_e_by_name(self, name, scope=None): async def query_ding_e_by_name(self, name, scope=None):
"""根据定额名称查询定额子目表中详情信息,使用向量检索扩大查询范围""" """根据定额名称查询定额子目表中详情信息,使用向量检索扩大查询范围"""
try: try:
conn = self.get_db_connection()
cursor = conn.cursor()
# 获取表名和字段映射 # 获取表名和字段映射
zimu_table = ExcelToSQLiteProcessor.ding_e_table_names["定额子目"] zimu_table = ExcelToSQLiteProcessor.ding_e_table_names["定额子目"]
mulu_table = ExcelToSQLiteProcessor.ding_e_table_names["定额目录"] mulu_table = ExcelToSQLiteProcessor.ding_e_table_names["定额目录"]
@@ -151,7 +152,7 @@ class QingDanDingEQueryService:
field_map = ExcelToSQLiteProcessor.ding_e_field_map field_map = ExcelToSQLiteProcessor.ding_e_field_map
# 1. 先使用向量检索获取相似名称 # 1. 先使用向量检索获取相似名称
similar_items = self.get_similar_names_by_vector(query_text=name, similar_items = await self.get_similar_names_by_vector(query_text=name,
vector_db=self.ding_e_vector_db, vector_db=self.ding_e_vector_db,
field_map=field_map, field_map=field_map,
scope=scope) scope=scope)
@@ -190,18 +191,16 @@ class QingDanDingEQueryService:
query += f" AND attr.{field_map['适用范围']} LIKE ?" query += f" AND attr.{field_map['适用范围']} LIKE ?"
params.append(f'%{scope}%') params.append(f'%{scope}%')
cursor.execute(query, params) async with await self.get_db_connection() as conn:
cursor = await conn.execute(query, params)
# 获取结果 # 获取结果
results = cursor.fetchall() results = await cursor.fetchall()
data = [dict(row) for row in results] data = [dict(row) for row in results]
# 对结果进行排序,将全字匹配的排在前面 # 对结果进行排序,将全字匹配的排在前面
data = self.sort_results_by_exact_match(data, name, field_map['名称']) data = self.sort_results_by_exact_match(data, name, field_map['名称'])
data = data[:self.top_k] data = data[:self.top_k]
conn.close()
if not data: if not data:
return {"success": True, "message": "未找到匹配的定额信息", "data": []} return {"success": True, "message": "未找到匹配的定额信息", "data": []}
@@ -219,12 +218,10 @@ class QingDanDingEQueryService:
except Exception as e: except Exception as e:
return {"success": False, "message": f"查询出错: {str(e)}"} return {"success": False, "message": f"查询出错: {str(e)}"}
def query_ding_e_by_code(self, code, scope=None): async def query_ding_e_by_code(self, code, scope=None):
"""根据定额编码查询定额子目表中详情信息""" """根据定额编码查询定额子目表中详情信息"""
try: try:
code = code.upper() code = code.upper()
conn = self.get_db_connection()
cursor = conn.cursor()
# 获取表名和字段映射 # 获取表名和字段映射
zimu_table = ExcelToSQLiteProcessor.ding_e_table_names["定额子目"] zimu_table = ExcelToSQLiteProcessor.ding_e_table_names["定额子目"]
@@ -255,18 +252,16 @@ class QingDanDingEQueryService:
query += f" AND attr.{field_map['适用范围']} LIKE ?" query += f" AND attr.{field_map['适用范围']} LIKE ?"
params.append(f'%{scope}%') params.append(f'%{scope}%')
cursor.execute(query, params) async with await self.get_db_connection() as conn:
cursor = await conn.execute(query, params)
# 获取结果 # 获取结果
results = cursor.fetchall() results = await cursor.fetchall()
data = [dict(row) for row in results] data = [dict(row) for row in results]
# 对结果进行排序,将全字匹配的排在前面 # 对结果进行排序,将全字匹配的排在前面
data = self.sort_results_by_exact_match(data, code, field_map['编码']) data = self.sort_results_by_exact_match(data, code, field_map['编码'])
data = data[:self.top_k] data = data[:self.top_k]
conn.close()
if not data: if not data:
return {"success": True, "message": "未找到匹配的定额信息", "data": []} return {"success": True, "message": "未找到匹配的定额信息", "data": []}
@@ -284,12 +279,9 @@ class QingDanDingEQueryService:
except Exception as e: except Exception as e:
return {"success": False, "message": f"查询出错: {str(e)}"} return {"success": False, "message": f"查询出错: {str(e)}"}
def query_qing_dan_by_name(self, name, scope=None): async def query_qing_dan_by_name(self, name, scope=None):
"""根据清单名称查询清单子目表中详情信息,使用向量检索扩大查询范围""" """根据清单名称查询清单子目表中详情信息,使用向量检索扩大查询范围"""
try: try:
conn = self.get_db_connection()
cursor = conn.cursor()
# 获取表名和字段映射 # 获取表名和字段映射
zimu_table = ExcelToSQLiteProcessor.qing_dan_table_names["清单子目"] zimu_table = ExcelToSQLiteProcessor.qing_dan_table_names["清单子目"]
mulu_table = ExcelToSQLiteProcessor.qing_dan_table_names["清单目录"] mulu_table = ExcelToSQLiteProcessor.qing_dan_table_names["清单目录"]
@@ -297,7 +289,7 @@ class QingDanDingEQueryService:
field_map = ExcelToSQLiteProcessor.qing_dan_field_map field_map = ExcelToSQLiteProcessor.qing_dan_field_map
# 1. 先使用向量检索获取相似名称 # 1. 先使用向量检索获取相似名称
similar_items = self.get_similar_names_by_vector(query_text=name, vector_db=self.qing_dan_vector_db, field_map=field_map, scope=scope) similar_items = await self.get_similar_names_by_vector(query_text=name, vector_db=self.qing_dan_vector_db, field_map=field_map, scope=scope)
similar_names = [item['mc'] for item in similar_items] similar_names = [item['mc'] for item in similar_items]
# 构建查询条件,始终包含原始名称的模糊匹配 # 构建查询条件,始终包含原始名称的模糊匹配
@@ -333,18 +325,16 @@ class QingDanDingEQueryService:
query += f" AND attr.{field_map['适用范围']} LIKE ?" query += f" AND attr.{field_map['适用范围']} LIKE ?"
params.append(f'%{scope}%') params.append(f'%{scope}%')
cursor.execute(query, params) async with await self.get_db_connection() as conn:
cursor = await conn.execute(query, params)
# 获取结果 # 获取结果
results = cursor.fetchall() results = await cursor.fetchall()
data = [dict(row) for row in results] data = [dict(row) for row in results]
# 对结果进行排序,将全字匹配的排在前面 # 对结果进行排序,将全字匹配的排在前面
data = self.sort_results_by_exact_match(data, name, field_map['名称']) data = self.sort_results_by_exact_match(data, name, field_map['名称'])
data = data[:self.top_k] data = data[:self.top_k]
conn.close()
if not data: if not data:
return {"success": True, "message": "未找到匹配的清单信息", "data": []} return {"success": True, "message": "未找到匹配的清单信息", "data": []}
@@ -362,12 +352,10 @@ class QingDanDingEQueryService:
except Exception as e: except Exception as e:
return {"success": False, "message": f"查询出错: {str(e)}"} return {"success": False, "message": f"查询出错: {str(e)}"}
def query_qing_dan_by_code(self, code, scope=None): async def query_qing_dan_by_code(self, code, scope=None):
"""根据清单编码查询清单子目表中详情信息""" """根据清单编码查询清单子目表中详情信息"""
try: try:
code = code.upper() code = code.upper()
conn = self.get_db_connection()
cursor = conn.cursor()
# 获取表名和字段映射 # 获取表名和字段映射
zimu_table = ExcelToSQLiteProcessor.qing_dan_table_names["清单子目"] zimu_table = ExcelToSQLiteProcessor.qing_dan_table_names["清单子目"]
@@ -398,18 +386,16 @@ class QingDanDingEQueryService:
query += f" AND attr.{field_map['适用范围']} LIKE ?" query += f" AND attr.{field_map['适用范围']} LIKE ?"
params.append(f'%{scope}%') params.append(f'%{scope}%')
cursor.execute(query, params) async with await self.get_db_connection() as conn:
cursor = await conn.execute(query, params)
# 获取结果 # 获取结果
results = cursor.fetchall() results = await cursor.fetchall()
data = [dict(row) for row in results] data = [dict(row) for row in results]
# 对结果进行排序,将全字匹配的排在前面 # 对结果进行排序,将全字匹配的排在前面
data = self.sort_results_by_exact_match(data, code, field_map['编码']) data = self.sort_results_by_exact_match(data, code, field_map['编码'])
data = data[:self.top_k] data = data[:self.top_k]
conn.close()
if not data: if not data:
return {"success": True, "message": "未找到匹配的清单信息", "data": []} return {"success": True, "message": "未找到匹配的清单信息", "data": []}
@@ -427,8 +413,8 @@ class QingDanDingEQueryService:
except Exception as e: except Exception as e:
return {"success": False, "message": f"查询出错: {str(e)}"} return {"success": False, "message": f"查询出错: {str(e)}"}
def batch_query(self, requests:BatchQueryRequest): async def batch_query(self, requests:BatchQueryRequest):
"""批量查询接口,支持向量检索""" """批量查询接口,支持向量检索(并发执行)"""
dinge_results = [] dinge_results = []
qingdan_results = [] qingdan_results = []
tracking_dict = {} # 用于跟踪已查询过的项目,避免重复 tracking_dict = {} # 用于跟踪已查询过的项目,避免重复
@@ -439,42 +425,49 @@ class QingDanDingEQueryService:
qingdan_info = requests.dinge_qingdan_info.qingdan_info qingdan_info = requests.dinge_qingdan_info.qingdan_info
scope = requests.scope scope = requests.scope
dinge_tasks = []
qingdan_tasks = []
# 处理定额编码查询 # 处理定额编码查询
for code in dinge_info.dinge_code_list or []: for code in dinge_info.dinge_code_list or []:
key = f"dinge_code_{code}_{scope}" key = f"dinge_code_{code}_{scope}"
if key not in tracking_dict: if key not in tracking_dict:
result = self.query_ding_e_by_code(code, scope) dinge_tasks.append(self.query_ding_e_by_code(code, scope))
if result["success"] and result["data"]:
dinge_results.extend(result["data"])
tracking_dict[key] = True tracking_dict[key] = True
# 处理定额名称查询 # 处理定额名称查询
for name in dinge_info.dinge_name_list or []: for name in dinge_info.dinge_name_list or []:
key = f"dinge_name_{name}_{scope}" key = f"dinge_name_{name}_{scope}"
if key not in tracking_dict: if key not in tracking_dict:
result = self.query_ding_e_by_name(name, scope) dinge_tasks.append(self.query_ding_e_by_name(name, scope))
if result["success"] and result["data"]:
dinge_results.extend(result["data"])
tracking_dict[key] = True tracking_dict[key] = True
# 处理清单编码查询 # 处理清单编码查询
for code in qingdan_info.qingdan_code_list or []: for code in qingdan_info.qingdan_code_list or []:
key = f"qingdan_code_{code}_{scope}" key = f"qingdan_code_{code}_{scope}"
if key not in tracking_dict: if key not in tracking_dict:
result = self.query_qing_dan_by_code(code, scope) qingdan_tasks.append(self.query_qing_dan_by_code(code, scope))
if result["success"] and result["data"]:
qingdan_results.extend(result["data"])
tracking_dict[key] = True tracking_dict[key] = True
# 处理清单名称查询 # 处理清单名称查询
for name in qingdan_info.qingdan_name_list or []: for name in qingdan_info.qingdan_name_list or []:
key = f"qingdan_name_{name}_{scope}" key = f"qingdan_name_{name}_{scope}"
if key not in tracking_dict: if key not in tracking_dict:
result = self.query_qing_dan_by_name(name, scope) qingdan_tasks.append(self.query_qing_dan_by_name(name, scope))
if result["success"] and result["data"]:
qingdan_results.extend(result["data"])
tracking_dict[key] = True tracking_dict[key] = True
# 并发执行
if dinge_tasks:
dinge_outs = await asyncio.gather(*dinge_tasks)
for result in dinge_outs:
if result and result.get("success") and result.get("data"):
dinge_results.extend(result["data"])
if qingdan_tasks:
qingdan_outs = await asyncio.gather(*qingdan_tasks)
for result in qingdan_outs:
if result and result.get("success") and result.get("data"):
qingdan_results.extend(result["data"])
# 限制返回结果数量 # 限制返回结果数量
dinge_results = dinge_results[:self.top_k] dinge_results = dinge_results[:self.top_k]
qingdan_results = qingdan_results[:self.top_k] qingdan_results = qingdan_results[:self.top_k]
@@ -505,7 +498,7 @@ async def query_ding_e_by_name(
name: str = Query(..., description="定额名称"), name: str = Query(..., description="定额名称"),
scope: Optional[str] = Query(None, description="适用范围") scope: Optional[str] = Query(None, description="适用范围")
): ):
result = query_service.query_ding_e_by_name(name, scope) result = await query_service.query_ding_e_by_name(name, scope)
if not result["success"]: if not result["success"]:
raise HTTPException(status_code=500, detail=result["message"]) raise HTTPException(status_code=500, detail=result["message"])
return QueryResponse(**result) return QueryResponse(**result)
@@ -516,7 +509,7 @@ async def query_ding_e_by_code(
code: str = Query(..., description="定额编码"), code: str = Query(..., description="定额编码"),
scope: Optional[str] = Query(None, description="适用范围") scope: Optional[str] = Query(None, description="适用范围")
): ):
result = query_service.query_ding_e_by_code(code, scope) result = await query_service.query_ding_e_by_code(code, scope)
if not result["success"]: if not result["success"]:
raise HTTPException(status_code=500, detail=result["message"]) raise HTTPException(status_code=500, detail=result["message"])
return QueryResponse(**result) return QueryResponse(**result)
@@ -527,7 +520,7 @@ async def query_qing_dan_by_name(
name: str = Query(..., description="清单名称"), name: str = Query(..., description="清单名称"),
scope: Optional[str] = Query(None, description="适用范围") scope: Optional[str] = Query(None, description="适用范围")
): ):
result = query_service.query_qing_dan_by_name(name, scope) result = await query_service.query_qing_dan_by_name(name, scope)
if not result["success"]: if not result["success"]:
raise HTTPException(status_code=500, detail=result["message"]) raise HTTPException(status_code=500, detail=result["message"])
return QueryResponse(**result) return QueryResponse(**result)
@@ -538,7 +531,7 @@ async def query_qing_dan_by_code(
code: str = Query(..., description="清单编码"), code: str = Query(..., description="清单编码"),
scope: Optional[str] = Query(None, description="适用范围") scope: Optional[str] = Query(None, description="适用范围")
): ):
result = query_service.query_qing_dan_by_code(code, scope) result = await query_service.query_qing_dan_by_code(code, scope)
if not result["success"]: if not result["success"]:
raise HTTPException(status_code=500, detail=result["message"]) raise HTTPException(status_code=500, detail=result["message"])
return QueryResponse(**result) return QueryResponse(**result)
@@ -546,7 +539,7 @@ async def query_qing_dan_by_code(
# 5. 批量查询定额和清单信息 # 5. 批量查询定额和清单信息
@app.post("/api/batch_query", response_model=BatchQueryResponse) @app.post("/api/batch_query", response_model=BatchQueryResponse)
async def batch_query(request: BatchQueryRequest): async def batch_query(request: BatchQueryRequest):
result = query_service.batch_query(request) result = await query_service.batch_query(request)
if not result["success"]: if not result["success"]:
raise HTTPException(status_code=500, detail=result["message"]) raise HTTPException(status_code=500, detail=result["message"])
return BatchQueryResponse(**result) return BatchQueryResponse(**result)
@@ -564,4 +557,4 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
main() main()
# uvicorn rag2_0.dify.query_dinge_qingdan_api:app --host 0.0.0.0 --port 8005 --workers 10 # uvicorn rag2_0.api.query_dinge_qingdan_api:app --host 0.0.0.0 --port 8005 --workers 10
-18
View File
@@ -1,18 +0,0 @@
#!/bin/bash
# 获取当前脚本所在的绝对路径
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
# 检查是否已经存在名为AnswerType的screen会话
if screen -ls | grep -q "\.AnswerType\s"; then
echo "Screen session 'AnswerType' already exists."
else
# 启动一个名为AnswerType的screen会话,并在其中执行后续命令
screen -dmS AnswerType bash -c "
cd \"$SCRIPT_DIR\"
uv run uvicorn rag2_0.api.AnswerType_api:app --host 0.0.0.0 --port 8003 --workers 1
"
# 输出提示信息
echo "Started screen session 'AnswerType' and executed the command."
fi
-18
View File
@@ -1,18 +0,0 @@
#!/bin/bash
# 获取当前脚本所在的绝对路径
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
# 检查是否已经存在名为DifyQueryRetrieval_api的screen会话
if screen -ls | grep "DifyQueryRetrieval_api"; then
echo "Screen session 'DifyQueryRetrieval_api' already exists."
else
# 启动一个名为DifyQueryRetrieval_api的screen会话,并在其中执行后续命令
screen -dmS DifyQueryRetrieval_api bash -c "
cd \"$SCRIPT_DIR\"
uv run uvicorn rag2_0.api.DifyQueryRetrieval_api:app --host 0.0.0.0 --port 8002 --workers 25
"
# 输出提示信息
echo "Started screen session 'DifyQueryRetrieval_api' and executed the command."
fi
-18
View File
@@ -1,18 +0,0 @@
#!/bin/bash
# 获取当前脚本所在的绝对路径
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
# 检查是否已经存在名为xinference的screen会话
if screen -ls | grep "intent_recognition_api"; then
echo "Screen session 'intent_recognition_api' already exists."
else
# 启动一个名为intent_recognition_api的screen会话,并在其中执行后续命令
screen -dmS intent_recognition_api bash -c "
cd \"$SCRIPT_DIR\"
uv run uvicorn rag2_0.api.intent_recognition_api:app --host 0.0.0.0 --port 8001 --workers 25
"
# 输出提示信息
echo "Started screen session 'intent_recognition_api' and executed the command."
fi
Generated
+14
View File
@@ -97,6 +97,18 @@ wheels = [
{ url = "https://mirrors.aliyun.com/pypi/packages/ec/6a/bc7e17a3e87a2985d3e8f4da4cd0f481060eb78fb08596c42be62c90a4d9/aiosignal-1.3.2-py2.py3-none-any.whl", hash = "sha256:45cde58e409a301715980c2b01d0c28bdde3770d8290b5eb2173759d9acb31a5" }, { url = "https://mirrors.aliyun.com/pypi/packages/ec/6a/bc7e17a3e87a2985d3e8f4da4cd0f481060eb78fb08596c42be62c90a4d9/aiosignal-1.3.2-py2.py3-none-any.whl", hash = "sha256:45cde58e409a301715980c2b01d0c28bdde3770d8290b5eb2173759d9acb31a5" },
] ]
[[package]]
name = "aiosqlite"
version = "0.21.0"
source = { registry = "https://mirrors.aliyun.com/pypi/simple" }
dependencies = [
{ name = "typing-extensions" },
]
sdist = { url = "https://mirrors.aliyun.com/pypi/packages/13/7d/8bca2bf9a247c2c5dfeec1d7a5f40db6518f88d314b8bca9da29670d2671/aiosqlite-0.21.0.tar.gz", hash = "sha256:131bb8056daa3bc875608c631c678cda73922a2d4ba8aec373b19f18c17e7aa3" }
wheels = [
{ url = "https://mirrors.aliyun.com/pypi/packages/f5/10/6c25ed6de94c49f88a91fa5018cb4c0f3625f31d5be9f771ebe5cc7cd506/aiosqlite-0.21.0-py3-none-any.whl", hash = "sha256:2549cf4057f95f53dcba16f2b64e8e2791d7e1adedb13197dd8ed77bb226d7d0" },
]
[[package]] [[package]]
name = "annotated-types" name = "annotated-types"
version = "0.7.0" version = "0.7.0"
@@ -1519,6 +1531,7 @@ name = "rag2-0"
version = "0.1.0" version = "0.1.0"
source = { virtual = "." } source = { virtual = "." }
dependencies = [ dependencies = [
{ name = "aiosqlite" },
{ name = "bs4" }, { name = "bs4" },
{ name = "faiss-cpu" }, { name = "faiss-cpu" },
{ name = "fastapi" }, { name = "fastapi" },
@@ -1548,6 +1561,7 @@ dependencies = [
[package.metadata] [package.metadata]
requires-dist = [ requires-dist = [
{ name = "aiosqlite", specifier = ">=0.20.0" },
{ name = "bs4", specifier = ">=0.0.2" }, { name = "bs4", specifier = ">=0.0.2" },
{ name = "faiss-cpu", specifier = ">=1.11.0" }, { name = "faiss-cpu", specifier = ">=1.11.0" },
{ name = "fastapi", specifier = ">=0.115.14" }, { name = "fastapi", specifier = ">=0.115.14" },