diff --git a/rag2_0/dify/AnswerType.py b/rag2_0/dify/AnswerType.py
index 8b4be05..faae6dd 100644
--- a/rag2_0/dify/AnswerType.py
+++ b/rag2_0/dify/AnswerType.py
@@ -10,6 +10,8 @@ from typing import Dict, List, Any, Optional
import asyncio
import threading
import queue
+import sqlite3
+from contextlib import closing
from dotenv import load_dotenv
import json
@@ -31,14 +33,47 @@ import sys
sys.path.append(os.getcwd())
from rag2_0.dify.DifyQueryRetrieval import DifyQueryRetrieval
-# 定义文件锁和JSON文件路径
-file_lock = asyncio.Lock()
-QUERY_LOG_DIR = os.path.join(os.getcwd(), "data", "query_logs")
-QUERY_DATA_FILE = os.path.join(QUERY_LOG_DIR, "answer_type_logs.json")
+# 定义数据库路径
+DATA_DIR = os.path.join(os.getcwd(), "data")
+DB_DIR = os.path.join(DATA_DIR, "db")
+DB_FILE = os.path.join(DB_DIR, "answer_logs.db")
# 创建异步日志队列和工作线程
log_queue = queue.Queue()
worker_thread = None
+db_lock = threading.Lock() # 数据库操作锁
+
+# 初始化数据库
+def init_database():
+ os.makedirs(DB_DIR, exist_ok=True)
+
+ with db_lock:
+ with closing(sqlite3.connect(DB_FILE)) as conn:
+ cursor = conn.cursor()
+
+ # 创建查询类型表
+ cursor.execute('''
+ CREATE TABLE IF NOT EXISTS query_types (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ query_type TEXT NOT NULL,
+ workflow_run_id TEXT NOT NULL,
+ timestamp TEXT NOT NULL
+ )
+ ''')
+
+ # 创建点踩原因表
+ cursor.execute('''
+ CREATE TABLE IF NOT EXISTS dislike_reasons (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ dislike_reason TEXT NOT NULL,
+ workflow_run_id TEXT NOT NULL,
+ timestamp TEXT NOT NULL
+ )
+ ''')
+
+ conn.commit()
+
+ logger.info("数据库初始化完成")
# 后台工作线程函数
def log_worker():
@@ -65,44 +100,49 @@ def log_worker():
# 提取数据处理逻辑到单独函数
def process_log_data(data):
try:
- # 确保目录存在
- os.makedirs(os.path.dirname(QUERY_DATA_FILE), exist_ok=True)
-
- # 读取现有数据
- existing_data = []
- if os.path.exists(QUERY_DATA_FILE) and os.path.getsize(QUERY_DATA_FILE) > 0:
- with open(QUERY_DATA_FILE, 'r', encoding='utf-8') as f:
- try:
- existing_data = json.load(f)
- except json.JSONDecodeError:
- logger.error(f"JSON文件解析错误,将创建新文件: {QUERY_DATA_FILE}")
- existing_data = []
-
- # 添加新数据
- existing_data.append(data)
-
- # 写入文件
- with open(QUERY_DATA_FILE, 'w', encoding='utf-8') as f:
- json.dump(existing_data, f, ensure_ascii=False, indent=2)
-
- logger.info(f"成功保存查询数据到: {QUERY_DATA_FILE}")
+ with db_lock:
+ with closing(sqlite3.connect(DB_FILE)) as conn:
+ cursor = conn.cursor()
+
+ if "dislike_reason" in data:
+ # 保存点踩原因
+ cursor.execute(
+ "INSERT INTO dislike_reasons (dislike_reason, workflow_run_id, timestamp) VALUES (?, ?, ?)",
+ (data["dislike_reason"], data["workflow_run_id"], data["timestamp"])
+ )
+ table_name = "dislike_reasons"
+ else:
+ # 保存查询类型
+ cursor.execute(
+ "INSERT INTO query_types (query_type, workflow_run_id, timestamp) VALUES (?, ?, ?)",
+ (data["query_type"], data["workflow_run_id"], data["timestamp"])
+ )
+ table_name = "query_types"
+
+ conn.commit()
+
+ logger.info(f"成功保存数据到表: {table_name}")
except Exception as e:
logger.error(f"处理日志数据时出错: {str(e)}", exc_info=True)
-# 创建日志目录
-os.makedirs(QUERY_LOG_DIR, exist_ok=True)
+# 创建数据目录
+os.makedirs(DATA_DIR, exist_ok=True)
# 配置日志 - 同时输出到控制台和文件
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
+# 创建日志目录
+LOG_DIR = os.path.join(DATA_DIR, "logs")
+os.makedirs(LOG_DIR, exist_ok=True)
+
# 创建控制台处理器
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
# 创建文件处理器
file_handler = logging.FileHandler(
- os.path.join(QUERY_LOG_DIR, "answer_type_service.log"),
+ os.path.join(LOG_DIR, "answer_type_service.log"),
encoding='utf-8'
)
file_handler.setLevel(logging.INFO)
@@ -145,12 +185,8 @@ app.add_middleware(
@app.on_event("startup")
async def startup_event():
global worker_thread
- # 确保日志目录存在
- os.makedirs(QUERY_LOG_DIR, exist_ok=True)
- # 确保日志文件存在
- if not os.path.exists(QUERY_DATA_FILE):
- with open(QUERY_DATA_FILE, 'w', encoding='utf-8') as f:
- json.dump([], f, ensure_ascii=False)
+ # 初始化数据库
+ init_database()
# 启动后台工作线程
worker_thread = threading.Thread(target=log_worker, daemon=True)
@@ -206,18 +242,138 @@ async def query_type(query_type: str, workflow_run_id:str):
logger.error(f"处理请求时出错: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"处理请求时出错: {str(e)}")
+@app.get("/dislike_reason", summary="记录点踩原因")
+async def dislike_reason(reason: str, workflow_run_id: str):
+ try:
+ # 记录请求
+ logger.info(f"接收到点踩原因: {reason}, workflow_run_id: {workflow_run_id}")
+
+ # 准备数据
+ timestamp = datetime.datetime.now().isoformat()
+ dislike_data = {
+ "dislike_reason": reason,
+ "timestamp": timestamp,
+ "workflow_run_id": workflow_run_id
+ }
+
+ # 将数据放入队列
+ try:
+ log_queue.put(dislike_data)
+ success = True
+ logger.info(f"点踩原因数据已加入队列,当前队列大小: {log_queue.qsize()}")
+ except Exception as e:
+ success = False
+ logger.error(f"加入队列时出错: {str(e)}", exc_info=True)
+
+ # 返回响应
+ content = f"点踩原因: {reason}
操作是否成功: {'成功' if success else '失败'}"
+ return HTMLResponse(content=content)
+ except Exception as e:
+ logger.error(f"处理点踩原因请求时出错: {str(e)}", exc_info=True)
+ raise HTTPException(status_code=500, detail=f"处理点踩原因请求时出错: {str(e)}")
+
+# 添加数据查询API
+@app.get("/stats", summary="查询统计数据")
+async def get_stats():
+ try:
+ with db_lock:
+ with closing(sqlite3.connect(DB_FILE)) as conn:
+ conn.row_factory = sqlite3.Row # 启用字典行工厂
+ cursor = conn.cursor()
+
+ # 查询类型统计
+ cursor.execute("SELECT COUNT(*) as count FROM query_types")
+ query_count = cursor.fetchone()['count']
+
+ # 点踩原因统计
+ cursor.execute("SELECT COUNT(*) as count FROM dislike_reasons")
+ dislike_count = cursor.fetchone()['count']
+
+ # 最近5条查询记录
+ cursor.execute("""
+ SELECT query_type, workflow_run_id, timestamp
+ FROM query_types
+ ORDER BY id DESC LIMIT 5
+ """)
+ recent_queries = [dict(row) for row in cursor.fetchall()]
+
+ # 最近5条点踩记录
+ cursor.execute("""
+ SELECT dislike_reason, workflow_run_id, timestamp
+ FROM dislike_reasons
+ ORDER BY id DESC LIMIT 5
+ """)
+ recent_dislikes = [dict(row) for row in cursor.fetchall()]
+
+ return {
+ "statistics": {
+ "total_queries": query_count,
+ "total_dislikes": dislike_count
+ },
+ "recent_data": {
+ "queries": recent_queries,
+ "dislikes": recent_dislikes
+ }
+ }
+ except Exception as e:
+ logger.error(f"获取统计数据时出错: {str(e)}", exc_info=True)
+ raise HTTPException(status_code=500, detail=f"获取统计数据时出错: {str(e)}")
+
+@app.get("/query_by_workflow_id", summary="根据工作流ID获取查询类型数据")
+async def get_query_by_workflow_id(workflow_run_id: str):
+ try:
+ with db_lock:
+ with closing(sqlite3.connect(DB_FILE)) as conn:
+ conn.row_factory = sqlite3.Row # 启用字典行工厂
+ cursor = conn.cursor()
+
+ # 查询指定工作流ID的查询类型数据
+ cursor.execute("""
+ SELECT id, query_type, workflow_run_id, timestamp
+ FROM query_types
+ WHERE workflow_run_id = ?
+ ORDER BY id DESC
+ """, (workflow_run_id,))
+
+ results = [dict(row) for row in cursor.fetchall()]
+
+ if not results:
+ return {"message": f"未找到工作流ID为 {workflow_run_id} 的查询类型数据", "data": []}
+
+ return {"data": results}
+ except Exception as e:
+ logger.error(f"根据工作流ID获取查询类型数据时出错: {str(e)}", exc_info=True)
+ raise HTTPException(status_code=500, detail=f"根据工作流ID获取查询类型数据时出错: {str(e)}")
+
+@app.get("/dislike_by_workflow_id", summary="根据工作流ID获取点踩原因数据")
+async def get_dislike_by_workflow_id(workflow_run_id: str):
+ try:
+ with db_lock:
+ with closing(sqlite3.connect(DB_FILE)) as conn:
+ conn.row_factory = sqlite3.Row # 启用字典行工厂
+ cursor = conn.cursor()
+
+ # 查询指定工作流ID的点踩原因数据
+ cursor.execute("""
+ SELECT id, dislike_reason, workflow_run_id, timestamp
+ FROM dislike_reasons
+ WHERE workflow_run_id = ?
+ ORDER BY id DESC
+ """, (workflow_run_id,))
+
+ results = [dict(row) for row in cursor.fetchall()]
+
+ if not results:
+ return {"message": f"未找到工作流ID为 {workflow_run_id} 的点踩原因数据", "data": []}
+
+ return {"data": results}
+ except Exception as e:
+ logger.error(f"根据工作流ID获取点踩原因数据时出错: {str(e)}", exc_info=True)
+ raise HTTPException(status_code=500, detail=f"根据工作流ID获取点踩原因数据时出错: {str(e)}")
+
if __name__ == "__main__":
# 使用Uvicorn运行FastAPI应用
import uvicorn
uvicorn.run("rag2_0.dify.AnswerType:app", host="0.0.0.0", port=8003, reload=False, workers=1, log_level="info")
- # # 使用uvicorn启动服务
- # import uvicorn
- # uvicorn.run(
- # "rag2_0.dify.intent_recognition_api:app",
- # host="0.0.0.0",
- # port=8001,
- # reload=False, # 开发环境启用热重载
- # workers=1 # 生产环境可以增加worker数量
- # )
# 生产环境可以使用以下命令启动:
# uvicorn rag2_0.dify.AnswerType:app --host 0.0.0.0 --port 8003 --workers 1
\ No newline at end of file
diff --git a/rag2_0/dify/DifyQueryRetrieval_api.py b/rag2_0/dify/DifyQueryRetrieval_api.py
index 52c05c9..aaffbf6 100644
--- a/rag2_0/dify/DifyQueryRetrieval_api.py
+++ b/rag2_0/dify/DifyQueryRetrieval_api.py
@@ -21,12 +21,15 @@ import sys
sys.path.append(os.getcwd())
from rag2_0.dify.DifyQueryRetrieval import DifyQueryRetrieval
+# 确保日志目录存在
+os.makedirs('data/logs', exist_ok=True)
logging.basicConfig(
level=logging.INFO,
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+ format='%(asctime)s - %(name)s - [%(thread)d] - %(levelname)s - %(message)s',
handlers=[
- logging.StreamHandler()
+ logging.StreamHandler(),
+ logging.FileHandler(f'data/logs/dify_query_retrieval_{datetime.datetime.now().strftime("%Y%m%d")}.log', encoding='utf-8')
]
)
logging.getLogger('httpx').setLevel(logging.WARNING)
diff --git a/rag2_0/dify/dify_tool.py b/rag2_0/dify/dify_tool.py
index af5b433..2e06a8e 100755
--- a/rag2_0/dify/dify_tool.py
+++ b/rag2_0/dify/dify_tool.py
@@ -357,6 +357,9 @@ class DifyTool:
"""
return self.dify_pgsql.get_app_conversations(conversation_id)
+ def get_workflow_node_executions_info(self, workflow_run_id:str):
+ return self.dify_pgsql.get_workflow_node_executions_info(workflow_run_id)
+
def get_message_rating(self, msg_id):
return self.dify_pgsql.get_message_rating(msg_id)
diff --git a/rag2_0/dify/intent_recognition_api.py b/rag2_0/dify/intent_recognition_api.py
index c527fe7..ccccb0a 100755
--- a/rag2_0/dify/intent_recognition_api.py
+++ b/rag2_0/dify/intent_recognition_api.py
@@ -18,12 +18,15 @@ import sys
sys.path.append(os.getcwd())
from rag2_0.intent_recognition import AsyncIntentRecognizer
+# 确保日志目录存在
+os.makedirs('data/logs', exist_ok=True)
logging.basicConfig(
level=logging.INFO,
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+ format='%(asctime)s - %(process)d - %(thread)d - %(name)s - %(levelname)s - %(message)s',
handlers=[
- logging.StreamHandler()
+ logging.StreamHandler(),
+ logging.FileHandler(f'data/logs/intent_recognition_{datetime.datetime.now().strftime("%Y%m%d")}.log', encoding='utf-8')
]
)
logging.getLogger('httpx').setLevel(logging.WARNING)
@@ -118,7 +121,7 @@ async def intent_recognize(request: IntentRecognizeRequest):
end_time = time.time()
current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S %z")
- logger.info(f"[{os.getpid()}] 意图识别耗时: {end_time - start_time:.2f}秒")
+ logger.info(f"意图识别耗时: {end_time - start_time:.2f}秒")
# 提取分类信息
classification = result["classification"]