重构会话处理逻辑,移除批量处理,改为直接并发处理每个会话,优化临时保存机制。同时,更新DifyExporter类,新增获取节点信息和词条列表的方法,调整导出列以包含备注信息
This commit is contained in:
@@ -397,37 +397,15 @@ class MariaDBClient:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def process_session_batch(db_client: MariaDBClient, session_batch: pd.DataFrame) -> List[List[Dict[str, Any]]]:
|
|
||||||
"""批量处理会话数据"""
|
|
||||||
conversations = []
|
|
||||||
|
|
||||||
for _, session_row in session_batch.iterrows():
|
|
||||||
try:
|
|
||||||
session_id = session_row['SESSION_ID']
|
|
||||||
messages_df = db_client.query_messages_by_session_id(session_id)
|
|
||||||
|
|
||||||
if messages_df is not None and not messages_df.empty:
|
|
||||||
conversation = db_client.data_processor.messages_df_to_list(messages_df)
|
|
||||||
if conversation:
|
|
||||||
conversations.append(conversation)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"处理会话 {session_row.get('SESSION_ID', 'unknown')} 时出错: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
return conversations
|
|
||||||
|
|
||||||
|
|
||||||
class SessionProcessor:
|
class SessionProcessor:
|
||||||
"""会话处理器,负责批量和并发处理"""
|
"""会话处理器,负责并发处理"""
|
||||||
|
|
||||||
def __init__(self, db_client: MariaDBClient, max_workers: int = None, batch_size: int = 50):
|
def __init__(self, db_client: MariaDBClient, max_workers: int = None):
|
||||||
self.db_client = db_client
|
self.db_client = db_client
|
||||||
self.max_workers = max_workers if max_workers is not None else os.cpu_count()
|
self.max_workers = max_workers if max_workers is not None else os.cpu_count()
|
||||||
self.batch_size = batch_size
|
|
||||||
self.temp_save_lock = threading.Lock() # 添加锁用于保护临时保存操作
|
self.temp_save_lock = threading.Lock() # 添加锁用于保护临时保存操作
|
||||||
|
|
||||||
logger.info(f"初始化会话处理器: max_workers={self.max_workers}, batch_size={self.batch_size}")
|
logger.info(f"初始化会话处理器: max_workers={self.max_workers}")
|
||||||
|
|
||||||
def process_sessions(self, sessions_df: pd.DataFrame) -> List[List[Dict[str, Any]]]:
|
def process_sessions(self, sessions_df: pd.DataFrame) -> List[List[Dict[str, Any]]]:
|
||||||
"""处理所有会话数据"""
|
"""处理所有会话数据"""
|
||||||
@@ -438,44 +416,53 @@ class SessionProcessor:
|
|||||||
total_sessions = len(sessions_df)
|
total_sessions = len(sessions_df)
|
||||||
logger.info(f"开始处理 {total_sessions} 个会话...")
|
logger.info(f"开始处理 {total_sessions} 个会话...")
|
||||||
|
|
||||||
# 分批处理
|
|
||||||
all_conversations = []
|
all_conversations = []
|
||||||
batch_count = (total_sessions + self.batch_size - 1) // self.batch_size
|
|
||||||
# 使用线程池处理批次
|
# 直接并发处理每个会话
|
||||||
|
def process_single_session(session_row):
|
||||||
|
try:
|
||||||
|
session_id = session_row['SESSION_ID']
|
||||||
|
messages_df = self.db_client.query_messages_by_session_id(session_id)
|
||||||
|
|
||||||
|
if messages_df is not None and not messages_df.empty:
|
||||||
|
conversation = self.db_client.data_processor.messages_df_to_list(messages_df)
|
||||||
|
if conversation:
|
||||||
|
return conversation
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"处理会话 {session_row.get('SESSION_ID', 'unknown')} 时出错: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 使用线程池处理所有会话
|
||||||
with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
|
with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
|
||||||
# 提交所有批次任务
|
# 提交所有会话处理任务
|
||||||
future_to_batch = {}
|
future_to_session = {
|
||||||
|
executor.submit(process_single_session, row): i
|
||||||
for i in range(0, total_sessions, self.batch_size):
|
for i, row in sessions_df.iterrows()
|
||||||
batch = sessions_df.iloc[i:i + self.batch_size]
|
}
|
||||||
future = executor.submit(process_session_batch, self.db_client, batch)
|
|
||||||
future_to_batch[future] = i // self.batch_size + 1
|
|
||||||
|
|
||||||
# 收集结果
|
# 收集结果
|
||||||
with tqdm(total=batch_count, desc="处理批次进度") as pbar:
|
with tqdm(total=total_sessions, desc="处理会话进度") as pbar:
|
||||||
for future in concurrent.futures.as_completed(future_to_batch):
|
for future in concurrent.futures.as_completed(future_to_session):
|
||||||
try:
|
try:
|
||||||
batch_conversations = future.result()
|
conversation = future.result()
|
||||||
all_conversations.extend(batch_conversations)
|
if conversation:
|
||||||
|
all_conversations.append(conversation)
|
||||||
# 使用锁保护临时列表的操作
|
|
||||||
with self.temp_save_lock:
|
|
||||||
# 每处理100个对话临时保存一次
|
# 每处理100个对话临时保存一次
|
||||||
logger.info(f"临时保存 {len(all_conversations)} 个对话")
|
if len(all_conversations) % 100 == 0:
|
||||||
temp_output_file = self.db_client.export_to_excel(
|
with self.temp_save_lock:
|
||||||
all_conversations,
|
logger.info(f"临时保存 {len(all_conversations)} 个对话")
|
||||||
f"客服对话记录_临时保存",
|
temp_output_file = self.db_client.export_to_excel(
|
||||||
output_dir="/data/QueryRewrite/data/excel"
|
all_conversations,
|
||||||
)
|
f"客服对话记录_临时保存",
|
||||||
if temp_output_file:
|
output_dir="/data/QueryRewrite/data/excel"
|
||||||
logger.info(f"临时保存完成: {temp_output_file}")
|
)
|
||||||
|
if temp_output_file:
|
||||||
batch_num = future_to_batch[future]
|
logger.info(f"临时保存完成: {temp_output_file}")
|
||||||
logger.debug(f"批次 {batch_num} 完成,获得 {len(batch_conversations)} 个对话")
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
batch_num = future_to_batch[future]
|
session_idx = future_to_session[future]
|
||||||
logger.error(f"处理批次 {batch_num} 时出错: {e}")
|
logger.error(f"处理会话索引 {session_idx} 时出错: {e}")
|
||||||
|
|
||||||
pbar.update(1)
|
pbar.update(1)
|
||||||
|
|
||||||
@@ -498,12 +485,12 @@ def main() -> None:
|
|||||||
|
|
||||||
logger.info(f"查询时间范围: {start_date} 到 {end_date}")
|
logger.info(f"查询时间范围: {start_date} 到 {end_date}")
|
||||||
# 创建会话处理器
|
# 创建会话处理器
|
||||||
processor = SessionProcessor(db_client, batch_size=100)
|
processor = SessionProcessor(db_client)
|
||||||
is_debug = hasattr(sys, 'gettrace') and sys.gettrace() is not None
|
# is_debug = hasattr(sys, 'gettrace') and sys.gettrace() is not None
|
||||||
if is_debug:
|
# if is_debug:
|
||||||
messages_df = db_client.query_messages_by_session_id("86c919e0-09f1-11f0-84ae-2daf59566989")
|
# messages_df = db_client.query_messages_by_session_id("86c919e0-09f1-11f0-84ae-2daf59566989")
|
||||||
print(db_client.data_processor.messages_df_to_list(messages_df))
|
# print(db_client.data_processor.messages_df_to_list(messages_df))
|
||||||
return []
|
# return []
|
||||||
|
|
||||||
sessions_df = db_client.query_sessions(start_date, end_date)
|
sessions_df = db_client.query_sessions(start_date, end_date)
|
||||||
|
|
||||||
|
|||||||
@@ -120,10 +120,54 @@ class DifyExporter:
|
|||||||
if vertical_classification == "固定话术类":
|
if vertical_classification == "固定话术类":
|
||||||
return "使用固定话术"
|
return "使用固定话术"
|
||||||
|
|
||||||
if sub_classification == "软件锁类":
|
|
||||||
return "固定引导至博微软件助手中操作"
|
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
def get_node_info_by_title(self, workflow_node_executions_info:list, title:str) -> dict:
|
||||||
|
"""
|
||||||
|
获取指定标题的节点信息
|
||||||
|
"""
|
||||||
|
if workflow_node_executions_info is None:
|
||||||
|
return None
|
||||||
|
for node_execution in workflow_node_executions_info:
|
||||||
|
if node_execution["title"] == title:
|
||||||
|
return node_execution
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_wiki_list(self, msg_debug_info) -> list:
|
||||||
|
"""
|
||||||
|
获取检索到的词条列表
|
||||||
|
"""
|
||||||
|
wiki_list = []
|
||||||
|
if msg_debug_info['workflow_node_executions_info'] is None:
|
||||||
|
return []
|
||||||
|
node_execution = self.get_node_info_by_title(msg_debug_info['workflow_node_executions_info'], "提取处理后的知识")
|
||||||
|
if node_execution is not None:
|
||||||
|
if node_execution["outputs"] is None:
|
||||||
|
return []
|
||||||
|
source_kno = json.loads(node_execution["outputs"])["source_kno"]
|
||||||
|
knowledge_list_metadata = json.loads(node_execution["outputs"])["knowledge_list_metadata"]
|
||||||
|
for knowledge in knowledge_list_metadata:
|
||||||
|
document_name = knowledge['metadata']['document_name']
|
||||||
|
wiki_list.append(document_name.split("/")[-1])
|
||||||
|
return wiki_list
|
||||||
|
|
||||||
|
lock_node_execution = self.get_node_info_by_title(msg_debug_info['workflow_node_executions_info'], "软件锁知识")
|
||||||
|
if lock_node_execution is not None:
|
||||||
|
if lock_node_execution["outputs"] is None:
|
||||||
|
return []
|
||||||
|
source_kno = json.loads(lock_node_execution["outputs"])['json'][0]['retrieve_result']
|
||||||
|
for knowledge in source_kno:
|
||||||
|
document_name = knowledge['metadata']['document_name']
|
||||||
|
wiki_list.append(document_name.split("/")[-1])
|
||||||
|
|
||||||
|
wiki_list.append("锁信息查询")
|
||||||
|
wiki_list.append("软件锁注册、激活、查锁、试用锁延期")
|
||||||
|
return wiki_list
|
||||||
|
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
def extract_message_info(self, message):
|
def extract_message_info(self, message):
|
||||||
"""
|
"""
|
||||||
从消息中提取信息
|
从消息中提取信息
|
||||||
@@ -152,18 +196,7 @@ class DifyExporter:
|
|||||||
if not msg_debug_info:
|
if not msg_debug_info:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
wiki_list = []
|
wiki_list = self.get_wiki_list(msg_debug_info)
|
||||||
if msg_debug_info['workflow_node_executions_info'] is not None:
|
|
||||||
for node_execution in msg_debug_info['workflow_node_executions_info']:
|
|
||||||
if node_execution["title"] == "提取处理后的知识":
|
|
||||||
if node_execution["outputs"] is None:
|
|
||||||
break
|
|
||||||
source_kno = json.loads(node_execution["outputs"])["source_kno"]
|
|
||||||
knowledge_list_metadata = json.loads(node_execution["outputs"])["knowledge_list_metadata"]
|
|
||||||
for knowledge in knowledge_list_metadata:
|
|
||||||
document_name = knowledge['metadata']['document_name']
|
|
||||||
wiki_list.append(document_name.split("/")[-1])
|
|
||||||
|
|
||||||
# 获取备注
|
# 获取备注
|
||||||
remark = self.get_remark(msg_debug_info)
|
remark = self.get_remark(msg_debug_info)
|
||||||
|
|
||||||
@@ -239,7 +272,7 @@ class DifyExporter:
|
|||||||
# 设置列的顺序
|
# 设置列的顺序
|
||||||
columns_order = [
|
columns_order = [
|
||||||
"msg_id", "提问", "回答", "提问人", "提问时间",
|
"msg_id", "提问", "回答", "提问人", "提问时间",
|
||||||
"评价", "问题分类", "检索到的词条"
|
"评价", "问题分类", "检索到的词条", "备注"
|
||||||
]
|
]
|
||||||
|
|
||||||
# 确保所有列都存在,如果不存在则添加空列
|
# 确保所有列都存在,如果不存在则添加空列
|
||||||
@@ -275,7 +308,8 @@ class DifyExporter:
|
|||||||
"提问时间": 15,
|
"提问时间": 15,
|
||||||
"评价": 10,
|
"评价": 10,
|
||||||
"问题分类": 20,
|
"问题分类": 20,
|
||||||
"检索到的词条": 40
|
"检索到的词条": 40,
|
||||||
|
"备注": 40
|
||||||
}
|
}
|
||||||
|
|
||||||
# 应用列宽设置
|
# 应用列宽设置
|
||||||
@@ -347,7 +381,7 @@ if __name__ == "__main__":
|
|||||||
help='查询日志文件路径')
|
help='查询日志文件路径')
|
||||||
parser.add_argument('--start_date', '-s', type=str, default="2025-07-14 00",
|
parser.add_argument('--start_date', '-s', type=str, default="2025-07-14 00",
|
||||||
help='开始日期时间,格式为YYYY-MM-DD HH,例如2025-07-08 14表示2025年7月8日14时(UTC+8时区)')
|
help='开始日期时间,格式为YYYY-MM-DD HH,例如2025-07-08 14表示2025年7月8日14时(UTC+8时区)')
|
||||||
parser.add_argument('--end_date', '-e', type=str, default="2025-07-14 15",
|
parser.add_argument('--end_date', '-e', type=str, default=None,
|
||||||
help='结束日期时间,格式为YYYY-MM-DD HH,例如2025-07-08 18表示2025年7月8日18时(UTC+8时区)')
|
help='结束日期时间,格式为YYYY-MM-DD HH,例如2025-07-08 18表示2025年7月8日18时(UTC+8时区)')
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|||||||
Reference in New Issue
Block a user