Files
QueryRewrite/rag2_0/demo/heli_db_to_excel.py
T

494 lines
18 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import annotations
import json
import os
import re
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple, Union
from dataclasses import dataclass
import threading
import pandas as pd
import pymysql
from tqdm import tqdm
import concurrent.futures
import sys
from queue import Queue, Empty, Full
# Directory that receives the log file; created up front because
# logging.FileHandler does not create missing parent directories.
_LOG_DIR = './data/logs'
os.makedirs(_LOG_DIR, exist_ok=True)

# Configure logging: every record goes both to the log file and to the console.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        # The log messages in this file are Chinese; pin UTF-8 so the file
        # content does not depend on the platform's default encoding.
        logging.FileHandler(
            os.path.join(_LOG_DIR, 'mariadb_client.log'),
            encoding='utf-8',
        ),
        logging.StreamHandler(),
    ],
)
logger = logging.getLogger(__name__)
# =====================
# Database configuration.
# The defaults are the values that used to be hard-coded here; each of the
# connection settings can be overridden with a HELI_DB_* environment variable
# so the script can target another server without editing the source.
# =====================
DB_HOST = os.environ.get('HELI_DB_HOST', '192.168.0.123')
DB_PORT = int(os.environ.get('HELI_DB_PORT', '3307'))
DB_USER = os.environ.get('HELI_DB_USER', 'fuzhimei')
# NOTE(security): a real password is committed here as the fallback value.
# Prefer supplying HELI_DB_PASSWORD from the environment and rotating this one.
DB_PASSWORD = os.environ.get('HELI_DB_PASSWORD', 'fuzhimei@135')
DB_CHARSET = 'utf8mb4'
# Timeouts passed to pymysql.connect() (connect / read / write).
DB_CONNECT_TIMEOUT = 10
DB_READ_TIMEOUT = 300
DB_WRITE_TIMEOUT = 300
def parse_session_tags(input_string):
    """Parse a ``sessionTag...`` formatted string into nested dictionaries.

    Any number of tags is supported (``sessionTagFirst``, ``sessionTagSecond``,
    ``sessionTagThird``, ...). Each ``sessionTagXxx={k=v, k=v}`` group becomes
    one entry of the result, keyed by the full tag name.

    Args:
        input_string: Raw text, optionally wrapped in square brackets and/or
            double quotes.

    Returns:
        ``{"sessionTagXxx": {key: value, ...}, ...}``. Values made only of
        decimal digits are converted to ``int``, ``true``/``false``
        (case-insensitive) to ``bool``; everything else stays a stripped
        string. A value cannot contain ``,`` or ``}``.
    """
    # Remove the surrounding square brackets and quotes.
    cleaned_string = input_string.strip('[]"')

    # "sessionTag" + arbitrary suffix + "=" + brace-delimited body.
    tag_pattern = r'sessionTag(\w+)=\{([^}]+)\}'
    # key=value pairs; a value ends right before the next ", key=" or at the
    # end of the body. Values may contain Chinese and other characters.
    kv_pattern = r'(\w+)=([^,}]+?)(?=,\s*\w+=|$|,\s*$)'

    result = {}
    for tag_suffix, content in re.findall(tag_pattern, cleaned_string):
        tag_data = {}
        for key, value in re.findall(kv_pattern, content):
            cleaned_value = value.strip()
            # isdecimal() rather than isdigit(): isdigit() also accepts
            # characters such as superscript digits, which int() rejects with
            # ValueError and would abort parsing of the whole string.
            if cleaned_value.isdecimal():
                tag_data[key] = int(cleaned_value)
            elif cleaned_value.lower() in ('true', 'false'):
                tag_data[key] = cleaned_value.lower() == 'true'
            else:
                tag_data[key] = cleaned_value
        # Rebuild the full tag name used as the result key.
        result[f'sessionTag{tag_suffix}'] = tag_data
    return result
class DataProcessor:
    """Stateless helpers that turn raw session/message rows into export rows."""

    @staticmethod
    def clean_html_tags(text: str) -> str:
        """Strip HTML tags from *text* and decode a few common entities.

        Args:
            text: Value to clean. Non-string values are accepted.

        Returns:
            The cleaned, whitespace-stripped text. ``None`` yields ``""`` and
            any other non-string value yields ``str(text)`` unchanged.
        """
        if not isinstance(text, str):
            return str(text) if text is not None else ""
        # Remove anything that looks like an HTML tag.
        clean_text = re.sub(r'<[^>]+>', '', text)
        # Decode entities. '&amp;' must come last: decoding it first would turn
        # an escaped entity such as '&amp;quot;' into '&quot;' and then into
        # '"', i.e. unescape it twice.
        html_entities = (
            ('&nbsp;', ' '),
            ('&lt;', '<'),
            ('&gt;', '>'),
            ('&quot;', '"'),
            ('&apos;', "'"),
            ('&amp;', '&'),
        )
        for entity, char in html_entities:
            clean_text = clean_text.replace(entity, char)
        return clean_text.strip()

    @staticmethod
    def get_session_tag_dict(json_data: str) -> dict:
        """Extract the session tag dictionary from a session's JSON payload.

        Args:
            json_data: JSON text expected to hold a ``sessionMultiTagList``
                entry; may be empty or ``None``.

        Returns:
            The result of ``parse_session_tags`` for that entry, or ``{}`` when
            the payload is empty, has no such entry, or cannot be parsed
            (parse failures are logged).
        """
        try:
            # A missing payload is a normal case, not an error worth logging.
            if not json_data:
                return {}
            json_data_dict = json.loads(json_data)
            session_tag_list_str = json_data_dict.get('sessionMultiTagList', None)
            if not session_tag_list_str:
                return {}
            return parse_session_tags(session_tag_list_str)
        except Exception as e:
            # Best effort: a broken payload must not abort the export.
            logger.error(f"解析会话标签时出错: {e}")
            return {}

    @staticmethod
    def messages_df_to_list(messages_df: pd.DataFrame, session_row) -> List[Dict[str, Any]]:
        """Convert one session's message rows into a list of export dicts.

        System messages, rows without a sender nickname and consecutive
        duplicates are dropped.

        Args:
            messages_df: Message rows with the columns MODE,
                SYSTEM_MODE_MESSAGE_TYPE, AGENT_NAME, CUS_NICK_NAME, MSG_TYPE,
                CONTENT, SESSION_ID, ACCOUNT, CREATE_TIME and SEQUENCE_ID.
            session_row: Mapping-like session record; its ``'JSON'`` entry is
                parsed for the session tags.

        Returns:
            One dict per kept message, in the row order of *messages_df*. The
            first/second level tag names are added when the session has them.
        """
        if messages_df.empty:
            return []

        # Drop system messages.
        mask = (messages_df["MODE"] != "system") & (messages_df["SYSTEM_MODE_MESSAGE_TYPE"].isna())
        filtered_df = messages_df[mask].copy()
        if filtered_df.empty:
            return []

        # Map the message direction to a sender role.
        filtered_df['message_sender'] = (
            filtered_df["MODE"].map({'reply': '坐席', 'receive': '访客'}).fillna('未知')
        )
        # Agents are shown by agent name, everybody else by customer nickname.
        filtered_df['sender_nickname'] = filtered_df.apply(
            lambda row: row["AGENT_NAME"] if row["message_sender"] == "坐席" else row["CUS_NICK_NAME"],
            axis=1,
        )

        def process_content(row):
            # Attachments and images get a prefix and HTML-stripped content;
            # every other message type is passed through unchanged.
            content = row["CONTENT"]
            if row["MSG_TYPE"] == "attachment":
                return f"附件:{DataProcessor.clean_html_tags(content)}"
            elif row["MSG_TYPE"] == "image":
                return f"图片:{DataProcessor.clean_html_tags(content)}"
            else:
                return content

        filtered_df['processed_content'] = filtered_df.apply(process_content, axis=1)

        # Drop rows whose sender nickname is missing or empty.
        filtered_df = filtered_df[
            filtered_df['sender_nickname'].notna() & (filtered_df['sender_nickname'] != '')
        ]

        # The tags belong to the session, so they are identical for every
        # message: resolve them once instead of once per row.
        session_tag_dict = DataProcessor.get_session_tag_dict(session_row['JSON'])
        tag_columns: Dict[str, Any] = {}
        for tag_key, column in (('sessionTagFirst', '一级标签'), ('sessionTagSecond', '二级标签')):
            if tag_key in session_tag_dict:
                tag_columns[column] = session_tag_dict[tag_key].get('tagName', '')

        result: List[Dict[str, Any]] = []
        for record in filtered_df.to_dict('records'):
            # Skip a message that repeats the previous one (same session,
            # sender, creation time and content).
            if result:
                previous = result[-1]
                if (
                    previous['会话id'] == record['SESSION_ID']
                    and previous['消息发送者'] == record['message_sender']
                    and previous['创建时间'] == record['CREATE_TIME']
                    and previous['消息内容'] == record['processed_content']
                ):
                    continue
            message_dict = {
                "账号id": record["ACCOUNT"],
                "会话id": record["SESSION_ID"],
                "消息内容": record["processed_content"],
                "消息发送者": record["message_sender"],
                "发送者昵称": record["sender_nickname"],
                "创建时间": record["CREATE_TIME"],
                "SEQUENCE_ID": record["SEQUENCE_ID"],
            }
            # Append the tag columns (if any) after the fixed columns.
            message_dict.update(tag_columns)
            result.append(message_dict)
        return result
class MariaDBClient:
    """Minimal MariaDB client with a small built-in pool for connection reuse.

    ``_pool`` holds idle connections; ``_active`` counts every connection the
    client currently owns (idle or checked out) and is capped at
    ``max_connections``. The client is usable as a context manager, which
    closes the idle connections on exit.
    """

    def __init__(self, max_connections: int = 10):
        """Create the pool and eagerly open up to three connections.

        Args:
            max_connections: Upper bound on connections owned by this client.
        """
        self.data_processor = DataProcessor()
        self._max_connections = max_connections
        self._pool: Queue = Queue(maxsize=max_connections)
        self._active = 0
        self._lock = threading.Lock()
        # Pre-create a few connections to reduce the latency of first queries.
        for _ in range(min(3, max_connections)):
            conn = self._create_connection()
            if conn is None:
                continue
            try:
                self._pool.put_nowait(conn)
                self._active += 1
            except Full:
                self._close_quietly(conn)

    @staticmethod
    def _close_quietly(conn) -> None:
        """Close *conn*, ignoring errors from an already broken connection."""
        try:
            conn.close()
        except Exception:
            pass

    def _create_connection(self):
        """Open a new connection from the module-level DB_* settings.

        Returns:
            The connection, or ``None`` when connecting failed (logged).
        """
        try:
            return pymysql.connect(
                host=DB_HOST,
                port=DB_PORT,
                user=DB_USER,
                password=DB_PASSWORD,
                charset=DB_CHARSET,
                connect_timeout=DB_CONNECT_TIMEOUT,
                read_timeout=DB_READ_TIMEOUT,
                write_timeout=DB_WRITE_TIMEOUT,
                autocommit=True,
            )
        except Exception as e:
            logger.error(f"创建数据库连接失败: {e}")
            return None

    def _acquire_connection(self):
        """Return a connection: idle one, freshly created one, or wait for one.

        Raises:
            RuntimeError: No connection became available within 30 seconds.
        """
        # Fast path: take an idle connection without blocking.
        try:
            return self._pool.get_nowait()
        except Empty:
            pass

        # Pool is empty. Reserve a slot while holding the lock, but connect
        # outside of it so a slow connect does not stall the other threads.
        with self._lock:
            may_create = self._active < self._max_connections
            if may_create:
                self._active += 1
        if may_create:
            conn = self._create_connection()
            if conn is not None:
                return conn
            # Connecting failed: give the reserved slot back.
            with self._lock:
                self._active = max(0, self._active - 1)

        # At the limit (or connecting failed): wait for a released connection.
        try:
            return self._pool.get(timeout=30)
        except Empty:
            raise RuntimeError("获取数据库连接超时") from None

    def _release_connection(self, conn):
        """Put *conn* back into the pool, or close it when the pool is full."""
        if not conn:
            return
        try:
            self._pool.put_nowait(conn)
        except Full:
            # Pool already full: close the surplus connection.
            self._close_quietly(conn)
            with self._lock:
                self._active = max(0, self._active - 1)

    def __enter__(self) -> 'MariaDBClient':
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self.close()

    def close(self) -> None:
        """Close every idle connection in the pool and reset the counter."""
        try:
            while True:
                try:
                    conn = self._pool.get_nowait()
                except Empty:
                    break
                self._close_quietly(conn)
        finally:
            with self._lock:
                self._active = 0

    def execute_query(self, sql: str, params: Optional[Tuple] = None) -> Tuple[Optional[pd.DataFrame], List[str]]:
        """Run *sql* on a pooled connection and return the rows as a DataFrame.

        Args:
            sql: Statement to execute, with ``%s`` placeholders.
            params: Values bound to the placeholders.

        Returns:
            ``(df, column_names)``. An empty result set yields an empty frame
            that still carries the column names. On any error the failure is
            logged and ``(None, [])`` is returned.
        """
        conn = None
        try:
            conn = self._acquire_connection()
            with conn.cursor() as cursor:
                cursor.execute(sql, params)
                results = cursor.fetchall()
                column_names = [desc[0] for desc in cursor.description] if cursor.description else []
            if results:
                return pd.DataFrame(list(results), columns=column_names), column_names
            # Keep the columns so callers can still select them by name.
            return pd.DataFrame(columns=column_names), column_names
        except Exception as e:
            logger.error(f"执行查询时出错: {e}")
            logger.error(f"SQL: {sql}")
            return None, []
        finally:
            if conn is not None:
                # Only healthy connections go back to the pool; a broken one
                # is closed and no longer counted.
                try:
                    conn.ping(reconnect=False)
                except Exception:
                    self._close_quietly(conn)
                    with self._lock:
                        self._active = max(0, self._active - 1)
                else:
                    self._release_connection(conn)

    def query_sessions(self, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
        """Query sessions with status 'assign' and start_date <= BEGIN_TIME < end_date.

        Returns:
            The sessions, newest first, or ``None`` when the query failed.
        """
        sql = """
            SELECT ACCOUNT, BEGIN_TIME, END_TIME, CUST_SEND_MESSAGE_COUNT,
            AGENT_SEND_MESSAGE_COUNT, STATUS, CHANNEL_NAME, SESSION_ID, SESSION_TAG_NAME, JSON
            FROM crm_hlyj.crm_hlyj_dsri
            WHERE BEGIN_TIME >= %s
            AND BEGIN_TIME < %s
            AND STATUS = 'assign'
            ORDER BY BEGIN_TIME DESC
        """
        df, _ = self.execute_query(sql, (start_date, end_date))
        return df

    def get_sequence_id_by_session_id(self, session_id: str) -> Optional[pd.DataFrame]:
        """Query the SEQUENCE_ID of every message of a session.

        Returns:
            A frame with the SEQUENCE_ID column, or ``None`` on failure.
        """
        sql = """
            SELECT SEQUENCE_ID
            FROM crm_hlyj.crm_hlyj_dmri
            WHERE SESSION_ID = %s
        """
        df, _ = self.execute_query(sql, (session_id,))
        return df

    def query_messages_by_session_id(self, session_id: str) -> Optional[pd.DataFrame]:
        """Query the message details of a session, ordered by creation time.

        Returns:
            The message rows, or ``None`` when the query failed.
        """
        sql = """
            SELECT CREATE_TIME, CUS_NICK_NAME, MODE, MSG_TYPE, AGENT_NAME, CONTENT,
            SESSION_ID, ACCOUNT, SYSTEM_MODE_MESSAGE_TYPE, SEQUENCE_ID
            FROM crm_hlyj.crm_hlyj_dmri
            WHERE SESSION_ID = %s
            ORDER BY CREATE_TIME
        """
        df, _ = self.execute_query(sql, (session_id,))
        return df

    def export_to_excel(self, data: List[Dict[str, Any]], filename: str, output_dir: str = "output") -> Optional[str]:
        """Write conversations to ``<output_dir>/<filename>.xlsx``.

        Args:
            data: Conversations; each one is a list of message dicts that
                carry a ``"会话id"`` key.
            filename: File name without extension.
            output_dir: Target directory, created when missing.

        Returns:
            The path of the written file, or ``None`` when there was nothing
            to export or the export failed (logged).
        """
        if not data:
            logger.warning(f"没有数据可导出到 {filename}")
            return None
        try:
            os.makedirs(output_dir, exist_ok=True)
            file_path = os.path.join(output_dir, f"{filename}.xlsx")

            # Flatten the conversations, separating sessions by an empty row.
            all_rows = []
            current_session_id = None
            for conversation in data:
                if not conversation:  # skip empty conversations
                    continue
                session_id = conversation[0]["会话id"]
                # "is not None" rather than truthiness, so a falsy session id
                # (0, "") still gets its separator row; only the very first
                # conversation has none.
                if current_session_id is not None and current_session_id != session_id:
                    all_rows.append({key: "" for key in conversation[0].keys()})
                current_session_id = session_id
                all_rows.extend(conversation)

            if not all_rows:
                logger.warning("没有有效数据可导出")
                return None

            df = pd.DataFrame(all_rows)
            with pd.ExcelWriter(file_path, engine='openpyxl') as writer:
                df.to_excel(writer, sheet_name='对话记录', index=False)
            logger.info(f"数据已导出到 {file_path}")
            return file_path
        except Exception as e:
            logger.error(f"导出到Excel时出错: {e}")
            return None
def main() -> None:
    """Export the conversations of a fixed time range to an Excel file.

    Queries the sessions of the range, loads and converts the messages of
    each session in a thread pool, writes a temporary export after every 100
    collected conversations and a final export at the end. Errors are logged
    rather than raised.
    """
    # Workers never outnumber pooled connections; otherwise surplus workers
    # could wait for a connection until the pool's acquire timeout, fail their
    # query and silently lose their session.
    pool_size = 10
    worker_count = max(1, min(os.cpu_count() or 1, pool_size))
    output_dir = "/data/QueryRewrite/data/excel"

    try:
        logger.info(f"使用数据库配置: {DB_HOST}:{DB_PORT}")
        with MariaDBClient(max_connections=pool_size) as db_client:
            # Query the sessions of the time range.
            start_date = '2025-08-01 00:00:00'
            end_date = '2025-08-01 23:00:00'
            logger.info(f"查询时间范围: {start_date} ~ {end_date}")
            sessions_df = db_client.query_sessions(start_date, end_date)
            if sessions_df is None or sessions_df.empty:
                logger.warning("没有找到符合条件的会话数据")
                return

            total_sessions = len(sessions_df)
            # Keyed by the session's row index so exports follow the order of
            # the session query instead of the order in which workers finish.
            # Only the main thread touches this dict, so no lock is needed.
            conversations_by_index: Dict[Any, List[Dict[str, Any]]] = {}

            def ordered_conversations() -> List[List[Dict[str, Any]]]:
                return [conversations_by_index[key] for key in sorted(conversations_by_index)]

            def process_single_session(session_row):
                """Return the converted conversation of one session, or None."""
                try:
                    session_id = session_row['SESSION_ID']
                    messages_df = db_client.query_messages_by_session_id(session_id)
                    if messages_df is not None and not messages_df.empty:
                        conversation = db_client.data_processor.messages_df_to_list(messages_df, session_row)
                        if conversation:
                            return conversation
                except Exception as e:
                    logger.error(f"处理会话 {session_row.get('SESSION_ID', 'unknown')} 时出错: {e}")
                return None

            # Process the sessions concurrently.
            with concurrent.futures.ThreadPoolExecutor(max_workers=worker_count) as executor:
                future_to_session = {
                    executor.submit(process_single_session, row): i
                    for i, row in sessions_df.iterrows()
                }
                with tqdm(total=total_sessions, desc="处理会话进度") as pbar:
                    for future in concurrent.futures.as_completed(future_to_session):
                        session_idx = future_to_session[future]
                        try:
                            conversation = future.result()
                            if conversation:
                                conversations_by_index[session_idx] = conversation
                                # Periodic temporary save of what has been collected.
                                if len(conversations_by_index) % 100 == 0:
                                    logger.info(f"临时保存 {len(conversations_by_index)} 个对话")
                                    temp_output_file = db_client.export_to_excel(
                                        ordered_conversations(),
                                        "客服对话记录_临时保存",
                                        output_dir=output_dir,
                                    )
                                    if temp_output_file:
                                        logger.info(f"临时保存完成: {temp_output_file}")
                        except Exception as e:
                            logger.error(f"处理会话索引 {session_idx} 时出错: {e}")
                        pbar.update(1)

            # Final export.
            all_conversations = ordered_conversations()
            if all_conversations:
                output_file = db_client.export_to_excel(
                    all_conversations,
                    "客服对话记录",
                    output_dir=output_dir,
                )
                if output_file:
                    logger.info(f"处理完成!共导出 {len(all_conversations)} 个对话到文件: {output_file}")
                else:
                    logger.error("导出文件失败")
            else:
                logger.warning("没有有效的对话数据可导出")
    except KeyboardInterrupt:
        logger.info("用户中断程序")
    except Exception as e:
        logger.error(f"程序执行出错: {e}", exc_info=True)


if __name__ == "__main__":
    main()