调整提示词、简化代码

This commit is contained in:
2025-09-25 16:23:13 +08:00
parent 640e02f89e
commit 2b13fdab99
5 changed files with 235 additions and 476 deletions
+1 -3
View File
@@ -11,9 +11,7 @@ data/excel/*.xlsx
!data/excel/Excel版 清单定额库/ !data/excel/Excel版 清单定额库/
!data/excel/Excel版 清单定额库/** !data/excel/Excel版 清单定额库/**
data/logs/* data/logs/*
rag2_0/dify/Test.py
data/query_logs/*
data/conversations/*
data/test* data/test*
data/temp* data/temp*
data/db/answer_logs.db data/db/answer_logs.db
data/db/qingdan_ding_e_ku.db
+155 -273
View File
@@ -5,23 +5,18 @@ from __future__ import annotations
import json import json
import os import os
import re import re
import configparser
import logging import logging
from datetime import datetime from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any, Dict, List, Optional, Tuple, Union
from dataclasses import dataclass from dataclasses import dataclass
from contextlib import contextmanager
import threading import threading
import time
from queue import Queue, Empty, Full
import pandas as pd import pandas as pd
import pymysql import pymysql
from pymysql.connections import Connection
from pymysql.cursors import Cursor
from tqdm import tqdm from tqdm import tqdm
import concurrent.futures import concurrent.futures
import sys import sys
from queue import Queue, Empty, Full
os.makedirs('./data/logs', exist_ok=True) os.makedirs('./data/logs', exist_ok=True)
# 配置日志 # 配置日志
@@ -35,6 +30,18 @@ logging.basicConfig(
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# =====================
# 硬编码数据库配置(简化)
# =====================
DB_HOST = '192.168.0.123'
DB_PORT = 3307
DB_USER = 'fuzhimei'
DB_PASSWORD = 'fuzhimei@135'
DB_CHARSET = 'utf8mb4'
DB_CONNECT_TIMEOUT = 10
DB_READ_TIMEOUT = 300
DB_WRITE_TIMEOUT = 300
def parse_session_tags(input_string): def parse_session_tags(input_string):
""" """
解析sessionTag格式的字符串,支持任意数量的sessionTag 解析sessionTag格式的字符串,支持任意数量的sessionTag
@@ -76,171 +83,6 @@ def parse_session_tags(input_string):
return result return result
@dataclass
class DatabaseConfig:
"""数据库配置类"""
host: str = '192.168.0.123'
port: int = 3307
user: str = 'fuzhimei'
password: str = 'fuzhimei@135'
charset: str = 'utf8mb4'
connect_timeout: int = 10
read_timeout: int = 300
write_timeout: int = 300
@classmethod
def from_config_file(cls, config_file: str = 'config.ini') -> 'DatabaseConfig':
"""从配置文件加载配置"""
if not os.path.exists(config_file):
logger.warning(f"配置文件 {config_file} 不存在,使用默认配置")
return cls()
config = configparser.ConfigParser()
config.read(config_file, encoding='utf-8')
if 'database' not in config:
logger.warning("配置文件中没有 [database] 部分,使用默认配置")
return cls()
db_config = config['database']
return cls(
host=db_config.get('host', cls.host),
port=int(db_config.get('port', cls.port)),
user=db_config.get('user', cls.user),
password=db_config.get('password', cls.password),
charset=db_config.get('charset', cls.charset),
connect_timeout=int(db_config.get('connect_timeout', cls.connect_timeout)),
read_timeout=int(db_config.get('read_timeout', cls.read_timeout)),
write_timeout=int(db_config.get('write_timeout', cls.write_timeout))
)
class ConnectionPool:
"""数据库连接池"""
def __init__(self, config: DatabaseConfig, max_connections: int = 10):
self.config = config
self.max_connections = max_connections
self.pool = Queue(maxsize=max_connections)
self.active_connections = 0
self.lock = threading.Lock()
# 预创建一些连接
self._initialize_pool()
def _initialize_pool(self) -> None:
"""初始化连接池,预创建一些连接"""
initial_connections = min(3, self.max_connections)
for _ in range(initial_connections):
try:
conn = self._create_connection()
if conn:
self.pool.put_nowait(conn)
self.active_connections += 1
except Full:
break
except Exception as e:
logger.error(f"初始化连接池时创建连接失败: {e}")
def _create_connection(self) -> Optional[Connection]:
"""创建新的数据库连接"""
try:
conn = pymysql.connect(
host=self.config.host,
port=self.config.port,
user=self.config.user,
password=self.config.password,
charset=self.config.charset,
connect_timeout=self.config.connect_timeout,
read_timeout=self.config.read_timeout,
write_timeout=self.config.write_timeout,
autocommit=True
)
return conn
except Exception as e:
logger.error(f"创建数据库连接失败: {e}")
return None
@contextmanager
def get_connection(self):
"""获取连接的上下文管理器"""
conn = None
try:
# 尝试从池中获取连接
try:
conn = self.pool.get_nowait()
except Empty:
# 池中没有连接,尝试创建新连接
with self.lock:
if self.active_connections < self.max_connections:
conn = self._create_connection()
if conn:
self.active_connections += 1
else:
raise Exception("无法创建新的数据库连接")
else:
# 等待可用连接
logger.info("等待可用连接...")
conn = self.pool.get(timeout=30)
# 检查连接是否仍然有效
if conn and not self._is_connection_alive(conn):
logger.warning("连接已失效,重新创建")
try:
conn.close()
except:
pass
conn = self._create_connection()
if not conn:
raise Exception("重新创建连接失败")
yield conn
except Exception as e:
logger.error(f"获取数据库连接时出错: {e}")
if conn:
try:
conn.close()
except:
pass
with self.lock:
self.active_connections -= 1
raise
else:
# 归还连接到池中
if conn:
try:
self.pool.put_nowait(conn)
except Full:
# 池已满,关闭连接
try:
conn.close()
except:
pass
with self.lock:
self.active_connections -= 1
def _is_connection_alive(self, conn: Connection) -> bool:
"""检查连接是否仍然有效"""
try:
conn.ping(reconnect=False)
return True
except:
return False
def close_all(self) -> None:
"""关闭所有连接"""
logger.info("正在关闭连接池中的所有连接...")
while not self.pool.empty():
try:
conn = self.pool.get_nowait()
conn.close()
except (Empty, Exception):
break
self.active_connections = 0
logger.info("连接池已关闭")
class DataProcessor: class DataProcessor:
"""数据处理器""" """数据处理器"""
@@ -357,12 +199,76 @@ class DataProcessor:
class MariaDBClient: class MariaDBClient:
"""优化后的MariaDB数据库客户端""" """简化版 MariaDB 客户端(内置轻量连接池以复用连接)"""
def __init__(self, config: DatabaseConfig, max_connections: int = 10): def __init__(self, max_connections: int = 10):
self.config = config
self.connection_pool = ConnectionPool(config, max_connections)
self.data_processor = DataProcessor() self.data_processor = DataProcessor()
self._max_connections = max_connections
self._pool: Queue = Queue(maxsize=max_connections)
self._active = 0
self._lock = threading.Lock()
# 预创建少量连接,降低首次延迟
initial = min(3, max_connections)
for _ in range(initial):
conn = self._create_connection()
if conn:
try:
self._pool.put_nowait(conn)
self._active += 1
except Full:
try:
conn.close()
except Exception:
pass
def _create_connection(self):
try:
return pymysql.connect(
host=DB_HOST,
port=DB_PORT,
user=DB_USER,
password=DB_PASSWORD,
charset=DB_CHARSET,
connect_timeout=DB_CONNECT_TIMEOUT,
read_timeout=DB_READ_TIMEOUT,
write_timeout=DB_WRITE_TIMEOUT,
autocommit=True
)
except Exception as e:
logger.error(f"创建数据库连接失败: {e}")
return None
def _acquire_connection(self):
# 先尝试不阻塞获取
try:
return self._pool.get_nowait()
except Empty:
# 池为空,若可创建新连接则创建,否则阻塞等待
with self._lock:
if self._active < self._max_connections:
conn = self._create_connection()
if conn:
self._active += 1
return conn
# 达到上限,阻塞等待可用连接
try:
return self._pool.get(timeout=30)
except Empty:
raise RuntimeError("获取数据库连接超时")
def _release_connection(self, conn):
if not conn:
return
try:
self._pool.put_nowait(conn)
except Full:
# 池已满,关闭多余连接
try:
conn.close()
except Exception:
pass
with self._lock:
self._active = max(0, self._active - 1)
def __enter__(self) -> 'MariaDBClient': def __enter__(self) -> 'MariaDBClient':
return self return self
@@ -371,30 +277,52 @@ class MariaDBClient:
self.close() self.close()
def close(self) -> None: def close(self) -> None:
"""关闭客户端""" """关闭连接池中的连接"""
self.connection_pool.close_all() try:
while True:
try:
conn = self._pool.get_nowait()
except Empty:
break
try:
conn.close()
except Exception:
pass
finally:
with self._lock:
self._active = 0
def execute_query(self, sql: str, params: Optional[Tuple] = None) -> Tuple[Optional[pd.DataFrame], List[str]]: def execute_query(self, sql: str, params: Optional[Tuple] = None) -> Tuple[Optional[pd.DataFrame], List[str]]:
"""执行SQL查询""" """执行SQL查询(复用连接池连接)"""
conn = None
try: try:
with self.connection_pool.get_connection() as conn: conn = self._acquire_connection()
with conn.cursor() as cursor: with conn.cursor() as cursor:
cursor.execute(sql, params) cursor.execute(sql, params)
results = cursor.fetchall() results = cursor.fetchall()
# 获取列名
column_names = [desc[0] for desc in cursor.description] if cursor.description else [] column_names = [desc[0] for desc in cursor.description] if cursor.description else []
if results: if results:
df = pd.DataFrame(results, columns=column_names) df = pd.DataFrame(results, columns=column_names)
return df, column_names return df, column_names
else: else:
return pd.DataFrame(), column_names return pd.DataFrame(), column_names
except Exception as e: except Exception as e:
logger.error(f"执行查询时出错: {e}") logger.error(f"执行查询时出错: {e}")
logger.error(f"SQL: {sql}") logger.error(f"SQL: {sql}")
return None, [] return None, []
finally:
if conn:
# 若连接异常,尝试关闭并减少活跃计数;否则归还
try:
conn.ping(reconnect=False)
self._release_connection(conn)
except Exception:
try:
conn.close()
except Exception:
pass
with self._lock:
self._active = max(0, self._active - 1)
def query_sessions(self, start_date: str, end_date: str) -> Optional[pd.DataFrame]: def query_sessions(self, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
"""查询指定日期范围内的会话数据""" """查询指定日期范围内的会话数据"""
@@ -483,101 +411,18 @@ class MariaDBClient:
logger.error(f"导出到Excel时出错: {e}") logger.error(f"导出到Excel时出错: {e}")
return None return None
class SessionProcessor:
"""会话处理器,负责并发处理"""
def __init__(self, db_client: MariaDBClient, max_workers: int = None):
self.db_client = db_client
self.max_workers = max_workers if max_workers is not None else os.cpu_count()
self.temp_save_lock = threading.Lock() # 添加锁用于保护临时保存操作
logger.info(f"初始化会话处理器: max_workers={self.max_workers}")
def process_sessions(self, sessions_df: pd.DataFrame) -> List[List[Dict[str, Any]]]:
"""处理所有会话数据"""
if sessions_df.empty:
logger.warning("没有会话数据需要处理")
return []
total_sessions = len(sessions_df)
logger.info(f"开始处理 {total_sessions} 个会话...")
all_conversations = []
# 直接并发处理每个会话
def process_single_session(session_row):
try:
session_id = session_row['SESSION_ID']
messages_df = self.db_client.query_messages_by_session_id(session_id)
if messages_df is not None and not messages_df.empty:
conversation = self.db_client.data_processor.messages_df_to_list(messages_df, session_row)
if conversation:
return conversation
except Exception as e:
logger.error(f"处理会话 {session_row.get('SESSION_ID', 'unknown')} 时出错: {e}")
return None
# 使用线程池处理所有会话
with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
# 提交所有会话处理任务
future_to_session = {
executor.submit(process_single_session, row): i
for i, row in sessions_df.iterrows()
}
# 收集结果
with tqdm(total=total_sessions, desc="处理会话进度") as pbar:
for future in concurrent.futures.as_completed(future_to_session):
try:
conversation = future.result()
if conversation:
all_conversations.append(conversation)
# 每处理100个对话临时保存一次
if len(all_conversations) % 100 == 0:
with self.temp_save_lock:
logger.info(f"临时保存 {len(all_conversations)} 个对话")
temp_output_file = self.db_client.export_to_excel(
all_conversations,
f"客服对话记录_临时保存",
output_dir="/data/QueryRewrite/data/excel"
)
if temp_output_file:
logger.info(f"临时保存完成: {temp_output_file}")
except Exception as e:
session_idx = future_to_session[future]
logger.error(f"处理会话索引 {session_idx} 时出错: {e}")
pbar.update(1)
logger.info(f"处理完成,共获得 {len(all_conversations)} 个有效对话")
return all_conversations
def main() -> None: def main() -> None:
"""主函数""" """主函数(精简版)"""
try: try:
# 加载配置 logger.info(f"使用数据库配置: {DB_HOST}:{DB_PORT}")
config = DatabaseConfig.from_config_file()
logger.info(f"使用数据库配置: {config.host}:{config.port}")
# 创建数据库客户端 # 创建数据库客户端(简化)
with MariaDBClient(config, max_connections=12) as db_client: with MariaDBClient() as db_client:
# 查询会话数据 # 查询会话数据
start_date = '2025-08-01 00:00:00' start_date = '2025-08-01 00:00:00'
end_date = '2025-08-01 23:00:00' end_date = '2025-08-01 23:00:00'
logger.info(f"查询时间范围: {start_date}{end_date}") logger.info(f"查询时间范围: {start_date}{end_date}")
# 创建会话处理器
processor = SessionProcessor(db_client)
# is_debug = hasattr(sys, 'gettrace') and sys.gettrace() is not None
# if is_debug:
# messages_df = db_client.query_messages_by_session_id("86c919e0-09f1-11f0-84ae-2daf59566989")
# print(db_client.data_processor.messages_df_to_list(messages_df))
# return []
sessions_df = db_client.query_sessions(start_date, end_date) sessions_df = db_client.query_sessions(start_date, end_date)
@@ -585,8 +430,46 @@ def main() -> None:
logger.warning("没有找到符合条件的会话数据") logger.warning("没有找到符合条件的会话数据")
return return
# 处理会话数据 # 直接并发处理每个会话(替代 SessionProcessor
all_conversations = processor.process_sessions(sessions_df) total_sessions = len(sessions_df)
all_conversations: List[List[Dict[str, Any]]] = []
temp_save_lock = threading.Lock()
def process_single_session(session_row):
try:
session_id = session_row['SESSION_ID']
messages_df = db_client.query_messages_by_session_id(session_id)
if messages_df is not None and not messages_df.empty:
conversation = db_client.data_processor.messages_df_to_list(messages_df, session_row)
if conversation:
return conversation
except Exception as e:
logger.error(f"处理会话 {session_row.get('SESSION_ID', 'unknown')} 时出错: {e}")
return None
with concurrent.futures.ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
future_to_session = {executor.submit(process_single_session, row): i for i, row in sessions_df.iterrows()}
with tqdm(total=total_sessions, desc="处理会话进度") as pbar:
for future in concurrent.futures.as_completed(future_to_session):
try:
conversation = future.result()
if conversation:
all_conversations.append(conversation)
if len(all_conversations) % 100 == 0:
with temp_save_lock:
logger.info(f"临时保存 {len(all_conversations)} 个对话")
temp_output_file = db_client.export_to_excel(
all_conversations,
f"客服对话记录_临时保存",
output_dir="/data/QueryRewrite/data/excel"
)
if temp_output_file:
logger.info(f"临时保存完成: {temp_output_file}")
except Exception as e:
session_idx = future_to_session[future]
logger.error(f"处理会话索引 {session_idx} 时出错: {e}")
pbar.update(1)
# 导出结果 # 导出结果
if all_conversations: if all_conversations:
output_file = db_client.export_to_excel( output_file = db_client.export_to_excel(
@@ -594,7 +477,6 @@ def main() -> None:
"客服对话记录", "客服对话记录",
output_dir="/data/QueryRewrite/data/excel" output_dir="/data/QueryRewrite/data/excel"
) )
if output_file: if output_file:
logger.info(f"处理完成!共导出 {len(all_conversations)} 个对话到文件: {output_file}") logger.info(f"处理完成!共导出 {len(all_conversations)} 个对话到文件: {output_file}")
else: else:
+6 -10
View File
@@ -17,21 +17,18 @@ from typing import List, Tuple, Dict, Any, Optional
import re import re
import jieba import jieba
import time import time
import threading
from .PromptTemplates import (classification_prompt, query_rewrite_prompt_pro, from .PromptTemplates import (classification_prompt, query_rewrite_prompt_pro,
extract_nouns_prompt, classification_info, extract_nouns_prompt, classification_info,
slot_filling_prompt, step_back_prompt, slot_filling_prompt, step_back_prompt)
hyde_prompt)
from .DataModels import ( from .DataModels import (
Classification, QueryRewrite, Term, TermList, Classification, QueryRewrite, Term, TermList,
SoftwareFunctionSlots, SoftwareTroubleShootingSlots, ProfessionalConsultingSlots, SoftwareFunctionSlots, SoftwareTroubleShootingSlots, ProfessionalConsultingSlots,
DataProblemSlots, FileExtensionConsultingSlots, SoftwareLockSlots, DataProblemSlots, FileExtensionConsultingSlots, SoftwareLockSlots,
InstallationDownloadSlots, ProblemDiagnosisSlots, OtherSlots, IntentAndSlotResult, InstallationDownloadSlots, ProblemDiagnosisSlots, OtherSlots,
StepBackPrompt, HypotheticalDocument StepBackPrompt
) )
from .ProfessionalNounVector import ProfessionalNounRetriever, AsyncProfessionalNounRetriever
from rag2_0.tool.ModelTool import OpenAiLLM from rag2_0.tool.ModelTool import OpenAiLLM
class AsyncIntentRecognizer: class AsyncIntentRecognizer:
@@ -344,8 +341,7 @@ class AsyncIntentRecognizer:
""" """
start_time = time.time() # 记录开始时间 start_time = time.time() # 记录开始时间
prompt=f""" prompt=f"""当前提问内容:
当前提问内容:
<query>{query}</query> <query>{query}</query>
对话上下文: 对话上下文:
<chat_history> <chat_history>
@@ -358,8 +354,7 @@ class AsyncIntentRecognizer:
{{ {{
"dinge_info_list":{{"dinge_code_list":["xxxx","xxxx"], "dinge_name_list":["xxxx","xxxx"]}}, "dinge_info_list":{{"dinge_code_list":["xxxx","xxxx"], "dinge_name_list":["xxxx","xxxx"]}},
"qingdan_info":{{"qingdan_code_list":["xxxx","xxxx"], "qingdan_name_list":["xxxx","xxxx"]}} "qingdan_info":{{"qingdan_code_list":["xxxx","xxxx"], "qingdan_name_list":["xxxx","xxxx"]}}
}} }}"""
"""
try: try:
# response = await self._llm.ainvoke(prompt, response_format={"type": "json_object"}, extra_body={"enable_thinking": False}) # response = await self._llm.ainvoke(prompt, response_format={"type": "json_object"}, extra_body={"enable_thinking": False})
@@ -489,6 +484,7 @@ class AsyncIntentRecognizer:
# 特殊处理 锁相关咨询 # 特殊处理 锁相关咨询
if classification.vertical_classification == "安装下载注册" and classification.sub_classification == "软件锁类": if classification.vertical_classification == "安装下载注册" and classification.sub_classification == "软件锁类":
process_lock_start_time = time.time() process_lock_start_time = time.time()
# 特殊处理提问只有锁号的问题,手动将问题改写为特定格式
rewrite.rewrite = self._process_lock_related_query(rewrite.rewrite) rewrite.rewrite = self._process_lock_related_query(rewrite.rewrite)
process_lock_end_time = time.time() process_lock_end_time = time.time()
process_lock_time = process_lock_end_time - process_lock_start_time process_lock_time = process_lock_end_time - process_lock_start_time
+7 -106
View File
@@ -56,13 +56,11 @@ classification_info="""【垂直领域分类】:
4. 问题排查类:软件安装下载失败、报错,系统兼容性问题等 4. 问题排查类:软件安装下载失败、报错,系统兼容性问题等
【固定话术类包括以下类】: 【固定话术类包括以下类】:
1. 规费咨询 1. 规费咨询
**以下两种情况才属于该类** **以下两种情况才属于该类**
1、当询问规费(如社会保障费和住房公积金)费率是/填多少 1、当询问规费(如社会保障费和住房公积金)费率是/填多少
2、去哪里获取规费费率 2、去哪里获取规费费率
**其余涉及规费的属于其他垂直领域分类** **其余涉及规费的属于其他垂直领域分类**
2. 调差下载更新 2. 调差下载更新
**以下两种情况才属于该类** **以下两种情况才属于该类**
1、询问如何下载导入调差文件、调差插件 1、询问如何下载导入调差文件、调差插件
@@ -73,12 +71,9 @@ classification_info="""【垂直领域分类】:
【其他】: 【其他】:
1. 其他 1. 其他
分类优先级: 分类优先级:固定话术类 > 软件问题 、 业务问题 、 安装下载注册 > 其他"""
固定话术类 > 软件问题 、 业务问题 、 安装下载注册 > 其他
"""
classification_prompt=""" classification_prompt="""用户正在使用电力造价软件或想询问电力造价领域相关知识,你需要根据用户的输入内容集合历史对话(如果存在),将其归类为以下垂直领域之一:
用户正在使用电力造价软件或想询问电力造价领域相关知识,你需要根据用户的输入内容集合历史对话(如果存在),将其归类为以下垂直领域之一:
{classification_info} {classification_info}
## 【历史对话记录】 ## 【历史对话记录】
@@ -96,88 +91,9 @@ classification_prompt="""
{{ {{
"vertical_classification":"软件问题", "vertical_classification":"软件问题",
"sub_classification":"软件功能" "sub_classification":"软件功能"
}} }}"""
""" query_rewrite_prompt_pro="""# 问答优化工程师
query_rewrite_prompt = """
# 电力造价专业问答优化工程师
你是一名电力造价专业问答优化工程师,负责通过专业关键词集合替换原始问题中的非专业表述以提升知识库检索准确率。
## 核心任务
将用户的原始问题结合专业术语库进行规范化重构,提高知识库检索的准确性和专业性。
## 处理流程
### 第一阶段:输入解析
1. 解析基础信息
- 原始问题(需保留核心语义){query}
- 关键词集合:{keywords}
### 第二阶段:匹配分析
**匹配规则:**
1. 检查原始问题中是否包含关键词集合中的`name`字段或`synonymous`字段中的任何词汇
2. 统计匹配的术语数量
3. 判断执行路径:
- 匹配术语 ≥ 1个 → 执行重构流程
- 匹配术语 = 0个 → 直接输出原始问题
### 第三阶段:问题重构
**重构原则(按优先级排序):**
1. **语义保真**:严格保持原问题的核心意图和诉求
2. **术语规范**
- 将匹配到的同义词替换为对应的标准术语(name字段)
- 对在关键词中的标准术语使用【】进行标记
- 保留在原问题中未在关键词库中的专业术语、限定词和修饰词
3. **结构优化**
- 保持原问题的语态特征5W2H
- 保持主谓宾结构清晰
- 保留时间、版本等限定条件
**术语处理规则:**
- 优先级1:保留原问题中的专业术语、限定词和修饰词(即使不在关键词库中)
- 优先级2:将同义词替换为标准术语并用【】标记
- 优先级3:对原问题中已存在的标准术语添加【】标记
# 输出规范
{output_format}
# 示范案例库
▶ 案例1(有效匹配)
输入:
原始问题:怎么把旧版西藏定额工程转到Z1新版
关键词:【'老版本定额升级', '批量设置定额', '西藏造价软件Z1'
输出:
{{"rewrite":"【西藏造价软件Z1】如何执行【老版本定额升级】操作?"}}
▶ 案例2(无效匹配)
输入:
原始问题:程序界面文字显示过小如何处理?
关键词:【'定额升级', '工程批量导入'
输出:
{{"rewrite":"程序界面文字显示过小如何处理?"}}
▶ 案例3(部分匹配,但保留修饰限定词)
输入:
原始问题:"配网软件D3能导出清单的计算公式吗?
关键词:【'配网工程计价通D3软件', '计算式'
输出(保留限定修饰词"清单")
{{"rewrite":"【配网工程计价通D3软件】能导出清单的【计算式】吗?"}}
## 质量检查清单
执行前请确认:
- [ ] 是否保持了原问题的核心诉求?
- [ ] 是否正确执行了同义词替换?
- [ ] 是否保留了原问题中的专业术语和限定条件?
- [ ] 是否正确使用了【】标记?
- [ ] 重构后的问题是否自然流畅?
"""
query_rewrite_prompt_pro="""
# 问答优化工程师
**角色**:基于历史对话和术语库重构问题,提升知识库检索准确率。 **角色**:基于历史对话和术语库重构问题,提升知识库检索准确率。
**最高准则** **最高准则**
1、保持问题核心意图,允许指代消除 1、保持问题核心意图,允许指代消除
@@ -221,23 +137,11 @@ query_rewrite_prompt_pro="""
## 输出规范 ## 输出规范
{output_format} {output_format}
## 示例模仿
示例1
输入:
<history>
'user': '811623110668是哪款软件的锁?
'assistant': 可通过查询软件锁的许可证信息,通过许可证名称可以判断对应软件
</history>
<query> ”锂离子电池储能安装“ </query>
输出:
{{"rewrite": "许可证名称为‘锂离子电池储能安装’对应什么软件?"}}
## 质量自检 ## 质量自检
- [] **主题是否合理继承?** - [] **主题是否合理继承?**
- [] 核心诉求是否保留? - [] 核心诉求是否保留?
- [] 语句是否自然流畅? - [] 语句是否自然流畅?
- [] 避免补充无关信息 - [] 避免补充无关信息"""
"""
slot_filling_prompt = """ slot_filling_prompt = """
你是一个专业的电力造价领域问题槽位填充助手。你需要从用户问题中提取关键信息,并填充到对应的数据结构中。 你是一个专业的电力造价领域问题槽位填充助手。你需要从用户问题中提取关键信息,并填充到对应的数据结构中。
@@ -282,8 +186,7 @@ slot_filling_prompt = """
""" """
# 意图优化环节提示词模板 # 意图优化环节提示词模板
step_back_prompt = """ step_back_prompt = """# 后退提示生成器
# 后退提示生成器
你是一个专业的电力造价领域问题抽象专家。你的任务是根据用户的具体问题,提出一个更抽象、更高层次的问题,帮助系统更好地理解用户的意图。 你是一个专业的电力造价领域问题抽象专家。你的任务是根据用户的具体问题,提出一个更抽象、更高层次的问题,帮助系统更好地理解用户的意图。
@@ -320,9 +223,7 @@ step_back_prompt = """
"original_query": "某个设备更换后,如何在系统中更新对应的定额?", "original_query": "某个设备更换后,如何在系统中更新对应的定额?",
"can_use_back_prompt": true, "can_use_back_prompt": true,
"step_back_query": ["如何更新设备对应的定额?", "如何更新定额?"] "step_back_query": ["如何更新设备对应的定额?", "如何更新定额?"]
}} }}"""
"""
follow_up_questions_prompt = """ follow_up_questions_prompt = """
# 后续问题生成器 # 后续问题生成器
+15 -33
View File
@@ -193,7 +193,9 @@ class OpenAiLLM:
messages=[{'role': 'user', 'content': user_prompt}], messages=[{'role': 'user', 'content': user_prompt}],
**kwargs **kwargs
) )
return completion.choices[0].message message = completion.choices[0].message
message.usage = completion.usage
return message
except Exception as e: except Exception as e:
raise RuntimeError(f"OpenAiLLM:invoke:error:{str(e)}") from e raise RuntimeError(f"OpenAiLLM:invoke:error:{str(e)}") from e
@@ -225,36 +227,16 @@ class OpenAiLLM:
if __name__ == "__main__": if __name__ == "__main__":
# 测试重排模型 # 测试重排模型
reranker = SiliconFlowReRankerModel() base_url = os.getenv("OPENAI_API_BASE")
model_name = os.getenv("MODEL_NAME", "gpt-3.5-turbo")
# 测试用例1:简单问题 # 初始化LLM
query = "如何通过【电力经济评价软件】的【打开】功能加载工程文件?" llm_params = {
documents = [] "temperature": 0.4, # 降低随机性,使结果更确定
results = reranker.rerank(query, documents) "top_p": 0.7,
print(f"测试用例1 - 查询:{query}") "model": model_name,
for idx, item in enumerate(results): "base_url": base_url
print(f"{idx+1}. 文档: {item['document']}, 分数: {item['score']}") }
print("-" * 50)
# 异步测试示例
async def test_async():
# 测试异步嵌入
api_key = APIKeyManager.get_api_key()
embeddings = XinferenceEmbeddings(api_key=api_key)
query_embedding = await embeddings.embed_query_async("测试查询")
print(f"异步嵌入向量维度: {len(query_embedding)}")
# 测试异步重排序
results = await SiliconFlowReRankerModel.rerank_async(query, documents)
print(f"异步重排序结果数量: {len(results)}")
# 测试异步LLM调用
llm = OpenAiLLM()
response = await llm.ainvoke("你好,请简单介绍一下自己")
print(f"异步LLM响应: {response.content}")
# 如果需要运行异步测试,取消下面的注释
# import asyncio
# asyncio.run(test_async())
_llm = OpenAiLLM(**llm_params)
promt="""你好,请简单介绍一下自己"""
print(_llm.invoke(promt))