调整提示词、简化代码

This commit is contained in:
2025-09-25 16:23:13 +08:00
parent 640e02f89e
commit 2b13fdab99
5 changed files with 235 additions and 476 deletions
+1 -3
View File
@@ -11,9 +11,7 @@ data/excel/*.xlsx
!data/excel/Excel版 清单定额库/
!data/excel/Excel版 清单定额库/**
data/logs/*
rag2_0/dify/Test.py
data/query_logs/*
data/conversations/*
data/test*
data/temp*
data/db/answer_logs.db
data/db/qingdan_ding_e_ku.db
+164 -282
View File
@@ -5,23 +5,18 @@ from __future__ import annotations
import json
import os
import re
import configparser
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple, Union
from dataclasses import dataclass
from contextlib import contextmanager
import threading
import time
from queue import Queue, Empty, Full
import pandas as pd
import pymysql
from pymysql.connections import Connection
from pymysql.cursors import Cursor
from tqdm import tqdm
import concurrent.futures
import sys
from queue import Queue, Empty, Full
os.makedirs('./data/logs', exist_ok=True)
# 配置日志
@@ -35,6 +30,18 @@ logging.basicConfig(
)
logger = logging.getLogger(__name__)
# =====================
# 硬编码数据库配置(简化)
# =====================
DB_HOST = '192.168.0.123'
DB_PORT = 3307
DB_USER = 'fuzhimei'
DB_PASSWORD = 'fuzhimei@135'
DB_CHARSET = 'utf8mb4'
DB_CONNECT_TIMEOUT = 10
DB_READ_TIMEOUT = 300
DB_WRITE_TIMEOUT = 300
def parse_session_tags(input_string):
"""
解析sessionTag格式的字符串,支持任意数量的sessionTag
@@ -76,171 +83,6 @@ def parse_session_tags(input_string):
return result
@dataclass
class DatabaseConfig:
"""数据库配置类"""
host: str = '192.168.0.123'
port: int = 3307
user: str = 'fuzhimei'
password: str = 'fuzhimei@135'
charset: str = 'utf8mb4'
connect_timeout: int = 10
read_timeout: int = 300
write_timeout: int = 300
@classmethod
def from_config_file(cls, config_file: str = 'config.ini') -> 'DatabaseConfig':
"""从配置文件加载配置"""
if not os.path.exists(config_file):
logger.warning(f"配置文件 {config_file} 不存在,使用默认配置")
return cls()
config = configparser.ConfigParser()
config.read(config_file, encoding='utf-8')
if 'database' not in config:
logger.warning("配置文件中没有 [database] 部分,使用默认配置")
return cls()
db_config = config['database']
return cls(
host=db_config.get('host', cls.host),
port=int(db_config.get('port', cls.port)),
user=db_config.get('user', cls.user),
password=db_config.get('password', cls.password),
charset=db_config.get('charset', cls.charset),
connect_timeout=int(db_config.get('connect_timeout', cls.connect_timeout)),
read_timeout=int(db_config.get('read_timeout', cls.read_timeout)),
write_timeout=int(db_config.get('write_timeout', cls.write_timeout))
)
class ConnectionPool:
"""数据库连接池"""
def __init__(self, config: DatabaseConfig, max_connections: int = 10):
self.config = config
self.max_connections = max_connections
self.pool = Queue(maxsize=max_connections)
self.active_connections = 0
self.lock = threading.Lock()
# 预创建一些连接
self._initialize_pool()
def _initialize_pool(self) -> None:
"""初始化连接池,预创建一些连接"""
initial_connections = min(3, self.max_connections)
for _ in range(initial_connections):
try:
conn = self._create_connection()
if conn:
self.pool.put_nowait(conn)
self.active_connections += 1
except Full:
break
except Exception as e:
logger.error(f"初始化连接池时创建连接失败: {e}")
def _create_connection(self) -> Optional[Connection]:
"""创建新的数据库连接"""
try:
conn = pymysql.connect(
host=self.config.host,
port=self.config.port,
user=self.config.user,
password=self.config.password,
charset=self.config.charset,
connect_timeout=self.config.connect_timeout,
read_timeout=self.config.read_timeout,
write_timeout=self.config.write_timeout,
autocommit=True
)
return conn
except Exception as e:
logger.error(f"创建数据库连接失败: {e}")
return None
@contextmanager
def get_connection(self):
"""获取连接的上下文管理器"""
conn = None
try:
# 尝试从池中获取连接
try:
conn = self.pool.get_nowait()
except Empty:
# 池中没有连接,尝试创建新连接
with self.lock:
if self.active_connections < self.max_connections:
conn = self._create_connection()
if conn:
self.active_connections += 1
else:
raise Exception("无法创建新的数据库连接")
else:
# 等待可用连接
logger.info("等待可用连接...")
conn = self.pool.get(timeout=30)
# 检查连接是否仍然有效
if conn and not self._is_connection_alive(conn):
logger.warning("连接已失效,重新创建")
try:
conn.close()
except:
pass
conn = self._create_connection()
if not conn:
raise Exception("重新创建连接失败")
yield conn
except Exception as e:
logger.error(f"获取数据库连接时出错: {e}")
if conn:
try:
conn.close()
except:
pass
with self.lock:
self.active_connections -= 1
raise
else:
# 归还连接到池中
if conn:
try:
self.pool.put_nowait(conn)
except Full:
# 池已满,关闭连接
try:
conn.close()
except:
pass
with self.lock:
self.active_connections -= 1
def _is_connection_alive(self, conn: Connection) -> bool:
"""检查连接是否仍然有效"""
try:
conn.ping(reconnect=False)
return True
except:
return False
def close_all(self) -> None:
"""关闭所有连接"""
logger.info("正在关闭连接池中的所有连接...")
while not self.pool.empty():
try:
conn = self.pool.get_nowait()
conn.close()
except (Empty, Exception):
break
self.active_connections = 0
logger.info("连接池已关闭")
class DataProcessor:
"""数据处理器"""
@@ -357,12 +199,76 @@ class DataProcessor:
class MariaDBClient:
"""优化后的MariaDB数据库客户端"""
"""简化版 MariaDB 客户端(内置轻量连接池以复用连接)"""
def __init__(self, config: DatabaseConfig, max_connections: int = 10):
self.config = config
self.connection_pool = ConnectionPool(config, max_connections)
def __init__(self, max_connections: int = 10):
self.data_processor = DataProcessor()
self._max_connections = max_connections
self._pool: Queue = Queue(maxsize=max_connections)
self._active = 0
self._lock = threading.Lock()
# 预创建少量连接,降低首次延迟
initial = min(3, max_connections)
for _ in range(initial):
conn = self._create_connection()
if conn:
try:
self._pool.put_nowait(conn)
self._active += 1
except Full:
try:
conn.close()
except Exception:
pass
def _create_connection(self):
try:
return pymysql.connect(
host=DB_HOST,
port=DB_PORT,
user=DB_USER,
password=DB_PASSWORD,
charset=DB_CHARSET,
connect_timeout=DB_CONNECT_TIMEOUT,
read_timeout=DB_READ_TIMEOUT,
write_timeout=DB_WRITE_TIMEOUT,
autocommit=True
)
except Exception as e:
logger.error(f"创建数据库连接失败: {e}")
return None
def _acquire_connection(self):
# 先尝试不阻塞获取
try:
return self._pool.get_nowait()
except Empty:
# 池为空,若可创建新连接则创建,否则阻塞等待
with self._lock:
if self._active < self._max_connections:
conn = self._create_connection()
if conn:
self._active += 1
return conn
# 达到上限,阻塞等待可用连接
try:
return self._pool.get(timeout=30)
except Empty:
raise RuntimeError("获取数据库连接超时")
def _release_connection(self, conn):
if not conn:
return
try:
self._pool.put_nowait(conn)
except Full:
# 池已满,关闭多余连接
try:
conn.close()
except Exception:
pass
with self._lock:
self._active = max(0, self._active - 1)
def __enter__(self) -> 'MariaDBClient':
return self
@@ -371,30 +277,52 @@ class MariaDBClient:
self.close()
def close(self) -> None:
"""关闭客户端"""
self.connection_pool.close_all()
"""关闭连接池中的连接"""
try:
while True:
try:
conn = self._pool.get_nowait()
except Empty:
break
try:
conn.close()
except Exception:
pass
finally:
with self._lock:
self._active = 0
def execute_query(self, sql: str, params: Optional[Tuple] = None) -> Tuple[Optional[pd.DataFrame], List[str]]:
"""执行SQL查询"""
"""执行SQL查询(复用连接池连接)"""
conn = None
try:
with self.connection_pool.get_connection() as conn:
with conn.cursor() as cursor:
cursor.execute(sql, params)
results = cursor.fetchall()
# 获取列名
column_names = [desc[0] for desc in cursor.description] if cursor.description else []
if results:
df = pd.DataFrame(results, columns=column_names)
return df, column_names
else:
return pd.DataFrame(), column_names
conn = self._acquire_connection()
with conn.cursor() as cursor:
cursor.execute(sql, params)
results = cursor.fetchall()
column_names = [desc[0] for desc in cursor.description] if cursor.description else []
if results:
df = pd.DataFrame(results, columns=column_names)
return df, column_names
else:
return pd.DataFrame(), column_names
except Exception as e:
logger.error(f"执行查询时出错: {e}")
logger.error(f"SQL: {sql}")
return None, []
finally:
if conn:
# 若连接异常,尝试关闭并减少活跃计数;否则归还
try:
conn.ping(reconnect=False)
self._release_connection(conn)
except Exception:
try:
conn.close()
except Exception:
pass
with self._lock:
self._active = max(0, self._active - 1)
def query_sessions(self, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
"""查询指定日期范围内的会话数据"""
@@ -483,101 +411,18 @@ class MariaDBClient:
logger.error(f"导出到Excel时出错: {e}")
return None
class SessionProcessor:
"""会话处理器,负责并发处理"""
def __init__(self, db_client: MariaDBClient, max_workers: int = None):
self.db_client = db_client
self.max_workers = max_workers if max_workers is not None else os.cpu_count()
self.temp_save_lock = threading.Lock() # 添加锁用于保护临时保存操作
logger.info(f"初始化会话处理器: max_workers={self.max_workers}")
def process_sessions(self, sessions_df: pd.DataFrame) -> List[List[Dict[str, Any]]]:
"""处理所有会话数据"""
if sessions_df.empty:
logger.warning("没有会话数据需要处理")
return []
total_sessions = len(sessions_df)
logger.info(f"开始处理 {total_sessions} 个会话...")
all_conversations = []
# 直接并发处理每个会话
def process_single_session(session_row):
try:
session_id = session_row['SESSION_ID']
messages_df = self.db_client.query_messages_by_session_id(session_id)
if messages_df is not None and not messages_df.empty:
conversation = self.db_client.data_processor.messages_df_to_list(messages_df, session_row)
if conversation:
return conversation
except Exception as e:
logger.error(f"处理会话 {session_row.get('SESSION_ID', 'unknown')} 时出错: {e}")
return None
# 使用线程池处理所有会话
with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
# 提交所有会话处理任务
future_to_session = {
executor.submit(process_single_session, row): i
for i, row in sessions_df.iterrows()
}
# 收集结果
with tqdm(total=total_sessions, desc="处理会话进度") as pbar:
for future in concurrent.futures.as_completed(future_to_session):
try:
conversation = future.result()
if conversation:
all_conversations.append(conversation)
# 每处理100个对话临时保存一次
if len(all_conversations) % 100 == 0:
with self.temp_save_lock:
logger.info(f"临时保存 {len(all_conversations)} 个对话")
temp_output_file = self.db_client.export_to_excel(
all_conversations,
f"客服对话记录_临时保存",
output_dir="/data/QueryRewrite/data/excel"
)
if temp_output_file:
logger.info(f"临时保存完成: {temp_output_file}")
except Exception as e:
session_idx = future_to_session[future]
logger.error(f"处理会话索引 {session_idx} 时出错: {e}")
pbar.update(1)
logger.info(f"处理完成,共获得 {len(all_conversations)} 个有效对话")
return all_conversations
def main() -> None:
"""主函数"""
"""主函数(精简版)"""
try:
# 加载配置
config = DatabaseConfig.from_config_file()
logger.info(f"使用数据库配置: {config.host}:{config.port}")
logger.info(f"使用数据库配置: {DB_HOST}:{DB_PORT}")
# 创建数据库客户端
with MariaDBClient(config, max_connections=12) as db_client:
# 创建数据库客户端(简化)
with MariaDBClient() as db_client:
# 查询会话数据
start_date = '2025-08-01 00:00:00'
end_date = '2025-08-01 23:00:00'
logger.info(f"查询时间范围: {start_date}{end_date}")
# 创建会话处理器
processor = SessionProcessor(db_client)
# is_debug = hasattr(sys, 'gettrace') and sys.gettrace() is not None
# if is_debug:
# messages_df = db_client.query_messages_by_session_id("86c919e0-09f1-11f0-84ae-2daf59566989")
# print(db_client.data_processor.messages_df_to_list(messages_df))
# return []
sessions_df = db_client.query_sessions(start_date, end_date)
@@ -585,8 +430,46 @@ def main() -> None:
logger.warning("没有找到符合条件的会话数据")
return
# 处理会话数据
all_conversations = processor.process_sessions(sessions_df)
# 直接并发处理每个会话(替代 SessionProcessor
total_sessions = len(sessions_df)
all_conversations: List[List[Dict[str, Any]]] = []
temp_save_lock = threading.Lock()
def process_single_session(session_row):
try:
session_id = session_row['SESSION_ID']
messages_df = db_client.query_messages_by_session_id(session_id)
if messages_df is not None and not messages_df.empty:
conversation = db_client.data_processor.messages_df_to_list(messages_df, session_row)
if conversation:
return conversation
except Exception as e:
logger.error(f"处理会话 {session_row.get('SESSION_ID', 'unknown')} 时出错: {e}")
return None
with concurrent.futures.ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
future_to_session = {executor.submit(process_single_session, row): i for i, row in sessions_df.iterrows()}
with tqdm(total=total_sessions, desc="处理会话进度") as pbar:
for future in concurrent.futures.as_completed(future_to_session):
try:
conversation = future.result()
if conversation:
all_conversations.append(conversation)
if len(all_conversations) % 100 == 0:
with temp_save_lock:
logger.info(f"临时保存 {len(all_conversations)} 个对话")
temp_output_file = db_client.export_to_excel(
all_conversations,
f"客服对话记录_临时保存",
output_dir="/data/QueryRewrite/data/excel"
)
if temp_output_file:
logger.info(f"临时保存完成: {temp_output_file}")
except Exception as e:
session_idx = future_to_session[future]
logger.error(f"处理会话索引 {session_idx} 时出错: {e}")
pbar.update(1)
# 导出结果
if all_conversations:
output_file = db_client.export_to_excel(
@@ -594,7 +477,6 @@ def main() -> None:
"客服对话记录",
output_dir="/data/QueryRewrite/data/excel"
)
if output_file:
logger.info(f"处理完成!共导出 {len(all_conversations)} 个对话到文件: {output_file}")
else:
+17 -21
View File
@@ -17,21 +17,18 @@ from typing import List, Tuple, Dict, Any, Optional
import re
import jieba
import time
import threading
from .PromptTemplates import (classification_prompt, query_rewrite_prompt_pro,
extract_nouns_prompt, classification_info,
slot_filling_prompt, step_back_prompt,
hyde_prompt)
slot_filling_prompt, step_back_prompt)
from .DataModels import (
Classification, QueryRewrite, Term, TermList,
SoftwareFunctionSlots, SoftwareTroubleShootingSlots, ProfessionalConsultingSlots,
DataProblemSlots, FileExtensionConsultingSlots, SoftwareLockSlots,
InstallationDownloadSlots, ProblemDiagnosisSlots, OtherSlots, IntentAndSlotResult,
StepBackPrompt, HypotheticalDocument
InstallationDownloadSlots, ProblemDiagnosisSlots, OtherSlots,
StepBackPrompt
)
from .ProfessionalNounVector import ProfessionalNounRetriever, AsyncProfessionalNounRetriever
from rag2_0.tool.ModelTool import OpenAiLLM
class AsyncIntentRecognizer:
@@ -344,22 +341,20 @@ class AsyncIntentRecognizer:
"""
start_time = time.time() # 记录开始时间
prompt=f"""
当前提问内容:
<query>{query}</query>
对话上下文:
<chat_history>
{json.dumps(chat_history, ensure_ascii=False)}
</chat_history>
prompt=f"""当前提问内容:
<query>{query}</query>
对话上下文:
<chat_history>
{json.dumps(chat_history, ensure_ascii=False)}
</chat_history>
1、请从当前提问内容中提取电力造价行中定额编码、定额名称、清单编码、清单名称
2、请勿随机编造,如果没有提取到内容返回空的JSON
3、返回结果为json格式,必须严格以纯JSON格式输出
{{
"dinge_info_list":{{"dinge_code_list":["xxxx","xxxx"], "dinge_name_list":["xxxx","xxxx"]}},
"qingdan_info":{{"qingdan_code_list":["xxxx","xxxx"], "qingdan_name_list":["xxxx","xxxx"]}}
}}
"""
1、请从当前提问内容中提取电力造价行中定额编码、定额名称、清单编码、清单名称
2、请勿随机编造,如果没有提取到内容返回空的JSON
3、返回结果为json格式,必须严格以纯JSON格式输出
{{
"dinge_info_list":{{"dinge_code_list":["xxxx","xxxx"], "dinge_name_list":["xxxx","xxxx"]}},
"qingdan_info":{{"qingdan_code_list":["xxxx","xxxx"], "qingdan_name_list":["xxxx","xxxx"]}}
}}"""
try:
# response = await self._llm.ainvoke(prompt, response_format={"type": "json_object"}, extra_body={"enable_thinking": False})
@@ -489,6 +484,7 @@ class AsyncIntentRecognizer:
# 特殊处理 锁相关咨询
if classification.vertical_classification == "安装下载注册" and classification.sub_classification == "软件锁类":
process_lock_start_time = time.time()
# 特殊处理提问只有锁号的问题,手动将问题改写为特定格式
rewrite.rewrite = self._process_lock_related_query(rewrite.rewrite)
process_lock_end_time = time.time()
process_lock_time = process_lock_end_time - process_lock_start_time
+34 -133
View File
@@ -56,13 +56,11 @@ classification_info="""【垂直领域分类】:
4. 问题排查类:软件安装下载失败、报错,系统兼容性问题等
【固定话术类包括以下类】:
1. 规费咨询
**以下两种情况才属于该类**
1、当询问规费(如社会保障费和住房公积金)费率是/填多少
2、去哪里获取规费费率
**其余涉及规费的属于其他垂直领域分类**
2. 调差下载更新
**以下两种情况才属于该类**
1、询问如何下载导入调差文件、调差插件
@@ -73,111 +71,29 @@ classification_info="""【垂直领域分类】:
【其他】:
1. 其他
分类优先级:
固定话术类 > 软件问题 、 业务问题 、 安装下载注册 > 其他
"""
分类优先级:固定话术类 > 软件问题 、 业务问题 、 安装下载注册 > 其他"""
classification_prompt="""
用户正在使用电力造价软件或想询问电力造价领域相关知识,你需要根据用户的输入内容集合历史对话(如果存在),将其归类为以下垂直领域之一:
{classification_info}
classification_prompt="""用户正在使用电力造价软件或想询问电力造价领域相关知识,你需要根据用户的输入内容集合历史对话(如果存在),将其归类为以下垂直领域之一:
{classification_info}
## 【历史对话记录】
{chat_history}
## 【历史对话记录】
{chat_history}
【用户输入】:
{user_input}
【用户输入】:
{user_input}
【输出格式要求】:
{output_format}
【示例】
用户输入1: 技改T1怎样新建工程
输出1:
{{
"vertical_classification":"软件问题",
"sub_classification":"软件功能"
}}
"""
query_rewrite_prompt = """
# 电力造价专业问答优化工程师
你是一名电力造价专业问答优化工程师,负责通过专业关键词集合替换原始问题中的非专业表述以提升知识库检索准确率。
## 核心任务
将用户的原始问题结合专业术语库进行规范化重构,提高知识库检索的准确性和专业性。
## 处理流程
### 第一阶段:输入解析
1. 解析基础信息
- 原始问题(需保留核心语义){query}
- 关键词集合:{keywords}
### 第二阶段:匹配分析
**匹配规则:**
1. 检查原始问题中是否包含关键词集合中的`name`字段或`synonymous`字段中的任何词汇
2. 统计匹配的术语数量
3. 判断执行路径:
- 匹配术语 ≥ 1个 → 执行重构流程
- 匹配术语 = 0个 → 直接输出原始问题
### 第三阶段:问题重构
**重构原则(按优先级排序):**
1. **语义保真**:严格保持原问题的核心意图和诉求
2. **术语规范**
- 将匹配到的同义词替换为对应的标准术语(name字段)
- 对在关键词中的标准术语使用【】进行标记
- 保留在原问题中未在关键词库中的专业术语、限定词和修饰词
3. **结构优化**
- 保持原问题的语态特征5W2H
- 保持主谓宾结构清晰
- 保留时间、版本等限定条件
**术语处理规则:**
- 优先级1:保留原问题中的专业术语、限定词和修饰词(即使不在关键词库中)
- 优先级2:将同义词替换为标准术语并用【】标记
- 优先级3:对原问题中已存在的标准术语添加【】标记
# 输出规范
【输出格式要求】:
{output_format}
# 示范案例库
▶ 案例1(有效匹配)
入:
原始问题:怎么把旧版西藏定额工程转到Z1新版
关键词:【'老版本定额升级', '批量设置定额', '西藏造价软件Z1'
输出:
{{"rewrite":"【西藏造价软件Z1】如何执行【老版本定额升级】操作?"}}
【示例】
用户输入1: 技改T1怎样新建工程
出1:
{{
"vertical_classification":"软件问题",
"sub_classification":"软件功能"
}}"""
▶ 案例2(无效匹配)
输入:
原始问题:程序界面文字显示过小如何处理?
关键词:【'定额升级', '工程批量导入'
输出:
{{"rewrite":"程序界面文字显示过小如何处理?"}}
▶ 案例3(部分匹配,但保留修饰限定词)
输入:
原始问题:"配网软件D3能导出清单的计算公式吗?
关键词:【'配网工程计价通D3软件', '计算式'
输出(保留限定修饰词"清单")
{{"rewrite":"【配网工程计价通D3软件】能导出清单的【计算式】吗?"}}
## 质量检查清单
执行前请确认:
- [ ] 是否保持了原问题的核心诉求?
- [ ] 是否正确执行了同义词替换?
- [ ] 是否保留了原问题中的专业术语和限定条件?
- [ ] 是否正确使用了【】标记?
- [ ] 重构后的问题是否自然流畅?
"""
query_rewrite_prompt_pro="""
# 问答优化工程师
query_rewrite_prompt_pro="""# 问答优化工程师
**角色**:基于历史对话和术语库重构问题,提升知识库检索准确率。
**最高准则**
1、保持问题核心意图,允许指代消除
@@ -196,18 +112,18 @@ query_rewrite_prompt_pro="""
## 处理流程
### 一、输入解析
- 原始问题(需保留核心语义):
<query> {query} </query>
- 原始问题(需保留核心语义):
<query> {query} </query>
- 术语库集合(用于同义词转标准词环节):
<keywords>
{keywords}
</keywords>
- 术语库集合(用于同义词转标准词环节):
<keywords>
{keywords}
</keywords>
- 历史对话记录:
<history>
{chat_history}
</history>
- 历史对话记录:
<history>
{chat_history}
</history>
### 一、重构流程
1、问题是否指代不明,指代不明时根据历史对话补充上下文
@@ -221,23 +137,11 @@ query_rewrite_prompt_pro="""
## 输出规范
{output_format}
## 示例模仿
示例1
输入:
<history>
'user': '811623110668是哪款软件的锁?
'assistant': 可通过查询软件锁的许可证信息,通过许可证名称可以判断对应软件
</history>
<query> ”锂离子电池储能安装“ </query>
输出:
{{"rewrite": "许可证名称为‘锂离子电池储能安装’对应什么软件?"}}
## 质量自检
- [] **主题是否合理继承?**
- [] 核心诉求是否保留?
- [] 语句是否自然流畅?
- [] 避免补充无关信息
"""
- [] 避免补充无关信息"""
slot_filling_prompt = """
你是一个专业的电力造价领域问题槽位填充助手。你需要从用户问题中提取关键信息,并填充到对应的数据结构中。
@@ -282,8 +186,7 @@ slot_filling_prompt = """
"""
# 意图优化环节提示词模板
step_back_prompt = """
# 后退提示生成器
step_back_prompt = """# 后退提示生成器
你是一个专业的电力造价领域问题抽象专家。你的任务是根据用户的具体问题,提出一个更抽象、更高层次的问题,帮助系统更好地理解用户的意图。
@@ -292,11 +195,11 @@ step_back_prompt = """
2. 考虑历史对话和会话背景,理解用户当前问题的上下文
3. 生成更抽象、更高层次的问题,称为"后退问题",后退问题可以生成多个,依次后退到更抽象、更高层次的问题
4. 后退问题应该:
- 更加通用和抽象,不应包含原始问题的具体细节(包括场景限定、界面限定等其他限定词语)
- 涵盖原始问题的核心主题
- 去除过于具体的限制条件(如时间、地点、特定版本、特定工程等)
- 保持在同一领域和主题范围内
- 依次移除问题中的限定词或者修饰词
- 更加通用和抽象,不应包含原始问题的具体细节(包括场景限定、界面限定等其他限定词语)
- 涵盖原始问题的核心主题
- 去除过于具体的限制条件(如时间、地点、特定版本、特定工程等)
- 保持在同一领域和主题范围内
- 依次移除问题中的限定词或者修饰词
## 输入
用户原始问题: {query}
@@ -320,9 +223,7 @@ step_back_prompt = """
"original_query": "某个设备更换后,如何在系统中更新对应的定额?",
"can_use_back_prompt": true,
"step_back_query": ["如何更新设备对应的定额?", "如何更新定额?"]
}}
"""
}}"""
follow_up_questions_prompt = """
# 后续问题生成器
+15 -33
View File
@@ -193,7 +193,9 @@ class OpenAiLLM:
messages=[{'role': 'user', 'content': user_prompt}],
**kwargs
)
return completion.choices[0].message
message = completion.choices[0].message
message.usage = completion.usage
return message
except Exception as e:
raise RuntimeError(f"OpenAiLLM:invoke:error:{str(e)}") from e
@@ -225,36 +227,16 @@ class OpenAiLLM:
if __name__ == "__main__":
# 测试重排模型
reranker = SiliconFlowReRankerModel()
# 测试用例1:简单问题
query = "如何通过【电力经济评价软件】的【打开】功能加载工程文件?"
documents = []
results = reranker.rerank(query, documents)
print(f"测试用例1 - 查询:{query}")
for idx, item in enumerate(results):
print(f"{idx+1}. 文档: {item['document']}, 分数: {item['score']}")
print("-" * 50)
# 异步测试示例
async def test_async():
# 测试异步嵌入
api_key = APIKeyManager.get_api_key()
embeddings = XinferenceEmbeddings(api_key=api_key)
query_embedding = await embeddings.embed_query_async("测试查询")
print(f"异步嵌入向量维度: {len(query_embedding)}")
# 测试异步重排序
results = await SiliconFlowReRankerModel.rerank_async(query, documents)
print(f"异步重排序结果数量: {len(results)}")
# 测试异步LLM调用
llm = OpenAiLLM()
response = await llm.ainvoke("你好,请简单介绍一下自己")
print(f"异步LLM响应: {response.content}")
# 如果需要运行异步测试,取消下面的注释
# import asyncio
# asyncio.run(test_async())
base_url = os.getenv("OPENAI_API_BASE")
model_name = os.getenv("MODEL_NAME", "gpt-3.5-turbo")
# 初始化LLM
llm_params = {
"temperature": 0.4, # 降低随机性,使结果更确定
"top_p": 0.7,
"model": model_name,
"base_url": base_url
}
_llm = OpenAiLLM(**llm_params)
promt="""你好,请简单介绍一下自己"""
print(_llm.invoke(promt))