重构会话处理逻辑,移除批量处理,改为直接并发处理每个会话,优化临时保存机制。同时,更新DifyExporter类,新增获取节点信息和词条列表的方法,调整导出列以包含备注信息
This commit is contained in:
@@ -397,37 +397,15 @@ class MariaDBClient:
|
||||
return None
|
||||
|
||||
|
||||
def process_session_batch(db_client: MariaDBClient, session_batch: pd.DataFrame) -> List[List[Dict[str, Any]]]:
|
||||
"""批量处理会话数据"""
|
||||
conversations = []
|
||||
|
||||
for _, session_row in session_batch.iterrows():
|
||||
try:
|
||||
session_id = session_row['SESSION_ID']
|
||||
messages_df = db_client.query_messages_by_session_id(session_id)
|
||||
|
||||
if messages_df is not None and not messages_df.empty:
|
||||
conversation = db_client.data_processor.messages_df_to_list(messages_df)
|
||||
if conversation:
|
||||
conversations.append(conversation)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理会话 {session_row.get('SESSION_ID', 'unknown')} 时出错: {e}")
|
||||
continue
|
||||
|
||||
return conversations
|
||||
|
||||
|
||||
class SessionProcessor:
|
||||
"""会话处理器,负责批量和并发处理"""
|
||||
"""会话处理器,负责并发处理"""
|
||||
|
||||
def __init__(self, db_client: MariaDBClient, max_workers: int = None, batch_size: int = 50):
|
||||
def __init__(self, db_client: MariaDBClient, max_workers: int = None):
|
||||
self.db_client = db_client
|
||||
self.max_workers = max_workers if max_workers is not None else os.cpu_count()
|
||||
self.batch_size = batch_size
|
||||
self.temp_save_lock = threading.Lock() # 添加锁用于保护临时保存操作
|
||||
|
||||
logger.info(f"初始化会话处理器: max_workers={self.max_workers}, batch_size={self.batch_size}")
|
||||
logger.info(f"初始化会话处理器: max_workers={self.max_workers}")
|
||||
|
||||
def process_sessions(self, sessions_df: pd.DataFrame) -> List[List[Dict[str, Any]]]:
|
||||
"""处理所有会话数据"""
|
||||
@@ -438,44 +416,53 @@ class SessionProcessor:
|
||||
total_sessions = len(sessions_df)
|
||||
logger.info(f"开始处理 {total_sessions} 个会话...")
|
||||
|
||||
# 分批处理
|
||||
all_conversations = []
|
||||
batch_count = (total_sessions + self.batch_size - 1) // self.batch_size
|
||||
# 使用线程池处理批次
|
||||
|
||||
# 直接并发处理每个会话
|
||||
def process_single_session(session_row):
|
||||
try:
|
||||
session_id = session_row['SESSION_ID']
|
||||
messages_df = self.db_client.query_messages_by_session_id(session_id)
|
||||
|
||||
if messages_df is not None and not messages_df.empty:
|
||||
conversation = self.db_client.data_processor.messages_df_to_list(messages_df)
|
||||
if conversation:
|
||||
return conversation
|
||||
except Exception as e:
|
||||
logger.error(f"处理会话 {session_row.get('SESSION_ID', 'unknown')} 时出错: {e}")
|
||||
return None
|
||||
|
||||
# 使用线程池处理所有会话
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
|
||||
# 提交所有批次任务
|
||||
future_to_batch = {}
|
||||
|
||||
for i in range(0, total_sessions, self.batch_size):
|
||||
batch = sessions_df.iloc[i:i + self.batch_size]
|
||||
future = executor.submit(process_session_batch, self.db_client, batch)
|
||||
future_to_batch[future] = i // self.batch_size + 1
|
||||
# 提交所有会话处理任务
|
||||
future_to_session = {
|
||||
executor.submit(process_single_session, row): i
|
||||
for i, row in sessions_df.iterrows()
|
||||
}
|
||||
|
||||
# 收集结果
|
||||
with tqdm(total=batch_count, desc="处理批次进度") as pbar:
|
||||
for future in concurrent.futures.as_completed(future_to_batch):
|
||||
with tqdm(total=total_sessions, desc="处理会话进度") as pbar:
|
||||
for future in concurrent.futures.as_completed(future_to_session):
|
||||
try:
|
||||
batch_conversations = future.result()
|
||||
all_conversations.extend(batch_conversations)
|
||||
|
||||
# 使用锁保护临时列表的操作
|
||||
with self.temp_save_lock:
|
||||
conversation = future.result()
|
||||
if conversation:
|
||||
all_conversations.append(conversation)
|
||||
|
||||
# 每处理100个对话临时保存一次
|
||||
logger.info(f"临时保存 {len(all_conversations)} 个对话")
|
||||
temp_output_file = self.db_client.export_to_excel(
|
||||
all_conversations,
|
||||
f"客服对话记录_临时保存",
|
||||
output_dir="/data/QueryRewrite/data/excel"
|
||||
)
|
||||
if temp_output_file:
|
||||
logger.info(f"临时保存完成: {temp_output_file}")
|
||||
|
||||
batch_num = future_to_batch[future]
|
||||
logger.debug(f"批次 {batch_num} 完成,获得 {len(batch_conversations)} 个对话")
|
||||
if len(all_conversations) % 100 == 0:
|
||||
with self.temp_save_lock:
|
||||
logger.info(f"临时保存 {len(all_conversations)} 个对话")
|
||||
temp_output_file = self.db_client.export_to_excel(
|
||||
all_conversations,
|
||||
f"客服对话记录_临时保存",
|
||||
output_dir="/data/QueryRewrite/data/excel"
|
||||
)
|
||||
if temp_output_file:
|
||||
logger.info(f"临时保存完成: {temp_output_file}")
|
||||
|
||||
except Exception as e:
|
||||
batch_num = future_to_batch[future]
|
||||
logger.error(f"处理批次 {batch_num} 时出错: {e}")
|
||||
session_idx = future_to_session[future]
|
||||
logger.error(f"处理会话索引 {session_idx} 时出错: {e}")
|
||||
|
||||
pbar.update(1)
|
||||
|
||||
@@ -498,12 +485,12 @@ def main() -> None:
|
||||
|
||||
logger.info(f"查询时间范围: {start_date} 到 {end_date}")
|
||||
# 创建会话处理器
|
||||
processor = SessionProcessor(db_client, batch_size=100)
|
||||
is_debug = hasattr(sys, 'gettrace') and sys.gettrace() is not None
|
||||
if is_debug:
|
||||
messages_df = db_client.query_messages_by_session_id("86c919e0-09f1-11f0-84ae-2daf59566989")
|
||||
print(db_client.data_processor.messages_df_to_list(messages_df))
|
||||
return []
|
||||
processor = SessionProcessor(db_client)
|
||||
# is_debug = hasattr(sys, 'gettrace') and sys.gettrace() is not None
|
||||
# if is_debug:
|
||||
# messages_df = db_client.query_messages_by_session_id("86c919e0-09f1-11f0-84ae-2daf59566989")
|
||||
# print(db_client.data_processor.messages_df_to_list(messages_df))
|
||||
# return []
|
||||
|
||||
sessions_df = db_client.query_sessions(start_date, end_date)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user