Files
QueryRewrite/rag2_0/demo/heli_db_to_excel.py
T

612 lines
23 KiB
Python
Executable File

#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import annotations
import json
import os
import re
import configparser
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple, Union
from dataclasses import dataclass
from contextlib import contextmanager
import threading
import time
from queue import Queue, Empty, Full
import pandas as pd
import pymysql
from pymysql.connections import Connection
from pymysql.cursors import Cursor
from tqdm import tqdm
import concurrent.futures
import sys
# Create the log directory up front so FileHandler can open its file.
os.makedirs('./data/logs', exist_ok=True)
# Root logging configuration: INFO and above is written both to
# ./data/logs/mariadb_client.log and to the console.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        # Explicit UTF-8: most log messages in this script are Chinese, and
        # without an encoding FileHandler uses the locale's default encoding.
        logging.FileHandler('./data/logs/mariadb_client.log', encoding='utf-8'),
        logging.StreamHandler(),
    ],
)
logger = logging.getLogger(__name__)
def parse_session_tags(input_string):
    """Parse ``sessionTag<Suffix>={key=value, ...}`` text into a nested dict.

    Any number of tags is supported (sessionTagFirst, sessionTagSecond,
    sessionTagThird, sessionTagFourth, ...). For example::

        '["sessionTagFirst={tagId=1, tagName=咨询}"]'
        -> {'sessionTagFirst': {'tagId': 1, 'tagName': '咨询'}}

    Value conversion: a value consisting only of decimal digits becomes
    ``int``; ``true``/``false`` (case-insensitive) becomes ``bool``; anything
    else is kept as a whitespace-stripped ``str``. Values containing ``,`` or
    ``}`` cannot be expressed in this format and are not parsed correctly.

    Args:
        input_string: the raw tag text. ``None`` yields ``{}``; a list or
            tuple of fragments is joined with ", " and parsed as one string.

    Returns:
        Dict mapping each full tag name (e.g. ``'sessionTagFirst'``) to the
        dict of its key/value pairs.
    """
    if input_string is None:
        return {}
    if isinstance(input_string, (list, tuple)):
        input_string = ', '.join(str(item) for item in input_string)
    # Drop the surrounding brackets/quotes of a serialized list, e.g. '["..."]'.
    cleaned_string = input_string.strip('[]"')
    # "sessionTag" + suffix + "=" + brace-delimited body.
    tag_pattern = r'sessionTag(\w+)=\{([^}]+)\}'
    # key=value pairs; a value extends lazily up to the next ", key=" or the end.
    kv_pattern = r'(\w+)=([^,}]+?)(?=,\s*\w+=|$|,\s*$)'
    result = {}
    for tag_suffix, content in re.findall(tag_pattern, cleaned_string):
        tag_data = {}
        for key, value in re.findall(kv_pattern, content):
            cleaned_value = value.strip()
            # isdecimal(), not isdigit(): isdigit() also accepts characters
            # such as '①' or '²' that int() rejects with ValueError.
            if cleaned_value.isdecimal():
                tag_data[key] = int(cleaned_value)
            elif cleaned_value.lower() in ('true', 'false'):
                tag_data[key] = cleaned_value.lower() == 'true'
            else:
                tag_data[key] = cleaned_value
        result[f'sessionTag{tag_suffix}'] = tag_data
    return result
@dataclass
class DatabaseConfig:
    """Connection settings for the MariaDB server.

    The field defaults below are used whenever the config file, its
    [database] section, or an individual key is missing.
    """

    # SECURITY NOTE(review): real-looking host/user/password defaults are
    # hard-coded here and therefore end up in version control; consider
    # supplying them only through config.ini or the environment.
    host: str = '192.168.0.123'
    port: int = 3307
    user: str = 'fuzhimei'
    password: str = 'fuzhimei@135'
    charset: str = 'utf8mb4'
    connect_timeout: int = 10
    read_timeout: int = 300
    write_timeout: int = 300

    @classmethod
    def from_config_file(cls, config_file: str = 'config.ini') -> 'DatabaseConfig':
        """Build a config from the [database] section of an INI file.

        A missing file or a missing [database] section logs a warning and
        returns an all-defaults instance. A missing key falls back to that
        field's default. Integer fields are converted with int(), so a
        malformed value raises ValueError.
        """
        if not os.path.exists(config_file):
            logger.warning(f"配置文件 {config_file} 不存在,使用默认配置")
            return cls()
        parser = configparser.ConfigParser()
        parser.read(config_file, encoding='utf-8')
        if 'database' not in parser:
            logger.warning("配置文件中没有 [database] 部分,使用默认配置")
            return cls()
        section = parser['database']
        text_fields = ('host', 'user', 'password', 'charset')
        int_fields = ('port', 'connect_timeout', 'read_timeout', 'write_timeout')
        settings = {name: section.get(name, getattr(cls, name)) for name in text_fields}
        settings.update(
            {name: section.getint(name, fallback=getattr(cls, name)) for name in int_fields}
        )
        return cls(**settings)
class ConnectionPool:
    """Thread-safe pool of pymysql connections.

    ``active_connections`` counts every connection the pool owns: idle ones
    waiting in ``self.pool`` plus those currently checked out through
    ``get_connection``. It never exceeds ``max_connections``, and every
    update to it happens under ``self.lock``.
    """

    def __init__(self, config: DatabaseConfig, max_connections: int = 10,
                 wait_timeout: float = 30):
        """Create the pool and eagerly open up to three connections.

        Args:
            config: connection parameters passed to ``pymysql.connect``.
            max_connections: upper bound on connections owned by the pool.
            wait_timeout: seconds ``get_connection`` waits for a connection
                when the pool is exhausted before raising ``queue.Empty``.
        """
        self.config = config
        self.max_connections = max_connections
        self.wait_timeout = wait_timeout
        self.pool = Queue(maxsize=max_connections)
        self.active_connections = 0
        self.lock = threading.Lock()
        # Pre-create a few connections so the first queries skip the connect cost.
        self._initialize_pool()

    def _initialize_pool(self) -> None:
        """Open up to three idle connections; failures are logged and skipped."""
        for _ in range(min(3, self.max_connections)):
            conn = self._create_connection()
            if conn is None:
                continue
            try:
                self.pool.put_nowait(conn)
            except Full:
                self._safe_close(conn)
                break
            with self.lock:
                self.active_connections += 1

    def _create_connection(self) -> Optional[Connection]:
        """Open a new autocommit connection; log the error and return None on failure."""
        try:
            return pymysql.connect(
                host=self.config.host,
                port=self.config.port,
                user=self.config.user,
                password=self.config.password,
                charset=self.config.charset,
                connect_timeout=self.config.connect_timeout,
                read_timeout=self.config.read_timeout,
                write_timeout=self.config.write_timeout,
                autocommit=True,
            )
        except Exception as e:
            logger.error(f"创建数据库连接失败: {e}")
            return None

    @staticmethod
    def _safe_close(conn: Connection) -> None:
        """Close ``conn`` on a best-effort basis, ignoring errors (e.g. already closed)."""
        try:
            conn.close()
        except Exception:
            pass

    def _reserve_slot(self) -> bool:
        """Claim capacity for one new connection; return False if at the limit."""
        with self.lock:
            if self.active_connections < self.max_connections:
                self.active_connections += 1
                return True
            return False

    def _release_slot(self) -> None:
        """Give back the capacity of a connection that no longer exists."""
        with self.lock:
            self.active_connections -= 1

    def _discard(self, conn: Connection) -> None:
        """Close an owned connection and remove it from the pool's accounting."""
        self._safe_close(conn)
        self._release_slot()

    def _return_to_pool(self, conn: Connection) -> None:
        """Put a healthy connection back as idle, or discard it if the queue is full."""
        try:
            self.pool.put_nowait(conn)
        except Full:
            self._discard(conn)

    def _ensure_alive(self, conn: Connection) -> Connection:
        """Return ``conn`` if it answers a ping, otherwise a fresh replacement.

        Raises:
            Exception: the replacement could not be created (its slot is released).
        """
        if self._is_connection_alive(conn):
            return conn
        logger.warning("连接已失效,重新创建")
        self._safe_close(conn)
        new_conn = self._create_connection()
        if new_conn is None:
            self._release_slot()
            raise Exception("重新创建连接失败")
        return new_conn

    def _acquire(self) -> Connection:
        """Take an idle connection, create one if under the limit, or wait for one.

        The lock is held only for slot bookkeeping, never while connecting or
        waiting, so other threads can keep returning and creating connections.

        Raises:
            Exception: a new connection could not be created.
            queue.Empty: nothing became available within ``wait_timeout`` seconds.
        """
        deadline = time.monotonic() + self.wait_timeout
        announced_wait = False
        while True:
            try:
                return self._ensure_alive(self.pool.get_nowait())
            except Empty:
                pass
            if self._reserve_slot():
                conn = self._create_connection()
                if conn is None:
                    self._release_slot()
                    raise Exception("无法创建新的数据库连接")
                return conn
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                raise Empty(f"等待可用连接超时 ({self.wait_timeout} 秒)")
            if not announced_wait:
                logger.info("等待可用连接...")
                announced_wait = True
            # Wake at least once per second: a slot freed by a discarded
            # connection is never signalled through the queue.
            try:
                return self._ensure_alive(self.pool.get(timeout=min(remaining, 1.0)))
            except Empty:
                continue

    @contextmanager
    def get_connection(self):
        """Context manager that checks out a live connection.

        On normal exit the connection is returned to the pool. If the ``with``
        body raises (including KeyboardInterrupt), the connection is closed
        and discarded because its state is unknown, and the exception
        propagates unchanged.

        Raises:
            Exception / queue.Empty: acquisition failed (logged, then re-raised).
        """
        try:
            conn = self._acquire()
        except Exception as e:
            logger.error(f"获取数据库连接时出错: {e}")
            raise
        try:
            yield conn
        except BaseException:
            self._discard(conn)
            raise
        else:
            self._return_to_pool(conn)

    def _is_connection_alive(self, conn: Connection) -> bool:
        """Return True if ``conn`` answers a ping (pymysql auto-reconnect disabled)."""
        try:
            conn.ping(reconnect=False)
            return True
        except Exception:
            return False

    def close_all(self) -> None:
        """Close every idle connection currently in the pool.

        Connections checked out at this moment are not touched and stay
        counted in ``active_connections``.
        """
        logger.info("正在关闭连接池中的所有连接...")
        closed = 0
        while True:
            try:
                conn = self.pool.get_nowait()
            except Empty:
                break
            # A failing close() must not leave the remaining connections open.
            self._safe_close(conn)
            closed += 1
        with self.lock:
            self.active_connections -= closed
        logger.info("连接池已关闭")
class DataProcessor:
    """Stateless helpers that turn raw message rows into export-ready dicts."""

    # HTML entities decoded by clean_html_tags, applied in this order.
    # '&amp;' must come last: decoding it earlier turns an escaped entity such
    # as '&amp;quot;' into '&quot;', which would then be decoded again to '"'.
    _HTML_ENTITIES = (
        ('&nbsp;', ' '),
        ('&lt;', '<'),
        ('&gt;', '>'),
        ('&quot;', '"'),
        ('&apos;', "'"),
        ('&amp;', '&'),
    )

    @staticmethod
    def clean_html_tags(text: str) -> str:
        """Strip HTML tags, decode a few common entities and trim whitespace.

        Non-string input is returned as ``str(text)`` (``""`` for ``None``)
        without any tag stripping.
        """
        if not isinstance(text, str):
            return str(text) if text is not None else ""
        clean_text = re.sub(r'<[^>]+>', '', text)
        for entity, char in DataProcessor._HTML_ENTITIES:
            clean_text = clean_text.replace(entity, char)
        return clean_text.strip()

    @staticmethod
    def get_session_tag_dict(json_data: str) -> dict:
        """Extract session tags from a session's JSON column.

        Reads the ``sessionMultiTagList`` field and parses it with
        ``parse_session_tags``. Returns ``{}`` when the JSON is missing or
        empty, is not a JSON object, or has no tag field. Parse errors are
        logged and also yield ``{}``.
        """
        # A NULL JSON column arrives as None; json.loads would raise TypeError
        # and log a spurious error for every such session.
        if not isinstance(json_data, (str, bytes, bytearray)) or not json_data:
            return {}
        try:
            payload = json.loads(json_data)
            if not isinstance(payload, dict):
                return {}
            session_tag_list_str = payload.get('sessionMultiTagList')
            if not session_tag_list_str:
                return {}
            return parse_session_tags(session_tag_list_str)
        except Exception as e:
            logger.error(f"解析会话标签时出错: {e}")
            return {}

    @staticmethod
    def messages_df_to_list(messages_df: pd.DataFrame, session_row=None) -> List[Dict[str, Any]]:
        """Convert one session's message rows into a list of export dicts.

        The following messages are dropped:
        - system messages (MODE == "system" or a non-null SYSTEM_MODE_MESSAGE_TYPE);
        - messages whose sender nickname is null or empty;
        - a message identical to the previously kept one (same session, sender,
          create time and processed content).

        Args:
            messages_df: rows as returned by
                ``MariaDBClient.query_messages_by_session_id``.
            session_row: the session record (anything with ``.get``, e.g. a
                pandas Series or dict). Its ``JSON`` value supplies the
                first/second level tag columns. ``None`` means no tags.

        Returns:
            One dict per kept message, keyed by the Excel column names.
        """
        if messages_df.empty:
            return []
        mask = (messages_df["MODE"] != "system") & messages_df["SYSTEM_MODE_MESSAGE_TYPE"].isna()
        filtered_df = messages_df[mask].copy()
        if filtered_df.empty:
            return []
        filtered_df['message_sender'] = (
            filtered_df["MODE"].map({'reply': '坐席', 'receive': '访客'}).fillna('未知')
        )
        # Agents are named by AGENT_NAME; visitors and unknown senders by CUS_NICK_NAME.
        is_agent = filtered_df['message_sender'] == '坐席'
        filtered_df['sender_nickname'] = filtered_df["AGENT_NAME"].where(
            is_agent, filtered_df["CUS_NICK_NAME"]
        )

        def process_content(row):
            # Attachments and images are prefixed and have their HTML stripped.
            content = row["CONTENT"]
            if row["MSG_TYPE"] == "attachment":
                return f"附件:{DataProcessor.clean_html_tags(content)}"
            if row["MSG_TYPE"] == "image":
                return f"图片:{DataProcessor.clean_html_tags(content)}"
            return content

        filtered_df['processed_content'] = filtered_df.apply(process_content, axis=1)
        filtered_df = filtered_df[
            filtered_df['sender_nickname'].notna() & (filtered_df['sender_nickname'] != '')
        ]

        # Tag columns are identical for every message of the session: compute once.
        json_data = session_row.get('JSON') if session_row is not None else None
        session_tag_dict = DataProcessor.get_session_tag_dict(json_data)
        tag_columns = {}
        for tag_key, column_name in (('sessionTagFirst', '一级标签'), ('sessionTagSecond', '二级标签')):
            if tag_key in session_tag_dict:
                tag_columns[column_name] = session_tag_dict[tag_key].get('tagName', '')

        result = []
        for record in filtered_df.to_dict('records'):
            if result:
                prev = result[-1]
                if (prev['会话id'] == record['SESSION_ID']
                        and prev['消息发送者'] == record['message_sender']
                        and prev['创建时间'] == record['CREATE_TIME']
                        and prev['消息内容'] == record['processed_content']):
                    continue
            message_dict = {
                "账号id": record["ACCOUNT"],
                "会话id": record["SESSION_ID"],
                "消息内容": record["processed_content"],
                "消息发送者": record["message_sender"],
                "发送者昵称": record["sender_nickname"],
                "创建时间": record["CREATE_TIME"],
                "SEQUENCE_ID": record["SEQUENCE_ID"],
            }
            message_dict.update(tag_columns)
            result.append(message_dict)
        return result
class MariaDBClient:
    """MariaDB client: pooled query helpers plus Excel export of conversations."""

    # Characters openpyxl refuses to store in a cell (ASCII control characters
    # other than tab, LF and CR); same pattern as openpyxl's ILLEGAL_CHARACTERS_RE.
    # A single one in chat content would otherwise make the whole export fail.
    _ILLEGAL_EXCEL_CHARS_RE = re.compile(r'[\000-\010]|[\013-\014]|[\016-\037]')

    def __init__(self, config: DatabaseConfig, max_connections: int = 10):
        """Create the client and its underlying ConnectionPool."""
        self.config = config
        self.connection_pool = ConnectionPool(config, max_connections)
        self.data_processor = DataProcessor()

    def __enter__(self) -> 'MariaDBClient':
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self.close()

    def close(self) -> None:
        """Close the pooled connections."""
        self.connection_pool.close_all()

    def execute_query(self, sql: str, params: Optional[Tuple] = None) -> Tuple[Optional[pd.DataFrame], List[str]]:
        """Run a SELECT and return ``(DataFrame, column_names)``.

        An empty result returns an empty DataFrame (with no columns) plus the
        column names. Any error is logged together with the SQL, and
        ``(None, [])`` is returned.
        """
        try:
            with self.connection_pool.get_connection() as conn:
                with conn.cursor() as cursor:
                    cursor.execute(sql, params)
                    results = cursor.fetchall()
                    column_names = [desc[0] for desc in cursor.description] if cursor.description else []
            # Build the DataFrame after the connection has gone back to the pool.
            if results:
                return pd.DataFrame(results, columns=column_names), column_names
            return pd.DataFrame(), column_names
        except Exception as e:
            logger.error(f"执行查询时出错: {e}")
            logger.error(f"SQL: {sql}")
            return None, []

    def query_sessions(self, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
        """Return 'assign' sessions with BEGIN_TIME in [start_date, end_date), newest first."""
        sql = """
        SELECT ACCOUNT, BEGIN_TIME, END_TIME, CUST_SEND_MESSAGE_COUNT,
        AGENT_SEND_MESSAGE_COUNT, STATUS, CHANNEL_NAME, SESSION_ID, SESSION_TAG_NAME, JSON
        FROM crm_hlyj.crm_hlyj_dsri
        WHERE BEGIN_TIME >= %s
        AND BEGIN_TIME < %s
        AND STATUS = 'assign'
        ORDER BY BEGIN_TIME DESC
        """
        df, _ = self.execute_query(sql, (start_date, end_date))
        return df

    def get_sequence_id_by_session_id(self, session_id: str) -> Optional[pd.DataFrame]:
        """Return the SEQUENCE_ID of every message in the given session."""
        sql = """
        SELECT SEQUENCE_ID
        FROM crm_hlyj.crm_hlyj_dmri
        WHERE SESSION_ID = %s
        """
        df, _ = self.execute_query(sql, (session_id,))
        return df

    def query_messages_by_session_id(self, session_id: str) -> Optional[pd.DataFrame]:
        """Return the messages of one session ordered by CREATE_TIME."""
        sql = """
        SELECT CREATE_TIME, CUS_NICK_NAME, MODE, MSG_TYPE, AGENT_NAME, CONTENT,
        SESSION_ID, ACCOUNT, SYSTEM_MODE_MESSAGE_TYPE, SEQUENCE_ID
        FROM crm_hlyj.crm_hlyj_dmri
        WHERE SESSION_ID = %s
        ORDER BY CREATE_TIME
        """
        df, _ = self.execute_query(sql, (session_id,))
        return df

    @staticmethod
    def _build_export_rows(data: List[List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
        """Flatten conversations into rows, with one blank row between sessions.

        Empty conversations are skipped. The blank row uses the keys of the
        next conversation's first message, with every value set to "".
        """
        all_rows = []
        current_session_id = None
        for conversation in data:
            if not conversation:
                continue
            session_id = conversation[0]["会话id"]
            # `is not None` rather than truthiness so a falsy id still gets a separator.
            if current_session_id is not None and current_session_id != session_id:
                all_rows.append({key: "" for key in conversation[0]})
            current_session_id = session_id
            all_rows.extend(conversation)
        return all_rows

    @classmethod
    def _strip_illegal_excel_chars(cls, df: pd.DataFrame) -> pd.DataFrame:
        """Return a copy of ``df`` with Excel-illegal control characters removed from text cells."""
        pattern = cls._ILLEGAL_EXCEL_CHARS_RE
        cleaned = df.copy()
        for column in cleaned.columns:
            if pd.api.types.is_string_dtype(cleaned[column].dtype):
                cleaned[column] = cleaned[column].map(
                    lambda value: pattern.sub('', value) if isinstance(value, str) else value
                )
        return cleaned

    def export_to_excel(self, data: List[Dict[str, Any]], filename: str, output_dir: str = "output") -> Optional[str]:
        """Write conversations to ``<output_dir>/<filename>.xlsx``.

        ``data`` is a list of conversations, each a list of message dicts (as
        produced by DataProcessor.messages_df_to_list). Everything goes into
        one sheet, with a blank row between sessions, and an existing file is
        overwritten.

        Returns:
            The written path, or None if there was nothing to export or an
            error occurred (logged).
        """
        if not data:
            logger.warning(f"没有数据可导出到 {filename}")
            return None
        try:
            os.makedirs(output_dir, exist_ok=True)
            file_path = os.path.join(output_dir, f"{filename}.xlsx")
            all_rows = self._build_export_rows(data)
            if not all_rows:
                logger.warning("没有有效数据可导出")
                return None
            df = self._strip_illegal_excel_chars(pd.DataFrame(all_rows))
            with pd.ExcelWriter(file_path, engine='openpyxl') as writer:
                df.to_excel(writer, sheet_name='对话记录', index=False)
            logger.info(f"数据已导出到 {file_path}")
            return file_path
        except Exception as e:
            logger.error(f"导出到Excel时出错: {e}")
            return None
class SessionProcessor:
    """Fetches and converts the messages of many sessions concurrently."""

    def __init__(self, db_client: MariaDBClient, max_workers: Optional[int] = None,
                 temp_output_dir: str = "/data/QueryRewrite/data/excel",
                 temp_save_every: int = 100):
        """
        Args:
            db_client: client used for message queries and Excel export.
            max_workers: worker threads; defaults to ``os.cpu_count()``.
            temp_output_dir: directory for the periodic checkpoint workbook.
            temp_save_every: write a checkpoint every this many collected
                conversations; ``0`` or less disables checkpoints. Each
                checkpoint rewrites the whole workbook, so the total cost grows
                quadratically with the number of conversations.
        """
        self.db_client = db_client
        self.max_workers = max_workers if max_workers is not None else os.cpu_count()
        self.temp_output_dir = temp_output_dir
        self.temp_save_every = temp_save_every
        # Serializes checkpoint writes.
        self.temp_save_lock = threading.Lock()
        logger.info(f"初始化会话处理器: max_workers={self.max_workers}")

    def _process_single_session(self, session_row) -> Optional[List[Dict[str, Any]]]:
        """Query and convert one session.

        Returns None when the session has no usable messages or an error
        occurs (the error is logged).
        """
        try:
            session_id = session_row['SESSION_ID']
            messages_df = self.db_client.query_messages_by_session_id(session_id)
            if messages_df is not None and not messages_df.empty:
                conversation = self.db_client.data_processor.messages_df_to_list(messages_df, session_row)
                if conversation:
                    return conversation
        except Exception as e:
            logger.error(f"处理会话 {session_row.get('SESSION_ID', 'unknown')} 时出错: {e}")
        return None

    @staticmethod
    def _in_session_order(finished: List[Tuple[int, List[Dict[str, Any]]]]) -> List[List[Dict[str, Any]]]:
        """Drop the positions from ``(position, conversation)`` pairs, sorted by position."""
        return [conversation for _, conversation in sorted(finished, key=lambda item: item[0])]

    def _temp_save(self, finished: List[Tuple[int, List[Dict[str, Any]]]]) -> None:
        """Write a checkpoint workbook with everything collected so far."""
        with self.temp_save_lock:
            snapshot = self._in_session_order(finished)
            logger.info(f"临时保存 {len(snapshot)} 个对话")
            temp_output_file = self.db_client.export_to_excel(
                snapshot,
                "客服对话记录_临时保存",
                output_dir=self.temp_output_dir,
            )
            if temp_output_file:
                logger.info(f"临时保存完成: {temp_output_file}")

    def process_sessions(self, sessions_df: pd.DataFrame) -> List[List[Dict[str, Any]]]:
        """Process every session row concurrently.

        Returns the non-empty conversations in the same row order as
        ``sessions_df``, whichever worker finishes first. Failures of
        individual sessions are logged and skipped.
        """
        if sessions_df.empty:
            logger.warning("没有会话数据需要处理")
            return []
        total_sessions = len(sessions_df)
        logger.info(f"开始处理 {total_sessions} 个会话...")
        # (row position, conversation) pairs. as_completed yields results in
        # completion order, so the position is needed to restore input order.
        finished: List[Tuple[int, List[Dict[str, Any]]]] = []
        with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            future_to_position = {
                executor.submit(self._process_single_session, row): position
                for position, (_, row) in enumerate(sessions_df.iterrows())
            }
            with tqdm(total=total_sessions, desc="处理会话进度") as pbar:
                for future in concurrent.futures.as_completed(future_to_position):
                    position = future_to_position[future]
                    try:
                        conversation = future.result()
                    except Exception as e:
                        logger.error(f"处理会话索引 {position} 时出错: {e}")
                        conversation = None
                    if conversation:
                        finished.append((position, conversation))
                        if self.temp_save_every > 0 and len(finished) % self.temp_save_every == 0:
                            self._temp_save(finished)
                    pbar.update(1)
        all_conversations = self._in_session_order(finished)
        logger.info(f"处理完成,共获得 {len(all_conversations)} 个有效对话")
        return all_conversations
def main(start_date: str = '2025-08-01 00:00:00',
         end_date: str = '2025-08-01 23:00:00',
         output_dir: str = "/data/QueryRewrite/data/excel",
         config_file: str = 'config.ini',
         max_connections: int = 12) -> None:
    """Export assigned customer-service sessions in [start_date, end_date) to Excel.

    Sessions come from ``MariaDBClient.query_sessions`` (BEGIN_TIME >= start
    and < end). Their messages are converted by ``SessionProcessor`` and
    exported as "客服对话记录" into ``output_dir``. Errors, including
    Ctrl-C, are logged rather than raised.

    NOTE(review): the end bound is exclusive, so with the default end_date of
    '2025-08-01 23:00:00', sessions starting 23:00-23:59 that day are not
    exported. Use '2025-08-02 00:00:00' if the whole day is intended.
    """
    try:
        config = DatabaseConfig.from_config_file(config_file)
        logger.info(f"使用数据库配置: {config.host}:{config.port}")
        with MariaDBClient(config, max_connections=max_connections) as db_client:
            logger.info(f"查询时间范围: {start_date} 至 {end_date}")
            processor = SessionProcessor(db_client)
            sessions_df = db_client.query_sessions(start_date, end_date)
            if sessions_df is None or sessions_df.empty:
                logger.warning("没有找到符合条件的会话数据")
                return
            all_conversations = processor.process_sessions(sessions_df)
            if not all_conversations:
                logger.warning("没有有效的对话数据可导出")
                return
            output_file = db_client.export_to_excel(
                all_conversations,
                "客服对话记录",
                output_dir=output_dir,
            )
            if output_file:
                logger.info(f"处理完成!共导出 {len(all_conversations)} 个对话到文件: {output_file}")
            else:
                logger.error("导出文件失败")
    except KeyboardInterrupt:
        logger.info("用户中断程序")
    except Exception as e:
        logger.error(f"程序执行出错: {e}", exc_info=True)
if __name__ == "__main__":
main()