重构会话处理逻辑,移除批量处理,改为直接并发处理每个会话,优化临时保存机制。同时,更新DifyExporter类,新增获取节点信息和词条列表的方法,调整导出列以包含备注信息

This commit is contained in:
2025-07-14 18:51:03 +08:00
parent af1e1a9d9b
commit 5b5a2f2b16
2 changed files with 99 additions and 78 deletions
+48 -61
View File
@@ -397,37 +397,15 @@ class MariaDBClient:
return None
def process_session_batch(db_client: MariaDBClient, session_batch: pd.DataFrame) -> List[List[Dict[str, Any]]]:
"""批量处理会话数据"""
conversations = []
for _, session_row in session_batch.iterrows():
try:
session_id = session_row['SESSION_ID']
messages_df = db_client.query_messages_by_session_id(session_id)
if messages_df is not None and not messages_df.empty:
conversation = db_client.data_processor.messages_df_to_list(messages_df)
if conversation:
conversations.append(conversation)
except Exception as e:
logger.error(f"处理会话 {session_row.get('SESSION_ID', 'unknown')} 时出错: {e}")
continue
return conversations
class SessionProcessor:
"""会话处理器,负责批量和并发处理"""
"""会话处理器,负责并发处理"""
def __init__(self, db_client: MariaDBClient, max_workers: int = None, batch_size: int = 50):
def __init__(self, db_client: MariaDBClient, max_workers: int = None):
self.db_client = db_client
self.max_workers = max_workers if max_workers is not None else os.cpu_count()
self.batch_size = batch_size
self.temp_save_lock = threading.Lock() # 添加锁用于保护临时保存操作
logger.info(f"初始化会话处理器: max_workers={self.max_workers}, batch_size={self.batch_size}")
logger.info(f"初始化会话处理器: max_workers={self.max_workers}")
def process_sessions(self, sessions_df: pd.DataFrame) -> List[List[Dict[str, Any]]]:
"""处理所有会话数据"""
@@ -438,44 +416,53 @@ class SessionProcessor:
total_sessions = len(sessions_df)
logger.info(f"开始处理 {total_sessions} 个会话...")
# 分批处理
all_conversations = []
batch_count = (total_sessions + self.batch_size - 1) // self.batch_size
# 使用线程池处理批次
# 直接并发处理每个会话
def process_single_session(session_row):
try:
session_id = session_row['SESSION_ID']
messages_df = self.db_client.query_messages_by_session_id(session_id)
if messages_df is not None and not messages_df.empty:
conversation = self.db_client.data_processor.messages_df_to_list(messages_df)
if conversation:
return conversation
except Exception as e:
logger.error(f"处理会话 {session_row.get('SESSION_ID', 'unknown')} 时出错: {e}")
return None
# 使用线程池处理所有会话
with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
# 提交所有批次任务
future_to_batch = {}
for i in range(0, total_sessions, self.batch_size):
batch = sessions_df.iloc[i:i + self.batch_size]
future = executor.submit(process_session_batch, self.db_client, batch)
future_to_batch[future] = i // self.batch_size + 1
# 提交所有会话处理任务
future_to_session = {
executor.submit(process_single_session, row): i
for i, row in sessions_df.iterrows()
}
# 收集结果
with tqdm(total=batch_count, desc="处理批次进度") as pbar:
for future in concurrent.futures.as_completed(future_to_batch):
with tqdm(total=total_sessions, desc="处理会话进度") as pbar:
for future in concurrent.futures.as_completed(future_to_session):
try:
batch_conversations = future.result()
all_conversations.extend(batch_conversations)
# 使用锁保护临时列表的操作
with self.temp_save_lock:
conversation = future.result()
if conversation:
all_conversations.append(conversation)
# 每处理100个对话临时保存一次
logger.info(f"临时保存 {len(all_conversations)} 个对话")
temp_output_file = self.db_client.export_to_excel(
all_conversations,
f"客服对话记录_临时保存",
output_dir="/data/QueryRewrite/data/excel"
)
if temp_output_file:
logger.info(f"临时保存完成: {temp_output_file}")
batch_num = future_to_batch[future]
logger.debug(f"批次 {batch_num} 完成,获得 {len(batch_conversations)} 个对话")
if len(all_conversations) % 100 == 0:
with self.temp_save_lock:
logger.info(f"临时保存 {len(all_conversations)} 个对话")
temp_output_file = self.db_client.export_to_excel(
all_conversations,
f"客服对话记录_临时保存",
output_dir="/data/QueryRewrite/data/excel"
)
if temp_output_file:
logger.info(f"临时保存完成: {temp_output_file}")
except Exception as e:
batch_num = future_to_batch[future]
logger.error(f"处理批次 {batch_num} 时出错: {e}")
session_idx = future_to_session[future]
logger.error(f"处理会话索引 {session_idx} 时出错: {e}")
pbar.update(1)
@@ -498,12 +485,12 @@ def main() -> None:
logger.info(f"查询时间范围: {start_date}{end_date}")
# 创建会话处理器
processor = SessionProcessor(db_client, batch_size=100)
is_debug = hasattr(sys, 'gettrace') and sys.gettrace() is not None
if is_debug:
messages_df = db_client.query_messages_by_session_id("86c919e0-09f1-11f0-84ae-2daf59566989")
print(db_client.data_processor.messages_df_to_list(messages_df))
return []
processor = SessionProcessor(db_client)
# is_debug = hasattr(sys, 'gettrace') and sys.gettrace() is not None
# if is_debug:
# messages_df = db_client.query_messages_by_session_id("86c919e0-09f1-11f0-84ae-2daf59566989")
# print(db_client.data_processor.messages_df_to_list(messages_df))
# return []
sessions_df = db_client.query_sessions(start_date, end_date)