调整提示词、简化代码
This commit is contained in:
+168
-286
@@ -5,23 +5,18 @@ from __future__ import annotations
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import configparser
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from dataclasses import dataclass
|
||||
from contextlib import contextmanager
|
||||
import threading
|
||||
import time
|
||||
from queue import Queue, Empty, Full
|
||||
|
||||
import pandas as pd
|
||||
import pymysql
|
||||
from pymysql.connections import Connection
|
||||
from pymysql.cursors import Cursor
|
||||
from tqdm import tqdm
|
||||
import concurrent.futures
|
||||
import sys
|
||||
from queue import Queue, Empty, Full
|
||||
|
||||
os.makedirs('./data/logs', exist_ok=True)
|
||||
# 配置日志
|
||||
@@ -35,6 +30,18 @@ logging.basicConfig(
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# =====================
# Database connection settings (simplified: plain module-level constants).
#
# Every value can be overridden through an environment variable of the same
# name; the previously hard-coded values are kept as defaults so existing
# deployments behave exactly as before.
# NOTE(review): the credentials are still embedded as defaults -- prefer
# supplying DB_USER / DB_PASSWORD through the environment and removing them
# from source control.
# =====================
DB_HOST = os.environ.get('DB_HOST', '192.168.0.123')
DB_PORT = int(os.environ.get('DB_PORT', '3307'))
DB_USER = os.environ.get('DB_USER', 'fuzhimei')
DB_PASSWORD = os.environ.get('DB_PASSWORD', 'fuzhimei@135')
DB_CHARSET = os.environ.get('DB_CHARSET', 'utf8mb4')
# Timeouts are in seconds and are passed straight to pymysql.connect().
DB_CONNECT_TIMEOUT = int(os.environ.get('DB_CONNECT_TIMEOUT', '10'))
DB_READ_TIMEOUT = int(os.environ.get('DB_READ_TIMEOUT', '300'))
DB_WRITE_TIMEOUT = int(os.environ.get('DB_WRITE_TIMEOUT', '300'))
|
||||
|
||||
def parse_session_tags(input_string):
|
||||
"""
|
||||
解析sessionTag格式的字符串,支持任意数量的sessionTag
|
||||
@@ -74,172 +81,7 @@ def parse_session_tags(input_string):
|
||||
session_key = f'sessionTag{tag_suffix}'
|
||||
result[session_key] = tag_data
|
||||
|
||||
return result
|
||||
|
||||
@dataclass
class DatabaseConfig:
    """Connection settings for the database server.

    The field defaults double as the fallback values used by
    ``from_config_file`` when an option is missing from the INI file.
    """

    host: str = '192.168.0.123'
    port: int = 3307
    user: str = 'fuzhimei'
    password: str = 'fuzhimei@135'
    charset: str = 'utf8mb4'
    connect_timeout: int = 10
    read_timeout: int = 300
    write_timeout: int = 300

    @classmethod
    def from_config_file(cls, config_file: str = 'config.ini') -> 'DatabaseConfig':
        """Build a configuration from the ``[database]`` section of an INI file.

        Args:
            config_file: Path of the INI file to read.

        Returns:
            A ``DatabaseConfig``. When the file or its ``[database]`` section
            is missing, a warning is logged and the defaults are returned.

        Raises:
            ValueError: If a numeric option cannot be converted with ``int``.
        """
        if not os.path.exists(config_file):
            logger.warning(f"配置文件 {config_file} 不存在,使用默认配置")
            return cls()

        parser = configparser.ConfigParser()
        parser.read(config_file, encoding='utf-8')
        if 'database' not in parser:
            logger.warning("配置文件中没有 [database] 部分,使用默认配置")
            return cls()

        section = parser['database']
        numeric_options = {'port', 'connect_timeout', 'read_timeout', 'write_timeout'}
        # Options are resolved in declaration order; a missing option falls
        # back to the class-level default of the same name.
        option_order = (
            'host', 'port', 'user', 'password', 'charset',
            'connect_timeout', 'read_timeout', 'write_timeout',
        )
        settings = {}
        for option in option_order:
            raw_value = section.get(option, getattr(cls, option))
            settings[option] = int(raw_value) if option in numeric_options else raw_value
        return cls(**settings)
|
||||
|
||||
|
||||
class ConnectionPool:
    """Pool of reusable pymysql connections shared between threads.

    Idle connections live in ``self.pool`` (a bounded ``Queue``).
    ``self.active_connections`` counts every connection the pool has created
    and not yet discarded, whether idle or checked out; it is only modified
    while holding ``self.lock``.
    """

    def __init__(self, config: DatabaseConfig, max_connections: int = 10):
        """Create the pool and eagerly open a few connections.

        Args:
            config: Connection settings passed to ``pymysql.connect``.
            max_connections: Upper bound on simultaneously existing connections.
        """
        self.config = config
        self.max_connections = max_connections
        self.pool = Queue(maxsize=max_connections)
        self.active_connections = 0
        self.lock = threading.Lock()

        # Pre-create a few connections.
        self._initialize_pool()

    def _initialize_pool(self) -> None:
        """Open up to three connections up front (fewer if the pool is smaller)."""
        initial_connections = min(3, self.max_connections)
        for _ in range(initial_connections):
            conn = self._create_connection()
            if conn is None:
                # _create_connection already logged the failure.
                continue
            try:
                self.pool.put_nowait(conn)
            except Full:
                # Do not leak the connection that no longer fits.
                self._safe_close(conn)
                break
            with self.lock:
                self.active_connections += 1

    def _create_connection(self) -> Optional[Connection]:
        """Open a new autocommit connection; return None (and log) on failure."""
        try:
            return pymysql.connect(
                host=self.config.host,
                port=self.config.port,
                user=self.config.user,
                password=self.config.password,
                charset=self.config.charset,
                connect_timeout=self.config.connect_timeout,
                read_timeout=self.config.read_timeout,
                write_timeout=self.config.write_timeout,
                autocommit=True,
            )
        except Exception as e:
            logger.error(f"创建数据库连接失败: {e}")
            return None

    @staticmethod
    def _safe_close(conn) -> None:
        """Close ``conn`` on a best-effort basis; a failing close is only logged."""
        try:
            conn.close()
        except Exception as e:
            logger.debug(f"关闭连接时出错: {e}")

    def _checkout(self):
        """Return an idle connection, a newly created one, or wait for one.

        Raises:
            Exception: If a new connection was allowed but could not be opened.
            queue.Empty: If the pool is at capacity and nothing is returned
                within 30 seconds.
        """
        try:
            return self.pool.get_nowait()
        except Empty:
            pass

        # Reserve a slot under the lock, but connect and wait outside of it so
        # a slow connect or a 30 s wait never blocks other threads.
        with self.lock:
            can_grow = self.active_connections < self.max_connections
            if can_grow:
                self.active_connections += 1

        if can_grow:
            conn = self._create_connection()
            if conn is None:
                with self.lock:
                    self.active_connections -= 1
                raise Exception("无法创建新的数据库连接")
            return conn

        logger.info("等待可用连接...")
        return self.pool.get(timeout=30)

    def _checkin(self, conn) -> None:
        """Return ``conn`` to the pool, discarding it when the pool is full."""
        try:
            self.pool.put_nowait(conn)
        except Full:
            self._safe_close(conn)
            with self.lock:
                self.active_connections -= 1

    @contextmanager
    def get_connection(self):
        """Context manager yielding a live connection.

        On normal exit the connection goes back to the pool. If acquiring the
        connection or the ``with`` body raises, the connection is closed, its
        slot is released, and the exception is re-raised.
        """
        conn = None
        try:
            conn = self._checkout()

            # Replace a connection the server has dropped.
            if not self._is_connection_alive(conn):
                logger.warning("连接已失效,重新创建")
                self._safe_close(conn)
                conn = None
                replacement = self._create_connection()
                if replacement is None:
                    # The dead connection's slot must be given back, otherwise
                    # the pool permanently shrinks by one.
                    with self.lock:
                        self.active_connections -= 1
                    raise Exception("重新创建连接失败")
                conn = replacement

            yield conn
        except Exception as e:
            logger.error(f"获取数据库连接时出错: {e}")
            if conn is not None:
                self._safe_close(conn)
                with self.lock:
                    self.active_connections -= 1
            raise
        else:
            self._checkin(conn)

    def _is_connection_alive(self, conn: Connection) -> bool:
        """Return True when ``conn.ping(reconnect=False)`` succeeds."""
        try:
            conn.ping(reconnect=False)
            return True
        except Exception:
            return False

    def close_all(self) -> None:
        """Close every idle connection and reset the active counter."""
        logger.info("正在关闭连接池中的所有连接...")
        while True:
            try:
                conn = self.pool.get_nowait()
            except Empty:
                break
            # Keep draining even if one connection fails to close.
            self._safe_close(conn)

        with self.lock:
            self.active_connections = 0
        logger.info("连接池已关闭")
|
||||
return result
|
||||
|
||||
|
||||
class DataProcessor:
|
||||
@@ -357,12 +199,76 @@ class DataProcessor:
|
||||
|
||||
|
||||
class MariaDBClient:
|
||||
"""优化后的MariaDB数据库客户端"""
|
||||
"""简化版 MariaDB 客户端(内置轻量连接池以复用连接)"""
|
||||
|
||||
def __init__(self, config: DatabaseConfig, max_connections: int = 10):
|
||||
self.config = config
|
||||
self.connection_pool = ConnectionPool(config, max_connections)
|
||||
def __init__(self, max_connections: int = 10):
    """Set up the data processor and a small built-in connection pool.

    Args:
        max_connections: Capacity of the idle-connection queue and upper
            bound used when deciding whether a new connection may be opened.
    """
    self.data_processor = DataProcessor()
    self._max_connections = max_connections
    self._pool: Queue = Queue(maxsize=max_connections)
    self._active = 0
    self._lock = threading.Lock()

    # Warm the pool with a few connections to cut first-query latency.
    warm_count = min(3, max_connections)
    for _ in range(warm_count):
        fresh = self._create_connection()
        if not fresh:
            continue
        try:
            self._pool.put_nowait(fresh)
        except Full:
            # No room left: drop the surplus connection quietly.
            try:
                fresh.close()
            except Exception:
                pass
        else:
            self._active += 1
|
||||
|
||||
def _create_connection(self):
    """Open a new autocommit connection from the module-level DB_* settings.

    Returns:
        The pymysql connection, or None when connecting fails (the error is
        logged rather than raised).
    """
    try:
        connect_kwargs = dict(
            host=DB_HOST,
            port=DB_PORT,
            user=DB_USER,
            password=DB_PASSWORD,
            charset=DB_CHARSET,
            connect_timeout=DB_CONNECT_TIMEOUT,
            read_timeout=DB_READ_TIMEOUT,
            write_timeout=DB_WRITE_TIMEOUT,
            autocommit=True,
        )
        connection = pymysql.connect(**connect_kwargs)
    except Exception as e:
        logger.error(f"创建数据库连接失败: {e}")
        return None
    return connection
|
||||
|
||||
def _acquire_connection(self):
|
||||
# 先尝试不阻塞获取
|
||||
try:
|
||||
return self._pool.get_nowait()
|
||||
except Empty:
|
||||
# 池为空,若可创建新连接则创建,否则阻塞等待
|
||||
with self._lock:
|
||||
if self._active < self._max_connections:
|
||||
conn = self._create_connection()
|
||||
if conn:
|
||||
self._active += 1
|
||||
return conn
|
||||
# 达到上限,阻塞等待可用连接
|
||||
try:
|
||||
return self._pool.get(timeout=30)
|
||||
except Empty:
|
||||
raise RuntimeError("获取数据库连接超时")
|
||||
|
||||
def _release_connection(self, conn):
|
||||
if not conn:
|
||||
return
|
||||
try:
|
||||
self._pool.put_nowait(conn)
|
||||
except Full:
|
||||
# 池已满,关闭多余连接
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
with self._lock:
|
||||
self._active = max(0, self._active - 1)
|
||||
|
||||
def __enter__(self) -> 'MariaDBClient':
|
||||
return self
|
||||
@@ -371,30 +277,52 @@ class MariaDBClient:
|
||||
self.close()
|
||||
|
||||
def close(self) -> None:
|
||||
"""关闭客户端"""
|
||||
self.connection_pool.close_all()
|
||||
"""关闭连接池中的连接"""
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
conn = self._pool.get_nowait()
|
||||
except Empty:
|
||||
break
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
with self._lock:
|
||||
self._active = 0
|
||||
|
||||
def execute_query(self, sql: str, params: Optional[Tuple] = None) -> Tuple[Optional[pd.DataFrame], List[str]]:
|
||||
"""执行SQL查询"""
|
||||
"""执行SQL查询(复用连接池连接)"""
|
||||
conn = None
|
||||
try:
|
||||
with self.connection_pool.get_connection() as conn:
|
||||
with conn.cursor() as cursor:
|
||||
cursor.execute(sql, params)
|
||||
results = cursor.fetchall()
|
||||
|
||||
# 获取列名
|
||||
column_names = [desc[0] for desc in cursor.description] if cursor.description else []
|
||||
|
||||
if results:
|
||||
df = pd.DataFrame(results, columns=column_names)
|
||||
return df, column_names
|
||||
else:
|
||||
return pd.DataFrame(), column_names
|
||||
|
||||
conn = self._acquire_connection()
|
||||
with conn.cursor() as cursor:
|
||||
cursor.execute(sql, params)
|
||||
results = cursor.fetchall()
|
||||
column_names = [desc[0] for desc in cursor.description] if cursor.description else []
|
||||
if results:
|
||||
df = pd.DataFrame(results, columns=column_names)
|
||||
return df, column_names
|
||||
else:
|
||||
return pd.DataFrame(), column_names
|
||||
except Exception as e:
|
||||
logger.error(f"执行查询时出错: {e}")
|
||||
logger.error(f"SQL: {sql}")
|
||||
return None, []
|
||||
finally:
|
||||
if conn:
|
||||
# 若连接异常,尝试关闭并减少活跃计数;否则归还
|
||||
try:
|
||||
conn.ping(reconnect=False)
|
||||
self._release_connection(conn)
|
||||
except Exception:
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
with self._lock:
|
||||
self._active = max(0, self._active - 1)
|
||||
|
||||
def query_sessions(self, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
|
||||
"""查询指定日期范围内的会话数据"""
|
||||
@@ -483,125 +411,79 @@ class MariaDBClient:
|
||||
logger.error(f"导出到Excel时出错: {e}")
|
||||
return None
|
||||
|
||||
|
||||
class SessionProcessor:
    """Converts sessions into conversations concurrently with a thread pool."""

    def __init__(self, db_client: MariaDBClient, max_workers: Optional[int] = None):
        """Store the client and choose the worker count.

        Args:
            db_client: Client used to fetch messages and export results.
            max_workers: Thread count; defaults to ``os.cpu_count()``.
        """
        self.db_client = db_client
        self.max_workers = max_workers if max_workers is not None else os.cpu_count()
        # Guards the temporary checkpoint export.
        self.temp_save_lock = threading.Lock()

        logger.info(f"初始化会话处理器: max_workers={self.max_workers}")

    def _process_single_session(self, session_row):
        """Fetch one session's messages and convert them; return None on failure or no data."""
        try:
            session_id = session_row['SESSION_ID']
            messages_df = self.db_client.query_messages_by_session_id(session_id)

            if messages_df is not None and not messages_df.empty:
                conversation = self.db_client.data_processor.messages_df_to_list(messages_df, session_row)
                if conversation:
                    return conversation
        except Exception as e:
            logger.error(f"处理会话 {session_row.get('SESSION_ID', 'unknown')} 时出错: {e}")
        return None

    def _save_checkpoint(self, conversations: List[List[Dict[str, Any]]]) -> None:
        """Export everything collected so far to the temporary Excel file."""
        with self.temp_save_lock:
            logger.info(f"临时保存 {len(conversations)} 个对话")
            temp_output_file = self.db_client.export_to_excel(
                conversations,
                "客服对话记录_临时保存",
                output_dir="/data/QueryRewrite/data/excel"
            )
            if temp_output_file:
                logger.info(f"临时保存完成: {temp_output_file}")

    def process_sessions(self, sessions_df: pd.DataFrame) -> List[List[Dict[str, Any]]]:
        """Process every session row and return the non-empty conversations.

        Args:
            sessions_df: One row per session; each row must provide ``SESSION_ID``.

        Returns:
            The conversations in completion order. Sessions that fail or have
            no messages are skipped (failures are logged).
        """
        if sessions_df.empty:
            logger.warning("没有会话数据需要处理")
            return []

        total_sessions = len(sessions_df)
        logger.info(f"开始处理 {total_sessions} 个会话...")

        all_conversations = []

        with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            # Map each future to its DataFrame index for error reporting.
            future_to_session = {
                executor.submit(self._process_single_session, row): i
                for i, row in sessions_df.iterrows()
            }

            with tqdm(total=total_sessions, desc="处理会话进度") as pbar:
                for future in concurrent.futures.as_completed(future_to_session):
                    try:
                        conversation = future.result()
                        if conversation:
                            all_conversations.append(conversation)
                            # Checkpoint only right after an append, so a run of
                            # empty results never re-exports the same data.
                            if len(all_conversations) % 100 == 0:
                                self._save_checkpoint(all_conversations)
                    except Exception as e:
                        session_idx = future_to_session[future]
                        logger.error(f"处理会话索引 {session_idx} 时出错: {e}")

                    pbar.update(1)

        logger.info(f"处理完成,共获得 {len(all_conversations)} 个有效对话")
        return all_conversations
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""主函数"""
|
||||
"""主函数(精简版)"""
|
||||
try:
|
||||
# 加载配置
|
||||
config = DatabaseConfig.from_config_file()
|
||||
logger.info(f"使用数据库配置: {config.host}:{config.port}")
|
||||
logger.info(f"使用数据库配置: {DB_HOST}:{DB_PORT}")
|
||||
|
||||
# 创建数据库客户端
|
||||
with MariaDBClient(config, max_connections=12) as db_client:
|
||||
# 创建数据库客户端(简化)
|
||||
with MariaDBClient() as db_client:
|
||||
# 查询会话数据
|
||||
start_date = '2025-08-01 00:00:00'
|
||||
end_date = '2025-08-01 23:00:00'
|
||||
|
||||
logger.info(f"查询时间范围: {start_date} 到 {end_date}")
|
||||
# 创建会话处理器
|
||||
processor = SessionProcessor(db_client)
|
||||
# is_debug = hasattr(sys, 'gettrace') and sys.gettrace() is not None
|
||||
# if is_debug:
|
||||
# messages_df = db_client.query_messages_by_session_id("86c919e0-09f1-11f0-84ae-2daf59566989")
|
||||
# print(db_client.data_processor.messages_df_to_list(messages_df))
|
||||
# return []
|
||||
|
||||
sessions_df = db_client.query_sessions(start_date, end_date)
|
||||
|
||||
if sessions_df is None or sessions_df.empty:
|
||||
logger.warning("没有找到符合条件的会话数据")
|
||||
return
|
||||
|
||||
# 处理会话数据
|
||||
all_conversations = processor.process_sessions(sessions_df)
|
||||
|
||||
# 直接并发处理每个会话(替代 SessionProcessor)
|
||||
total_sessions = len(sessions_df)
|
||||
all_conversations: List[List[Dict[str, Any]]] = []
|
||||
temp_save_lock = threading.Lock()
|
||||
|
||||
def process_single_session(session_row):
|
||||
try:
|
||||
session_id = session_row['SESSION_ID']
|
||||
messages_df = db_client.query_messages_by_session_id(session_id)
|
||||
if messages_df is not None and not messages_df.empty:
|
||||
conversation = db_client.data_processor.messages_df_to_list(messages_df, session_row)
|
||||
if conversation:
|
||||
return conversation
|
||||
except Exception as e:
|
||||
logger.error(f"处理会话 {session_row.get('SESSION_ID', 'unknown')} 时出错: {e}")
|
||||
return None
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
|
||||
future_to_session = {executor.submit(process_single_session, row): i for i, row in sessions_df.iterrows()}
|
||||
with tqdm(total=total_sessions, desc="处理会话进度") as pbar:
|
||||
for future in concurrent.futures.as_completed(future_to_session):
|
||||
try:
|
||||
conversation = future.result()
|
||||
if conversation:
|
||||
all_conversations.append(conversation)
|
||||
if len(all_conversations) % 100 == 0:
|
||||
with temp_save_lock:
|
||||
logger.info(f"临时保存 {len(all_conversations)} 个对话")
|
||||
temp_output_file = db_client.export_to_excel(
|
||||
all_conversations,
|
||||
f"客服对话记录_临时保存",
|
||||
output_dir="/data/QueryRewrite/data/excel"
|
||||
)
|
||||
if temp_output_file:
|
||||
logger.info(f"临时保存完成: {temp_output_file}")
|
||||
except Exception as e:
|
||||
session_idx = future_to_session[future]
|
||||
logger.error(f"处理会话索引 {session_idx} 时出错: {e}")
|
||||
pbar.update(1)
|
||||
|
||||
# 导出结果
|
||||
if all_conversations:
|
||||
output_file = db_client.export_to_excel(
|
||||
all_conversations,
|
||||
all_conversations,
|
||||
"客服对话记录",
|
||||
output_dir="/data/QueryRewrite/data/excel"
|
||||
)
|
||||
|
||||
if output_file:
|
||||
logger.info(f"处理完成!共导出 {len(all_conversations)} 个对话到文件: {output_file}")
|
||||
else:
|
||||
logger.error("导出文件失败")
|
||||
else:
|
||||
logger.warning("没有有效的对话数据可导出")
|
||||
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("用户中断程序")
|
||||
except Exception as e:
|
||||
|
||||
Reference in New Issue
Block a user