新增数据库支持,初始化数据库并创建查询类型和点踩原因表,优化日志记录,添加多个API以支持点踩原因和查询统计功能。
This commit is contained in:
+196
-40
@@ -10,6 +10,8 @@ from typing import Dict, List, Any, Optional
|
|||||||
import asyncio
|
import asyncio
|
||||||
import threading
|
import threading
|
||||||
import queue
|
import queue
|
||||||
|
import sqlite3
|
||||||
|
from contextlib import closing
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
import json
|
import json
|
||||||
@@ -31,14 +33,47 @@ import sys
|
|||||||
sys.path.append(os.getcwd())
|
sys.path.append(os.getcwd())
|
||||||
from rag2_0.dify.DifyQueryRetrieval import DifyQueryRetrieval
|
from rag2_0.dify.DifyQueryRetrieval import DifyQueryRetrieval
|
||||||
|
|
||||||
# 定义文件锁和JSON文件路径
|
# 定义数据库路径
|
||||||
file_lock = asyncio.Lock()
|
DATA_DIR = os.path.join(os.getcwd(), "data")
|
||||||
QUERY_LOG_DIR = os.path.join(os.getcwd(), "data", "query_logs")
|
DB_DIR = os.path.join(DATA_DIR, "db")
|
||||||
QUERY_DATA_FILE = os.path.join(QUERY_LOG_DIR, "answer_type_logs.json")
|
DB_FILE = os.path.join(DB_DIR, "answer_logs.db")
|
||||||
|
|
||||||
# 创建异步日志队列和工作线程
|
# 创建异步日志队列和工作线程
|
||||||
log_queue = queue.Queue()
|
log_queue = queue.Queue()
|
||||||
worker_thread = None
|
worker_thread = None
|
||||||
|
db_lock = threading.Lock() # 数据库操作锁
|
||||||
|
|
||||||
|
# 初始化数据库
|
||||||
|
def init_database():
|
||||||
|
os.makedirs(DB_DIR, exist_ok=True)
|
||||||
|
|
||||||
|
with db_lock:
|
||||||
|
with closing(sqlite3.connect(DB_FILE)) as conn:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
# 创建查询类型表
|
||||||
|
cursor.execute('''
|
||||||
|
CREATE TABLE IF NOT EXISTS query_types (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
query_type TEXT NOT NULL,
|
||||||
|
workflow_run_id TEXT NOT NULL,
|
||||||
|
timestamp TEXT NOT NULL
|
||||||
|
)
|
||||||
|
''')
|
||||||
|
|
||||||
|
# 创建点踩原因表
|
||||||
|
cursor.execute('''
|
||||||
|
CREATE TABLE IF NOT EXISTS dislike_reasons (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
dislike_reason TEXT NOT NULL,
|
||||||
|
workflow_run_id TEXT NOT NULL,
|
||||||
|
timestamp TEXT NOT NULL
|
||||||
|
)
|
||||||
|
''')
|
||||||
|
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
logger.info("数据库初始化完成")
|
||||||
|
|
||||||
# 后台工作线程函数
|
# 后台工作线程函数
|
||||||
def log_worker():
|
def log_worker():
|
||||||
@@ -65,44 +100,49 @@ def log_worker():
|
|||||||
# 提取数据处理逻辑到单独函数
|
# 提取数据处理逻辑到单独函数
|
||||||
def process_log_data(data):
|
def process_log_data(data):
|
||||||
try:
|
try:
|
||||||
# 确保目录存在
|
with db_lock:
|
||||||
os.makedirs(os.path.dirname(QUERY_DATA_FILE), exist_ok=True)
|
with closing(sqlite3.connect(DB_FILE)) as conn:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
|
||||||
# 读取现有数据
|
if "dislike_reason" in data:
|
||||||
existing_data = []
|
# 保存点踩原因
|
||||||
if os.path.exists(QUERY_DATA_FILE) and os.path.getsize(QUERY_DATA_FILE) > 0:
|
cursor.execute(
|
||||||
with open(QUERY_DATA_FILE, 'r', encoding='utf-8') as f:
|
"INSERT INTO dislike_reasons (dislike_reason, workflow_run_id, timestamp) VALUES (?, ?, ?)",
|
||||||
try:
|
(data["dislike_reason"], data["workflow_run_id"], data["timestamp"])
|
||||||
existing_data = json.load(f)
|
)
|
||||||
except json.JSONDecodeError:
|
table_name = "dislike_reasons"
|
||||||
logger.error(f"JSON文件解析错误,将创建新文件: {QUERY_DATA_FILE}")
|
else:
|
||||||
existing_data = []
|
# 保存查询类型
|
||||||
|
cursor.execute(
|
||||||
|
"INSERT INTO query_types (query_type, workflow_run_id, timestamp) VALUES (?, ?, ?)",
|
||||||
|
(data["query_type"], data["workflow_run_id"], data["timestamp"])
|
||||||
|
)
|
||||||
|
table_name = "query_types"
|
||||||
|
|
||||||
# 添加新数据
|
conn.commit()
|
||||||
existing_data.append(data)
|
|
||||||
|
|
||||||
# 写入文件
|
logger.info(f"成功保存数据到表: {table_name}")
|
||||||
with open(QUERY_DATA_FILE, 'w', encoding='utf-8') as f:
|
|
||||||
json.dump(existing_data, f, ensure_ascii=False, indent=2)
|
|
||||||
|
|
||||||
logger.info(f"成功保存查询数据到: {QUERY_DATA_FILE}")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"处理日志数据时出错: {str(e)}", exc_info=True)
|
logger.error(f"处理日志数据时出错: {str(e)}", exc_info=True)
|
||||||
|
|
||||||
# 创建日志目录
|
# 创建数据目录
|
||||||
os.makedirs(QUERY_LOG_DIR, exist_ok=True)
|
os.makedirs(DATA_DIR, exist_ok=True)
|
||||||
|
|
||||||
# 配置日志 - 同时输出到控制台和文件
|
# 配置日志 - 同时输出到控制台和文件
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
logger.setLevel(logging.INFO)
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
# 创建日志目录
|
||||||
|
LOG_DIR = os.path.join(DATA_DIR, "logs")
|
||||||
|
os.makedirs(LOG_DIR, exist_ok=True)
|
||||||
|
|
||||||
# 创建控制台处理器
|
# 创建控制台处理器
|
||||||
console_handler = logging.StreamHandler()
|
console_handler = logging.StreamHandler()
|
||||||
console_handler.setLevel(logging.INFO)
|
console_handler.setLevel(logging.INFO)
|
||||||
|
|
||||||
# 创建文件处理器
|
# 创建文件处理器
|
||||||
file_handler = logging.FileHandler(
|
file_handler = logging.FileHandler(
|
||||||
os.path.join(QUERY_LOG_DIR, "answer_type_service.log"),
|
os.path.join(LOG_DIR, "answer_type_service.log"),
|
||||||
encoding='utf-8'
|
encoding='utf-8'
|
||||||
)
|
)
|
||||||
file_handler.setLevel(logging.INFO)
|
file_handler.setLevel(logging.INFO)
|
||||||
@@ -145,12 +185,8 @@ app.add_middleware(
|
|||||||
@app.on_event("startup")
|
@app.on_event("startup")
|
||||||
async def startup_event():
|
async def startup_event():
|
||||||
global worker_thread
|
global worker_thread
|
||||||
# 确保日志目录存在
|
# 初始化数据库
|
||||||
os.makedirs(QUERY_LOG_DIR, exist_ok=True)
|
init_database()
|
||||||
# 确保日志文件存在
|
|
||||||
if not os.path.exists(QUERY_DATA_FILE):
|
|
||||||
with open(QUERY_DATA_FILE, 'w', encoding='utf-8') as f:
|
|
||||||
json.dump([], f, ensure_ascii=False)
|
|
||||||
|
|
||||||
# 启动后台工作线程
|
# 启动后台工作线程
|
||||||
worker_thread = threading.Thread(target=log_worker, daemon=True)
|
worker_thread = threading.Thread(target=log_worker, daemon=True)
|
||||||
@@ -206,18 +242,138 @@ async def query_type(query_type: str, workflow_run_id:str):
|
|||||||
logger.error(f"处理请求时出错: {str(e)}", exc_info=True)
|
logger.error(f"处理请求时出错: {str(e)}", exc_info=True)
|
||||||
raise HTTPException(status_code=500, detail=f"处理请求时出错: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"处理请求时出错: {str(e)}")
|
||||||
|
|
||||||
|
@app.get("/dislike_reason", summary="记录点踩原因")
|
||||||
|
async def dislike_reason(reason: str, workflow_run_id: str):
|
||||||
|
try:
|
||||||
|
# 记录请求
|
||||||
|
logger.info(f"接收到点踩原因: {reason}, workflow_run_id: {workflow_run_id}")
|
||||||
|
|
||||||
|
# 准备数据
|
||||||
|
timestamp = datetime.datetime.now().isoformat()
|
||||||
|
dislike_data = {
|
||||||
|
"dislike_reason": reason,
|
||||||
|
"timestamp": timestamp,
|
||||||
|
"workflow_run_id": workflow_run_id
|
||||||
|
}
|
||||||
|
|
||||||
|
# 将数据放入队列
|
||||||
|
try:
|
||||||
|
log_queue.put(dislike_data)
|
||||||
|
success = True
|
||||||
|
logger.info(f"点踩原因数据已加入队列,当前队列大小: {log_queue.qsize()}")
|
||||||
|
except Exception as e:
|
||||||
|
success = False
|
||||||
|
logger.error(f"加入队列时出错: {str(e)}", exc_info=True)
|
||||||
|
|
||||||
|
# 返回响应
|
||||||
|
content = f"<strong>点踩原因</strong>: {reason}<br><strong>操作是否成功</strong>: {'成功' if success else '失败'}"
|
||||||
|
return HTMLResponse(content=content)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"处理点踩原因请求时出错: {str(e)}", exc_info=True)
|
||||||
|
raise HTTPException(status_code=500, detail=f"处理点踩原因请求时出错: {str(e)}")
|
||||||
|
|
||||||
|
# 添加数据查询API
|
||||||
|
@app.get("/stats", summary="查询统计数据")
|
||||||
|
async def get_stats():
|
||||||
|
try:
|
||||||
|
with db_lock:
|
||||||
|
with closing(sqlite3.connect(DB_FILE)) as conn:
|
||||||
|
conn.row_factory = sqlite3.Row # 启用字典行工厂
|
||||||
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
# 查询类型统计
|
||||||
|
cursor.execute("SELECT COUNT(*) as count FROM query_types")
|
||||||
|
query_count = cursor.fetchone()['count']
|
||||||
|
|
||||||
|
# 点踩原因统计
|
||||||
|
cursor.execute("SELECT COUNT(*) as count FROM dislike_reasons")
|
||||||
|
dislike_count = cursor.fetchone()['count']
|
||||||
|
|
||||||
|
# 最近5条查询记录
|
||||||
|
cursor.execute("""
|
||||||
|
SELECT query_type, workflow_run_id, timestamp
|
||||||
|
FROM query_types
|
||||||
|
ORDER BY id DESC LIMIT 5
|
||||||
|
""")
|
||||||
|
recent_queries = [dict(row) for row in cursor.fetchall()]
|
||||||
|
|
||||||
|
# 最近5条点踩记录
|
||||||
|
cursor.execute("""
|
||||||
|
SELECT dislike_reason, workflow_run_id, timestamp
|
||||||
|
FROM dislike_reasons
|
||||||
|
ORDER BY id DESC LIMIT 5
|
||||||
|
""")
|
||||||
|
recent_dislikes = [dict(row) for row in cursor.fetchall()]
|
||||||
|
|
||||||
|
return {
|
||||||
|
"statistics": {
|
||||||
|
"total_queries": query_count,
|
||||||
|
"total_dislikes": dislike_count
|
||||||
|
},
|
||||||
|
"recent_data": {
|
||||||
|
"queries": recent_queries,
|
||||||
|
"dislikes": recent_dislikes
|
||||||
|
}
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取统计数据时出错: {str(e)}", exc_info=True)
|
||||||
|
raise HTTPException(status_code=500, detail=f"获取统计数据时出错: {str(e)}")
|
||||||
|
|
||||||
|
@app.get("/query_by_workflow_id", summary="根据工作流ID获取查询类型数据")
|
||||||
|
async def get_query_by_workflow_id(workflow_run_id: str):
|
||||||
|
try:
|
||||||
|
with db_lock:
|
||||||
|
with closing(sqlite3.connect(DB_FILE)) as conn:
|
||||||
|
conn.row_factory = sqlite3.Row # 启用字典行工厂
|
||||||
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
# 查询指定工作流ID的查询类型数据
|
||||||
|
cursor.execute("""
|
||||||
|
SELECT id, query_type, workflow_run_id, timestamp
|
||||||
|
FROM query_types
|
||||||
|
WHERE workflow_run_id = ?
|
||||||
|
ORDER BY id DESC
|
||||||
|
""", (workflow_run_id,))
|
||||||
|
|
||||||
|
results = [dict(row) for row in cursor.fetchall()]
|
||||||
|
|
||||||
|
if not results:
|
||||||
|
return {"message": f"未找到工作流ID为 {workflow_run_id} 的查询类型数据", "data": []}
|
||||||
|
|
||||||
|
return {"data": results}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"根据工作流ID获取查询类型数据时出错: {str(e)}", exc_info=True)
|
||||||
|
raise HTTPException(status_code=500, detail=f"根据工作流ID获取查询类型数据时出错: {str(e)}")
|
||||||
|
|
||||||
|
@app.get("/dislike_by_workflow_id", summary="根据工作流ID获取点踩原因数据")
|
||||||
|
async def get_dislike_by_workflow_id(workflow_run_id: str):
|
||||||
|
try:
|
||||||
|
with db_lock:
|
||||||
|
with closing(sqlite3.connect(DB_FILE)) as conn:
|
||||||
|
conn.row_factory = sqlite3.Row # 启用字典行工厂
|
||||||
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
# 查询指定工作流ID的点踩原因数据
|
||||||
|
cursor.execute("""
|
||||||
|
SELECT id, dislike_reason, workflow_run_id, timestamp
|
||||||
|
FROM dislike_reasons
|
||||||
|
WHERE workflow_run_id = ?
|
||||||
|
ORDER BY id DESC
|
||||||
|
""", (workflow_run_id,))
|
||||||
|
|
||||||
|
results = [dict(row) for row in cursor.fetchall()]
|
||||||
|
|
||||||
|
if not results:
|
||||||
|
return {"message": f"未找到工作流ID为 {workflow_run_id} 的点踩原因数据", "data": []}
|
||||||
|
|
||||||
|
return {"data": results}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"根据工作流ID获取点踩原因数据时出错: {str(e)}", exc_info=True)
|
||||||
|
raise HTTPException(status_code=500, detail=f"根据工作流ID获取点踩原因数据时出错: {str(e)}")
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# 使用Uvicorn运行FastAPI应用
|
# 使用Uvicorn运行FastAPI应用
|
||||||
import uvicorn
|
import uvicorn
|
||||||
uvicorn.run("rag2_0.dify.AnswerType:app", host="0.0.0.0", port=8003, reload=False, workers=1, log_level="info")
|
uvicorn.run("rag2_0.dify.AnswerType:app", host="0.0.0.0", port=8003, reload=False, workers=1, log_level="info")
|
||||||
# # 使用uvicorn启动服务
|
|
||||||
# import uvicorn
|
|
||||||
# uvicorn.run(
|
|
||||||
# "rag2_0.dify.intent_recognition_api:app",
|
|
||||||
# host="0.0.0.0",
|
|
||||||
# port=8001,
|
|
||||||
# reload=False, # 开发环境启用热重载
|
|
||||||
# workers=1 # 生产环境可以增加worker数量
|
|
||||||
# )
|
|
||||||
# 生产环境可以使用以下命令启动:
|
# 生产环境可以使用以下命令启动:
|
||||||
# uvicorn rag2_0.dify.AnswerType:app --host 0.0.0.0 --port 8003 --workers 1
|
# uvicorn rag2_0.dify.AnswerType:app --host 0.0.0.0 --port 8003 --workers 1
|
||||||
@@ -21,12 +21,15 @@ import sys
|
|||||||
sys.path.append(os.getcwd())
|
sys.path.append(os.getcwd())
|
||||||
from rag2_0.dify.DifyQueryRetrieval import DifyQueryRetrieval
|
from rag2_0.dify.DifyQueryRetrieval import DifyQueryRetrieval
|
||||||
|
|
||||||
|
# 确保日志目录存在
|
||||||
|
os.makedirs('data/logs', exist_ok=True)
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=logging.INFO,
|
level=logging.INFO,
|
||||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
format='%(asctime)s - %(name)s - [%(thread)d] - %(levelname)s - %(message)s',
|
||||||
handlers=[
|
handlers=[
|
||||||
logging.StreamHandler()
|
logging.StreamHandler(),
|
||||||
|
logging.FileHandler(f'data/logs/dify_query_retrieval_{datetime.datetime.now().strftime("%Y%m%d")}.log', encoding='utf-8')
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
logging.getLogger('httpx').setLevel(logging.WARNING)
|
logging.getLogger('httpx').setLevel(logging.WARNING)
|
||||||
|
|||||||
@@ -357,6 +357,9 @@ class DifyTool:
|
|||||||
"""
|
"""
|
||||||
return self.dify_pgsql.get_app_conversations(conversation_id)
|
return self.dify_pgsql.get_app_conversations(conversation_id)
|
||||||
|
|
||||||
|
def get_workflow_node_executions_info(self, workflow_run_id:str):
|
||||||
|
return self.dify_pgsql.get_workflow_node_executions_info(workflow_run_id)
|
||||||
|
|
||||||
def get_message_rating(self, msg_id):
|
def get_message_rating(self, msg_id):
|
||||||
return self.dify_pgsql.get_message_rating(msg_id)
|
return self.dify_pgsql.get_message_rating(msg_id)
|
||||||
|
|
||||||
|
|||||||
@@ -18,12 +18,15 @@ import sys
|
|||||||
sys.path.append(os.getcwd())
|
sys.path.append(os.getcwd())
|
||||||
from rag2_0.intent_recognition import AsyncIntentRecognizer
|
from rag2_0.intent_recognition import AsyncIntentRecognizer
|
||||||
|
|
||||||
|
# 确保日志目录存在
|
||||||
|
os.makedirs('data/logs', exist_ok=True)
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=logging.INFO,
|
level=logging.INFO,
|
||||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
format='%(asctime)s - %(process)d - %(thread)d - %(name)s - %(levelname)s - %(message)s',
|
||||||
handlers=[
|
handlers=[
|
||||||
logging.StreamHandler()
|
logging.StreamHandler(),
|
||||||
|
logging.FileHandler(f'data/logs/intent_recognition_{datetime.datetime.now().strftime("%Y%m%d")}.log', encoding='utf-8')
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
logging.getLogger('httpx').setLevel(logging.WARNING)
|
logging.getLogger('httpx').setLevel(logging.WARNING)
|
||||||
@@ -118,7 +121,7 @@ async def intent_recognize(request: IntentRecognizeRequest):
|
|||||||
|
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S %z")
|
current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S %z")
|
||||||
logger.info(f"[{os.getpid()}] 意图识别耗时: {end_time - start_time:.2f}秒")
|
logger.info(f"意图识别耗时: {end_time - start_time:.2f}秒")
|
||||||
|
|
||||||
# 提取分类信息
|
# 提取分类信息
|
||||||
classification = result["classification"]
|
classification = result["classification"]
|
||||||
|
|||||||
Reference in New Issue
Block a user