diff --git a/.gitignore b/.gitignore index 5439613..856df84 100644 --- a/.gitignore +++ b/.gitignore @@ -11,9 +11,7 @@ data/excel/*.xlsx !data/excel/Excel版 清单定额库/ !data/excel/Excel版 清单定额库/** data/logs/* -rag2_0/dify/Test.py -data/query_logs/* -data/conversations/* data/test* data/temp* data/db/answer_logs.db +data/db/qingdan_ding_e_ku.db diff --git a/rag2_0/demo/heli_db_to_excel.py b/rag2_0/demo/heli_db_to_excel.py index aac2b91..de43b52 100755 --- a/rag2_0/demo/heli_db_to_excel.py +++ b/rag2_0/demo/heli_db_to_excel.py @@ -5,23 +5,18 @@ from __future__ import annotations import json import os import re -import configparser import logging from datetime import datetime from typing import Any, Dict, List, Optional, Tuple, Union from dataclasses import dataclass -from contextlib import contextmanager import threading -import time -from queue import Queue, Empty, Full import pandas as pd import pymysql -from pymysql.connections import Connection -from pymysql.cursors import Cursor from tqdm import tqdm import concurrent.futures import sys +from queue import Queue, Empty, Full os.makedirs('./data/logs', exist_ok=True) # 配置日志 @@ -35,6 +30,18 @@ logging.basicConfig( ) logger = logging.getLogger(__name__) +# ===================== +# 硬编码数据库配置(简化) +# ===================== +DB_HOST = '192.168.0.123' +DB_PORT = 3307 +DB_USER = 'fuzhimei' +DB_PASSWORD = 'fuzhimei@135' +DB_CHARSET = 'utf8mb4' +DB_CONNECT_TIMEOUT = 10 +DB_READ_TIMEOUT = 300 +DB_WRITE_TIMEOUT = 300 + def parse_session_tags(input_string): """ 解析sessionTag格式的字符串,支持任意数量的sessionTag @@ -74,172 +81,7 @@ def parse_session_tags(input_string): session_key = f'sessionTag{tag_suffix}' result[session_key] = tag_data - return result - -@dataclass -class DatabaseConfig: - """数据库配置类""" - host: str = '192.168.0.123' - port: int = 3307 - user: str = 'fuzhimei' - password: str = 'fuzhimei@135' - charset: str = 'utf8mb4' - connect_timeout: int = 10 - read_timeout: int = 300 - write_timeout: int = 300 - - @classmethod - def from_config_file(cls, config_file: str = 'config.ini') -> 'DatabaseConfig': - """从配置文件加载配置""" - if not os.path.exists(config_file): - logger.warning(f"配置文件 {config_file} 不存在,使用默认配置") - return cls() - - config = configparser.ConfigParser() - config.read(config_file, encoding='utf-8') - - if 'database' not in config: - logger.warning("配置文件中没有 [database] 部分,使用默认配置") - return cls() - - db_config = config['database'] - return cls( - host=db_config.get('host', cls.host), - port=int(db_config.get('port', cls.port)), - user=db_config.get('user', cls.user), - password=db_config.get('password', cls.password), - charset=db_config.get('charset', cls.charset), - connect_timeout=int(db_config.get('connect_timeout', cls.connect_timeout)), - read_timeout=int(db_config.get('read_timeout', cls.read_timeout)), - write_timeout=int(db_config.get('write_timeout', cls.write_timeout)) - ) - - -class ConnectionPool: - """数据库连接池""" - - def __init__(self, config: DatabaseConfig, max_connections: int = 10): - self.config = config - self.max_connections = max_connections - self.pool = Queue(maxsize=max_connections) - self.active_connections = 0 - self.lock = threading.Lock() - - # 预创建一些连接 - self._initialize_pool() - - def _initialize_pool(self) -> None: - """初始化连接池,预创建一些连接""" - initial_connections = min(3, self.max_connections) - for _ in range(initial_connections): - try: - conn = self._create_connection() - if conn: - self.pool.put_nowait(conn) - self.active_connections += 1 - except Full: - break - except Exception as e: - logger.error(f"初始化连接池时创建连接失败: {e}") - - def _create_connection(self) -> Optional[Connection]: - """创建新的数据库连接""" - try: - conn = pymysql.connect( - host=self.config.host, - port=self.config.port, - user=self.config.user, - password=self.config.password, - charset=self.config.charset, - connect_timeout=self.config.connect_timeout, - read_timeout=self.config.read_timeout, - write_timeout=self.config.write_timeout, - autocommit=True - ) - return conn - except Exception as e: - logger.error(f"创建数据库连接失败: {e}") - return None - - @contextmanager - def get_connection(self): - """获取连接的上下文管理器""" - conn = None - try: - # 尝试从池中获取连接 - try: - conn = self.pool.get_nowait() - except Empty: - # 池中没有连接,尝试创建新连接 - with self.lock: - if self.active_connections < self.max_connections: - conn = self._create_connection() - if conn: - self.active_connections += 1 - else: - raise Exception("无法创建新的数据库连接") - else: - # 等待可用连接 - logger.info("等待可用连接...") - conn = self.pool.get(timeout=30) - - # 检查连接是否仍然有效 - if conn and not self._is_connection_alive(conn): - logger.warning("连接已失效,重新创建") - try: - conn.close() - except: - pass - conn = self._create_connection() - if not conn: - raise Exception("重新创建连接失败") - - yield conn - - except Exception as e: - logger.error(f"获取数据库连接时出错: {e}") - if conn: - try: - conn.close() - except: - pass - with self.lock: - self.active_connections -= 1 - raise - else: - # 归还连接到池中 - if conn: - try: - self.pool.put_nowait(conn) - except Full: - # 池已满,关闭连接 - try: - conn.close() - except: - pass - with self.lock: - self.active_connections -= 1 - - def _is_connection_alive(self, conn: Connection) -> bool: - """检查连接是否仍然有效""" - try: - conn.ping(reconnect=False) - return True - except: - return False - - def close_all(self) -> None: - """关闭所有连接""" - logger.info("正在关闭连接池中的所有连接...") - while not self.pool.empty(): - try: - conn = self.pool.get_nowait() - conn.close() - except (Empty, Exception): - break - - self.active_connections = 0 - logger.info("连接池已关闭") + return result class DataProcessor: @@ -357,12 +199,76 @@ class DataProcessor: class MariaDBClient: - """优化后的MariaDB数据库客户端""" + """简化版 MariaDB 客户端(内置轻量连接池以复用连接)""" - def __init__(self, config: DatabaseConfig, max_connections: int = 10): - self.config = config - self.connection_pool = ConnectionPool(config, max_connections) + def __init__(self, max_connections: int = 10): self.data_processor = DataProcessor() + self._max_connections = max_connections + self._pool: Queue = Queue(maxsize=max_connections) + self._active = 0 + self._lock = threading.Lock() + # 预创建少量连接,降低首次延迟 + initial = min(3, max_connections) + for _ in range(initial): + conn = self._create_connection() + if conn: + try: + self._pool.put_nowait(conn) + self._active += 1 + except Full: + try: + conn.close() + except Exception: + pass + + def _create_connection(self): + try: + return pymysql.connect( + host=DB_HOST, + port=DB_PORT, + user=DB_USER, + password=DB_PASSWORD, + charset=DB_CHARSET, + connect_timeout=DB_CONNECT_TIMEOUT, + read_timeout=DB_READ_TIMEOUT, + write_timeout=DB_WRITE_TIMEOUT, + autocommit=True + ) + except Exception as e: + logger.error(f"创建数据库连接失败: {e}") + return None + + def _acquire_connection(self): + # 先尝试不阻塞获取 + try: + return self._pool.get_nowait() + except Empty: + # 池为空,若可创建新连接则创建,否则阻塞等待 + with self._lock: + if self._active < self._max_connections: + conn = self._create_connection() + if conn: + self._active += 1 + return conn + # 达到上限,阻塞等待可用连接 + try: + return self._pool.get(timeout=30) + except Empty: + raise RuntimeError("获取数据库连接超时") + + def _release_connection(self, conn): + if not conn: + return + try: + self._pool.put_nowait(conn) + except Full: + # 池已满,关闭多余连接 + try: + conn.close() + except Exception: + pass + with self._lock: + self._active = max(0, self._active - 1) def __enter__(self) -> 'MariaDBClient': return self @@ -371,30 +277,52 @@ class MariaDBClient: self.close() def close(self) -> None: - """关闭客户端""" - self.connection_pool.close_all() + """关闭连接池中的连接""" + try: + while True: + try: + conn = self._pool.get_nowait() + except Empty: + break + try: + conn.close() + except Exception: + pass + finally: + with self._lock: + self._active = 0 def execute_query(self, sql: str, params: Optional[Tuple] = None) -> Tuple[Optional[pd.DataFrame], List[str]]: - """执行SQL查询""" + """执行SQL查询(复用连接池连接)""" + conn = None try: - with self.connection_pool.get_connection() as conn: - with conn.cursor() as cursor: - cursor.execute(sql, params) - results = cursor.fetchall() - - # 获取列名 - column_names = [desc[0] for desc in cursor.description] if cursor.description else [] - - if results: - df = pd.DataFrame(results, columns=column_names) - return df, column_names - else: - return pd.DataFrame(), column_names - + conn = self._acquire_connection() + with conn.cursor() as cursor: + cursor.execute(sql, params) + results = cursor.fetchall() + column_names = [desc[0] for desc in cursor.description] if cursor.description else [] + if results: + df = pd.DataFrame(results, columns=column_names) + return df, column_names + else: + return pd.DataFrame(), column_names except Exception as e: logger.error(f"执行查询时出错: {e}") logger.error(f"SQL: {sql}") return None, [] + finally: + if conn: + # 若连接异常,尝试关闭并减少活跃计数;否则归还 + try: + conn.ping(reconnect=False) + self._release_connection(conn) + except Exception: + try: + conn.close() + except Exception: + pass + with self._lock: + self._active = max(0, self._active - 1) def query_sessions(self, start_date: str, end_date: str) -> Optional[pd.DataFrame]: """查询指定日期范围内的会话数据""" @@ -483,125 +411,79 @@ class MariaDBClient: logger.error(f"导出到Excel时出错: {e}") return None - -class SessionProcessor: - """会话处理器,负责并发处理""" - - def __init__(self, db_client: MariaDBClient, max_workers: int = None): - self.db_client = db_client - self.max_workers = max_workers if max_workers is not None else os.cpu_count() - self.temp_save_lock = threading.Lock() # 添加锁用于保护临时保存操作 - - logger.info(f"初始化会话处理器: max_workers={self.max_workers}") - - def process_sessions(self, sessions_df: pd.DataFrame) -> List[List[Dict[str, Any]]]: - """处理所有会话数据""" - if sessions_df.empty: - logger.warning("没有会话数据需要处理") - return [] - - total_sessions = len(sessions_df) - logger.info(f"开始处理 {total_sessions} 个会话...") - - all_conversations = [] - - # 直接并发处理每个会话 - def process_single_session(session_row): - try: - session_id = session_row['SESSION_ID'] - messages_df = self.db_client.query_messages_by_session_id(session_id) - - if messages_df is not None and not messages_df.empty: - conversation = self.db_client.data_processor.messages_df_to_list(messages_df, session_row) - if conversation: - return conversation - except Exception as e: - logger.error(f"处理会话 {session_row.get('SESSION_ID', 'unknown')} 时出错: {e}") - return None - - # 使用线程池处理所有会话 - with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor: - # 提交所有会话处理任务 - future_to_session = { - executor.submit(process_single_session, row): i - for i, row in sessions_df.iterrows() - } - - # 收集结果 - with tqdm(total=total_sessions, desc="处理会话进度") as pbar: - for future in concurrent.futures.as_completed(future_to_session): - try: - conversation = future.result() - if conversation: - all_conversations.append(conversation) - - # 每处理100个对话临时保存一次 - if len(all_conversations) % 100 == 0: - with self.temp_save_lock: - logger.info(f"临时保存 {len(all_conversations)} 个对话") - temp_output_file = self.db_client.export_to_excel( - all_conversations, - f"客服对话记录_临时保存", - output_dir="/data/QueryRewrite/data/excel" - ) - if temp_output_file: - logger.info(f"临时保存完成: {temp_output_file}") - - except Exception as e: - session_idx = future_to_session[future] - logger.error(f"处理会话索引 {session_idx} 时出错: {e}") - - pbar.update(1) - - logger.info(f"处理完成,共获得 {len(all_conversations)} 个有效对话") - return all_conversations - - def main() -> None: - """主函数""" + """主函数(精简版)""" try: - # 加载配置 - config = DatabaseConfig.from_config_file() - logger.info(f"使用数据库配置: {config.host}:{config.port}") + logger.info(f"使用数据库配置: {DB_HOST}:{DB_PORT}") - # 创建数据库客户端 - with MariaDBClient(config, max_connections=12) as db_client: + # 创建数据库客户端(简化) + with MariaDBClient() as db_client: # 查询会话数据 start_date = '2025-08-01 00:00:00' end_date = '2025-08-01 23:00:00' logger.info(f"查询时间范围: {start_date} 到 {end_date}") - # 创建会话处理器 - processor = SessionProcessor(db_client) - # is_debug = hasattr(sys, 'gettrace') and sys.gettrace() is not None - # if is_debug: - # messages_df = db_client.query_messages_by_session_id("86c919e0-09f1-11f0-84ae-2daf59566989") - # print(db_client.data_processor.messages_df_to_list(messages_df)) - # return [] sessions_df = db_client.query_sessions(start_date, end_date) if sessions_df is None or sessions_df.empty: logger.warning("没有找到符合条件的会话数据") return - - # 处理会话数据 - all_conversations = processor.process_sessions(sessions_df) + + # 直接并发处理每个会话(替代 SessionProcessor) + total_sessions = len(sessions_df) + all_conversations: List[List[Dict[str, Any]]] = [] + temp_save_lock = threading.Lock() + + def process_single_session(session_row): + try: + session_id = session_row['SESSION_ID'] + messages_df = db_client.query_messages_by_session_id(session_id) + if messages_df is not None and not messages_df.empty: + conversation = db_client.data_processor.messages_df_to_list(messages_df, session_row) + if conversation: + return conversation + except Exception as e: + logger.error(f"处理会话 {session_row.get('SESSION_ID', 'unknown')} 时出错: {e}") + return None + + with concurrent.futures.ThreadPoolExecutor(max_workers=os.cpu_count()) as executor: + future_to_session = {executor.submit(process_single_session, row): i for i, row in sessions_df.iterrows()} + with tqdm(total=total_sessions, desc="处理会话进度") as pbar: + for future in concurrent.futures.as_completed(future_to_session): + try: + conversation = future.result() + if conversation: + all_conversations.append(conversation) + if len(all_conversations) % 100 == 0: + with temp_save_lock: + logger.info(f"临时保存 {len(all_conversations)} 个对话") + temp_output_file = db_client.export_to_excel( + all_conversations, + f"客服对话记录_临时保存", + output_dir="/data/QueryRewrite/data/excel" + ) + if temp_output_file: + logger.info(f"临时保存完成: {temp_output_file}") + except Exception as e: + session_idx = future_to_session[future] + logger.error(f"处理会话索引 {session_idx} 时出错: {e}") + pbar.update(1) + # 导出结果 if all_conversations: output_file = db_client.export_to_excel( - all_conversations, + all_conversations, "客服对话记录", output_dir="/data/QueryRewrite/data/excel" ) - if output_file: logger.info(f"处理完成!共导出 {len(all_conversations)} 个对话到文件: {output_file}") else: logger.error("导出文件失败") else: logger.warning("没有有效的对话数据可导出") - + except KeyboardInterrupt: logger.info("用户中断程序") except Exception as e: diff --git a/rag2_0/intent_recognition/IntentRecognition.py b/rag2_0/intent_recognition/IntentRecognition.py index 14cbe9e..a1790d2 100755 --- a/rag2_0/intent_recognition/IntentRecognition.py +++ b/rag2_0/intent_recognition/IntentRecognition.py @@ -17,21 +17,18 @@ from typing import List, Tuple, Dict, Any, Optional import re import jieba import time -import threading from .PromptTemplates import (classification_prompt, query_rewrite_prompt_pro, extract_nouns_prompt, classification_info, - slot_filling_prompt, step_back_prompt, - hyde_prompt) + slot_filling_prompt, step_back_prompt) from .DataModels import ( Classification, QueryRewrite, Term, TermList, SoftwareFunctionSlots, SoftwareTroubleShootingSlots, ProfessionalConsultingSlots, DataProblemSlots, FileExtensionConsultingSlots, SoftwareLockSlots, - InstallationDownloadSlots, ProblemDiagnosisSlots, OtherSlots, IntentAndSlotResult, - StepBackPrompt, HypotheticalDocument + InstallationDownloadSlots, ProblemDiagnosisSlots, OtherSlots, + StepBackPrompt ) -from .ProfessionalNounVector import ProfessionalNounRetriever, AsyncProfessionalNounRetriever from rag2_0.tool.ModelTool import OpenAiLLM class AsyncIntentRecognizer: @@ -344,22 +341,20 @@ class AsyncIntentRecognizer: """ start_time = time.time() # 记录开始时间 - prompt=f""" - 当前提问内容: - {query} - 对话上下文: - - {json.dumps(chat_history, ensure_ascii=False)} - + prompt=f"""当前提问内容: +{query} +对话上下文: + +{json.dumps(chat_history, ensure_ascii=False)} + - 1、请从当前提问内容中提取电力造价行中定额编码、定额名称、清单编码、清单名称 - 2、请勿随机编造,如果没有提取到内容返回空的JSON - 3、返回结果为json格式,必须严格以纯JSON格式输出 - {{ - "dinge_info_list":{{"dinge_code_list":["xxxx","xxxx"], "dinge_name_list":["xxxx","xxxx"]}}, - "qingdan_info":{{"qingdan_code_list":["xxxx","xxxx"], "qingdan_name_list":["xxxx","xxxx"]}} - }} - """ +1、请从当前提问内容中提取电力造价行中定额编码、定额名称、清单编码、清单名称 +2、请勿随机编造,如果没有提取到内容返回空的JSON +3、返回结果为json格式,必须严格以纯JSON格式输出 +{{ + "dinge_info_list":{{"dinge_code_list":["xxxx","xxxx"], "dinge_name_list":["xxxx","xxxx"]}}, + "qingdan_info":{{"qingdan_code_list":["xxxx","xxxx"], "qingdan_name_list":["xxxx","xxxx"]}} +}}""" try: # response = await self._llm.ainvoke(prompt, response_format={"type": "json_object"}, extra_body={"enable_thinking": False}) @@ -489,6 +484,7 @@ class AsyncIntentRecognizer: # 特殊处理 锁相关咨询 if classification.vertical_classification == "安装下载注册" and classification.sub_classification == "软件锁类": process_lock_start_time = time.time() + # 特殊处理提问只有锁号的问题,手动将问题改写为特定格式 rewrite.rewrite = self._process_lock_related_query(rewrite.rewrite) process_lock_end_time = time.time() process_lock_time = process_lock_end_time - process_lock_start_time diff --git a/rag2_0/intent_recognition/PromptTemplates.py b/rag2_0/intent_recognition/PromptTemplates.py index 774f792..a8766e6 100755 --- a/rag2_0/intent_recognition/PromptTemplates.py +++ b/rag2_0/intent_recognition/PromptTemplates.py @@ -56,13 +56,11 @@ classification_info="""【垂直领域分类】: 4. 问题排查类:软件安装下载失败、报错,系统兼容性问题等 【固定话术类包括以下类】: - 1. 规费咨询 **以下两种情况才属于该类** 1、当询问规费(如社会保障费和住房公积金)费率是/填多少 2、去哪里获取规费费率 **其余涉及规费的属于其他垂直领域分类** - 2. 调差下载更新 **以下两种情况才属于该类** 1、询问如何下载导入调差文件、调差插件 @@ -73,111 +71,29 @@ classification_info="""【垂直领域分类】: 【其他】: 1. 其他 -分类优先级: - 固定话术类 > 软件问题 、 业务问题 、 安装下载注册 > 其他 -""" +分类优先级:固定话术类 > 软件问题 、 业务问题 、 安装下载注册 > 其他""" -classification_prompt=""" - 用户正在使用电力造价软件或想询问电力造价领域相关知识,你需要根据用户的输入内容集合历史对话(如果存在),将其归类为以下垂直领域之一: - {classification_info} +classification_prompt="""用户正在使用电力造价软件或想询问电力造价领域相关知识,你需要根据用户的输入内容集合历史对话(如果存在),将其归类为以下垂直领域之一: +{classification_info} - ## 【历史对话记录】 - {chat_history} +## 【历史对话记录】 +{chat_history} - 【用户输入】: - {user_input} +【用户输入】: +{user_input} - 【输出格式要求】: - {output_format} - - 【示例】 - 用户输入1: 技改T1怎样新建工程 - 输出1: - {{ - "vertical_classification":"软件问题", - "sub_classification":"软件功能" - }} - -""" - -query_rewrite_prompt = """ - -# 电力造价专业问答优化工程师 - -你是一名电力造价专业问答优化工程师,负责通过专业关键词集合替换原始问题中的非专业表述以提升知识库检索准确率。 - -## 核心任务 -将用户的原始问题结合专业术语库进行规范化重构,提高知识库检索的准确性和专业性。 - -## 处理流程 -### 第一阶段:输入解析 -1. 解析基础信息 - - 原始问题(需保留核心语义):{query} - - 关键词集合:{keywords} - -### 第二阶段:匹配分析 -**匹配规则:** -1. 检查原始问题中是否包含关键词集合中的`name`字段或`synonymous`字段中的任何词汇 -2. 统计匹配的术语数量 -3. 判断执行路径: - - 匹配术语 ≥ 1个 → 执行重构流程 - - 匹配术语 = 0个 → 直接输出原始问题 - -### 第三阶段:问题重构 -**重构原则(按优先级排序):** - -1. **语义保真**:严格保持原问题的核心意图和诉求 -2. **术语规范**: - - 将匹配到的同义词替换为对应的标准术语(name字段) - - 对在关键词中的标准术语使用【】进行标记 - - 保留在原问题中未在关键词库中的专业术语、限定词和修饰词 -3. **结构优化**: - - 保持原问题的语态特征5W2H - - 保持主谓宾结构清晰 - - 保留时间、版本等限定条件 - -**术语处理规则:** -- 优先级1:保留原问题中的专业术语、限定词和修饰词(即使不在关键词库中) -- 优先级2:将同义词替换为标准术语并用【】标记 -- 优先级3:对原问题中已存在的标准术语添加【】标记 - -# 输出规范 +【输出格式要求】: {output_format} -# 示范案例库 -▶ 案例1(有效匹配) -输入: -原始问题:怎么把旧版西藏定额工程转到Z1新版 -关键词:【'老版本定额升级', '批量设置定额', '西藏造价软件Z1'】 -输出: -{{"rewrite":"【西藏造价软件Z1】如何执行【老版本定额升级】操作?"}} +【示例】 +用户输入1: 技改T1怎样新建工程 +输出1: +{{ +"vertical_classification":"软件问题", +"sub_classification":"软件功能" +}}""" -▶ 案例2(无效匹配) -输入: -原始问题:程序界面文字显示过小如何处理? -关键词:【'定额升级', '工程批量导入'】 -输出: -{{"rewrite":"程序界面文字显示过小如何处理?"}} - -▶ 案例3(部分匹配,但保留修饰限定词) -输入: -原始问题:"配网软件D3能导出清单的计算公式吗? -关键词:【'配网工程计价通D3软件', '计算式'】 -输出(保留限定修饰词"清单"): -{{"rewrite":"【配网工程计价通D3软件】能导出清单的【计算式】吗?"}} - -## 质量检查清单 -执行前请确认: -- [ ] 是否保持了原问题的核心诉求? -- [ ] 是否正确执行了同义词替换? -- [ ] 是否保留了原问题中的专业术语和限定条件? -- [ ] 是否正确使用了【】标记? -- [ ] 重构后的问题是否自然流畅? - """ - - -query_rewrite_prompt_pro=""" -# 问答优化工程师 +query_rewrite_prompt_pro="""# 问答优化工程师 **角色**:基于历史对话和术语库重构问题,提升知识库检索准确率。 **最高准则**: 1、保持问题核心意图,允许指代消除 @@ -196,18 +112,18 @@ query_rewrite_prompt_pro=""" ## 处理流程 ### 一、输入解析 - - 原始问题(需保留核心语义): - {query} +- 原始问题(需保留核心语义): + {query} - - 术语库集合(用于同义词转标准词环节): - - {keywords} - +- 术语库集合(用于同义词转标准词环节): + +{keywords} + - - 历史对话记录: - - {chat_history} - +- 历史对话记录: + +{chat_history} + ### 一、重构流程 1、问题是否指代不明,指代不明时根据历史对话补充上下文 @@ -221,23 +137,11 @@ query_rewrite_prompt_pro=""" ## 输出规范 {output_format} -## 示例模仿 -示例1: -输入: - -'user': '811623110668是哪款软件的锁? -'assistant': 可通过查询软件锁的许可证信息,通过许可证名称可以判断对应软件 - - ”锂离子电池储能安装“ -输出: -{{"rewrite": "许可证名称为‘锂离子电池储能安装’对应什么软件?"}} - ## 质量自检 - [] **主题是否合理继承?** - [] 核心诉求是否保留? - [] 语句是否自然流畅? -- [] 避免补充无关信息 -""" +- [] 避免补充无关信息""" slot_filling_prompt = """ 你是一个专业的电力造价领域问题槽位填充助手。你需要从用户问题中提取关键信息,并填充到对应的数据结构中。 @@ -282,8 +186,7 @@ slot_filling_prompt = """ """ # 意图优化环节提示词模板 -step_back_prompt = """ -# 后退提示生成器 +step_back_prompt = """# 后退提示生成器 你是一个专业的电力造价领域问题抽象专家。你的任务是根据用户的具体问题,提出一个更抽象、更高层次的问题,帮助系统更好地理解用户的意图。 @@ -292,11 +195,11 @@ step_back_prompt = """ 2. 考虑历史对话和会话背景,理解用户当前问题的上下文 3. 生成更抽象、更高层次的问题,称为"后退问题",后退问题可以生成多个,依次后退到更抽象、更高层次的问题 4. 后退问题应该: - - 更加通用和抽象,不应包含原始问题的具体细节(包括场景限定、界面限定等其他限定词语) - - 涵盖原始问题的核心主题 - - 去除过于具体的限制条件(如时间、地点、特定版本、特定工程等) - - 保持在同一领域和主题范围内 - - 依次移除问题中的限定词或者修饰词 +- 更加通用和抽象,不应包含原始问题的具体细节(包括场景限定、界面限定等其他限定词语) +- 涵盖原始问题的核心主题 +- 去除过于具体的限制条件(如时间、地点、特定版本、特定工程等) +- 保持在同一领域和主题范围内 +- 依次移除问题中的限定词或者修饰词 ## 输入 用户原始问题: {query} @@ -320,9 +223,7 @@ step_back_prompt = """ "original_query": "某个设备更换后,如何在系统中更新对应的定额?", "can_use_back_prompt": true, "step_back_query": ["如何更新设备对应的定额?", "如何更新定额?"] -}} - -""" +}}""" follow_up_questions_prompt = """ # 后续问题生成器 diff --git a/rag2_0/tool/ModelTool.py b/rag2_0/tool/ModelTool.py index 1bb92da..2cecdbf 100755 --- a/rag2_0/tool/ModelTool.py +++ b/rag2_0/tool/ModelTool.py @@ -193,7 +193,9 @@ class OpenAiLLM: messages=[{'role': 'user', 'content': user_prompt}], **kwargs ) - return completion.choices[0].message + message = completion.choices[0].message + message.usage = completion.usage + return message except Exception as e: raise RuntimeError(f"OpenAiLLM:invoke:error:{str(e)}") from e @@ -225,36 +227,16 @@ class OpenAiLLM: if __name__ == "__main__": # 测试重排模型 - reranker = SiliconFlowReRankerModel() - - # 测试用例1:简单问题 - query = "如何通过【电力经济评价软件】的【打开】功能加载工程文件?" - documents = [] - results = reranker.rerank(query, documents) - print(f"测试用例1 - 查询:{query}") - for idx, item in enumerate(results): - print(f"{idx+1}. 文档: {item['document']}, 分数: {item['score']}") - print("-" * 50) - - # 异步测试示例 - async def test_async(): - # 测试异步嵌入 - api_key = APIKeyManager.get_api_key() - embeddings = XinferenceEmbeddings(api_key=api_key) - query_embedding = await embeddings.embed_query_async("测试查询") - print(f"异步嵌入向量维度: {len(query_embedding)}") - - # 测试异步重排序 - results = await SiliconFlowReRankerModel.rerank_async(query, documents) - print(f"异步重排序结果数量: {len(results)}") - - # 测试异步LLM调用 - llm = OpenAiLLM() - response = await llm.ainvoke("你好,请简单介绍一下自己") - print(f"异步LLM响应: {response.content}") - - # 如果需要运行异步测试,取消下面的注释 - # import asyncio - # asyncio.run(test_async()) - + base_url = os.getenv("OPENAI_API_BASE") + model_name = os.getenv("MODEL_NAME", "gpt-3.5-turbo") + # 初始化LLM + llm_params = { + "temperature": 0.4, # 降低随机性,使结果更确定 + "top_p": 0.7, + "model": model_name, + "base_url": base_url + } + _llm = OpenAiLLM(**llm_params) + promt="""你好,请简单介绍一下自己""" + print(_llm.invoke(promt)) \ No newline at end of file