优化对话转工单功能,添加重试机制以提高稳定性,限制处理会话数量为前2000个,更新示例查询和文件路径,增强代码可读性和维护性。同时新增数据库客户端功能,支持批量处理会话数据并导出至Excel。

This commit is contained in:
2025-06-17 19:46:04 +08:00
parent a5c1548240
commit 22d48c951f
10 changed files with 718 additions and 96 deletions
+20 -8
View File
@@ -231,7 +231,7 @@ class DialogueToWorkorder:
output_format = self.user_question_and_solution_parser.get_format_instructions() output_format = self.user_question_and_solution_parser.get_format_instructions()
llm_prompt = prompt.format(output_format=output_format, dialogue_str=dialogue_str) llm_prompt = prompt.format(output_format=output_format, dialogue_str=dialogue_str)
response = self.llm.invoke(user_prompt=llm_prompt) response = self.llm.invoke(user_prompt=llm_prompt, need_retry=False)
try: try:
if response.content.count('user_question') == 1: if response.content.count('user_question') == 1:
@@ -261,7 +261,7 @@ class DialogueToWorkorder:
except Exception as e: except Exception as e:
output_format = self.user_question_and_solution_list_parser.get_format_instructions() output_format = self.user_question_and_solution_list_parser.get_format_instructions()
llm_prompt = prompt.format(output_format=output_format, dialogue_str=dialogue_str) llm_prompt = prompt.format(output_format=output_format, dialogue_str=dialogue_str)
response = self.llm.invoke(user_prompt=llm_prompt) response = self.llm.invoke(user_prompt=llm_prompt, need_retry=False)
user_question_and_solution_temp = self.user_question_and_solution_list_parser.parse(response.content) user_question_and_solution_temp = self.user_question_and_solution_list_parser.parse(response.content)
return user_question_and_solution_temp.user_question_list return user_question_and_solution_temp.user_question_list
@@ -293,7 +293,7 @@ class DialogueToWorkorder:
{dialogue_str} {dialogue_str}
""" """
response = self.llm.invoke(user_prompt=prompt) response = self.llm.invoke(user_prompt=prompt, need_retry=False)
product_name_and_module_name = self.product_name_and_module_name_parser.parse(response.content) product_name_and_module_name = self.product_name_and_module_name_parser.parse(response.content)
return product_name_and_module_name.product_name, product_name_and_module_name.module_name return product_name_and_module_name.product_name, product_name_and_module_name.module_name
@@ -322,7 +322,7 @@ class DialogueToWorkorder:
{dialogue_str} {dialogue_str}
""" """
response = self.llm.invoke(user_prompt=prompt) response = self.llm.invoke(user_prompt=prompt, need_retry=False)
product_line = self.product_line_parser.parse(response.content) product_line = self.product_line_parser.parse(response.content)
return product_line.product_line return product_line.product_line
@@ -358,7 +358,7 @@ class DialogueToWorkorder:
{dialogue_str} {dialogue_str}
""" """
response = self.llm.invoke(user_prompt=prompt) response = self.llm.invoke(user_prompt=prompt, need_retry=False)
question_type = self.question_type_parser.parse(response.content) question_type = self.question_type_parser.parse(response.content)
return question_type.question_type return question_type.question_type
@@ -394,7 +394,7 @@ class DialogueToWorkorder:
""" """
response = self.llm.invoke(user_prompt=prompt) response = self.llm.invoke(user_prompt=prompt, need_retry=False)
is_complaint = self.is_complaint_parser.parse(response.content) is_complaint = self.is_complaint_parser.parse(response.content)
return (is_complaint.is_dissatisfaction, return (is_complaint.is_dissatisfaction,
@@ -479,7 +479,19 @@ class DialogueToWorkorder:
# 按会话ID分组 # 按会话ID分组
conversation_dict = self.group_conversations_by_id(df) conversation_dict = self.group_conversations_by_id(df)
# 限制处理的会话数量为前2000个
if len(conversation_dict) > 2000:
print(f"会话总数为 {len(conversation_dict)},限制处理前2000个会话")
# 获取所有会话ID
conversation_ids = list(conversation_dict.keys())
# 只保留前2000个会话
limited_conversation_dict = {
conversation_id: conversation_dict[conversation_id]
for conversation_id in conversation_ids[:2000]
}
conversation_dict = limited_conversation_dict
else:
print(f"会话总数为 {len(conversation_dict)},处理全部会话")
# 使用线程池处理每个会话 # 使用线程池处理每个会话
workorder_dict_list = [] workorder_dict_list = []
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
@@ -593,7 +605,7 @@ def main():
args = parse_arguments() args = parse_arguments()
# 设置默认文件路径 # 设置默认文件路径
conversation_excel_path = args.conversation_file or os.path.join('data', 'excel', '会话内容详情20250528110230.xlsx') conversation_excel_path = args.conversation_file or os.path.join('data', 'excel', '2025年1月到6月12号所有对话记录.xlsx')
product_detail_excel_path = args.product_detail_file or os.path.join('data', 'excel', '产品详情_工单.xlsx') product_detail_excel_path = args.product_detail_file or os.path.join('data', 'excel', '产品详情_工单.xlsx')
# 创建处理实例 # 创建处理实例
+537
View File
@@ -0,0 +1,537 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import annotations
import json
import os
import re
import configparser
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple, Union
from dataclasses import dataclass
from contextlib import contextmanager
import threading
import time
from queue import Queue, Empty, Full
import pandas as pd
import pymysql
from pymysql.connections import Connection
from pymysql.cursors import Cursor
from tqdm import tqdm
import concurrent.futures
import sys
# 配置日志
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler('./data/log/mariadb_client.log'),
logging.StreamHandler()
]
)
logger = logging.getLogger(__name__)
os.makedirs('./data/log', exist_ok=True)
@dataclass
class DatabaseConfig:
"""数据库配置类"""
host: str = '192.168.0.123'
port: int = 3307
user: str = 'fuzhimei'
password: str = 'fuzhimei@135'
charset: str = 'utf8mb4'
connect_timeout: int = 10
read_timeout: int = 300
write_timeout: int = 300
@classmethod
def from_config_file(cls, config_file: str = 'config.ini') -> 'DatabaseConfig':
"""从配置文件加载配置"""
if not os.path.exists(config_file):
logger.warning(f"配置文件 {config_file} 不存在,使用默认配置")
return cls()
config = configparser.ConfigParser()
config.read(config_file, encoding='utf-8')
if 'database' not in config:
logger.warning("配置文件中没有 [database] 部分,使用默认配置")
return cls()
db_config = config['database']
return cls(
host=db_config.get('host', cls.host),
port=int(db_config.get('port', cls.port)),
user=db_config.get('user', cls.user),
password=db_config.get('password', cls.password),
charset=db_config.get('charset', cls.charset),
connect_timeout=int(db_config.get('connect_timeout', cls.connect_timeout)),
read_timeout=int(db_config.get('read_timeout', cls.read_timeout)),
write_timeout=int(db_config.get('write_timeout', cls.write_timeout))
)
class ConnectionPool:
"""数据库连接池"""
def __init__(self, config: DatabaseConfig, max_connections: int = 10):
self.config = config
self.max_connections = max_connections
self.pool = Queue(maxsize=max_connections)
self.active_connections = 0
self.lock = threading.Lock()
# 预创建一些连接
self._initialize_pool()
def _initialize_pool(self) -> None:
"""初始化连接池,预创建一些连接"""
initial_connections = min(3, self.max_connections)
for _ in range(initial_connections):
try:
conn = self._create_connection()
if conn:
self.pool.put_nowait(conn)
self.active_connections += 1
except Full:
break
except Exception as e:
logger.error(f"初始化连接池时创建连接失败: {e}")
def _create_connection(self) -> Optional[Connection]:
"""创建新的数据库连接"""
try:
conn = pymysql.connect(
host=self.config.host,
port=self.config.port,
user=self.config.user,
password=self.config.password,
charset=self.config.charset,
connect_timeout=self.config.connect_timeout,
read_timeout=self.config.read_timeout,
write_timeout=self.config.write_timeout,
autocommit=True
)
return conn
except Exception as e:
logger.error(f"创建数据库连接失败: {e}")
return None
@contextmanager
def get_connection(self):
"""获取连接的上下文管理器"""
conn = None
try:
# 尝试从池中获取连接
try:
conn = self.pool.get_nowait()
except Empty:
# 池中没有连接,尝试创建新连接
with self.lock:
if self.active_connections < self.max_connections:
conn = self._create_connection()
if conn:
self.active_connections += 1
else:
raise Exception("无法创建新的数据库连接")
else:
# 等待可用连接
logger.info("等待可用连接...")
conn = self.pool.get(timeout=30)
# 检查连接是否仍然有效
if conn and not self._is_connection_alive(conn):
logger.warning("连接已失效,重新创建")
try:
conn.close()
except:
pass
conn = self._create_connection()
if not conn:
raise Exception("重新创建连接失败")
yield conn
except Exception as e:
logger.error(f"获取数据库连接时出错: {e}")
if conn:
try:
conn.close()
except:
pass
with self.lock:
self.active_connections -= 1
raise
else:
# 归还连接到池中
if conn:
try:
self.pool.put_nowait(conn)
except Full:
# 池已满,关闭连接
try:
conn.close()
except:
pass
with self.lock:
self.active_connections -= 1
def _is_connection_alive(self, conn: Connection) -> bool:
"""检查连接是否仍然有效"""
try:
conn.ping(reconnect=False)
return True
except:
return False
def close_all(self) -> None:
"""关闭所有连接"""
logger.info("正在关闭连接池中的所有连接...")
while not self.pool.empty():
try:
conn = self.pool.get_nowait()
conn.close()
except (Empty, Exception):
break
self.active_connections = 0
logger.info("连接池已关闭")
class DataProcessor:
"""数据处理器"""
@staticmethod
def clean_html_tags(text: str) -> str:
"""清除文本中的HTML标签"""
if not isinstance(text, str):
return str(text) if text is not None else ""
# 使用正则表达式移除HTML标签
clean_text = re.sub(r'<[^>]+>', '', text)
# 处理HTML实体
html_entities = {
'&nbsp;': ' ',
'&lt;': '<',
'&gt;': '>',
'&amp;': '&',
'&quot;': '"',
'&apos;': "'"
}
for entity, char in html_entities.items():
clean_text = clean_text.replace(entity, char)
return clean_text.strip()
@staticmethod
def messages_df_to_list(messages_df: pd.DataFrame) -> List[Dict[str, Any]]:
"""将消息DataFrame转换为字典列表,使用高效的向量化操作"""
if messages_df.empty:
return []
# 过滤掉系统消息
mask = (messages_df["MODE"] != "system") & (messages_df["SYSTEM_MODE_MESSAGE_TYPE"].isna())
filtered_df = messages_df[mask].copy()
if filtered_df.empty:
return []
# 向量化操作
filtered_df['message_sender'] = filtered_df["MODE"].map({'reply': '坐席', 'receive': '访客'}).fillna('未知')
# 处理发送者昵称
filtered_df['sender_nickname'] = filtered_df.apply(
lambda row: row["AGENT_NAME"] if row["message_sender"] == "坐席" else row["CUS_NICK_NAME"],
axis=1
)
# 处理内容
def process_content(row):
content = row["CONTENT"]
if row["MSG_TYPE"] == "attachment":
return f"附件:{DataProcessor.clean_html_tags(content)}"
elif row["MSG_TYPE"] == "image":
return f"图片:{DataProcessor.clean_html_tags(content)}"
else:
return content
filtered_df['processed_content'] = filtered_df.apply(process_content, axis=1)
# 过滤掉空昵称
filtered_df = filtered_df[filtered_df['sender_nickname'].notna() & (filtered_df['sender_nickname'] != '')]
# 转换为字典列表
result = []
for record in filtered_df.to_dict('records'):
result.append({
"账号id": record["ACCOUNT"],
"会话id": record["SESSION_ID"],
"消息内容": record["processed_content"],
"消息发送者": record["message_sender"],
"发送者昵称": record["sender_nickname"],
"创建时间": record["CREATE_TIME"],
})
return result
class MariaDBClient:
"""优化后的MariaDB数据库客户端"""
def __init__(self, config: DatabaseConfig, max_connections: int = 10):
self.config = config
self.connection_pool = ConnectionPool(config, max_connections)
self.data_processor = DataProcessor()
def __enter__(self) -> 'MariaDBClient':
return self
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
self.close()
def close(self) -> None:
"""关闭客户端"""
self.connection_pool.close_all()
def execute_query(self, sql: str, params: Optional[Tuple] = None) -> Tuple[Optional[pd.DataFrame], List[str]]:
"""执行SQL查询"""
try:
with self.connection_pool.get_connection() as conn:
with conn.cursor() as cursor:
cursor.execute(sql, params)
results = cursor.fetchall()
# 获取列名
column_names = [desc[0] for desc in cursor.description] if cursor.description else []
if results:
df = pd.DataFrame(results, columns=column_names)
return df, column_names
else:
return pd.DataFrame(), column_names
except Exception as e:
logger.error(f"执行查询时出错: {e}")
logger.error(f"SQL: {sql}")
return None, []
def query_sessions(self, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
"""查询指定日期范围内的会话数据"""
sql = """
SELECT ACCOUNT, BEGIN_TIME, END_TIME, CUST_SEND_MESSAGE_COUNT,
AGENT_SEND_MESSAGE_COUNT, STATUS, CHANNEL_NAME, SESSION_ID, SESSION_TAG_NAME
FROM crm_hlyj.crm_hlyj_dsri
WHERE BEGIN_TIME >= %s
AND BEGIN_TIME < %s
AND STATUS = 'assign'
ORDER BY BEGIN_TIME DESC
"""
df, _ = self.execute_query(sql, (start_date, end_date))
return df
def query_messages_by_session_id(self, session_id: str) -> Optional[pd.DataFrame]:
"""根据会话ID查询消息详情"""
sql = """
SELECT CREATE_TIME, CUS_NICK_NAME, MODE, MSG_TYPE, AGENT_NAME, CONTENT,
SESSION_ID, ACCOUNT, SYSTEM_MODE_MESSAGE_TYPE
FROM crm_hlyj.crm_hlyj_dmri
WHERE SESSION_ID = %s
ORDER BY CREATE_TIME
"""
df, _ = self.execute_query(sql, (session_id,))
return df
def export_to_excel(self, data: List[Dict[str, Any]], filename: str, output_dir: str = "output") -> Optional[str]:
"""导出数据到Excel文件"""
if not data:
logger.warning(f"没有数据可导出到 {filename}")
return None
try:
# 创建输出目录
os.makedirs(output_dir, exist_ok=True)
# 生成文件路径
# timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
file_path = os.path.join(output_dir, f"{filename}.xlsx")
# 准备数据:不同对话之间添加空行
all_rows = []
current_session_id = None
for conversation in data:
if not conversation: # 跳过空对话
continue
# 如果是新的会话,添加空行(除了第一个会话)
if current_session_id and current_session_id != conversation[0]["会话id"]:
empty_row = {key: "" for key in conversation[0].keys()}
all_rows.append(empty_row)
# 更新当前会话ID
current_session_id = conversation[0]["会话id"]
# 添加当前会话的所有消息
all_rows.extend(conversation)
# 创建DataFrame并导出
if all_rows:
df = pd.DataFrame(all_rows)
with pd.ExcelWriter(file_path, engine='openpyxl') as writer:
df.to_excel(writer, sheet_name='对话记录', index=False)
logger.info(f"数据已导出到 {file_path}")
return file_path
else:
logger.warning("没有有效数据可导出")
return None
except Exception as e:
logger.error(f"导出到Excel时出错: {e}")
return None
def process_session_batch(db_client: MariaDBClient, session_batch: pd.DataFrame) -> List[List[Dict[str, Any]]]:
"""批量处理会话数据"""
conversations = []
for _, session_row in session_batch.iterrows():
try:
session_id = session_row['SESSION_ID']
messages_df = db_client.query_messages_by_session_id(session_id)
if messages_df is not None and not messages_df.empty:
conversation = db_client.data_processor.messages_df_to_list(messages_df)
if conversation:
conversations.append(conversation)
except Exception as e:
logger.error(f"处理会话 {session_row.get('SESSION_ID', 'unknown')} 时出错: {e}")
continue
return conversations
class SessionProcessor:
"""会话处理器,负责批量和并发处理"""
def __init__(self, db_client: MariaDBClient, max_workers: int = None, batch_size: int = 50):
self.db_client = db_client
self.max_workers = max_workers if max_workers is not None else os.cpu_count()
self.batch_size = batch_size
self.temp_save_lock = threading.Lock() # 添加锁用于保护临时保存操作
logger.info(f"初始化会话处理器: max_workers={self.max_workers}, batch_size={self.batch_size}")
def process_sessions(self, sessions_df: pd.DataFrame) -> List[List[Dict[str, Any]]]:
"""处理所有会话数据"""
if sessions_df.empty:
logger.warning("没有会话数据需要处理")
return []
total_sessions = len(sessions_df)
logger.info(f"开始处理 {total_sessions} 个会话...")
# 分批处理
all_conversations = []
batch_count = (total_sessions + self.batch_size - 1) // self.batch_size
# 使用线程池处理批次
with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
# 提交所有批次任务
future_to_batch = {}
for i in range(0, total_sessions, self.batch_size):
batch = sessions_df.iloc[i:i + self.batch_size]
future = executor.submit(process_session_batch, self.db_client, batch)
future_to_batch[future] = i // self.batch_size + 1
# 收集结果
with tqdm(total=batch_count, desc="处理批次进度") as pbar:
for future in concurrent.futures.as_completed(future_to_batch):
try:
batch_conversations = future.result()
all_conversations.extend(batch_conversations)
# 使用锁保护临时列表的操作
with self.temp_save_lock:
# 每处理100个对话临时保存一次
logger.info(f"临时保存 {len(all_conversations)} 个对话")
temp_output_file = self.db_client.export_to_excel(
all_conversations,
f"客服对话记录_临时保存",
output_dir="/data/QueryRewrite/data/excel"
)
if temp_output_file:
logger.info(f"临时保存完成: {temp_output_file}")
batch_num = future_to_batch[future]
logger.debug(f"批次 {batch_num} 完成,获得 {len(batch_conversations)} 个对话")
except Exception as e:
batch_num = future_to_batch[future]
logger.error(f"处理批次 {batch_num} 时出错: {e}")
pbar.update(1)
logger.info(f"处理完成,共获得 {len(all_conversations)} 个有效对话")
return all_conversations
def main() -> None:
"""主函数"""
try:
# 加载配置
config = DatabaseConfig.from_config_file()
logger.info(f"使用数据库配置: {config.host}:{config.port}")
# 创建数据库客户端
with MariaDBClient(config, max_connections=12) as db_client:
# 查询会话数据
start_date = '2025-01-01 00:00:00'
end_date = '2025-06-12 00:00:00'
logger.info(f"查询时间范围: {start_date}{end_date}")
# 创建会话处理器
processor = SessionProcessor(db_client, batch_size=100)
is_debug = hasattr(sys, 'gettrace') and sys.gettrace() is not None
if is_debug:
messages_df = db_client.query_messages_by_session_id("86c919e0-09f1-11f0-84ae-2daf59566989")
print(db_client.data_processor.messages_df_to_list(messages_df))
return []
sessions_df = db_client.query_sessions(start_date, end_date)
if sessions_df is None or sessions_df.empty:
logger.warning("没有找到符合条件的会话数据")
return
# 处理会话数据
all_conversations = processor.process_sessions(sessions_df)
# 导出结果
if all_conversations:
output_file = db_client.export_to_excel(
all_conversations,
"客服对话记录",
output_dir="/data/QueryRewrite/data/excel"
)
if output_file:
logger.info(f"处理完成!共导出 {len(all_conversations)} 个对话到文件: {output_file}")
else:
logger.error("导出文件失败")
else:
logger.warning("没有有效的对话数据可导出")
except KeyboardInterrupt:
logger.info("用户中断程序")
except Exception as e:
logger.error(f"程序执行出错: {e}", exc_info=True)
if __name__ == "__main__":
main()
+8 -7
View File
@@ -175,7 +175,7 @@ def save_results_to_excel(results, output_file, is_final=False):
logging.info(f"已保存{len(valid_results)}条结果至: {temp_output_file}") logging.info(f"已保存{len(valid_results)}条结果至: {temp_output_file}")
# 示例查询 # 示例查询
examples_query = """西藏软件呢""" examples_query = """储能软件如何操作"""
conversation_context="" conversation_context=""
chat_history=[ chat_history=[
{ {
@@ -214,8 +214,8 @@ def main():
# 读取提问数据 # 读取提问数据
current_dir = os.path.dirname(os.path.abspath(__file__)) current_dir = os.path.dirname(os.path.abspath(__file__))
data_file = os.path.join(current_dir, "..", "..", "data", "excel", "历史提问数据(like)_提问明确.xlsx") data_file = os.path.join(current_dir, "..", "..", "data", "excel", "200条点踩数据测试.xlsx")
output_file = os.path.join(current_dir, "..", "..", "data", "excel", "测试提问数据_槽位填充结果.xlsx") output_file = os.path.join(current_dir, "..", "..", "data", "excel", "200条点踩数据测试_槽位填充结果.xlsx")
# 检测是否为调试模式,调试模式下使用examples_query,否则从Excel读取 # 检测是否为调试模式,调试模式下使用examples_query,否则从Excel读取
is_debug = hasattr(sys, 'gettrace') and sys.gettrace() is not None is_debug = hasattr(sys, 'gettrace') and sys.gettrace() is not None
@@ -226,7 +226,7 @@ def main():
examples = load_questions_from_excel(data_file) examples = load_questions_from_excel(data_file)
if not is_debug: if not is_debug:
max_workers = 40 # 减少并发数以避免API限制 max_workers = 20 # 减少并发数以避免API限制
logging.info(f"共有 {len(examples)} 个问题需要处理,使用 {max_workers} 个并发线程") logging.info(f"共有 {len(examples)} 个问题需要处理,使用 {max_workers} 个并发线程")
# 创建一个与输入顺序相同的结果列表 # 创建一个与输入顺序相同的结果列表
@@ -260,9 +260,10 @@ def main():
logging.info(f"所有处理完成,最终结果已保存至: {output_file}") logging.info(f"所有处理完成,最终结果已保存至: {output_file}")
else: else:
for idx, query in enumerate(examples): for idx, query in enumerate(examples):
if query.strip() == "": if query.strip() == "":
continue continue
process_query(recognizer, query, conversation_context, chat_history, previous_slots) process_query(recognizer, query, conversation_context, chat_history, previous_slots)
# print(json.dumps(process_query(recognizer, query), ensure_ascii=False, indent=2))
def setup_logging(): def setup_logging():
# 配置日志输出到控制台 # 配置日志输出到控制台
+14 -1
View File
@@ -6,10 +6,23 @@ import json
import time import time
import threading import threading
import datetime import datetime
import logging
# 加载环境变量 # 加载环境变量
load_dotenv() load_dotenv()
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.StreamHandler()
]
)
logging.getLogger('httpx').setLevel(logging.WARNING)
logging.getLogger('openai').setLevel(logging.WARNING)
logger = logging.getLogger(__name__)
app = Flask(__name__) app = Flask(__name__)
# 创建线程锁,用于保护共享资源 # 创建线程锁,用于保护共享资源
@@ -50,7 +63,7 @@ def intent_recognize():
end_time = time.time() end_time = time.time()
current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S %z") current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S %z")
print(f"[{current_time}] [{os.getpid()}] [INFO] 意图识别耗时: {end_time - start_time:.2f}") logger.info(f"[{os.getpid()}] 意图识别耗时: {end_time - start_time:.2f}")
# 提取分类信息 # 提取分类信息
classification = result["classification"] classification = result["classification"]
+3 -1
View File
@@ -150,12 +150,14 @@ class SoftwareFunctionSlots(SlotBase):
software_name: str = Field(default="", description="软件名称") software_name: str = Field(default="", description="软件名称")
function_name: str = Field(default="", description="具体功能名称") function_name: str = Field(default="", description="具体功能名称")
operation: str = Field(default="", description="用户操作意图(如何使用功能、功能入口、功能使用场景)") operation: str = Field(default="", description="用户操作意图(如何使用功能、功能入口、功能使用场景)")
project_type: Optional[str] = Field(default="单工程", description="工程类型(单工程、多工程、批次工程)") project_type: Optional[str] = Field(default="单工程", description="工程类型(单工程、多工程、批次工程), 未明确提及则默认下是(单工程)")
software_version: Optional[str] = Field(default="", description="软件版本") software_version: Optional[str] = Field(default="", description="软件版本")
operation_steps: Optional[str] = Field(default="", description="操作步骤描述") operation_steps: Optional[str] = Field(default="", description="操作步骤描述")
def check_required_slots(self) -> Tuple[bool, Dict[str, str]]: def check_required_slots(self) -> Tuple[bool, Dict[str, str]]:
"""检查必填槽位是否都存在""" """检查必填槽位是否都存在"""
if self.project_type is None or len(self.project_type) == 0:
self.project_type="单工程"
missing_slots = {} missing_slots = {}
if not self.software_name: if not self.software_name:
missing_slots["software_name"] = f"{SoftwareFunctionSlots.model_fields['software_name'].description},可选值:{', '.join([name.value for name in SoftwareName if name not in [SoftwareName.UNKNOWN, SoftwareName.ALIASES]])}" missing_slots["software_name"] = f"{SoftwareFunctionSlots.model_fields['software_name'].description},可选值:{', '.join([name.value for name in SoftwareName if name not in [SoftwareName.UNKNOWN, SoftwareName.ALIASES]])}"
+72 -11
View File
@@ -14,6 +14,8 @@ import json
from typing import List, Tuple, Dict, Any, Optional from typing import List, Tuple, Dict, Any, Optional
import re import re
import jieba import jieba
import time
from .PromptTemplates import (classification_prompt, query_rewrite_prompt, from .PromptTemplates import (classification_prompt, query_rewrite_prompt,
extract_nouns_prompt, classification_info, extract_nouns_prompt, classification_info,
slot_filling_prompt) slot_filling_prompt)
@@ -95,7 +97,9 @@ class IntentRecognizer:
except Exception as e: except Exception as e:
raise RuntimeError(f"加载后缀关键词失败: {e}") from e raise RuntimeError(f"加载后缀关键词失败: {e}") from e
def _classify_intent(self, query: str) -> Classification: def _classify_intent(self, query: str, conversation_context: str = "",
chat_history: List[Dict[str, str]] = None,
previous_slots: Dict[str, Any] = None) -> Classification:
""" """
对用户输入进行意图分类 对用户输入进行意图分类
@@ -109,7 +113,9 @@ class IntentRecognizer:
classification_parser = PydanticOutputParser(pydantic_object=Classification) classification_parser = PydanticOutputParser(pydantic_object=Classification)
formatted_prompt = classification_prompt.format(user_input=query, formatted_prompt = classification_prompt.format(user_input=query,
classification_info=classification_info, classification_info=classification_info,
output_format=classification_parser.get_format_instructions()) output_format=classification_parser.get_format_instructions(),
conversation_context=conversation_context,
chat_history=json.dumps(chat_history, ensure_ascii=False))
# 调用LLM # 调用LLM
response = self._llm.invoke(formatted_prompt, False) response = self._llm.invoke(formatted_prompt, False)
@@ -208,7 +214,7 @@ class IntentRecognizer:
term_texts = ["名称:" + term.name + "|" + "同义词:" + ";".join(term.synonymous) for term in matched_terms] term_texts = ["名称:" + term.name + "|" + "同义词:" + ";".join(term.synonymous) for term in matched_terms]
# 使用重排序模型 # 使用重排序模型
xinference_reranker = SiliconFlowReRankerModel() xinference_reranker = XinferenceReRankerModel()
rerank_results = xinference_reranker.rerank(query_key, term_texts, top_k=top_k) rerank_results = xinference_reranker.rerank(query_key, term_texts, top_k=top_k)
# 将matched_terms转换为列表以便按索引访问 # 将matched_terms转换为列表以便按索引访问
@@ -220,7 +226,7 @@ class IntentRecognizer:
return reranked_terms return reranked_terms
except Exception as e: except Exception as e:
raise RuntimeError(f"SiliconFlowReRankerModel重排失败:{e}") from e raise RuntimeError(f"_rerank_matched_terms重排失败:{e}") from e
def _match_keywords(self, query: str, use_jieba: bool = False) -> Tuple[TermList, List[str]]: def _match_keywords(self, query: str, use_jieba: bool = False) -> Tuple[TermList, List[str]]:
""" """
@@ -233,18 +239,23 @@ class IntentRecognizer:
Returns: Returns:
匹配到的关键词列表 匹配到的关键词列表
""" """
start_time = time.time()
query_keys=[] query_keys=[]
# 步骤1: 使用LLM提取查询中的关键词 # 步骤1: 使用LLM提取查询中的关键词
try: try:
llm_start_time = time.time()
extracted_terms = self._extract_keywords_with_llm(query, use_jieba) extracted_terms = self._extract_keywords_with_llm(query, use_jieba)
for term in extracted_terms: for term in extracted_terms:
query_keys.append(term.name) query_keys.append(term.name)
llm_end_time = time.time()
llm_time = llm_end_time - llm_start_time
except Exception as e: except Exception as e:
raise RuntimeError(f"LLM关键词提取失败: {e}") from e raise RuntimeError(f"LLM关键词提取失败: {e}") from e
matched_terms = [] # 存储匹配到的Term对象 matched_terms = [] # 存储匹配到的Term对象
# 步骤2: 使用向量检索找到相似的专业名词 # 步骤2: 使用向量检索找到相似的专业名词
try: try:
vector_start_time = time.time()
# 对matched_terms中的每个关键字进行向量检索 # 对matched_terms中的每个关键字进行向量检索
for current_key in query_keys: for current_key in query_keys:
vector_results = self._noun_retriever.query(current_key, top_k=5, use_intersection=False) vector_results = self._noun_retriever.query(current_key, top_k=5, use_intersection=False)
@@ -262,12 +273,20 @@ class IntentRecognizer:
if len(current_key_terms) > 0: if len(current_key_terms) > 0:
reranked_terms = self._rerank_matched_terms(current_key, current_key_terms) reranked_terms = self._rerank_matched_terms(current_key, current_key_terms)
matched_terms.extend(reranked_terms) matched_terms.extend(reranked_terms)
vector_end_time = time.time()
vector_time = vector_end_time - vector_start_time
except Exception as e: except Exception as e:
raise RuntimeError(f"向量检索关键词时出错: {e}") from e raise RuntimeError(f"向量检索关键词时出错: {e}") from e
# 提取所有Term对象的名称并排序 # 提取所有Term对象的名称并排序
# 将set类型的matched_terms转换为TermList类型 # 将set类型的matched_terms转换为TermList类型
term_list = TermList(terms=list(matched_terms)) term_list = TermList(terms=list(matched_terms))
end_time = time.time()
total_time = end_time - start_time
# 输出整合的时间日志
logging.info(f"关键词匹配耗时统计 - 总耗时: {total_time:.2f}秒, 问题关键词提取: {llm_time:.2f}秒, 向量检索+重排序: {vector_time:.2f}")
return term_list, query_keys return term_list, query_keys
def _rewrite_query(self, query: str, keywords: TermList, query_keys:List[str], chat_history: List[Dict[str, str]] = None, context: str = "") -> QueryRewrite: def _rewrite_query(self, query: str, keywords: TermList, query_keys:List[str], chat_history: List[Dict[str, str]] = None, context: str = "") -> QueryRewrite:
@@ -282,6 +301,8 @@ class IntentRecognizer:
Returns: Returns:
改写结果 改写结果
""" """
rewrite_start_time = time.time()
# 准备问题改写提示 # 准备问题改写提示
# terms_dict = [term.model_dump(exclude={"description"}) for term in keywords.terms] # terms_dict = [term.model_dump(exclude={"description"}) for term in keywords.terms]
terms_dict = [term.model_dump() for term in keywords.terms] terms_dict = [term.model_dump() for term in keywords.terms]
@@ -303,6 +324,9 @@ class IntentRecognizer:
try: try:
# 尝试直接解析JSON响应 # 尝试直接解析JSON响应
parsed_output = query_rewrite_parser.parse(response.content) parsed_output = query_rewrite_parser.parse(response.content)
rewrite_end_time = time.time()
rewrite_time = rewrite_end_time - rewrite_start_time
logging.info(f"问题改写耗时统计 - 总耗时: {rewrite_time:.2f}")
return parsed_output return parsed_output
except Exception as e: except Exception as e:
raise RuntimeError(f"解析问题改写结果时出错: {e}") from e raise RuntimeError(f"解析问题改写结果时出错: {e}") from e
@@ -360,7 +384,10 @@ class IntentRecognizer:
# suffix_terms.append(suffix_term) # suffix_terms.append(suffix_term)
# return Classification(vertical_classification="安装下载", sub_classification="查询"), TermList(terms=suffix_terms), QueryRewrite(rewrite=query), matched_suffixes # return Classification(vertical_classification="安装下载", sub_classification="查询"), TermList(terms=suffix_terms), QueryRewrite(rewrite=query), matched_suffixes
if chat_history is None:
chat_history = []
if previous_slots is None:
previous_slots = {}
# 步骤1: 匹配关键词 # 步骤1: 匹配关键词
keywords_terms, query_keys = self._match_keywords(query, use_jieba) keywords_terms, query_keys = self._match_keywords(query, use_jieba)
@@ -397,7 +424,9 @@ class IntentRecognizer:
# } # }
def _fill_slots(self, query: str, classification: Classification) -> Dict[str, Any]: def _fill_slots(self, query: str, classification: Classification, conversation_context: str = "",
chat_history: List[Dict[str, str]] = None,
previous_slots: Dict[str, Any] = None,) -> Dict[str, Any]:
""" """
根据分类结果对问题进行槽位填充 根据分类结果对问题进行槽位填充
@@ -415,7 +444,7 @@ class IntentRecognizer:
raise RuntimeError("未找到匹配的槽位模型") raise RuntimeError("未找到匹配的槽位模型")
# 使用LLM进行槽位填充 # 使用LLM进行槽位填充
filled_slots = self._fill_slots_with_llm(query, classification, slot_model) filled_slots = self._fill_slots_with_llm(query, classification, slot_model, conversation_context, chat_history, previous_slots)
# 检查必填槽位是否都已填充 # 检查必填槽位是否都已填充
is_complete, missing_slots = filled_slots.check_required_slots() is_complete, missing_slots = filled_slots.check_required_slots()
@@ -467,7 +496,12 @@ class IntentRecognizer:
return None return None
def _fill_slots_with_llm(self, query: str, classification: Classification, slot_model_class: type) -> Any: def _fill_slots_with_llm(self, query: str,
classification: Classification,
slot_model_class: type,
conversation_context: str = "",
chat_history: List[Dict[str, str]] = None,
previous_slots: Dict[str, Any] = None) -> Any:
""" """
使用LLM进行槽位填充 使用LLM进行槽位填充
@@ -486,7 +520,10 @@ class IntentRecognizer:
query=query, query=query,
vertical_classification=classification.vertical_classification, vertical_classification=classification.vertical_classification,
sub_classification=classification.sub_classification, sub_classification=classification.sub_classification,
output_format=slot_parser.get_format_instructions() output_format=slot_parser.get_format_instructions(),
conversation_context=conversation_context,
chat_history=json.dumps(chat_history,ensure_ascii=False),
previous_slots=json.dumps(previous_slots,ensure_ascii=False),
) )
# 调用LLM # 调用LLM
@@ -537,8 +574,13 @@ class IntentRecognizer:
output_format=parser.get_format_instructions(), output_format=parser.get_format_instructions(),
classification_info=classification_info classification_info=classification_info
) )
# 调用LLM # 调用LLM
llm_start_time = time.time()
response = self._llm.invoke(formatted_prompt + output_example, False) response = self._llm.invoke(formatted_prompt + output_example, False)
llm_end_time = time.time()
llm_time = llm_end_time - llm_start_time
try: try:
# 解析LLM响应为JSON # 解析LLM响应为JSON
@@ -552,8 +594,19 @@ class IntentRecognizer:
if expected_slot_model is None: if expected_slot_model is None:
# 添加容错处理,应对LLM返回错误分类信息,一级分类跟二级分类错乱 # 添加容错处理,应对LLM返回错误分类信息,一级分类跟二级分类错乱
# 重新分类 # 重新分类
classification = self._classify_intent(user_input) classify_start_time = time.time()
fill_slots = self._fill_slots(user_input, classification) classification = self._classify_intent(user_input, conversation_context, chat_history, previous_slots)
classify_end_time = time.time()
classify_time = classify_end_time - classify_start_time
# logging.info(f"重新分类耗时: {classify_time:.2f}秒")
fill_start_time = time.time()
fill_slots = self._fill_slots(user_input, classification, conversation_context, chat_history, previous_slots)
fill_end_time = time.time()
fill_time = fill_end_time - fill_start_time
all_time=fill_end_time-llm_start_time
logging.info(f"总耗时:{all_time:.2f}秒,首次槽位+分类:{llm_time:.2f}秒, 重新分类耗时: {classify_time:.2f}秒, 重新槽位填充耗时: {fill_time:.2f}")
result = { result = {
"classification": classification.model_dump(), "classification": classification.model_dump(),
"slot_filling": fill_slots "slot_filling": fill_slots
@@ -562,7 +615,13 @@ class IntentRecognizer:
return result return result
elif expected_slot_model.__name__ != type(slot_filling).__name__: elif expected_slot_model.__name__ != type(slot_filling).__name__:
# 添加容错处理,应对LLM槽位与分类不匹配。重新填充槽位 # 添加容错处理,应对LLM槽位与分类不匹配。重新填充槽位
fill_start_time = time.time()
slot_filling = self._fill_slots(user_input, classification) slot_filling = self._fill_slots(user_input, classification)
fill_end_time = time.time()
fill_time = fill_end_time - fill_start_time
all_time=fill_end_time-llm_start_time
logging.info(f"总耗时:{all_time:.2f}秒,首次槽位+分类:{llm_time:.2f}秒, 重新槽位填充耗时: {fill_time:.2f}")
result = { result = {
"classification": classification.model_dump(), "classification": classification.model_dump(),
"slot_filling": slot_filling "slot_filling": slot_filling
@@ -570,6 +629,8 @@ class IntentRecognizer:
logging.warning(f"重新填充槽点") logging.warning(f"重新填充槽点")
return result return result
logging.info(f"意图识别+槽位LLM调用耗时: {llm_time:.2f}")
# 构建最终结果 # 构建最终结果
result = { result = {
"classification": classification.model_dump(), "classification": classification.model_dump(),
@@ -126,7 +126,7 @@ query_rewrite_prompt_pro_old="""
query_rewrite_prompt_pro=""" query_rewrite_prompt_pro="""
# 电力造价问答优化工程师(精简版) # 电力造价问答优化工程师(精简版)
**角色**:基于历史对话和专业术语库重构问题,提升知识库检索准确率。 **角色**:基于历史对话和术语库重构问题,提升知识库检索准确率。
## 核心原则 ## 核心原则
1. 语义保真 → 保持问题核心意图 1. 语义保真 → 保持问题核心意图
@@ -135,8 +135,14 @@ query_rewrite_prompt_pro="""
## 处理流程 ## 处理流程
### 一、输入解析 ### 一、输入解析
- 原始问题(需保留核心语义){query} - 原始问题(需保留核心语义):
- 关键词集合:{keywords} <query>
{query}
</query>
- 术语库集合:
<keywords>
{keywords}
</keywords>
- 历史对话记录: - 历史对话记录:
<history> <history>
{chat_history} {chat_history}
@@ -159,14 +165,14 @@ graph TD
### 三、重构优先级 ### 三、重构优先级
1. **背景补充** 1. **背景补充**
- 历史对话中确定的背景信息需要保留(例:"这软件""【配网工程D3" - 历史对话中确定的背景信息需要保留(例:"这软件""【配网工程计价通D3软件"
2. **术语处理** 2. **术语处理**
- 同义词转标准词 → 批量设置定额 - 同义词转标准词 → 将提问中的同义词(synonymous)替换为标准词(name)
- 存在即标记 → 【计算式】 - 存在即标记 → 【计算式】
3. **结构优化** 3. **结构优化**
- 保持原问题的5W2H特征 - 保持原问题的5W2H特征,确保问题意图不发生改变。
- 明确指代关系("该功能""【批量导入】功能" - 明确指代关系("该功能""【批量导入】功能"
## 输出规范 ## 输出规范
@@ -184,7 +190,7 @@ graph TD
- [] 背景信息是否合理补充? - [] 背景信息是否合理补充?
- [] 术语标记是否完整【】? - [] 术语标记是否完整【】?
- [] 语句是否自然流畅? - [] 语句是否自然流畅?
- [] 避免过度补充无关信息 - [] 避免补充无关信息
""" """
@@ -349,7 +355,7 @@ def generate_slot_mapping_doc() -> str:
doc.append(f"- {sub_class} -> {slot_model}") doc.append(f"- {sub_class} -> {slot_model}")
doc.append("\n## 【注意事项】") doc.append("\n## 【注意事项】")
doc.append("1. 分类与槽位模型必须严格对应") doc.append("1. 分类与槽位模型必须严格对应。严格遵守,不得违背")
doc.append("2. 每个分类只能使用其对应的槽位模型") doc.append("2. 每个分类只能使用其对应的槽位模型")
doc.append("3. 不允许混用不同分类的槽位模型") doc.append("3. 不允许混用不同分类的槽位模型")
@@ -58,6 +58,12 @@ classification_prompt="""
用户正在使用电力造价软件或想询问电力造价领域相关知识,你需要根据用户的输入内容,将其归类为以下垂直领域之一: 用户正在使用电力造价软件或想询问电力造价领域相关知识,你需要根据用户的输入内容,将其归类为以下垂直领域之一:
{classification_info} {classification_info}
## 【会话背景信息】
{conversation_context}
## 【历史对话记录】
{chat_history}
【用户输入】: 【用户输入】:
{user_input} {user_input}
@@ -154,6 +160,15 @@ slot_filling_prompt = """
【用户问题】 【用户问题】
{query} {query}
## 【会话背景信息】
{conversation_context}
## 【历史对话记录】
{chat_history}
## 【历史槽位信息】
{previous_slots}
【问题分类】 【问题分类】
垂直领域分类: {vertical_classification} 垂直领域分类: {vertical_classification}
子分类: {sub_classification} 子分类: {sub_classification}
+20 -10
View File
@@ -23,16 +23,6 @@ API_KEY_LIST=[
"sk-kzhxlqvqcxlnbdgnpalqnzumkmspepkttkgbophnkqanainw", "sk-kzhxlqvqcxlnbdgnpalqnzumkmspepkttkgbophnkqanainw",
"sk-bzttugqtlskrvguvhckwamdssvgmgnrqpsialpdbskfsyyak", "sk-bzttugqtlskrvguvhckwamdssvgmgnrqpsialpdbskfsyyak",
"sk-tovmogiablsoeabwgqyvevpcfichyjpuzqdymmvksspdrtqt", "sk-tovmogiablsoeabwgqyvevpcfichyjpuzqdymmvksspdrtqt",
"sk-wqdpapdkisovziexgcyxvumpwzbjnhqbxvcqcspzctjhyhjk",
"sk-bbntrnifrtdzhhgrtlrhvwbnaysuszviemshdakxonnnymnb",
"sk-vmpnwjxersrwybmfhfxgsvbmhsmpjldxseiyxovnysrlbuzi",
"sk-nscsxwfqigkfpfqfzebkmaickxjzbhtfwywdppmmobrrbfnw",
"sk-irbxuakhntsrusrympiubkkjbkabbfbdgpstqnxbztzdtxdq",
"sk-hcfojzczbgwgcuhzxkicxqrhadurtakwbawiesyxyvksmcoz",
"sk-wiyosqgyutjypgzibveiwkgqwfkfsnonrmvjfbvrbkoicciv",
"sk-ocglenyvxkkvzupzumoypnyndjpjqhivyqpedusunboglspz",
"sk-dtbawdwajkhdctrukundbkqwswzfzihqbebfuvqnfnounbuc",
"sk-zqiyiqtbwqgyeenkvppymfbkspriolwbnxnjakugzxyvcuql",
"sk-wtnjpejveiobtvzsmnuaefqkocsafbfyrtqkkyqardndtxcs", "sk-wtnjpejveiobtvzsmnuaefqkocsafbfyrtqkkyqardndtxcs",
"sk-gqdvtrwvzxewnagwsfakrvajtzwgcknatpflkesyqhzjrlal", "sk-gqdvtrwvzxewnagwsfakrvajtzwgcknatpflkesyqhzjrlal",
"sk-plivglrkxahodgtgjlaqdjusdoerxspjbcbizaybicarfyuk", "sk-plivglrkxahodgtgjlaqdjusdoerxspjbcbizaybicarfyuk",
@@ -96,6 +86,26 @@ API_KEY_LIST=[
"sk-jrdzerhmvrtvzawkksowbgkggkubwfquplmrxbdhespqgtis", "sk-jrdzerhmvrtvzawkksowbgkggkubwfquplmrxbdhespqgtis",
"sk-jjbpnkbeupsxyclcivbhizcfpfjrppddunbqynyjkqhtmpwu", "sk-jjbpnkbeupsxyclcivbhizcfpfjrppddunbqynyjkqhtmpwu",
"sk-oqehupcveovkjqqtxypqyifidcdissuyehwrkdwgruoyjkpq", "sk-oqehupcveovkjqqtxypqyifidcdissuyehwrkdwgruoyjkpq",
"sk-jnnmltwtqwuoyagoogzzeraczmyfxhoairiddgayksqdfnbr",
"sk-eghuepxnbcollzrjwbzqvbnhiiwagkejaclyhvaodeqgwrog",
"sk-poszkbjdmamimconjustnrxxqusuzlryxkrzkpronlenrmen",
"sk-zolvcegarsrwqhwgvwzgtqupodsdmckjiocyvoyldbkusbzc",
"sk-ywfafulcniaqdgdcsnbtqquaqeuiqlkcnknkaflwxyuemcow",
"sk-hhedmocgtfpywbbpwamgfkygrahiqsuurntlbqqbmjwfipmm",
"sk-gzdqfoyvulrqscdpjlwlufdecrsyjpmwpkknuhnjsvtyftox",
"sk-bkcufidsebujopqqwexwxwpmevrpelmvxzdymncvllcyojce",
"sk-olabhscekudzkyudypkcjvehwqunagubwdmtppugrjmcptwv",
"sk-zpdqyocliebhqpkuwvebpgcnfjdkvavdltimllmgkthwnwph",
"sk-gvhchlfelocjniuydusyhhwacnomxnvucjonzkhtqoplnbcr",
"sk-lzneagvdxhisodndnxnpkntghpkimjmjsebiqdzaoqzuhbla",
"sk-xotcfdkigykevngedupitbcatjqppxmcibjtcebyoglykuxz",
"sk-ufydqsdqnwsegaqwtappzwdyzqnoblyunfvslomnnmykedgk",
"sk-jwasykftbkyjzdqlwcxuicrwzxsbhttilxfefbrozrznpwlv",
"sk-xngteojwkxmftyaabjdwwgyoadspsowmcpcqobteutdcfmnr",
"sk-akzkgniebruqrtuqskvlibkpcxjuazhcatysptkfyqivldfn",
"sk-vpqkxtmcgkggllexchzysuewyfaoexzasoumxngdplzgwksw",
"sk-fvcsqdbqmdlwxzjyofrilusqcypbfyczogaqwqrjrwvojmer",
"sk-htjprscvfgskjtjzpxxxjhyymshagogykpawxekrrfbgftyx",
] ]
class APIKeyManager: class APIKeyManager:
+13 -48
View File
@@ -100,10 +100,10 @@ class XinferenceReRankerModel:
Returns: Returns:
List[dict]: 重排序后的文档列表,每个元素包含document内容、相关性分数和原始索引 List[dict]: 重排序后的文档列表,每个元素包含document内容、相关性分数和原始索引
""" """
url = "http://10.1.16.39:9995/v1/rerank" url = "http://172.20.0.145:9995/v1/rerank"
params = {"documents": documents, "query": query, "top_n": top_k, "return_documents": True, "model": os.getenv("RERANKER_MODEL_NAME")} params = {"documents": documents, "query": query, "top_n": top_k, "return_documents": True, "model": "bge-reranker-v2-m3"}
headers = { headers = {
"Authorization": "Bearer <token>", # 这里需要替换为实际的token "Authorization": "Bearer <token>", # 这里需要替换为实际的token
"Content-Type": "application/json" "Content-Type": "application/json"
@@ -140,8 +140,7 @@ class OpenAiLLM:
def invoke(self, user_prompt="你是谁?", need_retry=True): def invoke(self, user_prompt="你是谁?", need_retry=True):
# 初始化 OpenAI 客户端 # 初始化 OpenAI 客户端
api_key = APIKeyManager.get_api_key()
client = OpenAI(api_key=api_key, base_url=self._url)
max_retries = 3 max_retries = 3
retry_count = 0 retry_count = 0
@@ -149,6 +148,8 @@ class OpenAiLLM:
if need_retry: if need_retry:
while retry_count < max_retries: while retry_count < max_retries:
try: try:
api_key = APIKeyManager.get_api_key()
client = OpenAI(api_key=api_key, base_url=self._url)
# 创建 Completion 请求. 超时120s # 创建 Completion 请求. 超时120s
completion = client.chat.completions.create( completion = client.chat.completions.create(
model=self._model, model=self._model,
@@ -162,11 +163,13 @@ class OpenAiLLM:
retry_count += 1 retry_count += 1
if retry_count == max_retries: if retry_count == max_retries:
logging.error(f"LLM 重试{max_retries}次后仍然失败: {e}") logging.error(f"LLM 重试{max_retries}次后仍然失败: {e}")
return "" raise e
else: else:
time.sleep(5*retry_count) # 重试前等待1秒 time.sleep(5*retry_count) # 重试前等待1秒
else: else:
# 创建 Completion 请求. 超时120s # 创建 Completion 请求. 超时120s
api_key = APIKeyManager.get_api_key()
client = OpenAI(api_key=api_key, base_url=self._url)
completion = client.chat.completions.create( completion = client.chat.completions.create(
model=self._model, model=self._model,
messages=[{'role': 'user', 'content': user_prompt}], messages=[{'role': 'user', 'content': user_prompt}],
@@ -180,53 +183,15 @@ if __name__ == "__main__":
reranker = SiliconFlowReRankerModel() reranker = SiliconFlowReRankerModel()
# 测试用例1:简单问题 # 测试用例1:简单问题
query = "他想做什么" query = "如何通过【电力经济评价软件】的【打开】功能加载工程文件?"
documents = ["她想去公园跑步", "她想换一个新手机", "明天她想出去旅游"] documents = ["\n# (电力建设计价通软件) (概预算工程)工程备份管理\n## 操作步骤\n**方法一:** \n\n1、查找工程:输入工程文件名称的关键字,点击“查找”按钮,可以快速定位需查找的工程;\n\n![2](https://172.20.0.145/files/ea1d6edc-090c-4c35-be13-e506f0eeb176/image-preview)\n\n2、根据时间点找备份工程:选中对应工程文件,在右侧选中“备份时间”的备份记录,点击“还原工程”或者“另存为工程”;\n\n **还原工程:** 将工程还原保存在原路径下;\n\n **另存为工程:** 另存为一个新工程,可选择保存路径,保存后,可点击文件——打开,浏览到另存的新工程打开。\n\n注:不确定备份是否是需要时,优先建议另存为工程。\n\n![3](https://172.20.0.145/files/63fb4e9d-06ce-44e6-adbc-45ebc92c3e6f/image-preview)\n\n **方法二:** \n\n1、点击桌面软件快捷图标,右键属性—打开文件位置,直接定位软件安装根目录。 \n\n![打开文件所在位置](https://172.20.0.145/files/28540c85-2b85-4717-9b87-31eba886c5f0/image-preview)\n\n2、在软件安装根目录,点击“数据备份”文件夹,进入到文件夹内,根据修改日期找到对应工程,右键复制粘贴至桌面。\n\n![打开数据备份文件夹](https://172.20.0.145/files/76a4263c-ea75-4083-9ed8-7f2e28412e51/image-preview)\n\n![按照时间排序,找到.bak文件复制到桌面](https://172.20.0.145/files/eaa59154-6fbc-494e-817c-e5f7610e6294/image-preview)\n\n3、定位桌面复制粘贴出来的数据工程,右键\"重命名\",将bak修改成相应的文件后缀(概预算工程及施工图预算工程后缀为zwzj,招标工程及投标工程后缀为zwqd),然后点击“确定”,再通过软件的“文件”——“打开”按钮去浏览工程打开。\n",
"\n# (配网计价通D3)插件管理/全国版和专版切换\n## 使用场景\n1.打开软件提示“当前工程文件为全国版文件,请使用全国版软件打开!”,该如何打开这个工程呢?\n\n![1](https://172.20.0.145/files/3751d1c2-da12-4076-bd9f-3e1c9eced1ab/image-preview)\n\n2.打开软件提示“当前工程文件为辽宁版文件,请确认是否要在全国版软件中打开?”,这是什么意思?点击“确定”又可以打开工程?\n\n![2](https://172.20.0.145/files/68c18647-c4fa-4496-974a-f5b07352063e/image-preview)\n## 知识原理\n\n## 费用去向\n\n",
"\n(电力建设计价通软件) 云造价--停用\n# 工程文件管理\n\n## 【主页】中点击“云端工程管理”,进入博微服务大厅;\n![6](https://172.20.0.145/files/69d1d763-bb6b-459f-b7ba-264c788225e9/image-preview)![7](https://172.20.0.145/files/eb93d8ea-d3c8-45f3-86ab-f23c0d244d96/image-preview)\n## 工程文件管理界面中显示云端备份的工程列表,可支持\n![8](https://172.20.0.145/files/91200171-182b-4de7-a895-f5118bfe9d98/image-preview)\n## 高级设置:可对历史版本数量进行设置,默认数量为10,可设置(5-15);\n![9](https://172.20.0.145/files/31b63307-3313-4402-9652-9b46e9549aec/image-preview)\n## 历史版本:勾选单个工程,点击“历史版本”可查看该工程保存的不同时间节点的历史工程;\n![10](https://172.20.0.145/files/1ae83da5-9e86-4a44-9af1-013199a933c6/image-preview)\n## 在线查阅:可查看工程数据,仅为只读模式不支持任何编辑;\n![11](https://172.20.0.145/files/61ef2bc8-0d73-47bf-b66f-aebcc39d7a46/image-preview)\n## 下载;选择需要的工程点击“下载”,可下载软件版本工程;\n![12](https://172.20.0.145/files/3391d7b2-c289-485a-a459-754ce6c8e294/image-preview)\n",
"\n(配网D3软件)打开工程\n\n# (配网D3软件)打开工程\n\n## 功能入口\n各界面点击“文件”按钮——“打开”按钮 \n\n![打开工程-功能入口](https://172.20.0.145/files/7c6dbff1-932c-4cf5-ab16-b17a6309af0e/image-preview)\n## 操作步骤\n**打开工程:** \n\n点击“打开”按钮,浏览到工程存放位置,选中工程文件,点击“打开”即可。"]
results = reranker.rerank(query, documents) results = reranker.rerank(query, documents)
print(f"测试用例1 - 查询:{query}") print(f"测试用例1 - 查询:{query}")
for idx, item in enumerate(results): for idx, item in enumerate(results):
print(f"{idx+1}. 文档: {item['document']}, 分数: {item['score']}") print(f"{idx+1}. 文档: {item['document']}, 分数: {item['score']}")
print("-" * 50) print("-" * 50)
# 测试用例2:技术问题
query = "Python如何处理JSON数据"
documents = [
"Python中可以使用json模块来处理JSON数据,例如json.loads()将JSON字符串转换为字典",
"Java提供了多种处理JSON的库,比如Jackson和Gson",
"在Python中,可以使用pandas库来分析CSV数据",
"JavaScript可以使用JSON.parse()方法解析JSON字符串"
]
results = reranker.rerank(query, documents)
print(f"测试用例2 - 查询:{query}")
for idx, item in enumerate(results):
print(f"{idx+1}. 文档: {item['document']}, 分数: {item['score']}")
print("-" * 50)
# 测试用例3:医疗问题
query = "高血压的症状有哪些"
documents = [
"高血压的常见症状包括头痛、头晕、耳鸣和视力模糊",
"糖尿病的症状包括多饮、多尿和体重减轻",
"心脏病的症状通常包括胸痛、呼吸急促和疲劳",
"高血压患者应该定期监测血压,保持健康的生活方式"
]
results = reranker.rerank(query, documents)
print(f"测试用例3 - 查询:{query}")
for idx, item in enumerate(results):
print(f"{idx+1}. 文档: {item['document']}, 分数: {item['score']}")
print("-" * 50)
# 测试用例4:长文本查询和文档
query = "人工智能在医疗领域的应用及其伦理问题"
documents = [
"人工智能在医疗诊断中的应用已经显示出良好的效果,例如通过分析医学影像来检测疾病。然而,这也引发了关于医生角色和责任的伦理问题。",
"在教育领域,人工智能可以提供个性化学习体验,适应不同学生的学习进度和风格。",
"医疗伦理问题主要包括患者隐私保护、知情同意和医疗资源分配等方面。",
"人工智能技术在金融领域的应用主要集中在风险评估、欺诈检测和算法交易等方面。"
]
results = reranker.rerank(query, documents)
print(f"测试用例4 - 查询:{query}")
for idx, item in enumerate(results):
print(f"{idx+1}. 文档: {item['document']}, 分数: {item['score']}")
print("-" * 50)