新增数据库支持,初始化数据库并创建查询类型和点踩原因表,优化日志记录,添加多个API以支持点踩原因和查询统计功能。

This commit is contained in:
2025-07-26 12:36:01 +08:00
parent 3f6f5d038c
commit 780f423200
4 changed files with 213 additions and 48 deletions
+199 -43
View File
@@ -10,6 +10,8 @@ from typing import Dict, List, Any, Optional
import asyncio import asyncio
import threading import threading
import queue import queue
import sqlite3
from contextlib import closing
from dotenv import load_dotenv from dotenv import load_dotenv
import json import json
@@ -31,14 +33,47 @@ import sys
sys.path.append(os.getcwd()) sys.path.append(os.getcwd())
from rag2_0.dify.DifyQueryRetrieval import DifyQueryRetrieval from rag2_0.dify.DifyQueryRetrieval import DifyQueryRetrieval
# 定义文件锁和JSON文件路径 # 定义数据库路径
file_lock = asyncio.Lock() DATA_DIR = os.path.join(os.getcwd(), "data")
QUERY_LOG_DIR = os.path.join(os.getcwd(), "data", "query_logs") DB_DIR = os.path.join(DATA_DIR, "db")
QUERY_DATA_FILE = os.path.join(QUERY_LOG_DIR, "answer_type_logs.json") DB_FILE = os.path.join(DB_DIR, "answer_logs.db")
# 创建异步日志队列和工作线程 # 创建异步日志队列和工作线程
log_queue = queue.Queue() log_queue = queue.Queue()
worker_thread = None worker_thread = None
db_lock = threading.Lock() # 数据库操作锁
# Initialise the database
def init_database():
    """Create the SQLite database directory, tables and indexes if missing.

    Creates ``DB_DIR`` when needed, then opens ``DB_FILE`` while holding
    ``db_lock`` and ensures the ``query_types`` and ``dislike_reasons``
    tables exist. Every statement uses ``IF NOT EXISTS``, so calling this
    function repeatedly is harmless.

    Both tables additionally get an index on ``workflow_run_id`` because
    that column is used as a lookup filter; without the index each lookup
    is a full table scan.
    """
    os.makedirs(DB_DIR, exist_ok=True)
    with db_lock:
        with closing(sqlite3.connect(DB_FILE)) as conn:
            cursor = conn.cursor()

            # Table holding one row per recorded query type
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS query_types (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    query_type TEXT NOT NULL,
                    workflow_run_id TEXT NOT NULL,
                    timestamp TEXT NOT NULL
                )
            ''')

            # Table holding one row per recorded dislike reason
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS dislike_reasons (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    dislike_reason TEXT NOT NULL,
                    workflow_run_id TEXT NOT NULL,
                    timestamp TEXT NOT NULL
                )
            ''')

            # Indexes supporting "WHERE workflow_run_id = ?" lookups
            cursor.execute(
                "CREATE INDEX IF NOT EXISTS idx_query_types_workflow_run_id "
                "ON query_types (workflow_run_id)"
            )
            cursor.execute(
                "CREATE INDEX IF NOT EXISTS idx_dislike_reasons_workflow_run_id "
                "ON dislike_reasons (workflow_run_id)"
            )

            conn.commit()

    logger.info("数据库初始化完成")
# 后台工作线程函数 # 后台工作线程函数
def log_worker(): def log_worker():
@@ -65,44 +100,49 @@ def log_worker():
# 提取数据处理逻辑到单独函数 # 提取数据处理逻辑到单独函数
def process_log_data(data):
    """Persist one queued log record into the SQLite database.

    A record containing the key ``"dislike_reason"`` is inserted into the
    ``dislike_reasons`` table; any other record is inserted into the
    ``query_types`` table and must contain ``"query_type"``. Both kinds
    must also contain ``"workflow_run_id"`` and ``"timestamp"``.

    Any failure (missing key, database error, ...) is logged with its
    traceback and swallowed, so the function itself never raises.

    Args:
        data: Mapping describing the record to store.
    """
    try:
        with db_lock:
            with closing(sqlite3.connect(DB_FILE)) as conn:
                cursor = conn.cursor()

                if "dislike_reason" in data:
                    # Store a dislike reason
                    cursor.execute(
                        "INSERT INTO dislike_reasons (dislike_reason, workflow_run_id, timestamp) VALUES (?, ?, ?)",
                        (data["dislike_reason"], data["workflow_run_id"], data["timestamp"])
                    )
                    table_name = "dislike_reasons"
                else:
                    # Store a query type
                    cursor.execute(
                        "INSERT INTO query_types (query_type, workflow_run_id, timestamp) VALUES (?, ?, ?)",
                        (data["query_type"], data["workflow_run_id"], data["timestamp"])
                    )
                    table_name = "query_types"

                conn.commit()

        logger.info(f"成功保存数据到表: {table_name}")
    except Exception as e:
        logger.error(f"处理日志数据时出错: {str(e)}", exc_info=True)
# 创建日志目录 # 创建数据目录
os.makedirs(QUERY_LOG_DIR, exist_ok=True) os.makedirs(DATA_DIR, exist_ok=True)
# 配置日志 - 同时输出到控制台和文件 # 配置日志 - 同时输出到控制台和文件
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
# 创建日志目录
LOG_DIR = os.path.join(DATA_DIR, "logs")
os.makedirs(LOG_DIR, exist_ok=True)
# 创建控制台处理器 # 创建控制台处理器
console_handler = logging.StreamHandler() console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO) console_handler.setLevel(logging.INFO)
# 创建文件处理器 # 创建文件处理器
file_handler = logging.FileHandler( file_handler = logging.FileHandler(
os.path.join(QUERY_LOG_DIR, "answer_type_service.log"), os.path.join(LOG_DIR, "answer_type_service.log"),
encoding='utf-8' encoding='utf-8'
) )
file_handler.setLevel(logging.INFO) file_handler.setLevel(logging.INFO)
@@ -145,12 +185,8 @@ app.add_middleware(
@app.on_event("startup") @app.on_event("startup")
async def startup_event(): async def startup_event():
global worker_thread global worker_thread
# 确保日志目录存在 # 初始化数据库
os.makedirs(QUERY_LOG_DIR, exist_ok=True) init_database()
# 确保日志文件存在
if not os.path.exists(QUERY_DATA_FILE):
with open(QUERY_DATA_FILE, 'w', encoding='utf-8') as f:
json.dump([], f, ensure_ascii=False)
# 启动后台工作线程 # 启动后台工作线程
worker_thread = threading.Thread(target=log_worker, daemon=True) worker_thread = threading.Thread(target=log_worker, daemon=True)
@@ -206,18 +242,138 @@ async def query_type(query_type: str, workflow_run_id:str):
logger.error(f"处理请求时出错: {str(e)}", exc_info=True) logger.error(f"处理请求时出错: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"处理请求时出错: {str(e)}") raise HTTPException(status_code=500, detail=f"处理请求时出错: {str(e)}")
@app.get("/dislike_reason", summary="记录点踩原因")
async def dislike_reason(reason: str, workflow_run_id: str):
    """Queue a dislike reason for persistence and return an HTML receipt.

    The record (reason, ISO-format timestamp, workflow run id) is placed on
    ``log_queue``; this handler does not write to the database itself.

    Args:
        reason: Free-text dislike reason supplied by the caller.
        workflow_run_id: Identifier of the workflow run the reason refers to.

    Returns:
        An ``HTMLResponse`` echoing the reason and whether queuing succeeded.

    Raises:
        HTTPException: With status 500 when an unexpected error occurs.
    """
    # Local import keeps this block self-contained.
    import html

    try:
        # Log the incoming request
        logger.info(f"接收到点踩原因: {reason}, workflow_run_id: {workflow_run_id}")

        # Build the record
        timestamp = datetime.datetime.now().isoformat()
        dislike_data = {
            "dislike_reason": reason,
            "timestamp": timestamp,
            "workflow_run_id": workflow_run_id
        }

        # Hand the record to the queue
        try:
            log_queue.put(dislike_data)
            success = True
            logger.info(f"点踩原因数据已加入队列,当前队列大小: {log_queue.qsize()}")
        except Exception as e:
            success = False
            logger.error(f"加入队列时出错: {str(e)}", exc_info=True)

        # `reason` is caller-controlled and is rendered as HTML, so it must
        # be escaped to prevent markup/script injection into the response.
        safe_reason = html.escape(reason)
        content = f"<strong>点踩原因</strong>: {safe_reason}<br><strong>操作是否成功</strong>: {'成功' if success else '失败'}"
        return HTMLResponse(content=content)
    except Exception as e:
        logger.error(f"处理点踩原因请求时出错: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"处理点踩原因请求时出错: {str(e)}")
# Data query API
def _collect_stats():
    """Read row counts and the five newest rows of both tables (blocking)."""
    with db_lock:
        with closing(sqlite3.connect(DB_FILE)) as conn:
            conn.row_factory = sqlite3.Row  # rows become addressable by column name
            cursor = conn.cursor()

            # Total number of recorded query types
            cursor.execute("SELECT COUNT(*) as count FROM query_types")
            query_count = cursor.fetchone()['count']

            # Total number of recorded dislike reasons
            cursor.execute("SELECT COUNT(*) as count FROM dislike_reasons")
            dislike_count = cursor.fetchone()['count']

            # Five most recent query-type rows
            cursor.execute("""
                SELECT query_type, workflow_run_id, timestamp
                FROM query_types
                ORDER BY id DESC LIMIT 5
            """)
            recent_queries = [dict(row) for row in cursor.fetchall()]

            # Five most recent dislike rows
            cursor.execute("""
                SELECT dislike_reason, workflow_run_id, timestamp
                FROM dislike_reasons
                ORDER BY id DESC LIMIT 5
            """)
            recent_dislikes = [dict(row) for row in cursor.fetchall()]

    return {
        "statistics": {
            "total_queries": query_count,
            "total_dislikes": dislike_count
        },
        "recent_data": {
            "queries": recent_queries,
            "dislikes": recent_dislikes
        }
    }


@app.get("/stats", summary="查询统计数据")
async def get_stats():
    """Return totals and the latest records of both log tables.

    The SQLite work and the wait on ``db_lock`` are blocking, so they run
    in the default executor instead of on the event loop thread.

    Returns:
        A dict with ``statistics`` (``total_queries``, ``total_dislikes``)
        and ``recent_data`` (``queries``, ``dislikes``; up to five rows each,
        newest first).

    Raises:
        HTTPException: With status 500 when reading the database fails.
    """
    try:
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, _collect_stats)
    except Exception as e:
        logger.error(f"获取统计数据时出错: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"获取统计数据时出错: {str(e)}")
def _select_query_types(workflow_run_id):
    """Fetch all ``query_types`` rows for one workflow run, newest first (blocking)."""
    with db_lock:
        with closing(sqlite3.connect(DB_FILE)) as conn:
            conn.row_factory = sqlite3.Row  # rows become addressable by column name
            cursor = conn.cursor()
            cursor.execute("""
                SELECT id, query_type, workflow_run_id, timestamp
                FROM query_types
                WHERE workflow_run_id = ?
                ORDER BY id DESC
            """, (workflow_run_id,))
            return [dict(row) for row in cursor.fetchall()]


@app.get("/query_by_workflow_id", summary="根据工作流ID获取查询类型数据")
async def get_query_by_workflow_id(workflow_run_id: str):
    """Return the query-type records stored for a workflow run.

    The blocking SQLite read (including the wait on ``db_lock``) runs in
    the default executor so the event loop is not stalled.

    Args:
        workflow_run_id: Workflow run identifier to look up.

    Returns:
        ``{"data": [...]}`` with matching rows ordered by ``id`` descending,
        or ``{"message": ..., "data": []}`` when nothing matches.

    Raises:
        HTTPException: With status 500 when reading the database fails.
    """
    try:
        loop = asyncio.get_running_loop()
        results = await loop.run_in_executor(None, _select_query_types, workflow_run_id)

        if not results:
            return {"message": f"未找到工作流ID为 {workflow_run_id} 的查询类型数据", "data": []}

        return {"data": results}
    except Exception as e:
        logger.error(f"根据工作流ID获取查询类型数据时出错: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"根据工作流ID获取查询类型数据时出错: {str(e)}")
def _select_dislike_reasons(workflow_run_id):
    """Fetch all ``dislike_reasons`` rows for one workflow run, newest first (blocking)."""
    with db_lock:
        with closing(sqlite3.connect(DB_FILE)) as conn:
            conn.row_factory = sqlite3.Row  # rows become addressable by column name
            cursor = conn.cursor()
            cursor.execute("""
                SELECT id, dislike_reason, workflow_run_id, timestamp
                FROM dislike_reasons
                WHERE workflow_run_id = ?
                ORDER BY id DESC
            """, (workflow_run_id,))
            return [dict(row) for row in cursor.fetchall()]


@app.get("/dislike_by_workflow_id", summary="根据工作流ID获取点踩原因数据")
async def get_dislike_by_workflow_id(workflow_run_id: str):
    """Return the dislike-reason records stored for a workflow run.

    The blocking SQLite read (including the wait on ``db_lock``) runs in
    the default executor so the event loop is not stalled.

    Args:
        workflow_run_id: Workflow run identifier to look up.

    Returns:
        ``{"data": [...]}`` with matching rows ordered by ``id`` descending,
        or ``{"message": ..., "data": []}`` when nothing matches.

    Raises:
        HTTPException: With status 500 when reading the database fails.
    """
    try:
        loop = asyncio.get_running_loop()
        results = await loop.run_in_executor(None, _select_dislike_reasons, workflow_run_id)

        if not results:
            return {"message": f"未找到工作流ID为 {workflow_run_id} 的点踩原因数据", "data": []}

        return {"data": results}
    except Exception as e:
        logger.error(f"根据工作流ID获取点踩原因数据时出错: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"根据工作流ID获取点踩原因数据时出错: {str(e)}")
if __name__ == "__main__": if __name__ == "__main__":
# 使用Uvicorn运行FastAPI应用 # 使用Uvicorn运行FastAPI应用
import uvicorn import uvicorn
uvicorn.run("rag2_0.dify.AnswerType:app", host="0.0.0.0", port=8003, reload=False, workers=1, log_level="info") uvicorn.run("rag2_0.dify.AnswerType:app", host="0.0.0.0", port=8003, reload=False, workers=1, log_level="info")
# # 使用uvicorn启动服务
# import uvicorn
# uvicorn.run(
# "rag2_0.dify.intent_recognition_api:app",
# host="0.0.0.0",
# port=8001,
# reload=False, # 开发环境启用热重载
# workers=1 # 生产环境可以增加worker数量
# )
# 生产环境可以使用以下命令启动: # 生产环境可以使用以下命令启动:
# uvicorn rag2_0.dify.AnswerType:app --host 0.0.0.0 --port 8003 --workers 1 # uvicorn rag2_0.dify.AnswerType:app --host 0.0.0.0 --port 8003 --workers 1
+5 -2
View File
@@ -21,12 +21,15 @@ import sys
sys.path.append(os.getcwd()) sys.path.append(os.getcwd())
from rag2_0.dify.DifyQueryRetrieval import DifyQueryRetrieval from rag2_0.dify.DifyQueryRetrieval import DifyQueryRetrieval
# 确保日志目录存在
os.makedirs('data/logs', exist_ok=True)
logging.basicConfig( logging.basicConfig(
level=logging.INFO, level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', format='%(asctime)s - %(name)s - [%(thread)d] - %(levelname)s - %(message)s',
handlers=[ handlers=[
logging.StreamHandler() logging.StreamHandler(),
logging.FileHandler(f'data/logs/dify_query_retrieval_{datetime.datetime.now().strftime("%Y%m%d")}.log', encoding='utf-8')
] ]
) )
logging.getLogger('httpx').setLevel(logging.WARNING) logging.getLogger('httpx').setLevel(logging.WARNING)
+3
View File
@@ -357,6 +357,9 @@ class DifyTool:
""" """
return self.dify_pgsql.get_app_conversations(conversation_id) return self.dify_pgsql.get_app_conversations(conversation_id)
def get_workflow_node_executions_info(self, workflow_run_id: str):
    """Return node-execution info for a workflow run.

    Thin pass-through: forwards ``workflow_run_id`` to
    ``self.dify_pgsql.get_workflow_node_executions_info`` and returns its
    result unchanged. The shape of that result is defined by the backend
    object and is not visible here.

    Args:
        workflow_run_id: Identifier of the workflow run to look up.
    """
    return self.dify_pgsql.get_workflow_node_executions_info(workflow_run_id)
def get_message_rating(self, msg_id): def get_message_rating(self, msg_id):
return self.dify_pgsql.get_message_rating(msg_id) return self.dify_pgsql.get_message_rating(msg_id)
+6 -3
View File
@@ -18,12 +18,15 @@ import sys
sys.path.append(os.getcwd()) sys.path.append(os.getcwd())
from rag2_0.intent_recognition import AsyncIntentRecognizer from rag2_0.intent_recognition import AsyncIntentRecognizer
# 确保日志目录存在
os.makedirs('data/logs', exist_ok=True)
logging.basicConfig( logging.basicConfig(
level=logging.INFO, level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', format='%(asctime)s - %(process)d - %(thread)d - %(name)s - %(levelname)s - %(message)s',
handlers=[ handlers=[
logging.StreamHandler() logging.StreamHandler(),
logging.FileHandler(f'data/logs/intent_recognition_{datetime.datetime.now().strftime("%Y%m%d")}.log', encoding='utf-8')
] ]
) )
logging.getLogger('httpx').setLevel(logging.WARNING) logging.getLogger('httpx').setLevel(logging.WARNING)
@@ -118,7 +121,7 @@ async def intent_recognize(request: IntentRecognizeRequest):
end_time = time.time() end_time = time.time()
current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S %z") current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S %z")
logger.info(f"[{os.getpid()}] 意图识别耗时: {end_time - start_time:.2f}") logger.info(f"意图识别耗时: {end_time - start_time:.2f}")
# 提取分类信息 # 提取分类信息
classification = result["classification"] classification = result["classification"]