重构会话处理逻辑,移除批量处理,改为直接并发处理每个会话,优化临时保存机制。同时,更新DifyExporter类,新增获取节点信息和词条列表的方法,调整导出列以包含备注信息

This commit is contained in:
2025-07-14 18:51:03 +08:00
parent af1e1a9d9b
commit 5b5a2f2b16
2 changed files with 99 additions and 78 deletions
+48 -61
View File
@@ -397,37 +397,15 @@ class MariaDBClient:
return None
def process_session_batch(db_client: MariaDBClient, session_batch: pd.DataFrame) -> List[List[Dict[str, Any]]]:
"""批量处理会话数据"""
conversations = []
for _, session_row in session_batch.iterrows():
try:
session_id = session_row['SESSION_ID']
messages_df = db_client.query_messages_by_session_id(session_id)
if messages_df is not None and not messages_df.empty:
conversation = db_client.data_processor.messages_df_to_list(messages_df)
if conversation:
conversations.append(conversation)
except Exception as e:
logger.error(f"处理会话 {session_row.get('SESSION_ID', 'unknown')} 时出错: {e}")
continue
return conversations
class SessionProcessor:
"""会话处理器,负责批量和并发处理"""
"""会话处理器,负责并发处理"""
def __init__(self, db_client: MariaDBClient, max_workers: int = None, batch_size: int = 50):
def __init__(self, db_client: MariaDBClient, max_workers: int = None):
self.db_client = db_client
self.max_workers = max_workers if max_workers is not None else os.cpu_count()
self.batch_size = batch_size
self.temp_save_lock = threading.Lock() # 添加锁用于保护临时保存操作
logger.info(f"初始化会话处理器: max_workers={self.max_workers}, batch_size={self.batch_size}")
logger.info(f"初始化会话处理器: max_workers={self.max_workers}")
def process_sessions(self, sessions_df: pd.DataFrame) -> List[List[Dict[str, Any]]]:
"""处理所有会话数据"""
@@ -438,44 +416,53 @@ class SessionProcessor:
total_sessions = len(sessions_df)
logger.info(f"开始处理 {total_sessions} 个会话...")
# 分批处理
all_conversations = []
batch_count = (total_sessions + self.batch_size - 1) // self.batch_size
# 使用线程池处理批次
# 直接并发处理每个会话
def process_single_session(session_row):
try:
session_id = session_row['SESSION_ID']
messages_df = self.db_client.query_messages_by_session_id(session_id)
if messages_df is not None and not messages_df.empty:
conversation = self.db_client.data_processor.messages_df_to_list(messages_df)
if conversation:
return conversation
except Exception as e:
logger.error(f"处理会话 {session_row.get('SESSION_ID', 'unknown')} 时出错: {e}")
return None
# 使用线程池处理所有会话
with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
# 提交所有批次任务
future_to_batch = {}
for i in range(0, total_sessions, self.batch_size):
batch = sessions_df.iloc[i:i + self.batch_size]
future = executor.submit(process_session_batch, self.db_client, batch)
future_to_batch[future] = i // self.batch_size + 1
# 提交所有会话处理任务
future_to_session = {
executor.submit(process_single_session, row): i
for i, row in sessions_df.iterrows()
}
# 收集结果
with tqdm(total=batch_count, desc="处理批次进度") as pbar:
for future in concurrent.futures.as_completed(future_to_batch):
with tqdm(total=total_sessions, desc="处理会话进度") as pbar:
for future in concurrent.futures.as_completed(future_to_session):
try:
batch_conversations = future.result()
all_conversations.extend(batch_conversations)
# 使用锁保护临时列表的操作
with self.temp_save_lock:
conversation = future.result()
if conversation:
all_conversations.append(conversation)
# 每处理100个对话临时保存一次
logger.info(f"临时保存 {len(all_conversations)} 个对话")
temp_output_file = self.db_client.export_to_excel(
all_conversations,
f"客服对话记录_临时保存",
output_dir="/data/QueryRewrite/data/excel"
)
if temp_output_file:
logger.info(f"临时保存完成: {temp_output_file}")
batch_num = future_to_batch[future]
logger.debug(f"批次 {batch_num} 完成,获得 {len(batch_conversations)} 个对话")
if len(all_conversations) % 100 == 0:
with self.temp_save_lock:
logger.info(f"临时保存 {len(all_conversations)} 个对话")
temp_output_file = self.db_client.export_to_excel(
all_conversations,
f"客服对话记录_临时保存",
output_dir="/data/QueryRewrite/data/excel"
)
if temp_output_file:
logger.info(f"临时保存完成: {temp_output_file}")
except Exception as e:
batch_num = future_to_batch[future]
logger.error(f"处理批次 {batch_num} 时出错: {e}")
session_idx = future_to_session[future]
logger.error(f"处理会话索引 {session_idx} 时出错: {e}")
pbar.update(1)
@@ -498,12 +485,12 @@ def main() -> None:
logger.info(f"查询时间范围: {start_date}{end_date}")
# 创建会话处理器
processor = SessionProcessor(db_client, batch_size=100)
is_debug = hasattr(sys, 'gettrace') and sys.gettrace() is not None
if is_debug:
messages_df = db_client.query_messages_by_session_id("86c919e0-09f1-11f0-84ae-2daf59566989")
print(db_client.data_processor.messages_df_to_list(messages_df))
return []
processor = SessionProcessor(db_client)
# is_debug = hasattr(sys, 'gettrace') and sys.gettrace() is not None
# if is_debug:
# messages_df = db_client.query_messages_by_session_id("86c919e0-09f1-11f0-84ae-2daf59566989")
# print(db_client.data_processor.messages_df_to_list(messages_df))
# return []
sessions_df = db_client.query_sessions(start_date, end_date)
+51 -17
View File
@@ -120,10 +120,54 @@ class DifyExporter:
if vertical_classification == "固定话术类":
return "使用固定话术"
if sub_classification == "软件锁类":
return "固定引导至博微软件助手中操作"
return ""
def get_node_info_by_title(self, workflow_node_executions_info:list, title:str) -> dict:
"""
获取指定标题的节点信息
"""
if workflow_node_executions_info is None:
return None
for node_execution in workflow_node_executions_info:
if node_execution["title"] == title:
return node_execution
return None
def get_wiki_list(self, msg_debug_info) -> list:
"""
获取检索到的词条列表
"""
wiki_list = []
if msg_debug_info['workflow_node_executions_info'] is None:
return []
node_execution = self.get_node_info_by_title(msg_debug_info['workflow_node_executions_info'], "提取处理后的知识")
if node_execution is not None:
if node_execution["outputs"] is None:
return []
source_kno = json.loads(node_execution["outputs"])["source_kno"]
knowledge_list_metadata = json.loads(node_execution["outputs"])["knowledge_list_metadata"]
for knowledge in knowledge_list_metadata:
document_name = knowledge['metadata']['document_name']
wiki_list.append(document_name.split("/")[-1])
return wiki_list
lock_node_execution = self.get_node_info_by_title(msg_debug_info['workflow_node_executions_info'], "软件锁知识")
if lock_node_execution is not None:
if lock_node_execution["outputs"] is None:
return []
source_kno = json.loads(lock_node_execution["outputs"])['json'][0]['retrieve_result']
for knowledge in source_kno:
document_name = knowledge['metadata']['document_name']
wiki_list.append(document_name.split("/")[-1])
wiki_list.append("锁信息查询")
wiki_list.append("软件锁注册、激活、查锁、试用锁延期")
return wiki_list
return []
def extract_message_info(self, message):
"""
从消息中提取信息
@@ -152,18 +196,7 @@ class DifyExporter:
if not msg_debug_info:
return None
wiki_list = []
if msg_debug_info['workflow_node_executions_info'] is not None:
for node_execution in msg_debug_info['workflow_node_executions_info']:
if node_execution["title"] == "提取处理后的知识":
if node_execution["outputs"] is None:
break
source_kno = json.loads(node_execution["outputs"])["source_kno"]
knowledge_list_metadata = json.loads(node_execution["outputs"])["knowledge_list_metadata"]
for knowledge in knowledge_list_metadata:
document_name = knowledge['metadata']['document_name']
wiki_list.append(document_name.split("/")[-1])
wiki_list = self.get_wiki_list(msg_debug_info)
# 获取备注
remark = self.get_remark(msg_debug_info)
@@ -239,7 +272,7 @@ class DifyExporter:
# 设置列的顺序
columns_order = [
"msg_id", "提问", "回答", "提问人", "提问时间",
"评价", "问题分类", "检索到的词条"
"评价", "问题分类", "检索到的词条", "备注"
]
# 确保所有列都存在,如果不存在则添加空列
@@ -275,7 +308,8 @@ class DifyExporter:
"提问时间": 15,
"评价": 10,
"问题分类": 20,
"检索到的词条": 40
"检索到的词条": 40,
"备注": 40
}
# 应用列宽设置
@@ -347,7 +381,7 @@ if __name__ == "__main__":
help='查询日志文件路径')
parser.add_argument('--start_date', '-s', type=str, default="2025-07-14 00",
help='开始日期时间,格式为YYYY-MM-DD HH,例如2025-07-08 14表示2025年7月8日14时(UTC+8时区)')
parser.add_argument('--end_date', '-e', type=str, default="2025-07-14 15",
parser.add_argument('--end_date', '-e', type=str, default=None,
help='结束日期时间,格式为YYYY-MM-DD HH,例如2025-07-08 18表示2025年7月8日18时(UTC+8时区)')
args = parser.parse_args()