diff --git a/rag2_0/demo/dialogue_to_workorder.py b/rag2_0/demo/dialogue_to_workorder.py index 7c3d7e1..63eb5bd 100755 --- a/rag2_0/demo/dialogue_to_workorder.py +++ b/rag2_0/demo/dialogue_to_workorder.py @@ -231,7 +231,7 @@ class DialogueToWorkorder: output_format = self.user_question_and_solution_parser.get_format_instructions() llm_prompt = prompt.format(output_format=output_format, dialogue_str=dialogue_str) - response = self.llm.invoke(user_prompt=llm_prompt) + response = self.llm.invoke(user_prompt=llm_prompt, need_retry=False) try: if response.content.count('user_question') == 1: @@ -261,7 +261,7 @@ class DialogueToWorkorder: except Exception as e: output_format = self.user_question_and_solution_list_parser.get_format_instructions() llm_prompt = prompt.format(output_format=output_format, dialogue_str=dialogue_str) - response = self.llm.invoke(user_prompt=llm_prompt) + response = self.llm.invoke(user_prompt=llm_prompt, need_retry=False) user_question_and_solution_temp = self.user_question_and_solution_list_parser.parse(response.content) return user_question_and_solution_temp.user_question_list @@ -293,7 +293,7 @@ class DialogueToWorkorder: {dialogue_str} """ - response = self.llm.invoke(user_prompt=prompt) + response = self.llm.invoke(user_prompt=prompt, need_retry=False) product_name_and_module_name = self.product_name_and_module_name_parser.parse(response.content) return product_name_and_module_name.product_name, product_name_and_module_name.module_name @@ -322,7 +322,7 @@ class DialogueToWorkorder: {dialogue_str} """ - response = self.llm.invoke(user_prompt=prompt) + response = self.llm.invoke(user_prompt=prompt, need_retry=False) product_line = self.product_line_parser.parse(response.content) return product_line.product_line @@ -358,7 +358,7 @@ class DialogueToWorkorder: {dialogue_str} """ - response = self.llm.invoke(user_prompt=prompt) + response = self.llm.invoke(user_prompt=prompt, need_retry=False) question_type = self.question_type_parser.parse(response.content) return question_type.question_type @@ -394,7 +394,7 @@ class DialogueToWorkorder: """ - response = self.llm.invoke(user_prompt=prompt) + response = self.llm.invoke(user_prompt=prompt, need_retry=False) is_complaint = self.is_complaint_parser.parse(response.content) return (is_complaint.is_dissatisfaction, @@ -479,7 +479,19 @@ class DialogueToWorkorder: # 按会话ID分组 conversation_dict = self.group_conversations_by_id(df) - + # 限制处理的会话数量为前2000个 + if len(conversation_dict) > 2000: + print(f"会话总数为 {len(conversation_dict)},限制处理前2000个会话") + # 获取所有会话ID + conversation_ids = list(conversation_dict.keys()) + # 只保留前2000个会话 + limited_conversation_dict = { + conversation_id: conversation_dict[conversation_id] + for conversation_id in conversation_ids[:2000] + } + conversation_dict = limited_conversation_dict + else: + print(f"会话总数为 {len(conversation_dict)},处理全部会话") # 使用线程池处理每个会话 workorder_dict_list = [] with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: @@ -593,7 +605,7 @@ def main(): args = parse_arguments() # 设置默认文件路径 - conversation_excel_path = args.conversation_file or os.path.join('data', 'excel', '会话内容详情20250528110230.xlsx') + conversation_excel_path = args.conversation_file or os.path.join('data', 'excel', '2025年1月到6月12号所有对话记录.xlsx') product_detail_excel_path = args.product_detail_file or os.path.join('data', 'excel', '产品详情_工单.xlsx') # 创建处理实例 diff --git a/rag2_0/demo/heli_db_to_excel.py b/rag2_0/demo/heli_db_to_excel.py new file mode 100644 index 0000000..9018c09 --- /dev/null +++ b/rag2_0/demo/heli_db_to_excel.py @@ -0,0 +1,537 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from __future__ import annotations +import json +import os +import re +import configparser +import logging +from datetime import datetime +from typing import Any, Dict, List, Optional, Tuple, Union +from dataclasses import dataclass +from contextlib import contextmanager +import threading +import time +from queue import Queue, Empty, Full + +import pandas as pd +import pymysql +from pymysql.connections import Connection +from pymysql.cursors import Cursor +from tqdm import tqdm +import concurrent.futures +import sys + +# 配置日志 +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + handlers=[ + logging.FileHandler('./data/log/mariadb_client.log'), + logging.StreamHandler() + ] +) +logger = logging.getLogger(__name__) +os.makedirs('./data/log', exist_ok=True) + +@dataclass +class DatabaseConfig: + """数据库配置类""" + host: str = '192.168.0.123' + port: int = 3307 + user: str = 'fuzhimei' + password: str = 'fuzhimei@135' + charset: str = 'utf8mb4' + connect_timeout: int = 10 + read_timeout: int = 300 + write_timeout: int = 300 + + @classmethod + def from_config_file(cls, config_file: str = 'config.ini') -> 'DatabaseConfig': + """从配置文件加载配置""" + if not os.path.exists(config_file): + logger.warning(f"配置文件 {config_file} 不存在,使用默认配置") + return cls() + + config = configparser.ConfigParser() + config.read(config_file, encoding='utf-8') + + if 'database' not in config: + logger.warning("配置文件中没有 [database] 部分,使用默认配置") + return cls() + + db_config = config['database'] + return cls( + host=db_config.get('host', cls.host), + port=int(db_config.get('port', cls.port)), + user=db_config.get('user', cls.user), + password=db_config.get('password', cls.password), + charset=db_config.get('charset', cls.charset), + connect_timeout=int(db_config.get('connect_timeout', cls.connect_timeout)), + read_timeout=int(db_config.get('read_timeout', cls.read_timeout)), + write_timeout=int(db_config.get('write_timeout', cls.write_timeout)) + ) + + +class ConnectionPool: + """数据库连接池""" + + def __init__(self, config: DatabaseConfig, max_connections: int = 10): + self.config = config + self.max_connections = max_connections + self.pool = Queue(maxsize=max_connections) + self.active_connections = 0 + self.lock = threading.Lock() + + # 预创建一些连接 + self._initialize_pool() + + def _initialize_pool(self) -> None: + """初始化连接池,预创建一些连接""" + initial_connections = min(3, self.max_connections) + for _ in range(initial_connections): + try: + conn = self._create_connection() + if conn: + self.pool.put_nowait(conn) + self.active_connections += 1 + except Full: + break + except Exception as e: + logger.error(f"初始化连接池时创建连接失败: {e}") + + def _create_connection(self) -> Optional[Connection]: + """创建新的数据库连接""" + try: + conn = pymysql.connect( + host=self.config.host, + port=self.config.port, + user=self.config.user, + password=self.config.password, + charset=self.config.charset, + connect_timeout=self.config.connect_timeout, + read_timeout=self.config.read_timeout, + write_timeout=self.config.write_timeout, + autocommit=True + ) + return conn + except Exception as e: + logger.error(f"创建数据库连接失败: {e}") + return None + + @contextmanager + def get_connection(self): + """获取连接的上下文管理器""" + conn = None + try: + # 尝试从池中获取连接 + try: + conn = self.pool.get_nowait() + except Empty: + # 池中没有连接,尝试创建新连接 + with self.lock: + if self.active_connections < self.max_connections: + conn = self._create_connection() + if conn: + self.active_connections += 1 + else: + raise Exception("无法创建新的数据库连接") + else: + # 等待可用连接 + logger.info("等待可用连接...") + conn = self.pool.get(timeout=30) + + # 检查连接是否仍然有效 + if conn and not self._is_connection_alive(conn): + logger.warning("连接已失效,重新创建") + try: + conn.close() + except: + pass + conn = self._create_connection() + if not conn: + raise Exception("重新创建连接失败") + + yield conn + + except Exception as e: + logger.error(f"获取数据库连接时出错: {e}") + if conn: + try: + conn.close() + except: + pass + with self.lock: + self.active_connections -= 1 + raise + else: + # 归还连接到池中 + if conn: + try: + self.pool.put_nowait(conn) + except Full: + # 池已满,关闭连接 + try: + conn.close() + except: + pass + with self.lock: + self.active_connections -= 1 + + def _is_connection_alive(self, conn: Connection) -> bool: + """检查连接是否仍然有效""" + try: + conn.ping(reconnect=False) + return True + except: + return False + + def close_all(self) -> None: + """关闭所有连接""" + logger.info("正在关闭连接池中的所有连接...") + while not self.pool.empty(): + try: + conn = self.pool.get_nowait() + conn.close() + except (Empty, Exception): + break + + self.active_connections = 0 + logger.info("连接池已关闭") + + +class DataProcessor: + """数据处理器""" + + @staticmethod + def clean_html_tags(text: str) -> str: + """清除文本中的HTML标签""" + if not isinstance(text, str): + return str(text) if text is not None else "" + + # 使用正则表达式移除HTML标签 + clean_text = re.sub(r'<[^>]+>', '', text) + # 处理HTML实体 + html_entities = { + ' ': ' ', + '<': '<', + '>': '>', + '&': '&', + '"': '"', + ''': "'" + } + for entity, char in html_entities.items(): + clean_text = clean_text.replace(entity, char) + + return clean_text.strip() + + @staticmethod + def messages_df_to_list(messages_df: pd.DataFrame) -> List[Dict[str, Any]]: + """将消息DataFrame转换为字典列表,使用高效的向量化操作""" + if messages_df.empty: + return [] + + # 过滤掉系统消息 + mask = (messages_df["MODE"] != "system") & (messages_df["SYSTEM_MODE_MESSAGE_TYPE"].isna()) + filtered_df = messages_df[mask].copy() + + if filtered_df.empty: + return [] + + # 向量化操作 + filtered_df['message_sender'] = filtered_df["MODE"].map({'reply': '坐席', 'receive': '访客'}).fillna('未知') + + # 处理发送者昵称 + filtered_df['sender_nickname'] = filtered_df.apply( + lambda row: row["AGENT_NAME"] if row["message_sender"] == "坐席" else row["CUS_NICK_NAME"], + axis=1 + ) + + # 处理内容 + def process_content(row): + content = row["CONTENT"] + if row["MSG_TYPE"] == "attachment": + return f"附件:{DataProcessor.clean_html_tags(content)}" + elif row["MSG_TYPE"] == "image": + return f"图片:{DataProcessor.clean_html_tags(content)}" + else: + return content + + filtered_df['processed_content'] = filtered_df.apply(process_content, axis=1) + + # 过滤掉空昵称 + filtered_df = filtered_df[filtered_df['sender_nickname'].notna() & (filtered_df['sender_nickname'] != '')] + + # 转换为字典列表 + result = [] + for record in filtered_df.to_dict('records'): + result.append({ + "账号id": record["ACCOUNT"], + "会话id": record["SESSION_ID"], + "消息内容": record["processed_content"], + "消息发送者": record["message_sender"], + "发送者昵称": record["sender_nickname"], + "创建时间": record["CREATE_TIME"], + }) + + return result + + +class MariaDBClient: + """优化后的MariaDB数据库客户端""" + + def __init__(self, config: DatabaseConfig, max_connections: int = 10): + self.config = config + self.connection_pool = ConnectionPool(config, max_connections) + self.data_processor = DataProcessor() + + def __enter__(self) -> 'MariaDBClient': + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + self.close() + + def close(self) -> None: + """关闭客户端""" + self.connection_pool.close_all() + + def execute_query(self, sql: str, params: Optional[Tuple] = None) -> Tuple[Optional[pd.DataFrame], List[str]]: + """执行SQL查询""" + try: + with self.connection_pool.get_connection() as conn: + with conn.cursor() as cursor: + cursor.execute(sql, params) + results = cursor.fetchall() + + # 获取列名 + column_names = [desc[0] for desc in cursor.description] if cursor.description else [] + + if results: + df = pd.DataFrame(results, columns=column_names) + return df, column_names + else: + return pd.DataFrame(), column_names + + except Exception as e: + logger.error(f"执行查询时出错: {e}") + logger.error(f"SQL: {sql}") + return None, [] + + def query_sessions(self, start_date: str, end_date: str) -> Optional[pd.DataFrame]: + """查询指定日期范围内的会话数据""" + sql = """ + SELECT ACCOUNT, BEGIN_TIME, END_TIME, CUST_SEND_MESSAGE_COUNT, + AGENT_SEND_MESSAGE_COUNT, STATUS, CHANNEL_NAME, SESSION_ID, SESSION_TAG_NAME + FROM crm_hlyj.crm_hlyj_dsri + WHERE BEGIN_TIME >= %s + AND BEGIN_TIME < %s + AND STATUS = 'assign' + ORDER BY BEGIN_TIME DESC + """ + + df, _ = self.execute_query(sql, (start_date, end_date)) + return df + + def query_messages_by_session_id(self, session_id: str) -> Optional[pd.DataFrame]: + """根据会话ID查询消息详情""" + sql = """ + SELECT CREATE_TIME, CUS_NICK_NAME, MODE, MSG_TYPE, AGENT_NAME, CONTENT, + SESSION_ID, ACCOUNT, SYSTEM_MODE_MESSAGE_TYPE + FROM crm_hlyj.crm_hlyj_dmri + WHERE SESSION_ID = %s + ORDER BY CREATE_TIME + """ + + df, _ = self.execute_query(sql, (session_id,)) + return df + + def export_to_excel(self, data: List[Dict[str, Any]], filename: str, output_dir: str = "output") -> Optional[str]: + """导出数据到Excel文件""" + if not data: + logger.warning(f"没有数据可导出到 {filename}") + return None + + try: + # 创建输出目录 + os.makedirs(output_dir, exist_ok=True) + + # 生成文件路径 + # timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + file_path = os.path.join(output_dir, f"{filename}.xlsx") + + # 准备数据:不同对话之间添加空行 + all_rows = [] + current_session_id = None + + for conversation in data: + if not conversation: # 跳过空对话 + continue + + # 如果是新的会话,添加空行(除了第一个会话) + if current_session_id and current_session_id != conversation[0]["会话id"]: + empty_row = {key: "" for key in conversation[0].keys()} + all_rows.append(empty_row) + + # 更新当前会话ID + current_session_id = conversation[0]["会话id"] + + # 添加当前会话的所有消息 + all_rows.extend(conversation) + + # 创建DataFrame并导出 + if all_rows: + df = pd.DataFrame(all_rows) + with pd.ExcelWriter(file_path, engine='openpyxl') as writer: + df.to_excel(writer, sheet_name='对话记录', index=False) + + logger.info(f"数据已导出到 {file_path}") + return file_path + else: + logger.warning("没有有效数据可导出") + return None + + except Exception as e: + logger.error(f"导出到Excel时出错: {e}") + return None + + +def process_session_batch(db_client: MariaDBClient, session_batch: pd.DataFrame) -> List[List[Dict[str, Any]]]: + """批量处理会话数据""" + conversations = [] + + for _, session_row in session_batch.iterrows(): + try: + session_id = session_row['SESSION_ID'] + messages_df = db_client.query_messages_by_session_id(session_id) + + if messages_df is not None and not messages_df.empty: + conversation = db_client.data_processor.messages_df_to_list(messages_df) + if conversation: + conversations.append(conversation) + + except Exception as e: + logger.error(f"处理会话 {session_row.get('SESSION_ID', 'unknown')} 时出错: {e}") + continue + + return conversations + + +class SessionProcessor: + """会话处理器,负责批量和并发处理""" + + def __init__(self, db_client: MariaDBClient, max_workers: int = None, batch_size: int = 50): + self.db_client = db_client + self.max_workers = max_workers if max_workers is not None else os.cpu_count() + self.batch_size = batch_size + self.temp_save_lock = threading.Lock() # 添加锁用于保护临时保存操作 + + logger.info(f"初始化会话处理器: max_workers={self.max_workers}, batch_size={self.batch_size}") + + def process_sessions(self, sessions_df: pd.DataFrame) -> List[List[Dict[str, Any]]]: + """处理所有会话数据""" + if sessions_df.empty: + logger.warning("没有会话数据需要处理") + return [] + + total_sessions = len(sessions_df) + logger.info(f"开始处理 {total_sessions} 个会话...") + + # 分批处理 + all_conversations = [] + batch_count = (total_sessions + self.batch_size - 1) // self.batch_size + # 使用线程池处理批次 + with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor: + # 提交所有批次任务 + future_to_batch = {} + + for i in range(0, total_sessions, self.batch_size): + batch = sessions_df.iloc[i:i + self.batch_size] + future = executor.submit(process_session_batch, self.db_client, batch) + future_to_batch[future] = i // self.batch_size + 1 + + # 收集结果 + with tqdm(total=batch_count, desc="处理批次进度") as pbar: + for future in concurrent.futures.as_completed(future_to_batch): + try: + batch_conversations = future.result() + all_conversations.extend(batch_conversations) + + # 使用锁保护临时列表的操作 + with self.temp_save_lock: + # 每处理100个对话临时保存一次 + logger.info(f"临时保存 {len(all_conversations)} 个对话") + temp_output_file = self.db_client.export_to_excel( + all_conversations, + f"客服对话记录_临时保存", + output_dir="/data/QueryRewrite/data/excel" + ) + if temp_output_file: + logger.info(f"临时保存完成: {temp_output_file}") + + batch_num = future_to_batch[future] + logger.debug(f"批次 {batch_num} 完成,获得 {len(batch_conversations)} 个对话") + + except Exception as e: + batch_num = future_to_batch[future] + logger.error(f"处理批次 {batch_num} 时出错: {e}") + + pbar.update(1) + + logger.info(f"处理完成,共获得 {len(all_conversations)} 个有效对话") + return all_conversations + + +def main() -> None: + """主函数""" + try: + # 加载配置 + config = DatabaseConfig.from_config_file() + logger.info(f"使用数据库配置: {config.host}:{config.port}") + + # 创建数据库客户端 + with MariaDBClient(config, max_connections=12) as db_client: + # 查询会话数据 + start_date = '2025-01-01 00:00:00' + end_date = '2025-06-12 00:00:00' + + logger.info(f"查询时间范围: {start_date} 到 {end_date}") + # 创建会话处理器 + processor = SessionProcessor(db_client, batch_size=100) + is_debug = hasattr(sys, 'gettrace') and sys.gettrace() is not None + if is_debug: + messages_df = db_client.query_messages_by_session_id("86c919e0-09f1-11f0-84ae-2daf59566989") + print(db_client.data_processor.messages_df_to_list(messages_df)) + return [] + + sessions_df = db_client.query_sessions(start_date, end_date) + + if sessions_df is None or sessions_df.empty: + logger.warning("没有找到符合条件的会话数据") + return + + # 处理会话数据 + all_conversations = processor.process_sessions(sessions_df) + # 导出结果 + if all_conversations: + output_file = db_client.export_to_excel( + all_conversations, + "客服对话记录", + output_dir="/data/QueryRewrite/data/excel" + ) + + if output_file: + logger.info(f"处理完成!共导出 {len(all_conversations)} 个对话到文件: {output_file}") + else: + logger.error("导出文件失败") + else: + logger.warning("没有有效的对话数据可导出") + + except KeyboardInterrupt: + logger.info("用户中断程序") + except Exception as e: + logger.error(f"程序执行出错: {e}", exc_info=True) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/rag2_0/demo/intent_recognition_example.py b/rag2_0/demo/intent_recognition_example.py index beebf20..3e896b8 100644 --- a/rag2_0/demo/intent_recognition_example.py +++ b/rag2_0/demo/intent_recognition_example.py @@ -175,7 +175,7 @@ def save_results_to_excel(results, output_file, is_final=False): logging.info(f"已保存{len(valid_results)}条结果至: {temp_output_file}") # 示例查询 -examples_query = """那西藏软件呢""" +examples_query = """那储能软件如何操作""" conversation_context="" chat_history=[ { @@ -214,8 +214,8 @@ def main(): # 读取提问数据 current_dir = os.path.dirname(os.path.abspath(__file__)) - data_file = os.path.join(current_dir, "..", "..", "data", "excel", "历史提问数据(like)_提问明确.xlsx") - output_file = os.path.join(current_dir, "..", "..", "data", "excel", "测试提问数据_槽位填充结果.xlsx") + data_file = os.path.join(current_dir, "..", "..", "data", "excel", "200条点踩数据测试.xlsx") + output_file = os.path.join(current_dir, "..", "..", "data", "excel", "200条点踩数据测试_槽位填充结果.xlsx") # 检测是否为调试模式,调试模式下使用examples_query,否则从Excel读取 is_debug = hasattr(sys, 'gettrace') and sys.gettrace() is not None @@ -226,7 +226,7 @@ def main(): examples = load_questions_from_excel(data_file) if not is_debug: - max_workers = 40 # 减少并发数以避免API限制 + max_workers = 20 # 减少并发数以避免API限制 logging.info(f"共有 {len(examples)} 个问题需要处理,使用 {max_workers} 个并发线程") # 创建一个与输入顺序相同的结果列表 @@ -260,9 +260,10 @@ def main(): logging.info(f"所有处理完成,最终结果已保存至: {output_file}") else: for idx, query in enumerate(examples): - if query.strip() == "": - continue - process_query(recognizer, query, conversation_context, chat_history, previous_slots) + if query.strip() == "": + continue + process_query(recognizer, query, conversation_context, chat_history, previous_slots) + # print(json.dumps(process_query(recognizer, query), ensure_ascii=False, indent=2)) def setup_logging(): # 配置日志输出到控制台 diff --git a/rag2_0/dify/intent_recognition_api.py b/rag2_0/dify/intent_recognition_api.py index 42da11e..344362a 100644 --- a/rag2_0/dify/intent_recognition_api.py +++ b/rag2_0/dify/intent_recognition_api.py @@ -6,10 +6,23 @@ import json import time import threading import datetime +import logging # 加载环境变量 load_dotenv() +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + handlers=[ + logging.StreamHandler() + ] +) +logging.getLogger('httpx').setLevel(logging.WARNING) +logging.getLogger('openai').setLevel(logging.WARNING) + +logger = logging.getLogger(__name__) + app = Flask(__name__) # 创建线程锁,用于保护共享资源 @@ -50,8 +63,8 @@ def intent_recognize(): end_time = time.time() current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S %z") - print(f"[{current_time}] [{os.getpid()}] [INFO] 意图识别耗时: {end_time - start_time:.2f}秒") - + logger.info(f"[{os.getpid()}] 意图识别耗时: {end_time - start_time:.2f}秒") + # 提取分类信息 classification = result["classification"] diff --git a/rag2_0/intent_recognition/DataModels.py b/rag2_0/intent_recognition/DataModels.py index f345f5e..7a25fcb 100644 --- a/rag2_0/intent_recognition/DataModels.py +++ b/rag2_0/intent_recognition/DataModels.py @@ -150,12 +150,14 @@ class SoftwareFunctionSlots(SlotBase): software_name: str = Field(default="", description="软件名称") function_name: str = Field(default="", description="具体功能名称") operation: str = Field(default="", description="用户操作意图(如何使用功能、功能入口、功能使用场景)") - project_type: Optional[str] = Field(default="单工程", description="工程类型(单工程、多工程、批次工程)") + project_type: Optional[str] = Field(default="单工程", description="工程类型(单工程、多工程、批次工程), 未明确提及则默认下是(单工程)") software_version: Optional[str] = Field(default="", description="软件版本") operation_steps: Optional[str] = Field(default="", description="操作步骤描述") def check_required_slots(self) -> Tuple[bool, Dict[str, str]]: """检查必填槽位是否都存在""" + if self.project_type is None or len(self.project_type) == 0: + self.project_type="单工程" missing_slots = {} if not self.software_name: missing_slots["software_name"] = f"{SoftwareFunctionSlots.model_fields['software_name'].description},可选值:{', '.join([name.value for name in SoftwareName if name not in [SoftwareName.UNKNOWN, SoftwareName.ALIASES]])}" diff --git a/rag2_0/intent_recognition/IntentRecognition.py b/rag2_0/intent_recognition/IntentRecognition.py index d1e0cc0..5920a4e 100644 --- a/rag2_0/intent_recognition/IntentRecognition.py +++ b/rag2_0/intent_recognition/IntentRecognition.py @@ -14,6 +14,8 @@ import json from typing import List, Tuple, Dict, Any, Optional import re import jieba +import time + from .PromptTemplates import (classification_prompt, query_rewrite_prompt, extract_nouns_prompt, classification_info, slot_filling_prompt) @@ -95,7 +97,9 @@ class IntentRecognizer: except Exception as e: raise RuntimeError(f"加载后缀关键词失败: {e}") from e - def _classify_intent(self, query: str) -> Classification: + def _classify_intent(self, query: str, conversation_context: str = "", + chat_history: List[Dict[str, str]] = None, + previous_slots: Dict[str, Any] = None) -> Classification: """ 对用户输入进行意图分类 @@ -109,7 +113,9 @@ class IntentRecognizer: classification_parser = PydanticOutputParser(pydantic_object=Classification) formatted_prompt = classification_prompt.format(user_input=query, classification_info=classification_info, - output_format=classification_parser.get_format_instructions()) + output_format=classification_parser.get_format_instructions(), + conversation_context=conversation_context, + chat_history=json.dumps(chat_history, ensure_ascii=False)) # 调用LLM response = self._llm.invoke(formatted_prompt, False) @@ -208,7 +214,7 @@ class IntentRecognizer: term_texts = ["名称:" + term.name + "|" + "同义词:" + ";".join(term.synonymous) for term in matched_terms] # 使用重排序模型 - xinference_reranker = SiliconFlowReRankerModel() + xinference_reranker = XinferenceReRankerModel() rerank_results = xinference_reranker.rerank(query_key, term_texts, top_k=top_k) # 将matched_terms转换为列表以便按索引访问 @@ -220,7 +226,7 @@ class IntentRecognizer: return reranked_terms except Exception as e: - raise RuntimeError(f"SiliconFlowReRankerModel重排失败:{e}") from e + raise RuntimeError(f"_rerank_matched_terms重排失败:{e}") from e def _match_keywords(self, query: str, use_jieba: bool = False) -> Tuple[TermList, List[str]]: """ @@ -233,18 +239,23 @@ class IntentRecognizer: Returns: 匹配到的关键词列表 """ + start_time = time.time() query_keys=[] # 步骤1: 使用LLM提取查询中的关键词 try: + llm_start_time = time.time() extracted_terms = self._extract_keywords_with_llm(query, use_jieba) for term in extracted_terms: query_keys.append(term.name) + llm_end_time = time.time() + llm_time = llm_end_time - llm_start_time except Exception as e: raise RuntimeError(f"LLM关键词提取失败: {e}") from e matched_terms = [] # 存储匹配到的Term对象 # 步骤2: 使用向量检索找到相似的专业名词 try: + vector_start_time = time.time() # 对matched_terms中的每个关键字进行向量检索 for current_key in query_keys: vector_results = self._noun_retriever.query(current_key, top_k=5, use_intersection=False) @@ -262,12 +273,20 @@ class IntentRecognizer: if len(current_key_terms) > 0: reranked_terms = self._rerank_matched_terms(current_key, current_key_terms) matched_terms.extend(reranked_terms) + vector_end_time = time.time() + vector_time = vector_end_time - vector_start_time except Exception as e: raise RuntimeError(f"向量检索关键词时出错: {e}") from e # 提取所有Term对象的名称并排序 # 将set类型的matched_terms转换为TermList类型 term_list = TermList(terms=list(matched_terms)) + end_time = time.time() + total_time = end_time - start_time + + # 输出整合的时间日志 + logging.info(f"关键词匹配耗时统计 - 总耗时: {total_time:.2f}秒, 问题关键词提取: {llm_time:.2f}秒, 向量检索+重排序: {vector_time:.2f}秒") + return term_list, query_keys def _rewrite_query(self, query: str, keywords: TermList, query_keys:List[str], chat_history: List[Dict[str, str]] = None, context: str = "") -> QueryRewrite: @@ -282,6 +301,8 @@ class IntentRecognizer: Returns: 改写结果 """ + + rewrite_start_time = time.time() # 准备问题改写提示 # terms_dict = [term.model_dump(exclude={"description"}) for term in keywords.terms] terms_dict = [term.model_dump() for term in keywords.terms] @@ -295,7 +316,7 @@ class IntentRecognizer: keywords=keywords_str, chat_history=chat_history, context=context) - + # 调用LLM response = self._llm.invoke(formatted_prompt, False) @@ -303,6 +324,9 @@ class IntentRecognizer: try: # 尝试直接解析JSON响应 parsed_output = query_rewrite_parser.parse(response.content) + rewrite_end_time = time.time() + rewrite_time = rewrite_end_time - rewrite_start_time + logging.info(f"问题改写耗时统计 - 总耗时: {rewrite_time:.2f}秒") return parsed_output except Exception as e: raise RuntimeError(f"解析问题改写结果时出错: {e}") from e @@ -360,7 +384,10 @@ class IntentRecognizer: # suffix_terms.append(suffix_term) # return Classification(vertical_classification="安装下载", sub_classification="查询"), TermList(terms=suffix_terms), QueryRewrite(rewrite=query), matched_suffixes - + if chat_history is None: + chat_history = [] + if previous_slots is None: + previous_slots = {} # 步骤1: 匹配关键词 keywords_terms, query_keys = self._match_keywords(query, use_jieba) @@ -397,7 +424,9 @@ class IntentRecognizer: # } - def _fill_slots(self, query: str, classification: Classification) -> Dict[str, Any]: + def _fill_slots(self, query: str, classification: Classification, conversation_context: str = "", + chat_history: List[Dict[str, str]] = None, + previous_slots: Dict[str, Any] = None,) -> Dict[str, Any]: """ 根据分类结果对问题进行槽位填充 @@ -415,7 +444,7 @@ class IntentRecognizer: raise RuntimeError("未找到匹配的槽位模型") # 使用LLM进行槽位填充 - filled_slots = self._fill_slots_with_llm(query, classification, slot_model) + filled_slots = self._fill_slots_with_llm(query, classification, slot_model, conversation_context, chat_history, previous_slots) # 检查必填槽位是否都已填充 is_complete, missing_slots = filled_slots.check_required_slots() @@ -467,7 +496,12 @@ class IntentRecognizer: return None - def _fill_slots_with_llm(self, query: str, classification: Classification, slot_model_class: type) -> Any: + def _fill_slots_with_llm(self, query: str, + classification: Classification, + slot_model_class: type, + conversation_context: str = "", + chat_history: List[Dict[str, str]] = None, + previous_slots: Dict[str, Any] = None) -> Any: """ 使用LLM进行槽位填充 @@ -486,7 +520,10 @@ class IntentRecognizer: query=query, vertical_classification=classification.vertical_classification, sub_classification=classification.sub_classification, - output_format=slot_parser.get_format_instructions() + output_format=slot_parser.get_format_instructions(), + conversation_context=conversation_context, + chat_history=json.dumps(chat_history,ensure_ascii=False), + previous_slots=json.dumps(previous_slots,ensure_ascii=False), ) # 调用LLM @@ -537,9 +574,14 @@ class IntentRecognizer: output_format=parser.get_format_instructions(), classification_info=classification_info ) + # 调用LLM + llm_start_time = time.time() response = self._llm.invoke(formatted_prompt + output_example, False) + llm_end_time = time.time() + llm_time = llm_end_time - llm_start_time + try: # 解析LLM响应为JSON result_json = parser.parse(response.content) @@ -552,8 +594,19 @@ class IntentRecognizer: if expected_slot_model is None: # 添加容错处理,应对LLM返回错误分类信息,一级分类跟二级分类错乱 # 重新分类 - classification = self._classify_intent(user_input) - fill_slots = self._fill_slots(user_input, classification) + classify_start_time = time.time() + classification = self._classify_intent(user_input, conversation_context, chat_history, previous_slots) + classify_end_time = time.time() + classify_time = classify_end_time - classify_start_time + # logging.info(f"重新分类耗时: {classify_time:.2f}秒") + + fill_start_time = time.time() + fill_slots = self._fill_slots(user_input, classification, conversation_context, chat_history, previous_slots) + fill_end_time = time.time() + fill_time = fill_end_time - fill_start_time + all_time=fill_end_time-llm_start_time + logging.info(f"总耗时:{all_time:.2f}秒,首次槽位+分类:{llm_time:.2f}秒, 重新分类耗时: {classify_time:.2f}秒, 重新槽位填充耗时: {fill_time:.2f}秒") + result = { "classification": classification.model_dump(), "slot_filling": fill_slots @@ -562,13 +615,21 @@ class IntentRecognizer: return result elif expected_slot_model.__name__ != type(slot_filling).__name__: # 添加容错处理,应对LLM槽位与分类不匹配。重新填充槽位 + fill_start_time = time.time() slot_filling = self._fill_slots(user_input, classification) + fill_end_time = time.time() + fill_time = fill_end_time - fill_start_time + all_time=fill_end_time-llm_start_time + logging.info(f"总耗时:{all_time:.2f}秒,首次槽位+分类:{llm_time:.2f}秒, 重新槽位填充耗时: {fill_time:.2f}秒") + result = { "classification": classification.model_dump(), "slot_filling": slot_filling } logging.warning(f"重新填充槽点") return result + + logging.info(f"意图识别+槽位LLM调用耗时: {llm_time:.2f}秒") # 构建最终结果 result = { diff --git a/rag2_0/intent_recognition/Multi_PromptTemplates.py b/rag2_0/intent_recognition/Multi_PromptTemplates.py index a29534a..5d99b0f 100644 --- a/rag2_0/intent_recognition/Multi_PromptTemplates.py +++ b/rag2_0/intent_recognition/Multi_PromptTemplates.py @@ -126,7 +126,7 @@ query_rewrite_prompt_pro_old=""" query_rewrite_prompt_pro=""" # 电力造价问答优化工程师(精简版) -**角色**:基于历史对话和专业术语库重构问题,提升知识库检索准确率。 +**角色**:基于历史对话和术语库重构问题,提升知识库检索准确率。 ## 核心原则 1. 语义保真 → 保持问题核心意图 @@ -135,8 +135,14 @@ query_rewrite_prompt_pro=""" ## 处理流程 ### 一、输入解析 - - 原始问题(需保留核心语义):{query} - - 关键词集合:{keywords} + - 原始问题(需保留核心语义): + + {query} + + - 术语库集合: + + {keywords} + - 历史对话记录: {chat_history} @@ -159,14 +165,14 @@ graph TD ### 三、重构优先级 1. **背景补充** - - 历史对话中确定的背景信息需要保留(例:"这软件"→"【配网工程D3】") + - 历史对话中确定的背景信息需要保留(例:"这软件"→"【配网工程计价通D3软件】") 2. **术语处理** - - 同义词转标准词 → 批量设置定额 + - 同义词转标准词 → 将提问中的同义词(synonymous)替换为标准词(name) - 存在即标记 → 【计算式】 3. **结构优化** - - 保持原问题的5W2H特征 + - 保持原问题的5W2H特征,确保问题意图不发生改变。 - 明确指代关系("该功能"→"【批量导入】功能") ## 输出规范 @@ -184,7 +190,7 @@ graph TD - [] 背景信息是否合理补充? - [] 术语标记是否完整【】? - [] 语句是否自然流畅? -- [] 避免过度补充无关信息 +- [] 避免补充无关信息 """ @@ -349,7 +355,7 @@ def generate_slot_mapping_doc() -> str: doc.append(f"- {sub_class} -> {slot_model}") doc.append("\n## 【注意事项】") - doc.append("1. 分类与槽位模型必须严格对应") + doc.append("1. 分类与槽位模型必须严格对应。严格遵守,不得违背") doc.append("2. 每个分类只能使用其对应的槽位模型") doc.append("3. 不允许混用不同分类的槽位模型") diff --git a/rag2_0/intent_recognition/PromptTemplates.py b/rag2_0/intent_recognition/PromptTemplates.py index f8776e6..b628f4b 100644 --- a/rag2_0/intent_recognition/PromptTemplates.py +++ b/rag2_0/intent_recognition/PromptTemplates.py @@ -58,6 +58,12 @@ classification_prompt=""" 用户正在使用电力造价软件或想询问电力造价领域相关知识,你需要根据用户的输入内容,将其归类为以下垂直领域之一: {classification_info} + ## 【会话背景信息】 + {conversation_context} + + ## 【历史对话记录】 + {chat_history} + 【用户输入】: {user_input} @@ -154,6 +160,15 @@ slot_filling_prompt = """ 【用户问题】 {query} +## 【会话背景信息】 +{conversation_context} + +## 【历史对话记录】 +{chat_history} + +## 【历史槽位信息】 +{previous_slots} + 【问题分类】 垂直领域分类: {vertical_classification} 子分类: {sub_classification} diff --git a/rag2_0/tool/APIKeyManager.py b/rag2_0/tool/APIKeyManager.py index e2fc25e..3685150 100644 --- a/rag2_0/tool/APIKeyManager.py +++ b/rag2_0/tool/APIKeyManager.py @@ -23,16 +23,6 @@ API_KEY_LIST=[ "sk-kzhxlqvqcxlnbdgnpalqnzumkmspepkttkgbophnkqanainw", "sk-bzttugqtlskrvguvhckwamdssvgmgnrqpsialpdbskfsyyak", "sk-tovmogiablsoeabwgqyvevpcfichyjpuzqdymmvksspdrtqt", -"sk-wqdpapdkisovziexgcyxvumpwzbjnhqbxvcqcspzctjhyhjk", -"sk-bbntrnifrtdzhhgrtlrhvwbnaysuszviemshdakxonnnymnb", -"sk-vmpnwjxersrwybmfhfxgsvbmhsmpjldxseiyxovnysrlbuzi", -"sk-nscsxwfqigkfpfqfzebkmaickxjzbhtfwywdppmmobrrbfnw", -"sk-irbxuakhntsrusrympiubkkjbkabbfbdgpstqnxbztzdtxdq", -"sk-hcfojzczbgwgcuhzxkicxqrhadurtakwbawiesyxyvksmcoz", -"sk-wiyosqgyutjypgzibveiwkgqwfkfsnonrmvjfbvrbkoicciv", -"sk-ocglenyvxkkvzupzumoypnyndjpjqhivyqpedusunboglspz", -"sk-dtbawdwajkhdctrukundbkqwswzfzihqbebfuvqnfnounbuc", -"sk-zqiyiqtbwqgyeenkvppymfbkspriolwbnxnjakugzxyvcuql", "sk-wtnjpejveiobtvzsmnuaefqkocsafbfyrtqkkyqardndtxcs", "sk-gqdvtrwvzxewnagwsfakrvajtzwgcknatpflkesyqhzjrlal", "sk-plivglrkxahodgtgjlaqdjusdoerxspjbcbizaybicarfyuk", @@ -96,6 +86,26 @@ API_KEY_LIST=[ "sk-jrdzerhmvrtvzawkksowbgkggkubwfquplmrxbdhespqgtis", "sk-jjbpnkbeupsxyclcivbhizcfpfjrppddunbqynyjkqhtmpwu", "sk-oqehupcveovkjqqtxypqyifidcdissuyehwrkdwgruoyjkpq", +"sk-jnnmltwtqwuoyagoogzzeraczmyfxhoairiddgayksqdfnbr", +"sk-eghuepxnbcollzrjwbzqvbnhiiwagkejaclyhvaodeqgwrog", +"sk-poszkbjdmamimconjustnrxxqusuzlryxkrzkpronlenrmen", +"sk-zolvcegarsrwqhwgvwzgtqupodsdmckjiocyvoyldbkusbzc", +"sk-ywfafulcniaqdgdcsnbtqquaqeuiqlkcnknkaflwxyuemcow", +"sk-hhedmocgtfpywbbpwamgfkygrahiqsuurntlbqqbmjwfipmm", +"sk-gzdqfoyvulrqscdpjlwlufdecrsyjpmwpkknuhnjsvtyftox", +"sk-bkcufidsebujopqqwexwxwpmevrpelmvxzdymncvllcyojce", +"sk-olabhscekudzkyudypkcjvehwqunagubwdmtppugrjmcptwv", +"sk-zpdqyocliebhqpkuwvebpgcnfjdkvavdltimllmgkthwnwph", +"sk-gvhchlfelocjniuydusyhhwacnomxnvucjonzkhtqoplnbcr", +"sk-lzneagvdxhisodndnxnpkntghpkimjmjsebiqdzaoqzuhbla", +"sk-xotcfdkigykevngedupitbcatjqppxmcibjtcebyoglykuxz", +"sk-ufydqsdqnwsegaqwtappzwdyzqnoblyunfvslomnnmykedgk", +"sk-jwasykftbkyjzdqlwcxuicrwzxsbhttilxfefbrozrznpwlv", +"sk-xngteojwkxmftyaabjdwwgyoadspsowmcpcqobteutdcfmnr", +"sk-akzkgniebruqrtuqskvlibkpcxjuazhcatysptkfyqivldfn", +"sk-vpqkxtmcgkggllexchzysuewyfaoexzasoumxngdplzgwksw", +"sk-fvcsqdbqmdlwxzjyofrilusqcypbfyczogaqwqrjrwvojmer", +"sk-htjprscvfgskjtjzpxxxjhyymshagogykpawxekrrfbgftyx", ] class APIKeyManager: diff --git a/rag2_0/tool/ModelTool.py b/rag2_0/tool/ModelTool.py index f5f4bc2..8ee20b7 100644 --- a/rag2_0/tool/ModelTool.py +++ b/rag2_0/tool/ModelTool.py @@ -100,10 +100,10 @@ class XinferenceReRankerModel: Returns: List[dict]: 重排序后的文档列表,每个元素包含document内容、相关性分数和原始索引 """ - url = "http://10.1.16.39:9995/v1/rerank" + url = "http://172.20.0.145:9995/v1/rerank" - params = {"documents": documents, "query": query, "top_n": top_k, "return_documents": True, "model": os.getenv("RERANKER_MODEL_NAME")} + params = {"documents": documents, "query": query, "top_n": top_k, "return_documents": True, "model": "bge-reranker-v2-m3"} headers = { "Authorization": "Bearer ", # 这里需要替换为实际的token "Content-Type": "application/json" @@ -140,8 +140,7 @@ class OpenAiLLM: def invoke(self, user_prompt="你是谁?", need_retry=True): # 初始化 OpenAI 客户端 - api_key = APIKeyManager.get_api_key() - client = OpenAI(api_key=api_key, base_url=self._url) + max_retries = 3 retry_count = 0 @@ -149,6 +148,8 @@ class OpenAiLLM: if need_retry: while retry_count < max_retries: try: + api_key = APIKeyManager.get_api_key() + client = OpenAI(api_key=api_key, base_url=self._url) # 创建 Completion 请求. 超时120s completion = client.chat.completions.create( model=self._model, @@ -162,11 +163,13 @@ class OpenAiLLM: retry_count += 1 if retry_count == max_retries: logging.error(f"LLM 重试{max_retries}次后仍然失败: {e}") - return "" + raise e else: time.sleep(5*retry_count) # 重试前等待1秒 else: # 创建 Completion 请求. 超时120s + api_key = APIKeyManager.get_api_key() + client = OpenAI(api_key=api_key, base_url=self._url) completion = client.chat.completions.create( model=self._model, messages=[{'role': 'user', 'content': user_prompt}], @@ -180,53 +183,15 @@ if __name__ == "__main__": reranker = SiliconFlowReRankerModel() # 测试用例1:简单问题 - query = "他想做什么" - documents = ["她想去公园跑步", "她想换一个新手机", "明天她想出去旅游"] + query = "如何通过【电力经济评价软件】的【打开】功能加载工程文件?" + documents = ["\n# (电力建设计价通软件) (概预算工程)工程备份管理\n## 操作步骤\n**方法一:** \n\n1、查找工程:输入工程文件名称的关键字,点击“查找”按钮,可以快速定位需查找的工程;\n\n![2](https://172.20.0.145/files/ea1d6edc-090c-4c35-be13-e506f0eeb176/image-preview)\n\n2、根据时间点找备份工程:选中对应工程文件,在右侧选中“备份时间”的备份记录,点击“还原工程”或者“另存为工程”;\n\n **还原工程:** 将工程还原保存在原路径下;\n\n **另存为工程:** 另存为一个新工程,可选择保存路径,保存后,可点击文件——打开,浏览到另存的新工程打开。\n\n注:不确定备份是否是需要时,优先建议另存为工程。\n\n![3](https://172.20.0.145/files/63fb4e9d-06ce-44e6-adbc-45ebc92c3e6f/image-preview)\n\n **方法二:** \n\n1、点击桌面软件快捷图标,右键属性—打开文件位置,直接定位软件安装根目录。 \n\n![打开文件所在位置](https://172.20.0.145/files/28540c85-2b85-4717-9b87-31eba886c5f0/image-preview)\n\n2、在软件安装根目录,点击“数据备份”文件夹,进入到文件夹内,根据修改日期找到对应工程,右键复制粘贴至桌面。\n\n![打开数据备份文件夹](https://172.20.0.145/files/76a4263c-ea75-4083-9ed8-7f2e28412e51/image-preview)\n\n![按照时间排序,找到.bak文件复制到桌面](https://172.20.0.145/files/eaa59154-6fbc-494e-817c-e5f7610e6294/image-preview)\n\n3、定位桌面复制粘贴出来的数据工程,右键\"重命名\",将bak修改成相应的文件后缀(概预算工程及施工图预算工程后缀为zwzj,招标工程及投标工程后缀为zwqd),然后点击“确定”,再通过软件的“文件”——“打开”按钮去浏览工程打开。\n", + "\n# (配网计价通D3)插件管理/全国版和专版切换\n## 使用场景\n1.打开软件提示“当前工程文件为全国版文件,请使用全国版软件打开!”,该如何打开这个工程呢?\n\n![1](https://172.20.0.145/files/3751d1c2-da12-4076-bd9f-3e1c9eced1ab/image-preview)\n\n2.打开软件提示“当前工程文件为辽宁版文件,请确认是否要在全国版软件中打开?”,这是什么意思?点击“确定”又可以打开工程?\n\n![2](https://172.20.0.145/files/68c18647-c4fa-4496-974a-f5b07352063e/image-preview)\n## 知识原理\n\n## 费用去向\n\n", + "\n(电力建设计价通软件) 云造价--停用\n# 工程文件管理\n\n## 【主页】中点击“云端工程管理”,进入博微服务大厅;\n![6](https://172.20.0.145/files/69d1d763-bb6b-459f-b7ba-264c788225e9/image-preview)![7](https://172.20.0.145/files/eb93d8ea-d3c8-45f3-86ab-f23c0d244d96/image-preview)\n## 工程文件管理界面中显示云端备份的工程列表,可支持\n![8](https://172.20.0.145/files/91200171-182b-4de7-a895-f5118bfe9d98/image-preview)\n## 高级设置:可对历史版本数量进行设置,默认数量为10,可设置(5-15);\n![9](https://172.20.0.145/files/31b63307-3313-4402-9652-9b46e9549aec/image-preview)\n## 历史版本:勾选单个工程,点击“历史版本”可查看该工程保存的不同时间节点的历史工程;\n![10](https://172.20.0.145/files/1ae83da5-9e86-4a44-9af1-013199a933c6/image-preview)\n## 在线查阅:可查看工程数据,仅为只读模式不支持任何编辑;\n![11](https://172.20.0.145/files/61ef2bc8-0d73-47bf-b66f-aebcc39d7a46/image-preview)\n## 下载;选择需要的工程点击“下载”,可下载软件版本工程;\n![12](https://172.20.0.145/files/3391d7b2-c289-485a-a459-754ce6c8e294/image-preview)\n", + "\n(配网D3软件)打开工程\n\n# (配网D3软件)打开工程\n\n## 功能入口\n各界面点击“文件”按钮——“打开”按钮 \n\n![打开工程-功能入口](https://172.20.0.145/files/7c6dbff1-932c-4cf5-ab16-b17a6309af0e/image-preview)\n## 操作步骤\n**打开工程:** \n\n点击“打开”按钮,浏览到工程存放位置,选中工程文件,点击“打开”即可。"] results = reranker.rerank(query, documents) print(f"测试用例1 - 查询:{query}") for idx, item in enumerate(results): print(f"{idx+1}. 文档: {item['document']}, 分数: {item['score']}") print("-" * 50) - # 测试用例2:技术问题 - query = "Python如何处理JSON数据" - documents = [ - "Python中可以使用json模块来处理JSON数据,例如json.loads()将JSON字符串转换为字典", - "Java提供了多种处理JSON的库,比如Jackson和Gson", - "在Python中,可以使用pandas库来分析CSV数据", - "JavaScript可以使用JSON.parse()方法解析JSON字符串" - ] - results = reranker.rerank(query, documents) - print(f"测试用例2 - 查询:{query}") - for idx, item in enumerate(results): - print(f"{idx+1}. 文档: {item['document']}, 分数: {item['score']}") - print("-" * 50) - - # 测试用例3:医疗问题 - query = "高血压的症状有哪些" - documents = [ - "高血压的常见症状包括头痛、头晕、耳鸣和视力模糊", - "糖尿病的症状包括多饮、多尿和体重减轻", - "心脏病的症状通常包括胸痛、呼吸急促和疲劳", - "高血压患者应该定期监测血压,保持健康的生活方式" - ] - results = reranker.rerank(query, documents) - print(f"测试用例3 - 查询:{query}") - for idx, item in enumerate(results): - print(f"{idx+1}. 文档: {item['document']}, 分数: {item['score']}") - print("-" * 50) - - # 测试用例4:长文本查询和文档 - query = "人工智能在医疗领域的应用及其伦理问题" - documents = [ - "人工智能在医疗诊断中的应用已经显示出良好的效果,例如通过分析医学影像来检测疾病。然而,这也引发了关于医生角色和责任的伦理问题。", - "在教育领域,人工智能可以提供个性化学习体验,适应不同学生的学习进度和风格。", - "医疗伦理问题主要包括患者隐私保护、知情同意和医疗资源分配等方面。", - "人工智能技术在金融领域的应用主要集中在风险评估、欺诈检测和算法交易等方面。" - ] - results = reranker.rerank(query, documents) - print(f"测试用例4 - 查询:{query}") - for idx, item in enumerate(results): - print(f"{idx+1}. 文档: {item['document']}, 分数: {item['score']}") - print("-" * 50)