From 5b5a2f2b16aac1ced656336ed9ea5ef349a8c98e Mon Sep 17 00:00:00 2001 From: ouyangyouzhang Date: Mon, 14 Jul 2025 18:51:03 +0800 Subject: [PATCH] =?UTF-8?q?=E9=87=8D=E6=9E=84=E4=BC=9A=E8=AF=9D=E5=A4=84?= =?UTF-8?q?=E7=90=86=E9=80=BB=E8=BE=91=EF=BC=8C=E7=A7=BB=E9=99=A4=E6=89=B9?= =?UTF-8?q?=E9=87=8F=E5=A4=84=E7=90=86=EF=BC=8C=E6=94=B9=E4=B8=BA=E7=9B=B4?= =?UTF-8?q?=E6=8E=A5=E5=B9=B6=E5=8F=91=E5=A4=84=E7=90=86=E6=AF=8F=E4=B8=AA?= =?UTF-8?q?=E4=BC=9A=E8=AF=9D=EF=BC=8C=E4=BC=98=E5=8C=96=E4=B8=B4=E6=97=B6?= =?UTF-8?q?=E4=BF=9D=E5=AD=98=E6=9C=BA=E5=88=B6=E3=80=82=E5=90=8C=E6=97=B6?= =?UTF-8?q?=EF=BC=8C=E6=9B=B4=E6=96=B0DifyExporter=E7=B1=BB=EF=BC=8C?= =?UTF-8?q?=E6=96=B0=E5=A2=9E=E8=8E=B7=E5=8F=96=E8=8A=82=E7=82=B9=E4=BF=A1?= =?UTF-8?q?=E6=81=AF=E5=92=8C=E8=AF=8D=E6=9D=A1=E5=88=97=E8=A1=A8=E7=9A=84?= =?UTF-8?q?=E6=96=B9=E6=B3=95=EF=BC=8C=E8=B0=83=E6=95=B4=E5=AF=BC=E5=87=BA?= =?UTF-8?q?=E5=88=97=E4=BB=A5=E5=8C=85=E5=90=AB=E5=A4=87=E6=B3=A8=E4=BF=A1?= =?UTF-8?q?=E6=81=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- rag2_0/demo/heli_db_to_excel.py | 109 ++++++++++++++------------------ rag2_0/dify/export_new_dify.py | 68 +++++++++++++++----- 2 files changed, 99 insertions(+), 78 deletions(-) diff --git a/rag2_0/demo/heli_db_to_excel.py b/rag2_0/demo/heli_db_to_excel.py index 823fd90..7146933 100755 --- a/rag2_0/demo/heli_db_to_excel.py +++ b/rag2_0/demo/heli_db_to_excel.py @@ -397,37 +397,15 @@ class MariaDBClient: return None -def process_session_batch(db_client: MariaDBClient, session_batch: pd.DataFrame) -> List[List[Dict[str, Any]]]: - """批量处理会话数据""" - conversations = [] - - for _, session_row in session_batch.iterrows(): - try: - session_id = session_row['SESSION_ID'] - messages_df = db_client.query_messages_by_session_id(session_id) - - if messages_df is not None and not messages_df.empty: - conversation = db_client.data_processor.messages_df_to_list(messages_df) - if conversation: - conversations.append(conversation) - - except Exception as e: - logger.error(f"处理会话 {session_row.get('SESSION_ID', 'unknown')} 时出错: {e}") - continue - - return conversations - - class SessionProcessor: - """会话处理器,负责批量和并发处理""" + """会话处理器,负责并发处理""" - def __init__(self, db_client: MariaDBClient, max_workers: int = None, batch_size: int = 50): + def __init__(self, db_client: MariaDBClient, max_workers: int = None): self.db_client = db_client self.max_workers = max_workers if max_workers is not None else os.cpu_count() - self.batch_size = batch_size self.temp_save_lock = threading.Lock() # 添加锁用于保护临时保存操作 - logger.info(f"初始化会话处理器: max_workers={self.max_workers}, batch_size={self.batch_size}") + logger.info(f"初始化会话处理器: max_workers={self.max_workers}") def process_sessions(self, sessions_df: pd.DataFrame) -> List[List[Dict[str, Any]]]: """处理所有会话数据""" @@ -438,44 +416,53 @@ class SessionProcessor: total_sessions = len(sessions_df) logger.info(f"开始处理 {total_sessions} 个会话...") - # 分批处理 all_conversations = [] - batch_count = (total_sessions + self.batch_size - 1) // self.batch_size - # 使用线程池处理批次 + + # 直接并发处理每个会话 + def process_single_session(session_row): + try: + session_id = session_row['SESSION_ID'] + messages_df = self.db_client.query_messages_by_session_id(session_id) + + if messages_df is not None and not messages_df.empty: + conversation = self.db_client.data_processor.messages_df_to_list(messages_df) + if conversation: + return conversation + except Exception as e: + logger.error(f"处理会话 {session_row.get('SESSION_ID', 'unknown')} 时出错: {e}") + return None + + # 使用线程池处理所有会话 with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor: - # 提交所有批次任务 - future_to_batch = {} - - for i in range(0, total_sessions, self.batch_size): - batch = sessions_df.iloc[i:i + self.batch_size] - future = executor.submit(process_session_batch, self.db_client, batch) - future_to_batch[future] = i // self.batch_size + 1 + # 提交所有会话处理任务 + future_to_session = { + executor.submit(process_single_session, row): i + for i, row in sessions_df.iterrows() + } # 收集结果 - with tqdm(total=batch_count, desc="处理批次进度") as pbar: - for future in concurrent.futures.as_completed(future_to_batch): + with tqdm(total=total_sessions, desc="处理会话进度") as pbar: + for future in concurrent.futures.as_completed(future_to_session): try: - batch_conversations = future.result() - all_conversations.extend(batch_conversations) - - # 使用锁保护临时列表的操作 - with self.temp_save_lock: + conversation = future.result() + if conversation: + all_conversations.append(conversation) + # 每处理100个对话临时保存一次 - logger.info(f"临时保存 {len(all_conversations)} 个对话") - temp_output_file = self.db_client.export_to_excel( - all_conversations, - f"客服对话记录_临时保存", - output_dir="/data/QueryRewrite/data/excel" - ) - if temp_output_file: - logger.info(f"临时保存完成: {temp_output_file}") - - batch_num = future_to_batch[future] - logger.debug(f"批次 {batch_num} 完成,获得 {len(batch_conversations)} 个对话") + if len(all_conversations) % 100 == 0: + with self.temp_save_lock: + logger.info(f"临时保存 {len(all_conversations)} 个对话") + temp_output_file = self.db_client.export_to_excel( + all_conversations, + f"客服对话记录_临时保存", + output_dir="/data/QueryRewrite/data/excel" + ) + if temp_output_file: + logger.info(f"临时保存完成: {temp_output_file}") except Exception as e: - batch_num = future_to_batch[future] - logger.error(f"处理批次 {batch_num} 时出错: {e}") + session_idx = future_to_session[future] + logger.error(f"处理会话索引 {session_idx} 时出错: {e}") pbar.update(1) @@ -498,12 +485,12 @@ def main() -> None: logger.info(f"查询时间范围: {start_date} 到 {end_date}") # 创建会话处理器 - processor = SessionProcessor(db_client, batch_size=100) - is_debug = hasattr(sys, 'gettrace') and sys.gettrace() is not None - if is_debug: - messages_df = db_client.query_messages_by_session_id("86c919e0-09f1-11f0-84ae-2daf59566989") - print(db_client.data_processor.messages_df_to_list(messages_df)) - return [] + processor = SessionProcessor(db_client) + # is_debug = hasattr(sys, 'gettrace') and sys.gettrace() is not None + # if is_debug: + # messages_df = db_client.query_messages_by_session_id("86c919e0-09f1-11f0-84ae-2daf59566989") + # print(db_client.data_processor.messages_df_to_list(messages_df)) + # return [] sessions_df = db_client.query_sessions(start_date, end_date) diff --git a/rag2_0/dify/export_new_dify.py b/rag2_0/dify/export_new_dify.py index 04cc6b1..a821ccc 100644 --- a/rag2_0/dify/export_new_dify.py +++ b/rag2_0/dify/export_new_dify.py @@ -120,10 +120,54 @@ class DifyExporter: if vertical_classification == "固定话术类": return "使用固定话术" - if sub_classification == "软件锁类": - return "固定引导至博微软件助手中操作" return "" + def get_node_info_by_title(self, workflow_node_executions_info:list, title:str) -> dict: + """ + 获取指定标题的节点信息 + """ + if workflow_node_executions_info is None: + return None + for node_execution in workflow_node_executions_info: + if node_execution["title"] == title: + return node_execution + + return None + + def get_wiki_list(self, msg_debug_info) -> list: + """ + 获取检索到的词条列表 + """ + wiki_list = [] + if msg_debug_info['workflow_node_executions_info'] is None: + return [] + node_execution = self.get_node_info_by_title(msg_debug_info['workflow_node_executions_info'], "提取处理后的知识") + if node_execution is not None: + if node_execution["outputs"] is None: + return [] + source_kno = json.loads(node_execution["outputs"])["source_kno"] + knowledge_list_metadata = json.loads(node_execution["outputs"])["knowledge_list_metadata"] + for knowledge in knowledge_list_metadata: + document_name = knowledge['metadata']['document_name'] + wiki_list.append(document_name.split("/")[-1]) + return wiki_list + + lock_node_execution = self.get_node_info_by_title(msg_debug_info['workflow_node_executions_info'], "软件锁知识") + if lock_node_execution is not None: + if lock_node_execution["outputs"] is None: + return [] + source_kno = json.loads(lock_node_execution["outputs"])['json'][0]['retrieve_result'] + for knowledge in source_kno: + document_name = knowledge['metadata']['document_name'] + wiki_list.append(document_name.split("/")[-1]) + + wiki_list.append("锁信息查询") + wiki_list.append("软件锁注册、激活、查锁、试用锁延期") + return wiki_list + + return [] + + def extract_message_info(self, message): """ 从消息中提取信息 @@ -152,18 +196,7 @@ class DifyExporter: if not msg_debug_info: return None - wiki_list = [] - if msg_debug_info['workflow_node_executions_info'] is not None: - for node_execution in msg_debug_info['workflow_node_executions_info']: - if node_execution["title"] == "提取处理后的知识": - if node_execution["outputs"] is None: - break - source_kno = json.loads(node_execution["outputs"])["source_kno"] - knowledge_list_metadata = json.loads(node_execution["outputs"])["knowledge_list_metadata"] - for knowledge in knowledge_list_metadata: - document_name = knowledge['metadata']['document_name'] - wiki_list.append(document_name.split("/")[-1]) - + wiki_list = self.get_wiki_list(msg_debug_info) # 获取备注 remark = self.get_remark(msg_debug_info) @@ -239,7 +272,7 @@ class DifyExporter: # 设置列的顺序 columns_order = [ "msg_id", "提问", "回答", "提问人", "提问时间", - "评价", "问题分类", "检索到的词条" + "评价", "问题分类", "检索到的词条", "备注" ] # 确保所有列都存在,如果不存在则添加空列 @@ -275,7 +308,8 @@ class DifyExporter: "提问时间": 15, "评价": 10, "问题分类": 20, - "检索到的词条": 40 + "检索到的词条": 40, + "备注": 40 } # 应用列宽设置 @@ -347,7 +381,7 @@ if __name__ == "__main__": help='查询日志文件路径') parser.add_argument('--start_date', '-s', type=str, default="2025-07-14 00", help='开始日期时间,格式为YYYY-MM-DD HH,例如2025-07-08 14表示2025年7月8日14时(UTC+8时区)') - parser.add_argument('--end_date', '-e', type=str, default="2025-07-14 15", + parser.add_argument('--end_date', '-e', type=str, default=None, help='结束日期时间,格式为YYYY-MM-DD HH,例如2025-07-08 18表示2025年7月8日18时(UTC+8时区)') args = parser.parse_args()