调整提示词、简化代码

This commit is contained in:
2025-09-25 16:23:13 +08:00
parent 640e02f89e
commit 2b13fdab99
5 changed files with 235 additions and 476 deletions
+168 -286
View File
@@ -5,23 +5,18 @@ from __future__ import annotations
import json
import os
import re
import configparser
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple, Union
from dataclasses import dataclass
from contextlib import contextmanager
import threading
import time
from queue import Queue, Empty, Full
import pandas as pd
import pymysql
from pymysql.connections import Connection
from pymysql.cursors import Cursor
from tqdm import tqdm
import concurrent.futures
import sys
from queue import Queue, Empty, Full
os.makedirs('./data/logs', exist_ok=True)
# 配置日志
@@ -35,6 +30,18 @@ logging.basicConfig(
)
logger = logging.getLogger(__name__)
# =====================
# 硬编码数据库配置(简化)
# =====================
DB_HOST = '192.168.0.123'
DB_PORT = 3307
DB_USER = 'fuzhimei'
DB_PASSWORD = 'fuzhimei@135'
DB_CHARSET = 'utf8mb4'
DB_CONNECT_TIMEOUT = 10
DB_READ_TIMEOUT = 300
DB_WRITE_TIMEOUT = 300
def parse_session_tags(input_string):
"""
解析sessionTag格式的字符串,支持任意数量的sessionTag
@@ -74,172 +81,7 @@ def parse_session_tags(input_string):
session_key = f'sessionTag{tag_suffix}'
result[session_key] = tag_data
return result
@dataclass
class DatabaseConfig:
"""数据库配置类"""
host: str = '192.168.0.123'
port: int = 3307
user: str = 'fuzhimei'
password: str = 'fuzhimei@135'
charset: str = 'utf8mb4'
connect_timeout: int = 10
read_timeout: int = 300
write_timeout: int = 300
@classmethod
def from_config_file(cls, config_file: str = 'config.ini') -> 'DatabaseConfig':
"""从配置文件加载配置"""
if not os.path.exists(config_file):
logger.warning(f"配置文件 {config_file} 不存在,使用默认配置")
return cls()
config = configparser.ConfigParser()
config.read(config_file, encoding='utf-8')
if 'database' not in config:
logger.warning("配置文件中没有 [database] 部分,使用默认配置")
return cls()
db_config = config['database']
return cls(
host=db_config.get('host', cls.host),
port=int(db_config.get('port', cls.port)),
user=db_config.get('user', cls.user),
password=db_config.get('password', cls.password),
charset=db_config.get('charset', cls.charset),
connect_timeout=int(db_config.get('connect_timeout', cls.connect_timeout)),
read_timeout=int(db_config.get('read_timeout', cls.read_timeout)),
write_timeout=int(db_config.get('write_timeout', cls.write_timeout))
)
class ConnectionPool:
"""数据库连接池"""
def __init__(self, config: DatabaseConfig, max_connections: int = 10):
self.config = config
self.max_connections = max_connections
self.pool = Queue(maxsize=max_connections)
self.active_connections = 0
self.lock = threading.Lock()
# 预创建一些连接
self._initialize_pool()
def _initialize_pool(self) -> None:
"""初始化连接池,预创建一些连接"""
initial_connections = min(3, self.max_connections)
for _ in range(initial_connections):
try:
conn = self._create_connection()
if conn:
self.pool.put_nowait(conn)
self.active_connections += 1
except Full:
break
except Exception as e:
logger.error(f"初始化连接池时创建连接失败: {e}")
def _create_connection(self) -> Optional[Connection]:
"""创建新的数据库连接"""
try:
conn = pymysql.connect(
host=self.config.host,
port=self.config.port,
user=self.config.user,
password=self.config.password,
charset=self.config.charset,
connect_timeout=self.config.connect_timeout,
read_timeout=self.config.read_timeout,
write_timeout=self.config.write_timeout,
autocommit=True
)
return conn
except Exception as e:
logger.error(f"创建数据库连接失败: {e}")
return None
@contextmanager
def get_connection(self):
"""获取连接的上下文管理器"""
conn = None
try:
# 尝试从池中获取连接
try:
conn = self.pool.get_nowait()
except Empty:
# 池中没有连接,尝试创建新连接
with self.lock:
if self.active_connections < self.max_connections:
conn = self._create_connection()
if conn:
self.active_connections += 1
else:
raise Exception("无法创建新的数据库连接")
else:
# 等待可用连接
logger.info("等待可用连接...")
conn = self.pool.get(timeout=30)
# 检查连接是否仍然有效
if conn and not self._is_connection_alive(conn):
logger.warning("连接已失效,重新创建")
try:
conn.close()
except:
pass
conn = self._create_connection()
if not conn:
raise Exception("重新创建连接失败")
yield conn
except Exception as e:
logger.error(f"获取数据库连接时出错: {e}")
if conn:
try:
conn.close()
except:
pass
with self.lock:
self.active_connections -= 1
raise
else:
# 归还连接到池中
if conn:
try:
self.pool.put_nowait(conn)
except Full:
# 池已满,关闭连接
try:
conn.close()
except:
pass
with self.lock:
self.active_connections -= 1
def _is_connection_alive(self, conn: Connection) -> bool:
"""检查连接是否仍然有效"""
try:
conn.ping(reconnect=False)
return True
except:
return False
def close_all(self) -> None:
"""关闭所有连接"""
logger.info("正在关闭连接池中的所有连接...")
while not self.pool.empty():
try:
conn = self.pool.get_nowait()
conn.close()
except (Empty, Exception):
break
self.active_connections = 0
logger.info("连接池已关闭")
return result
class DataProcessor:
@@ -357,12 +199,76 @@ class DataProcessor:
class MariaDBClient:
"""优化后的MariaDB数据库客户端"""
"""简化版 MariaDB 客户端(内置轻量连接池以复用连接)"""
def __init__(self, config: DatabaseConfig, max_connections: int = 10):
self.config = config
self.connection_pool = ConnectionPool(config, max_connections)
def __init__(self, max_connections: int = 10):
self.data_processor = DataProcessor()
self._max_connections = max_connections
self._pool: Queue = Queue(maxsize=max_connections)
self._active = 0
self._lock = threading.Lock()
# 预创建少量连接,降低首次延迟
initial = min(3, max_connections)
for _ in range(initial):
conn = self._create_connection()
if conn:
try:
self._pool.put_nowait(conn)
self._active += 1
except Full:
try:
conn.close()
except Exception:
pass
def _create_connection(self):
try:
return pymysql.connect(
host=DB_HOST,
port=DB_PORT,
user=DB_USER,
password=DB_PASSWORD,
charset=DB_CHARSET,
connect_timeout=DB_CONNECT_TIMEOUT,
read_timeout=DB_READ_TIMEOUT,
write_timeout=DB_WRITE_TIMEOUT,
autocommit=True
)
except Exception as e:
logger.error(f"创建数据库连接失败: {e}")
return None
def _acquire_connection(self):
# 先尝试不阻塞获取
try:
return self._pool.get_nowait()
except Empty:
# 池为空,若可创建新连接则创建,否则阻塞等待
with self._lock:
if self._active < self._max_connections:
conn = self._create_connection()
if conn:
self._active += 1
return conn
# 达到上限,阻塞等待可用连接
try:
return self._pool.get(timeout=30)
except Empty:
raise RuntimeError("获取数据库连接超时")
def _release_connection(self, conn):
if not conn:
return
try:
self._pool.put_nowait(conn)
except Full:
# 池已满,关闭多余连接
try:
conn.close()
except Exception:
pass
with self._lock:
self._active = max(0, self._active - 1)
def __enter__(self) -> 'MariaDBClient':
return self
@@ -371,30 +277,52 @@ class MariaDBClient:
self.close()
def close(self) -> None:
"""关闭客户端"""
self.connection_pool.close_all()
"""关闭连接池中的连接"""
try:
while True:
try:
conn = self._pool.get_nowait()
except Empty:
break
try:
conn.close()
except Exception:
pass
finally:
with self._lock:
self._active = 0
def execute_query(self, sql: str, params: Optional[Tuple] = None) -> Tuple[Optional[pd.DataFrame], List[str]]:
"""执行SQL查询"""
"""执行SQL查询(复用连接池连接)"""
conn = None
try:
with self.connection_pool.get_connection() as conn:
with conn.cursor() as cursor:
cursor.execute(sql, params)
results = cursor.fetchall()
# 获取列名
column_names = [desc[0] for desc in cursor.description] if cursor.description else []
if results:
df = pd.DataFrame(results, columns=column_names)
return df, column_names
else:
return pd.DataFrame(), column_names
conn = self._acquire_connection()
with conn.cursor() as cursor:
cursor.execute(sql, params)
results = cursor.fetchall()
column_names = [desc[0] for desc in cursor.description] if cursor.description else []
if results:
df = pd.DataFrame(results, columns=column_names)
return df, column_names
else:
return pd.DataFrame(), column_names
except Exception as e:
logger.error(f"执行查询时出错: {e}")
logger.error(f"SQL: {sql}")
return None, []
finally:
if conn:
# 若连接异常,尝试关闭并减少活跃计数;否则归还
try:
conn.ping(reconnect=False)
self._release_connection(conn)
except Exception:
try:
conn.close()
except Exception:
pass
with self._lock:
self._active = max(0, self._active - 1)
def query_sessions(self, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
"""查询指定日期范围内的会话数据"""
@@ -483,125 +411,79 @@ class MariaDBClient:
logger.error(f"导出到Excel时出错: {e}")
return None
class SessionProcessor:
"""会话处理器,负责并发处理"""
def __init__(self, db_client: MariaDBClient, max_workers: int = None):
self.db_client = db_client
self.max_workers = max_workers if max_workers is not None else os.cpu_count()
self.temp_save_lock = threading.Lock() # 添加锁用于保护临时保存操作
logger.info(f"初始化会话处理器: max_workers={self.max_workers}")
def process_sessions(self, sessions_df: pd.DataFrame) -> List[List[Dict[str, Any]]]:
"""处理所有会话数据"""
if sessions_df.empty:
logger.warning("没有会话数据需要处理")
return []
total_sessions = len(sessions_df)
logger.info(f"开始处理 {total_sessions} 个会话...")
all_conversations = []
# 直接并发处理每个会话
def process_single_session(session_row):
try:
session_id = session_row['SESSION_ID']
messages_df = self.db_client.query_messages_by_session_id(session_id)
if messages_df is not None and not messages_df.empty:
conversation = self.db_client.data_processor.messages_df_to_list(messages_df, session_row)
if conversation:
return conversation
except Exception as e:
logger.error(f"处理会话 {session_row.get('SESSION_ID', 'unknown')} 时出错: {e}")
return None
# 使用线程池处理所有会话
with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
# 提交所有会话处理任务
future_to_session = {
executor.submit(process_single_session, row): i
for i, row in sessions_df.iterrows()
}
# 收集结果
with tqdm(total=total_sessions, desc="处理会话进度") as pbar:
for future in concurrent.futures.as_completed(future_to_session):
try:
conversation = future.result()
if conversation:
all_conversations.append(conversation)
# 每处理100个对话临时保存一次
if len(all_conversations) % 100 == 0:
with self.temp_save_lock:
logger.info(f"临时保存 {len(all_conversations)} 个对话")
temp_output_file = self.db_client.export_to_excel(
all_conversations,
f"客服对话记录_临时保存",
output_dir="/data/QueryRewrite/data/excel"
)
if temp_output_file:
logger.info(f"临时保存完成: {temp_output_file}")
except Exception as e:
session_idx = future_to_session[future]
logger.error(f"处理会话索引 {session_idx} 时出错: {e}")
pbar.update(1)
logger.info(f"处理完成,共获得 {len(all_conversations)} 个有效对话")
return all_conversations
def main() -> None:
"""主函数"""
"""主函数(精简版)"""
try:
# 加载配置
config = DatabaseConfig.from_config_file()
logger.info(f"使用数据库配置: {config.host}:{config.port}")
logger.info(f"使用数据库配置: {DB_HOST}:{DB_PORT}")
# 创建数据库客户端
with MariaDBClient(config, max_connections=12) as db_client:
# 创建数据库客户端(简化)
with MariaDBClient() as db_client:
# 查询会话数据
start_date = '2025-08-01 00:00:00'
end_date = '2025-08-01 23:00:00'
logger.info(f"查询时间范围: {start_date}{end_date}")
# 创建会话处理器
processor = SessionProcessor(db_client)
# is_debug = hasattr(sys, 'gettrace') and sys.gettrace() is not None
# if is_debug:
# messages_df = db_client.query_messages_by_session_id("86c919e0-09f1-11f0-84ae-2daf59566989")
# print(db_client.data_processor.messages_df_to_list(messages_df))
# return []
sessions_df = db_client.query_sessions(start_date, end_date)
if sessions_df is None or sessions_df.empty:
logger.warning("没有找到符合条件的会话数据")
return
# 处理会话数据
all_conversations = processor.process_sessions(sessions_df)
# 直接并发处理每个会话(替代 SessionProcessor
total_sessions = len(sessions_df)
all_conversations: List[List[Dict[str, Any]]] = []
temp_save_lock = threading.Lock()
def process_single_session(session_row):
try:
session_id = session_row['SESSION_ID']
messages_df = db_client.query_messages_by_session_id(session_id)
if messages_df is not None and not messages_df.empty:
conversation = db_client.data_processor.messages_df_to_list(messages_df, session_row)
if conversation:
return conversation
except Exception as e:
logger.error(f"处理会话 {session_row.get('SESSION_ID', 'unknown')} 时出错: {e}")
return None
with concurrent.futures.ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
future_to_session = {executor.submit(process_single_session, row): i for i, row in sessions_df.iterrows()}
with tqdm(total=total_sessions, desc="处理会话进度") as pbar:
for future in concurrent.futures.as_completed(future_to_session):
try:
conversation = future.result()
if conversation:
all_conversations.append(conversation)
if len(all_conversations) % 100 == 0:
with temp_save_lock:
logger.info(f"临时保存 {len(all_conversations)} 个对话")
temp_output_file = db_client.export_to_excel(
all_conversations,
f"客服对话记录_临时保存",
output_dir="/data/QueryRewrite/data/excel"
)
if temp_output_file:
logger.info(f"临时保存完成: {temp_output_file}")
except Exception as e:
session_idx = future_to_session[future]
logger.error(f"处理会话索引 {session_idx} 时出错: {e}")
pbar.update(1)
# 导出结果
if all_conversations:
output_file = db_client.export_to_excel(
all_conversations,
all_conversations,
"客服对话记录",
output_dir="/data/QueryRewrite/data/excel"
)
if output_file:
logger.info(f"处理完成!共导出 {len(all_conversations)} 个对话到文件: {output_file}")
else:
logger.error("导出文件失败")
else:
logger.warning("没有有效的对话数据可导出")
except KeyboardInterrupt:
logger.info("用户中断程序")
except Exception as e: