537 lines
20 KiB
Python
Executable File
537 lines
20 KiB
Python
Executable File
#!/usr/bin/env python
|
|
# -*- coding: utf-8 -*-
|
|
|
|
from __future__ import annotations
|
|
import json
|
|
import os
|
|
import re
|
|
import configparser
|
|
import logging
|
|
from datetime import datetime
|
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
from dataclasses import dataclass
|
|
from contextlib import contextmanager
|
|
import threading
|
|
import time
|
|
from queue import Queue, Empty, Full
|
|
|
|
import pandas as pd
|
|
import pymysql
|
|
from pymysql.connections import Connection
|
|
from pymysql.cursors import Cursor
|
|
from tqdm import tqdm
|
|
import concurrent.futures
|
|
import sys
|
|
|
|
# Configure logging.
# The log directory has to exist *before* FileHandler is constructed, because
# FileHandler opens its file immediately and raises FileNotFoundError when the
# parent directory is missing.
os.makedirs('./data/log', exist_ok=True)

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        # Log messages contain Chinese text; pin the encoding so the file is
        # written the same way regardless of the platform's default locale.
        logging.FileHandler('./data/log/mariadb_client.log', encoding='utf-8'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)
|
|
|
|
@dataclass
class DatabaseConfig:
    """Connection settings for the database server.

    Every field has a built-in default, so ``DatabaseConfig()`` is always
    constructible; ``from_config_file`` overrides the defaults with values
    found in an INI file.
    """
    # NOTE(review): real-looking credentials are embedded as defaults here;
    # consider moving them out of source control.
    host: str = '192.168.0.123'
    port: int = 3307
    user: str = 'fuzhimei'
    password: str = 'fuzhimei@135'
    charset: str = 'utf8mb4'
    connect_timeout: int = 10
    read_timeout: int = 300
    write_timeout: int = 300

    @classmethod
    def from_config_file(cls, config_file: str = 'config.ini') -> 'DatabaseConfig':
        """Build a config from the ``[database]`` section of an INI file.

        Args:
            config_file: Path of the INI file to read.

        Returns:
            A config whose fields come from the file where present and from
            the class defaults otherwise. If the file or the ``[database]``
            section is missing, an all-defaults config is returned and a
            warning is logged.

        Raises:
            ValueError: If a numeric option cannot be converted to ``int``.
        """
        if not os.path.exists(config_file):
            logger.warning(f"配置文件 {config_file} 不存在,使用默认配置")
            return cls()

        parser = configparser.ConfigParser()
        parser.read(config_file, encoding='utf-8')

        if not parser.has_section('database'):
            logger.warning("配置文件中没有 [database] 部分,使用默认配置")
            return cls()

        section = parser['database']
        text_options = ('host', 'user', 'password', 'charset')
        int_options = ('port', 'connect_timeout', 'read_timeout', 'write_timeout')

        # The class attributes created by @dataclass hold the field defaults,
        # so they double as the fallback for options absent from the file.
        settings = {
            name: section.get(name, getattr(cls, name))
            for name in text_options
        }
        for name in int_options:
            settings[name] = int(section.get(name, getattr(cls, name)))

        return cls(**settings)
|
|
|
|
|
|
class ConnectionPool:
    """Bounded pool of pymysql connections shared between threads.

    Idle connections live in ``self.pool`` (a ``Queue``);
    ``self.active_connections`` counts every connection that currently
    exists, idle or checked out, and is only modified while holding
    ``self.lock``.
    """

    def __init__(self, config: DatabaseConfig, max_connections: int = 10):
        """Create the pool and eagerly open a few connections.

        Args:
            config: Connection settings used for every new connection.
            max_connections: Upper bound on simultaneously open connections.
        """
        self.config = config
        self.max_connections = max_connections
        self.pool = Queue(maxsize=max_connections)
        self.active_connections = 0
        self.lock = threading.Lock()

        # Pre-create a few connections.
        self._initialize_pool()

    def _initialize_pool(self) -> None:
        """Open up to three connections (or fewer if the pool is smaller)."""
        initial_connections = min(3, self.max_connections)
        for _ in range(initial_connections):
            conn = self._create_connection()
            if conn is None:
                # _create_connection already logged the failure.
                continue
            try:
                self.pool.put_nowait(conn)
            except Full:
                # Nowhere to keep it: close it instead of leaking it.
                self._close_quietly(conn)
                break
            with self.lock:
                self.active_connections += 1

    def _create_connection(self) -> Optional[Connection]:
        """Open a new autocommit connection.

        Returns:
            The connection, or ``None`` if connecting failed (the error is
            logged).
        """
        try:
            return pymysql.connect(
                host=self.config.host,
                port=self.config.port,
                user=self.config.user,
                password=self.config.password,
                charset=self.config.charset,
                connect_timeout=self.config.connect_timeout,
                read_timeout=self.config.read_timeout,
                write_timeout=self.config.write_timeout,
                autocommit=True
            )
        except Exception as e:
            logger.error(f"创建数据库连接失败: {e}")
            return None

    @contextmanager
    def get_connection(self):
        """Context manager that lends out one live connection.

        On a clean exit the connection goes back to the pool. If the
        ``with`` body raises (including ``KeyboardInterrupt``), the
        connection is closed and its slot is released, because its state is
        unknown.

        Raises:
            queue.Empty: No connection became available within 30 seconds.
            Exception: A new or replacement connection could not be created.
        """
        try:
            conn = self._acquire_connection()
        except Exception as e:
            logger.error(f"获取数据库连接时出错: {e}")
            raise

        reusable = False
        try:
            yield conn
            reusable = True
        except Exception as e:
            # This error comes from the caller's ``with`` body, not from
            # acquisition, so it gets its own message.
            logger.error(f"使用数据库连接时出错: {e}")
            raise
        finally:
            # ``finally`` also covers BaseException, which the original
            # except/else pair skipped and thereby leaked the connection.
            if reusable:
                self._return_connection(conn)
            else:
                self._discard_connection(conn)

    def _acquire_connection(self) -> Connection:
        """Take an idle connection, open a new one, or wait for one.

        The slot counter is only changed for slots this call actually
        reserved, so a failed acquisition leaves ``active_connections``
        exactly as it found it.
        """
        try:
            conn = self.pool.get_nowait()
        except Empty:
            conn = None

        if conn is None:
            # Reserve a slot under the lock, but do the slow work (connecting
            # or waiting) outside it so other threads are not blocked.
            with self.lock:
                may_create = self.active_connections < self.max_connections
                if may_create:
                    self.active_connections += 1

            if may_create:
                conn = self._create_connection()
                if conn is None:
                    with self.lock:
                        self.active_connections -= 1
                    raise Exception("无法创建新的数据库连接")
                return conn

            # Pool is at capacity: wait for another thread to return one.
            logger.info("等待可用连接...")
            conn = self.pool.get(timeout=30)

        # Connections that sat idle may have been dropped by the server.
        if self._is_connection_alive(conn):
            return conn

        logger.warning("连接已失效,重新创建")
        self._close_quietly(conn)
        replacement = self._create_connection()
        if replacement is None:
            # The dead connection's slot is gone for good.
            with self.lock:
                self.active_connections -= 1
            raise Exception("重新创建连接失败")
        return replacement

    def _return_connection(self, conn: Connection) -> None:
        """Put a healthy connection back, or close it if the queue is full."""
        try:
            self.pool.put_nowait(conn)
        except Full:
            self._discard_connection(conn)

    def _discard_connection(self, conn: Connection) -> None:
        """Close a connection and release its slot."""
        self._close_quietly(conn)
        with self.lock:
            self.active_connections -= 1

    @staticmethod
    def _close_quietly(conn: Connection) -> None:
        """Close a connection, ignoring (but debug-logging) any failure."""
        try:
            conn.close()
        except Exception as e:
            logger.debug(f"关闭连接时出错(已忽略): {e}")

    def _is_connection_alive(self, conn: Connection) -> bool:
        """Return True if the server still answers a ping on ``conn``."""
        try:
            conn.ping(reconnect=False)
            return True
        except Exception:
            return False

    def close_all(self) -> None:
        """Close every idle connection and reset the counter.

        A connection that fails to close does not stop the remaining ones
        from being closed.
        """
        logger.info("正在关闭连接池中的所有连接...")
        while True:
            try:
                conn = self.pool.get_nowait()
            except Empty:
                break
            self._close_quietly(conn)

        with self.lock:
            self.active_connections = 0
        logger.info("连接池已关闭")
|
|
|
|
|
|
class DataProcessor:
    """Stateless helpers that turn raw message rows into export records."""

    # Compiled once; both helpers run for every attachment/image message.
    _TAG_RE = re.compile(r'<[^>]+>')
    _ENTITY_RE = re.compile(r'&(?:nbsp|lt|gt|amp|quot|apos|#39);')
    # Both spellings of the apostrophe entity are accepted.
    _ENTITY_MAP = {
        '&nbsp;': ' ',
        '&lt;': '<',
        '&gt;': '>',
        '&amp;': '&',
        '&quot;': '"',
        '&apos;': "'",
        '&#39;': "'",
    }

    @staticmethod
    def clean_html_tags(text: str) -> str:
        """Strip HTML tags from ``text`` and decode a few common entities.

        Args:
            text: The text to clean. Non-string values are accepted:
                ``None`` and float NaN become ``""``, anything else is
                returned as ``str(text)``.

        Returns:
            The text without tags, with the entities in ``_ENTITY_MAP``
            decoded, and with surrounding whitespace stripped.
        """
        if text is None:
            return ""
        if not isinstance(text, str):
            # NaN is how pandas represents a missing value; it is the only
            # float that differs from itself.
            if isinstance(text, float) and text != text:
                return ""
            return str(text)

        without_tags = DataProcessor._TAG_RE.sub('', text)
        # Decode in a single pass so that already-decoded output is never
        # decoded again ("&amp;quot;" must yield "&quot;", not '"').
        decoded = DataProcessor._ENTITY_RE.sub(
            lambda match: DataProcessor._ENTITY_MAP[match.group(0)],
            without_tags,
        )
        return decoded.strip()

    @staticmethod
    def messages_df_to_list(messages_df: pd.DataFrame) -> List[Dict[str, Any]]:
        """Convert a message DataFrame into a list of export dictionaries.

        Rows are dropped when ``MODE`` is ``"system"``, when
        ``SYSTEM_MODE_MESSAGE_TYPE`` is set, or when the sender nickname is
        missing or empty.

        Args:
            messages_df: Frame with the columns ``MODE``,
                ``SYSTEM_MODE_MESSAGE_TYPE``, ``AGENT_NAME``,
                ``CUS_NICK_NAME``, ``MSG_TYPE``, ``CONTENT``, ``ACCOUNT``,
                ``SESSION_ID`` and ``CREATE_TIME``.

        Returns:
            One dictionary per remaining row, in the frame's row order.
        """
        if messages_df.empty:
            return []

        # Filter out system messages.
        keep = (messages_df["MODE"] != "system") & messages_df["SYSTEM_MODE_MESSAGE_TYPE"].isna()
        df = messages_df[keep].copy()
        if df.empty:
            return []

        df['message_sender'] = df["MODE"].map({'reply': '坐席', 'receive': '访客'}).fillna('未知')

        # Agent messages carry the agent's name, everything else the
        # customer's nickname.
        df['sender_nickname'] = df["AGENT_NAME"].where(
            df['message_sender'] == "坐席", df["CUS_NICK_NAME"]
        )

        # Attachments and images get a prefix and have their markup removed;
        # every other message type keeps its content unchanged.
        prefixes = {"attachment": "附件:", "image": "图片:"}
        df['processed_content'] = [
            f"{prefixes[msg_type]}{DataProcessor.clean_html_tags(content)}"
            if msg_type in prefixes else content
            for msg_type, content in zip(df["MSG_TYPE"], df["CONTENT"])
        ]

        # Drop rows without a usable nickname.
        has_nickname = df['sender_nickname'].notna() & (df['sender_nickname'] != '')

        # Column order here fixes the key order of every record and therefore
        # the column order of the exported sheet.
        export_columns = {
            "ACCOUNT": "账号id",
            "SESSION_ID": "会话id",
            "processed_content": "消息内容",
            "message_sender": "消息发送者",
            "sender_nickname": "发送者昵称",
            "CREATE_TIME": "创建时间",
        }
        selected = df.loc[has_nickname, list(export_columns)]
        return selected.rename(columns=export_columns).to_dict('records')
|
|
|
|
|
|
class MariaDBClient:
    """Database client: runs queries through a connection pool and exports results."""

    def __init__(self, config: DatabaseConfig, max_connections: int = 10):
        """Create the client together with its connection pool.

        Args:
            config: Connection settings.
            max_connections: Size limit handed to the connection pool.
        """
        self.config = config
        self.connection_pool = ConnectionPool(config, max_connections)
        self.data_processor = DataProcessor()

    def __enter__(self) -> 'MariaDBClient':
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self.close()

    def close(self) -> None:
        """Close the client by closing every pooled connection."""
        self.connection_pool.close_all()

    def execute_query(self, sql: str, params: Optional[Tuple] = None) -> Tuple[Optional[pd.DataFrame], List[str]]:
        """Run one SQL statement and return its result as a DataFrame.

        Args:
            sql: Statement text, using ``%s`` placeholders for parameters.
            params: Values for the placeholders, or ``None``.

        Returns:
            ``(frame, column_names)``. For an empty result the frame is empty
            but still carries the result's columns. On any error the failure
            is logged and ``(None, [])`` is returned.
        """
        try:
            with self.connection_pool.get_connection() as conn:
                with conn.cursor() as cursor:
                    cursor.execute(sql, params)
                    rows = cursor.fetchall()

                    # description is None for statements that return no rows.
                    column_names = [desc[0] for desc in cursor.description] if cursor.description else []

                    # Build the frame the same way for zero rows, so callers
                    # can rely on the columns being present.
                    return pd.DataFrame(list(rows), columns=column_names), column_names

        except Exception as e:
            logger.error(f"执行查询时出错: {e}")
            logger.error(f"SQL: {sql}")
            return None, []

    def query_sessions(self, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
        """Fetch sessions with status ``assign`` that began in ``[start_date, end_date)``.

        Args:
            start_date: Inclusive lower bound for ``BEGIN_TIME``.
            end_date: Exclusive upper bound for ``BEGIN_TIME``.

        Returns:
            The sessions, newest first, or ``None`` if the query failed.
        """
        sql = """
            SELECT ACCOUNT, BEGIN_TIME, END_TIME, CUST_SEND_MESSAGE_COUNT,
                   AGENT_SEND_MESSAGE_COUNT, STATUS, CHANNEL_NAME, SESSION_ID, SESSION_TAG_NAME
            FROM crm_hlyj.crm_hlyj_dsri
            WHERE BEGIN_TIME >= %s
              AND BEGIN_TIME < %s
              AND STATUS = 'assign'
            ORDER BY BEGIN_TIME DESC
        """

        df, _ = self.execute_query(sql, (start_date, end_date))
        return df

    def query_messages_by_session_id(self, session_id: str) -> Optional[pd.DataFrame]:
        """Fetch all messages of one session, ordered by ``CREATE_TIME``.

        Args:
            session_id: Value matched against the ``SESSION_ID`` column.

        Returns:
            The messages, or ``None`` if the query failed.
        """
        sql = """
            SELECT CREATE_TIME, CUS_NICK_NAME, MODE, MSG_TYPE, AGENT_NAME, CONTENT,
                   SESSION_ID, ACCOUNT, SYSTEM_MODE_MESSAGE_TYPE
            FROM crm_hlyj.crm_hlyj_dmri
            WHERE SESSION_ID = %s
            ORDER BY CREATE_TIME
        """

        df, _ = self.execute_query(sql, (session_id,))
        return df

    def export_to_excel(self, data: List[List[Dict[str, Any]]], filename: str, output_dir: str = "output") -> Optional[str]:
        """Write conversations to ``<output_dir>/<filename>.xlsx``.

        Args:
            data: Conversations; each is a non-empty list of message
                dictionaries that all contain the key ``"会话id"``.
            filename: File name without extension.
            output_dir: Target directory; created if missing.

        Returns:
            The path of the written file, or ``None`` when there was nothing
            to write or writing failed (the reason is logged).
        """
        if not data:
            logger.warning(f"没有数据可导出到 {filename}")
            return None

        try:
            os.makedirs(output_dir, exist_ok=True)
            file_path = os.path.join(output_dir, f"{filename}.xlsx")

            # Flatten the conversations, separating different sessions with
            # one blank row.
            all_rows = []
            current_session_id = None

            for conversation in data:
                if not conversation:  # skip empty conversations
                    continue

                session_id = conversation[0]["会话id"]
                # No separator before the very first session.
                if current_session_id and current_session_id != session_id:
                    all_rows.append({key: "" for key in conversation[0]})

                current_session_id = session_id
                all_rows.extend(conversation)

            if not all_rows:
                logger.warning("没有有效数据可导出")
                return None

            with pd.ExcelWriter(file_path, engine='openpyxl') as writer:
                pd.DataFrame(all_rows).to_excel(writer, sheet_name='对话记录', index=False)

            logger.info(f"数据已导出到 {file_path}")
            return file_path

        except Exception as e:
            logger.error(f"导出到Excel时出错: {e}")
            return None
|
|
|
|
|
|
def process_session_batch(db_client: MariaDBClient, session_batch: pd.DataFrame) -> List[List[Dict[str, Any]]]:
    """Load and convert the messages of every session in a batch.

    Args:
        db_client: Client used to query messages; its ``data_processor``
            converts each message frame into a list of dictionaries.
        session_batch: Rows that each carry a ``SESSION_ID``.

    Returns:
        One list of message dictionaries per session that produced at least
        one message. A session that raises is logged and skipped.
    """
    collected: List[List[Dict[str, Any]]] = []

    for _, row in session_batch.iterrows():
        try:
            frame = db_client.query_messages_by_session_id(row['SESSION_ID'])
            # Nothing to convert when the query failed or returned no rows.
            if frame is None or frame.empty:
                continue

            parsed = db_client.data_processor.messages_df_to_list(frame)
            if parsed:
                collected.append(parsed)

        except Exception as e:
            logger.error(f"处理会话 {row.get('SESSION_ID', 'unknown')} 时出错: {e}")

    return collected
|
|
|
|
|
|
class SessionProcessor:
    """Splits sessions into batches and processes the batches on a thread pool."""

    def __init__(self, db_client: MariaDBClient, max_workers: Optional[int] = None, batch_size: int = 50,
                 checkpoint_every: int = 100,
                 checkpoint_dir: str = "/data/QueryRewrite/data/excel"):
        """Configure batching, concurrency and intermediate saves.

        Args:
            db_client: Client used for message queries and Excel export.
            max_workers: Number of worker threads; defaults to
                ``os.cpu_count()``. It is capped at the client's
                connection-pool size when that is available, because a worker
                without a connection only waits and eventually times out.
            batch_size: Number of sessions per submitted task; must be > 0.
            checkpoint_every: Write an intermediate file once at least this
                many new conversations were collected since the last one.
                Values <= 0 write after every completed batch.
            checkpoint_dir: Directory that receives the intermediate file.

        Raises:
            ValueError: If ``batch_size`` is not positive.
        """
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        workers = max_workers if max_workers is not None else os.cpu_count()
        pool_limit = getattr(getattr(db_client, "connection_pool", None), "max_connections", None)
        if pool_limit:
            workers = pool_limit if workers is None else min(workers, pool_limit)

        self.db_client = db_client
        self.max_workers = workers
        self.batch_size = batch_size
        self.checkpoint_every = checkpoint_every
        self.checkpoint_dir = checkpoint_dir
        self.temp_save_lock = threading.Lock()  # guards the intermediate save

        logger.info(f"初始化会话处理器: max_workers={self.max_workers}, batch_size={self.batch_size}")

    def _save_checkpoint(self, conversations: List[List[Dict[str, Any]]]) -> None:
        """Write everything collected so far to the intermediate Excel file."""
        with self.temp_save_lock:
            try:
                logger.info(f"临时保存 {len(conversations)} 个对话")
                temp_output_file = self.db_client.export_to_excel(
                    conversations,
                    "客服对话记录_临时保存",
                    output_dir=self.checkpoint_dir
                )
                if temp_output_file:
                    logger.info(f"临时保存完成: {temp_output_file}")
            except Exception as e:
                # A failed intermediate save must not abort the run.
                logger.error(f"临时保存失败: {e}")

    def process_sessions(self, sessions_df: pd.DataFrame) -> List[List[Dict[str, Any]]]:
        """Process every session and return the collected conversations.

        Args:
            sessions_df: Sessions to process, one per row.

        Returns:
            The conversations of all batches that completed, in completion
            order. A batch that raises is logged and skipped.
        """
        if sessions_df.empty:
            logger.warning("没有会话数据需要处理")
            return []

        total_sessions = len(sessions_df)
        logger.info(f"开始处理 {total_sessions} 个会话...")

        all_conversations: List[List[Dict[str, Any]]] = []
        batch_count = (total_sessions + self.batch_size - 1) // self.batch_size
        saved_at = 0  # len(all_conversations) at the last intermediate save

        with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            # Submit every batch up front; the value is the 1-based batch number.
            future_to_batch = {
                executor.submit(
                    process_session_batch,
                    self.db_client,
                    sessions_df.iloc[start:start + self.batch_size],
                ): start // self.batch_size + 1
                for start in range(0, total_sessions, self.batch_size)
            }

            # Collect results as batches finish.
            with tqdm(total=batch_count, desc="处理批次进度") as pbar:
                for future in concurrent.futures.as_completed(future_to_batch):
                    batch_num = future_to_batch[future]
                    try:
                        batch_conversations = future.result()
                    except Exception as e:
                        logger.error(f"处理批次 {batch_num} 时出错: {e}")
                    else:
                        all_conversations.extend(batch_conversations)
                        logger.debug(f"批次 {batch_num} 完成,获得 {len(batch_conversations)} 个对话")

                        # Each save rewrites the whole workbook, so only save
                        # once enough new conversations have accumulated.
                        if len(all_conversations) - saved_at >= self.checkpoint_every:
                            self._save_checkpoint(all_conversations)
                            saved_at = len(all_conversations)
                    finally:
                        pbar.update(1)

        logger.info(f"处理完成,共获得 {len(all_conversations)} 个有效对话")
        return all_conversations
|
|
|
|
|
|
def main() -> None:
    """Export the conversations of a fixed date range to an Excel file.

    When the interpreter runs under a tracer (for example a debugger), only
    one hard-coded session is queried and printed. Every error is logged
    here; nothing propagates to the caller.
    """
    try:
        # Load configuration.
        config = DatabaseConfig.from_config_file()
        logger.info(f"使用数据库配置: {config.host}:{config.port}")

        with MariaDBClient(config, max_connections=12) as db_client:
            start_date = '2025-01-01 00:00:00'
            end_date = '2025-06-12 00:00:00'

            logger.info(f"查询时间范围: {start_date} 到 {end_date}")

            is_debug = hasattr(sys, 'gettrace') and sys.gettrace() is not None
            if is_debug:
                messages_df = db_client.query_messages_by_session_id("86c919e0-09f1-11f0-84ae-2daf59566989")
                # The query returns None on failure, which the converter
                # cannot handle.
                if messages_df is None:
                    logger.error("调试查询失败")
                else:
                    print(db_client.data_processor.messages_df_to_list(messages_df))
                return

            processor = SessionProcessor(db_client, batch_size=100)

            sessions_df = db_client.query_sessions(start_date, end_date)

            if sessions_df is None or sessions_df.empty:
                logger.warning("没有找到符合条件的会话数据")
                return

            all_conversations = processor.process_sessions(sessions_df)

            if not all_conversations:
                logger.warning("没有有效的对话数据可导出")
                return

            output_file = db_client.export_to_excel(
                all_conversations,
                "客服对话记录",
                output_dir="/data/QueryRewrite/data/excel"
            )

            if output_file:
                logger.info(f"处理完成!共导出 {len(all_conversations)} 个对话到文件: {output_file}")
            else:
                logger.error("导出文件失败")

    except KeyboardInterrupt:
        logger.info("用户中断程序")
    except Exception as e:
        logger.error(f"程序执行出错: {e}", exc_info=True)
|
|
|
|
|
|
# Run the export only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()