From c02fe2624a3407262dc6098e9ccd023b9e5e9ef0 Mon Sep 17 00:00:00 2001 From: ouyangyouzhang Date: Mon, 4 Aug 2025 08:43:35 +0800 Subject: [PATCH] =?UTF-8?q?=E5=90=8C=E6=AD=A5=E5=90=88=E7=90=86=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E5=BA=93=E5=AD=97=E6=AE=B5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- rag2_0/demo/heli_db_to_excel.py | 87 ++++++++++++++++++-- rag2_0/intent_recognition/PromptTemplates.py | 2 +- 2 files changed, 81 insertions(+), 8 deletions(-) diff --git a/rag2_0/demo/heli_db_to_excel.py b/rag2_0/demo/heli_db_to_excel.py index 6d59907..aac2b91 100755 --- a/rag2_0/demo/heli_db_to_excel.py +++ b/rag2_0/demo/heli_db_to_excel.py @@ -35,6 +35,46 @@ logging.basicConfig( ) logger = logging.getLogger(__name__) +def parse_session_tags(input_string): + """ + 解析sessionTag格式的字符串,支持任意数量的sessionTag + 支持格式:sessionTagFirst, sessionTagSecond, sessionTagThird, sessionTagFourth 等 + """ + # 去除外层的方括号和引号 + cleaned_string = input_string.strip('[]"') + + # 使用正则表达式匹配所有的sessionTag + # 匹配模式:sessionTag + 任意后缀 + = + 花括号内容 + pattern = r'sessionTag(\w+)=\{([^}]+)\}' + matches = re.findall(pattern, cleaned_string) + + result = {} + + for tag_suffix, content in matches: + # 解析每个tag的内容 + tag_data = {} + + # 提取键值对,支持中文和各种字符 + kv_pattern = r'(\w+)=([^,}]+?)(?=,\s*\w+=|$|,\s*$)' + kv_matches = re.findall(kv_pattern, content) + + for key, value in kv_matches: + # 清理值的空白字符 + cleaned_value = value.strip() + + # 尝试转换数据类型 + if cleaned_value.isdigit(): + tag_data[key] = int(cleaned_value) + elif cleaned_value.lower() in ['true', 'false']: + tag_data[key] = cleaned_value.lower() == 'true' + else: + tag_data[key] = cleaned_value + + # 构造完整的sessionTag键名 + session_key = f'sessionTag{tag_suffix}' + result[session_key] = tag_data + + return result @dataclass class DatabaseConfig: @@ -228,7 +268,23 @@ class DataProcessor: return clean_text.strip() @staticmethod - def messages_df_to_list(messages_df: pd.DataFrame) -> List[Dict[str, Any]]: + def get_session_tag_dict(json_data: str) -> dict: + """解析JSON数据获取会话标签字典""" + try: + json_data_dict = json.loads(json_data) + session_tag_list_str = json_data_dict.get('sessionMultiTagList', None) + if not session_tag_list_str: + return {} + + result = parse_session_tags(session_tag_list_str) + return result + + except (json.JSONDecodeError, Exception) as e: + logger.error(f"解析会话标签时出错: {e}") + return {} + + @staticmethod + def messages_df_to_list(messages_df: pd.DataFrame, session_row) -> List[Dict[str, Any]]: """将消息DataFrame转换为字典列表,使用高效的向量化操作""" if messages_df.empty: return [] @@ -249,6 +305,9 @@ class DataProcessor: axis=1 ) + json_data = session_row['JSON'] + session_tag_dict = DataProcessor.get_session_tag_dict(json_data) + # 处理内容 def process_content(row): content = row["CONTENT"] @@ -270,7 +329,9 @@ class DataProcessor: # 如果上一条消息和当前消息的发送者、创建时间、消息内容相同,则跳过 if result and result[-1]['会话id'] == record['SESSION_ID'] and result[-1]['消息发送者'] == record['message_sender'] and result[-1]['创建时间'] == record['CREATE_TIME'] and result[-1]['消息内容'] == record['processed_content']: continue - result.append({ + + # 创建消息字典 + message_dict = { "账号id": record["ACCOUNT"], "会话id": record["SESSION_ID"], "消息内容": record["processed_content"], @@ -278,7 +339,19 @@ class DataProcessor: "发送者昵称": record["sender_nickname"], "创建时间": record["CREATE_TIME"], "SEQUENCE_ID": record["SEQUENCE_ID"], - }) + } + + # 添加标签信息(如果有) + if session_tag_dict: + if 'sessionTagFirst' in session_tag_dict: + first_tag = session_tag_dict['sessionTagFirst'] + message_dict["一级标签"] = first_tag.get('tagName', '') + + if 'sessionTagSecond' in session_tag_dict: + second_tag = session_tag_dict['sessionTagSecond'] + message_dict["二级标签"] = second_tag.get('tagName', '') + + result.append(message_dict) return result @@ -327,7 +400,7 @@ class MariaDBClient: """查询指定日期范围内的会话数据""" sql = """ SELECT ACCOUNT, BEGIN_TIME, END_TIME, CUST_SEND_MESSAGE_COUNT, - AGENT_SEND_MESSAGE_COUNT, STATUS, CHANNEL_NAME, SESSION_ID, SESSION_TAG_NAME + AGENT_SEND_MESSAGE_COUNT, STATUS, CHANNEL_NAME, SESSION_ID, SESSION_TAG_NAME, JSON FROM crm_hlyj.crm_hlyj_dsri WHERE BEGIN_TIME >= %s AND BEGIN_TIME < %s @@ -439,7 +512,7 @@ class SessionProcessor: messages_df = self.db_client.query_messages_by_session_id(session_id) if messages_df is not None and not messages_df.empty: - conversation = self.db_client.data_processor.messages_df_to_list(messages_df) + conversation = self.db_client.data_processor.messages_df_to_list(messages_df, session_row) if conversation: return conversation except Exception as e: @@ -494,8 +567,8 @@ def main() -> None: # 创建数据库客户端 with MariaDBClient(config, max_connections=12) as db_client: # 查询会话数据 - start_date = '2025-06-12 00:00:00' - end_date = '2025-07-01 00:00:00' + start_date = '2025-08-01 00:00:00' + end_date = '2025-08-01 23:00:00' logger.info(f"查询时间范围: {start_date} 到 {end_date}") # 创建会话处理器 diff --git a/rag2_0/intent_recognition/PromptTemplates.py b/rag2_0/intent_recognition/PromptTemplates.py index 2feb93b..6ceb363 100755 --- a/rag2_0/intent_recognition/PromptTemplates.py +++ b/rag2_0/intent_recognition/PromptTemplates.py @@ -39,7 +39,7 @@ classification_info="""【垂直领域分类】: 2. 故障排查:软件运行异常、软件报错、软件显示错误等 【业务问题包括以下两类】: -1. 专业咨询:涉及电力造价规范、工程计价规则问题、行业标准解读等 +1. 专业咨询:涉及电力造价规范、工程计价规则问题、行业标准解读、定额套用、建筑问题等 2. 数据问题:涉及电力造价费用、造价指标的计算或构成等 【安装下载注册包括以下三类】: