重构会话处理逻辑,移除批量处理,改为直接并发处理每个会话,优化临时保存机制。同时,更新DifyExporter类,新增获取节点信息和词条列表的方法,调整导出列以包含备注信息
This commit is contained in:
@@ -397,37 +397,15 @@ class MariaDBClient:
|
||||
return None
|
||||
|
||||
|
||||
def process_session_batch(db_client: MariaDBClient, session_batch: pd.DataFrame) -> List[List[Dict[str, Any]]]:
|
||||
"""批量处理会话数据"""
|
||||
conversations = []
|
||||
|
||||
for _, session_row in session_batch.iterrows():
|
||||
try:
|
||||
session_id = session_row['SESSION_ID']
|
||||
messages_df = db_client.query_messages_by_session_id(session_id)
|
||||
|
||||
if messages_df is not None and not messages_df.empty:
|
||||
conversation = db_client.data_processor.messages_df_to_list(messages_df)
|
||||
if conversation:
|
||||
conversations.append(conversation)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理会话 {session_row.get('SESSION_ID', 'unknown')} 时出错: {e}")
|
||||
continue
|
||||
|
||||
return conversations
|
||||
|
||||
|
||||
class SessionProcessor:
|
||||
"""会话处理器,负责批量和并发处理"""
|
||||
"""会话处理器,负责并发处理"""
|
||||
|
||||
def __init__(self, db_client: MariaDBClient, max_workers: int = None, batch_size: int = 50):
|
||||
def __init__(self, db_client: MariaDBClient, max_workers: int = None):
|
||||
self.db_client = db_client
|
||||
self.max_workers = max_workers if max_workers is not None else os.cpu_count()
|
||||
self.batch_size = batch_size
|
||||
self.temp_save_lock = threading.Lock() # 添加锁用于保护临时保存操作
|
||||
|
||||
logger.info(f"初始化会话处理器: max_workers={self.max_workers}, batch_size={self.batch_size}")
|
||||
logger.info(f"初始化会话处理器: max_workers={self.max_workers}")
|
||||
|
||||
def process_sessions(self, sessions_df: pd.DataFrame) -> List[List[Dict[str, Any]]]:
|
||||
"""处理所有会话数据"""
|
||||
@@ -438,29 +416,41 @@ class SessionProcessor:
|
||||
total_sessions = len(sessions_df)
|
||||
logger.info(f"开始处理 {total_sessions} 个会话...")
|
||||
|
||||
# 分批处理
|
||||
all_conversations = []
|
||||
batch_count = (total_sessions + self.batch_size - 1) // self.batch_size
|
||||
# 使用线程池处理批次
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
|
||||
# 提交所有批次任务
|
||||
future_to_batch = {}
|
||||
|
||||
for i in range(0, total_sessions, self.batch_size):
|
||||
batch = sessions_df.iloc[i:i + self.batch_size]
|
||||
future = executor.submit(process_session_batch, self.db_client, batch)
|
||||
future_to_batch[future] = i // self.batch_size + 1
|
||||
# 直接并发处理每个会话
|
||||
def process_single_session(session_row):
|
||||
try:
|
||||
session_id = session_row['SESSION_ID']
|
||||
messages_df = self.db_client.query_messages_by_session_id(session_id)
|
||||
|
||||
if messages_df is not None and not messages_df.empty:
|
||||
conversation = self.db_client.data_processor.messages_df_to_list(messages_df)
|
||||
if conversation:
|
||||
return conversation
|
||||
except Exception as e:
|
||||
logger.error(f"处理会话 {session_row.get('SESSION_ID', 'unknown')} 时出错: {e}")
|
||||
return None
|
||||
|
||||
# 使用线程池处理所有会话
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
|
||||
# 提交所有会话处理任务
|
||||
future_to_session = {
|
||||
executor.submit(process_single_session, row): i
|
||||
for i, row in sessions_df.iterrows()
|
||||
}
|
||||
|
||||
# 收集结果
|
||||
with tqdm(total=batch_count, desc="处理批次进度") as pbar:
|
||||
for future in concurrent.futures.as_completed(future_to_batch):
|
||||
with tqdm(total=total_sessions, desc="处理会话进度") as pbar:
|
||||
for future in concurrent.futures.as_completed(future_to_session):
|
||||
try:
|
||||
batch_conversations = future.result()
|
||||
all_conversations.extend(batch_conversations)
|
||||
conversation = future.result()
|
||||
if conversation:
|
||||
all_conversations.append(conversation)
|
||||
|
||||
# 使用锁保护临时列表的操作
|
||||
with self.temp_save_lock:
|
||||
# 每处理100个对话临时保存一次
|
||||
if len(all_conversations) % 100 == 0:
|
||||
with self.temp_save_lock:
|
||||
logger.info(f"临时保存 {len(all_conversations)} 个对话")
|
||||
temp_output_file = self.db_client.export_to_excel(
|
||||
all_conversations,
|
||||
@@ -470,12 +460,9 @@ class SessionProcessor:
|
||||
if temp_output_file:
|
||||
logger.info(f"临时保存完成: {temp_output_file}")
|
||||
|
||||
batch_num = future_to_batch[future]
|
||||
logger.debug(f"批次 {batch_num} 完成,获得 {len(batch_conversations)} 个对话")
|
||||
|
||||
except Exception as e:
|
||||
batch_num = future_to_batch[future]
|
||||
logger.error(f"处理批次 {batch_num} 时出错: {e}")
|
||||
session_idx = future_to_session[future]
|
||||
logger.error(f"处理会话索引 {session_idx} 时出错: {e}")
|
||||
|
||||
pbar.update(1)
|
||||
|
||||
@@ -498,12 +485,12 @@ def main() -> None:
|
||||
|
||||
logger.info(f"查询时间范围: {start_date} 到 {end_date}")
|
||||
# 创建会话处理器
|
||||
processor = SessionProcessor(db_client, batch_size=100)
|
||||
is_debug = hasattr(sys, 'gettrace') and sys.gettrace() is not None
|
||||
if is_debug:
|
||||
messages_df = db_client.query_messages_by_session_id("86c919e0-09f1-11f0-84ae-2daf59566989")
|
||||
print(db_client.data_processor.messages_df_to_list(messages_df))
|
||||
return []
|
||||
processor = SessionProcessor(db_client)
|
||||
# is_debug = hasattr(sys, 'gettrace') and sys.gettrace() is not None
|
||||
# if is_debug:
|
||||
# messages_df = db_client.query_messages_by_session_id("86c919e0-09f1-11f0-84ae-2daf59566989")
|
||||
# print(db_client.data_processor.messages_df_to_list(messages_df))
|
||||
# return []
|
||||
|
||||
sessions_df = db_client.query_sessions(start_date, end_date)
|
||||
|
||||
|
||||
@@ -120,10 +120,54 @@ class DifyExporter:
|
||||
if vertical_classification == "固定话术类":
|
||||
return "使用固定话术"
|
||||
|
||||
if sub_classification == "软件锁类":
|
||||
return "固定引导至博微软件助手中操作"
|
||||
return ""
|
||||
|
||||
def get_node_info_by_title(self, workflow_node_executions_info:list, title:str) -> dict:
|
||||
"""
|
||||
获取指定标题的节点信息
|
||||
"""
|
||||
if workflow_node_executions_info is None:
|
||||
return None
|
||||
for node_execution in workflow_node_executions_info:
|
||||
if node_execution["title"] == title:
|
||||
return node_execution
|
||||
|
||||
return None
|
||||
|
||||
def get_wiki_list(self, msg_debug_info) -> list:
|
||||
"""
|
||||
获取检索到的词条列表
|
||||
"""
|
||||
wiki_list = []
|
||||
if msg_debug_info['workflow_node_executions_info'] is None:
|
||||
return []
|
||||
node_execution = self.get_node_info_by_title(msg_debug_info['workflow_node_executions_info'], "提取处理后的知识")
|
||||
if node_execution is not None:
|
||||
if node_execution["outputs"] is None:
|
||||
return []
|
||||
source_kno = json.loads(node_execution["outputs"])["source_kno"]
|
||||
knowledge_list_metadata = json.loads(node_execution["outputs"])["knowledge_list_metadata"]
|
||||
for knowledge in knowledge_list_metadata:
|
||||
document_name = knowledge['metadata']['document_name']
|
||||
wiki_list.append(document_name.split("/")[-1])
|
||||
return wiki_list
|
||||
|
||||
lock_node_execution = self.get_node_info_by_title(msg_debug_info['workflow_node_executions_info'], "软件锁知识")
|
||||
if lock_node_execution is not None:
|
||||
if lock_node_execution["outputs"] is None:
|
||||
return []
|
||||
source_kno = json.loads(lock_node_execution["outputs"])['json'][0]['retrieve_result']
|
||||
for knowledge in source_kno:
|
||||
document_name = knowledge['metadata']['document_name']
|
||||
wiki_list.append(document_name.split("/")[-1])
|
||||
|
||||
wiki_list.append("锁信息查询")
|
||||
wiki_list.append("软件锁注册、激活、查锁、试用锁延期")
|
||||
return wiki_list
|
||||
|
||||
return []
|
||||
|
||||
|
||||
def extract_message_info(self, message):
|
||||
"""
|
||||
从消息中提取信息
|
||||
@@ -152,18 +196,7 @@ class DifyExporter:
|
||||
if not msg_debug_info:
|
||||
return None
|
||||
|
||||
wiki_list = []
|
||||
if msg_debug_info['workflow_node_executions_info'] is not None:
|
||||
for node_execution in msg_debug_info['workflow_node_executions_info']:
|
||||
if node_execution["title"] == "提取处理后的知识":
|
||||
if node_execution["outputs"] is None:
|
||||
break
|
||||
source_kno = json.loads(node_execution["outputs"])["source_kno"]
|
||||
knowledge_list_metadata = json.loads(node_execution["outputs"])["knowledge_list_metadata"]
|
||||
for knowledge in knowledge_list_metadata:
|
||||
document_name = knowledge['metadata']['document_name']
|
||||
wiki_list.append(document_name.split("/")[-1])
|
||||
|
||||
wiki_list = self.get_wiki_list(msg_debug_info)
|
||||
# 获取备注
|
||||
remark = self.get_remark(msg_debug_info)
|
||||
|
||||
@@ -239,7 +272,7 @@ class DifyExporter:
|
||||
# 设置列的顺序
|
||||
columns_order = [
|
||||
"msg_id", "提问", "回答", "提问人", "提问时间",
|
||||
"评价", "问题分类", "检索到的词条"
|
||||
"评价", "问题分类", "检索到的词条", "备注"
|
||||
]
|
||||
|
||||
# 确保所有列都存在,如果不存在则添加空列
|
||||
@@ -275,7 +308,8 @@ class DifyExporter:
|
||||
"提问时间": 15,
|
||||
"评价": 10,
|
||||
"问题分类": 20,
|
||||
"检索到的词条": 40
|
||||
"检索到的词条": 40,
|
||||
"备注": 40
|
||||
}
|
||||
|
||||
# 应用列宽设置
|
||||
@@ -347,7 +381,7 @@ if __name__ == "__main__":
|
||||
help='查询日志文件路径')
|
||||
parser.add_argument('--start_date', '-s', type=str, default="2025-07-14 00",
|
||||
help='开始日期时间,格式为YYYY-MM-DD HH,例如2025-07-08 14表示2025年7月8日14时(UTC+8时区)')
|
||||
parser.add_argument('--end_date', '-e', type=str, default="2025-07-14 15",
|
||||
parser.add_argument('--end_date', '-e', type=str, default=None,
|
||||
help='结束日期时间,格式为YYYY-MM-DD HH,例如2025-07-08 18表示2025年7月8日18时(UTC+8时区)')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
Reference in New Issue
Block a user