优化意图识别示例,新增命令行参数解析功能,支持输入输出文件路径和调试模式,增强代码可读性和灵活性。同时更新Dify工具,调整检索信息获取逻辑,确保重排得分信息的正确传递。
This commit is contained in:
@@ -17,6 +17,7 @@ import concurrent.futures
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import time
|
import time
|
||||||
import sys
|
import sys
|
||||||
|
import argparse
|
||||||
from typing import List, Dict
|
from typing import List, Dict
|
||||||
# 加载环境变量
|
# 加载环境变量
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
@@ -176,6 +177,7 @@ def save_results_to_excel(results, output_file, is_final=False):
|
|||||||
|
|
||||||
# 示例查询
|
# 示例查询
|
||||||
examples_query = """那储能软件如何操作"""
|
examples_query = """那储能软件如何操作"""
|
||||||
|
examples_query = """博微软件如何新建工程啊"""
|
||||||
conversation_context=""
|
conversation_context=""
|
||||||
chat_history=[
|
chat_history=[
|
||||||
{
|
{
|
||||||
@@ -199,34 +201,68 @@ previous_slots={
|
|||||||
"software_version": None,
|
"software_version": None,
|
||||||
"operation_steps": None
|
"operation_steps": None
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def _positive_int(text):
    """Convert a command-line string to an int and reject values below 1.

    Used as an argparse ``type=`` callable, so failures are reported as
    ``argparse.ArgumentTypeError`` and surface as a normal usage error.
    """
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"无效的整数: {text!r}") from None
    if value < 1:
        raise argparse.ArgumentTypeError(f"必须为正整数: {text!r}")
    return value


def parse_arguments(argv=None):
    """Parse the command-line options of the intent-recognition example.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            lets argparse read ``sys.argv[1:]``, which is exactly what the
            original zero-argument call did.

    Returns:
        argparse.Namespace with the attributes ``input``, ``output``,
        ``model``, ``api_base``, ``max_workers``, ``debug`` and ``query``.
        Options that are not supplied are ``None``, except ``max_workers``
        (20) and ``debug`` (False).

    Raises:
        SystemExit: raised by argparse on invalid arguments, including a
            ``--max_workers`` value that is not a positive integer.
    """
    parser = argparse.ArgumentParser(description='意图识别和问题改写工具')

    # Data file paths.
    parser.add_argument('--input', '-i', type=str,
                        help='输入Excel文件路径,包含待处理的提问数据(第一列)')
    parser.add_argument('--output', '-o', type=str,
                        help='输出Excel文件路径,用于保存处理结果')

    # LLM-related options.
    parser.add_argument('--model', '-m', type=str,
                        help='LLM模型名称,默认使用环境变量中的配置')
    parser.add_argument('--api_base', '-a', type=str,
                        help='API基础URL,默认使用环境变量中的配置')

    # Processing options. The worker count is validated here so that a
    # zero or negative value is rejected up front instead of failing later
    # when the value is actually used.
    parser.add_argument('--max_workers', '-w', type=_positive_int, default=20,
                        help='并发处理的最大线程数,默认为20')
    parser.add_argument('--debug', '-d', action='store_true',
                        help='启用调试模式,使用示例查询而非从文件读取')
    parser.add_argument('--query', '-q', type=str,
                        help='在调试模式下使用的查询字符串')

    return parser.parse_args(argv)
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
"""
|
"""
|
||||||
意图识别和问题改写示例
|
意图识别和问题改写示例
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# 从环境变量中获取配置
|
# 解析命令行参数
|
||||||
|
args = parse_arguments()
|
||||||
|
|
||||||
|
# 从环境变量中获取配置,命令行参数优先
|
||||||
api_key = os.getenv("OPENAI_API_KEY")
|
api_key = os.getenv("OPENAI_API_KEY")
|
||||||
base_url = os.getenv("OPENAI_API_BASE")
|
base_url = args.api_base if args.api_base else os.getenv("OPENAI_API_BASE")
|
||||||
model_name = os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")
|
model_name = args.model if args.model else os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")
|
||||||
|
|
||||||
# 初始化意图识别器
|
# 初始化意图识别器
|
||||||
recognizer = IntentRecognizer(api_key=api_key, base_url=base_url, model_name=model_name)
|
recognizer = IntentRecognizer(api_key=api_key, base_url=base_url, model_name=model_name)
|
||||||
|
|
||||||
# 读取提问数据
|
# 读取提问数据
|
||||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
data_file = os.path.join(current_dir, "..", "..", "data", "excel", "200条点踩数据测试.xlsx")
|
data_file = args.input if args.input else os.path.join(current_dir, "..", "..", "data", "excel", "历史提问数据(dislike)_提问明确.xlsx")
|
||||||
output_file = os.path.join(current_dir, "..", "..", "data", "excel", "200条点踩数据测试_槽位填充结果.xlsx")
|
output_file = args.output if args.output else os.path.join(current_dir, "..", "..", "data", "excel", "历史提问数据(dislike)_槽位(分类)填充结果.xlsx")
|
||||||
|
|
||||||
|
# 检测是否为调试模式
|
||||||
|
is_debug = args.debug or (hasattr(sys, 'gettrace') and sys.gettrace() is not None)
|
||||||
|
|
||||||
# 检测是否为调试模式,调试模式下使用examples_query,否则从Excel读取
|
|
||||||
is_debug = hasattr(sys, 'gettrace') and sys.gettrace() is not None
|
|
||||||
# is_debug = False
|
|
||||||
if is_debug:
|
if is_debug:
|
||||||
examples = examples_query.strip().split("\n")
|
# 如果提供了查询参数,使用它;否则使用默认示例
|
||||||
|
if args.query:
|
||||||
|
examples = [args.query]
|
||||||
|
else:
|
||||||
|
examples = examples_query.strip().split("\n")
|
||||||
else:
|
else:
|
||||||
examples = load_questions_from_excel(data_file)
|
examples = load_questions_from_excel(data_file)
|
||||||
|
|
||||||
if not is_debug:
|
if not is_debug:
|
||||||
max_workers = 20 # 减少并发数以避免API限制
|
max_workers = args.max_workers
|
||||||
logging.info(f"共有 {len(examples)} 个问题需要处理,使用 {max_workers} 个并发线程")
|
logging.info(f"共有 {len(examples)} 个问题需要处理,使用 {max_workers} 个并发线程")
|
||||||
|
|
||||||
# 创建一个与输入顺序相同的结果列表
|
# 创建一个与输入顺序相同的结果列表
|
||||||
@@ -262,8 +298,8 @@ def main():
|
|||||||
for idx, query in enumerate(examples):
|
for idx, query in enumerate(examples):
|
||||||
if query.strip() == "":
|
if query.strip() == "":
|
||||||
continue
|
continue
|
||||||
process_query(recognizer, query, conversation_context, chat_history, previous_slots)
|
# process_query(recognizer, query, conversation_context, chat_history, previous_slots)
|
||||||
# print(json.dumps(process_query(recognizer, query), ensure_ascii=False, indent=2))
|
print(json.dumps(process_query(recognizer, query), ensure_ascii=False, indent=2))
|
||||||
|
|
||||||
def setup_logging():
|
def setup_logging():
|
||||||
# 配置日志输出到控制台
|
# 配置日志输出到控制台
|
||||||
|
|||||||
+34
-19
@@ -318,7 +318,7 @@ content: "{content}"
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
return -1
|
return -1
|
||||||
|
|
||||||
def get_retrieve_info(self, query: str, outputs: dict) -> tuple:
|
def get_retrieve_info(self, query: str, outputs: dict, reranker_sorce_info:list) -> tuple:
|
||||||
"""
|
"""
|
||||||
获取检索信息并计算分数
|
获取检索信息并计算分数
|
||||||
|
|
||||||
@@ -333,20 +333,21 @@ content: "{content}"
|
|||||||
min_score = 10
|
min_score = 10
|
||||||
total_score = 0
|
total_score = 0
|
||||||
valid_scores = 0
|
valid_scores = 0
|
||||||
retrieve_content = []
|
retrieve_title = []
|
||||||
|
|
||||||
# 使用线程池并发计算分数
|
# 使用线程池并发计算分数
|
||||||
with ThreadPoolExecutor() as executor:
|
with ThreadPoolExecutor() as executor:
|
||||||
# 创建任务列表
|
# 创建任务列表
|
||||||
future_to_content = {}
|
future_to_content = {}
|
||||||
for result in outputs["result"]:
|
for result in outputs:
|
||||||
content = result["content"].strip()
|
content = result["segment_content"].strip()
|
||||||
|
segment_id = result["segment_id"].strip()
|
||||||
future = executor.submit(self.calculate_score, query=query, content=content)
|
future = executor.submit(self.calculate_score, query=query, content=content)
|
||||||
future_to_content[future] = content
|
future_to_content[future] = (content, segment_id)
|
||||||
|
|
||||||
# 收集结果
|
# 收集结果
|
||||||
for future in as_completed(future_to_content):
|
for future in as_completed(future_to_content):
|
||||||
content = future_to_content[future]
|
content, segment_id = future_to_content[future]
|
||||||
score = future.result()
|
score = future.result()
|
||||||
content_title = content.split("\n")[0]
|
content_title = content.split("\n")[0]
|
||||||
|
|
||||||
@@ -357,10 +358,11 @@ content: "{content}"
|
|||||||
valid_scores += 1
|
valid_scores += 1
|
||||||
|
|
||||||
if content_title:
|
if content_title:
|
||||||
retrieve_content.append(content_title + f"--相关性得分({score}分)")
|
current_score = next((cur_source_info["score"] for cur_source_info in reranker_sorce_info if cur_source_info["segment_id"] == segment_id), None)
|
||||||
|
retrieve_title.append(content_title + f"--LLM得分({score}分)--重排得分({current_score:.2f}分)")
|
||||||
|
|
||||||
avg_score = total_score / valid_scores if valid_scores > 0 else 0
|
avg_score = total_score / valid_scores if valid_scores > 0 else 0
|
||||||
return retrieve_content, max_score, min_score, avg_score
|
return retrieve_title, max_score, min_score, avg_score
|
||||||
|
|
||||||
|
|
||||||
class NewWorkflowChat(BaseWorkflowChat):
|
class NewWorkflowChat(BaseWorkflowChat):
|
||||||
@@ -395,7 +397,6 @@ class NewWorkflowChat(BaseWorkflowChat):
|
|||||||
"新问题分类": workflow_info["问题分类"],
|
"新问题分类": workflow_info["问题分类"],
|
||||||
"槽点信息": workflow_info["槽点信息"],
|
"槽点信息": workflow_info["槽点信息"],
|
||||||
"新检索词条": workflow_info["检索词条"],
|
"新检索词条": workflow_info["检索词条"],
|
||||||
"检索内容": workflow_info["检索内容"],
|
|
||||||
"message_id":message_id
|
"message_id":message_id
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -421,14 +422,23 @@ class NewWorkflowChat(BaseWorkflowChat):
|
|||||||
vertical_classification = ""
|
vertical_classification = ""
|
||||||
sub_classification = ""
|
sub_classification = ""
|
||||||
slot_info = ""
|
slot_info = ""
|
||||||
|
reranker_sorce=[]
|
||||||
try:
|
try:
|
||||||
|
# 先取出重排得分
|
||||||
message_info = DifyTool.get_message_debug_info_by_id(message_id=message_id)
|
message_info = DifyTool.get_message_debug_info_by_id(message_id=message_id)
|
||||||
for workflow_node in message_info["workflow_node_executions_info"]:
|
for workflow_node in message_info["workflow_node_executions_info"]:
|
||||||
if workflow_node["title"] == "知识检索结果后处理":
|
if workflow_node["title"] == "软件知识检索聚合":
|
||||||
outputs = json.loads(workflow_node["outputs"])
|
retrieve_outputs = json.loads(workflow_node["inputs"])["result"]
|
||||||
retrieve_title, max_score, min_score, avg_score = self.get_retrieve_info(query=query, outputs=outputs)
|
reranker_sorce = [{"score":result["metadata"]["score"], "segment_id":result["metadata"]["segment_id"]} for result in retrieve_outputs]
|
||||||
retrieve_content = outputs["result"]
|
|
||||||
|
|
||||||
|
for workflow_node in message_info["workflow_node_executions_info"]:
|
||||||
|
if workflow_node["title"] == "软件知识检索聚合":
|
||||||
|
retrieve_outputs = json.loads(workflow_node["inputs"])["result"]
|
||||||
|
reranker_sorce = [{"score":result["metadata"]["score"], "segment_id":result["metadata"]["segment_id"]} for result in retrieve_outputs]
|
||||||
|
elif workflow_node["title"] == "提取处理后的知识":
|
||||||
|
outputs = json.loads(workflow_node["outputs"])["knowledge_list"]
|
||||||
|
retrieve_title, max_score, min_score, avg_score = self.get_retrieve_info(query=query, outputs=outputs, reranker_sorce_info=reranker_sorce)
|
||||||
elif workflow_node["title"] == "问题优化结果解析":
|
elif workflow_node["title"] == "问题优化结果解析":
|
||||||
outputs = json.loads(workflow_node["outputs"])
|
outputs = json.loads(workflow_node["outputs"])
|
||||||
rewrite_query = outputs["optimize_query"]
|
rewrite_query = outputs["optimize_query"]
|
||||||
@@ -439,11 +449,18 @@ class NewWorkflowChat(BaseWorkflowChat):
|
|||||||
slot_info = json.dumps(json_result["slot_filling"], ensure_ascii=False, indent=2)
|
slot_info = json.dumps(json_result["slot_filling"], ensure_ascii=False, indent=2)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
retrieve_content = ""
|
||||||
|
if len(reranker_sorce)==0:
|
||||||
|
retrieve_content="未检索知识库"
|
||||||
|
elif len(reranker_sorce) > 0 and len(retrieve_title)==0:
|
||||||
|
retrieve_content = "知识与提问不相关,被丢弃"
|
||||||
|
else:
|
||||||
|
retrieve_content = "\n".join(retrieve_title)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"问题改写": rewrite_query,
|
"问题改写": rewrite_query,
|
||||||
"检索词条": "\n".join(retrieve_title) if retrieve_title else "未检索知识库",
|
"检索词条": retrieve_content,
|
||||||
"检索内容": retrieve_content,
|
|
||||||
"问题分类": f"{vertical_classification} - {sub_classification}",
|
"问题分类": f"{vertical_classification} - {sub_classification}",
|
||||||
"槽点信息": slot_info,
|
"槽点信息": slot_info,
|
||||||
|
|
||||||
@@ -479,7 +496,6 @@ class OldWorkFlowChat(BaseWorkflowChat):
|
|||||||
"旧流程答案": answer,
|
"旧流程答案": answer,
|
||||||
"旧问题改写": workflow_info["问题改写"],
|
"旧问题改写": workflow_info["问题改写"],
|
||||||
"旧检索词条": workflow_info["检索词条"],
|
"旧检索词条": workflow_info["检索词条"],
|
||||||
"检索内容": workflow_info["检索内容"],
|
|
||||||
"message_id":message_id
|
"message_id":message_id
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -519,7 +535,6 @@ class OldWorkFlowChat(BaseWorkflowChat):
|
|||||||
return {
|
return {
|
||||||
"问题改写": rewrite_query,
|
"问题改写": rewrite_query,
|
||||||
"检索词条": "\n".join(retrieve_title) if retrieve_title else "未检索知识库",
|
"检索词条": "\n".join(retrieve_title) if retrieve_title else "未检索知识库",
|
||||||
"检索内容": retrieve_content,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -411,7 +411,6 @@ content: "{content}"
|
|||||||
return {
|
return {
|
||||||
"问题改写": rewrite_query,
|
"问题改写": rewrite_query,
|
||||||
"检索词条": "\n".join(retrieve_title) if retrieve_title else "未检索知识库",
|
"检索词条": "\n".join(retrieve_title) if retrieve_title else "未检索知识库",
|
||||||
"检索内容": retrieve_content,
|
|
||||||
"问题分类": f"{vertical_classification} - {sub_classification}",
|
"问题分类": f"{vertical_classification} - {sub_classification}",
|
||||||
"槽点信息": slot_info
|
"槽点信息": slot_info
|
||||||
}
|
}
|
||||||
@@ -451,7 +450,6 @@ content: "{content}"
|
|||||||
return {
|
return {
|
||||||
"问题改写": rewrite_query,
|
"问题改写": rewrite_query,
|
||||||
"检索词条": "\n".join(retrieve_title) if retrieve_title else "未检索知识库",
|
"检索词条": "\n".join(retrieve_title) if retrieve_title else "未检索知识库",
|
||||||
"检索内容": retrieve_content,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_retrieve_title_similarity(self, old_retrieve_content:list[dict], new_retrieve_content:list[dict]) -> str:
|
def get_retrieve_title_similarity(self, old_retrieve_content:list[dict], new_retrieve_content:list[dict]) -> str:
|
||||||
@@ -589,9 +587,7 @@ content: "{content}"
|
|||||||
|
|
||||||
if judge_result is None:
|
if judge_result is None:
|
||||||
judge_result = ""
|
judge_result = ""
|
||||||
|
|
||||||
# retrieve_title_score = self.get_retrieve_title_similarity(old_retrieve_content=old_workflow_info["检索内容"], new_retrieve_content=new_workflow_info["检索内容"])
|
|
||||||
|
|
||||||
# 返回结果
|
# 返回结果
|
||||||
return {
|
return {
|
||||||
"问题": query,
|
"问题": query,
|
||||||
|
|||||||
@@ -0,0 +1,101 @@
|
|||||||
|
import pandas as pd
|
||||||
|
import random
|
||||||
|
import math
|
||||||
|
|
||||||
|
# Balanced sampling of work-order records: rows are grouped by software
# product and question type, an equal quota is drawn from every group, and
# the sample is topped up from the leftover rows until the target is reached.

# Input workbook with the raw work-order records.
work_order_excel="data/excel/6万工单记录.xlsx"
# Output workbook that receives the balanced sample.
balanced_output_excel = "data/excel/均衡提取2000条工单.xlsx"

# software -> question type -> list of (row index, row) tuples
soft_row_data={
    "博微配网计价通D3":{"基本功能":[], "高级功能":[]},
    "储能C1软件":{"基本功能":[], "高级功能":[]},
    "西藏计价通Z1":{"基本功能":[], "高级功能":[]},
    "技改检修工程计价通T1":{"基本功能":[], "高级功能":[]},
    "检修清单计价通T1":{"基本功能":[], "高级功能":[]},
    "电力建设计价通软件":{"基本功能":[], "高级功能":[]},
}

# Routing rules, evaluated in order; the first match wins.
# (keyword required in "产品线", keyword required in "产品名称" or None, bucket)
routing_rules = [
    ("博微配网计价通D3", None, "博微配网计价通D3"),
    ("博微电力建设计价通软件", None, "电力建设计价通软件"),
    ("新能源系列", "博微新型储能电站建设计价通C1软件", "储能C1软件"),
    ("博微西藏计价通Z1", None, "西藏计价通Z1"),
    ("博微技改检修计价通T1软件", "技改检修计价通T1软件-概预算", "技改检修工程计价通T1"),
    ("博微技改检修计价通T1软件", "技改检修计价通T1软件-清单", "检修清单计价通T1"),
]

df = pd.read_excel(work_order_excel)

# Rows whose product matched a rule but whose question type is not one of
# the known categories (including empty cells).
skipped_unknown_type = 0

for idx, row in df.iterrows():
    if pd.isna(row["产品线"]):
        continue

    # Normalise both cells to str: an empty "产品名称" cell is read as a NaN
    # float, and a substring test against a float raises TypeError.
    product_line = str(row["产品线"])
    raw_product_name = row["产品名称"]
    product_name = "" if pd.isna(raw_product_name) else str(raw_product_name)

    for line_keyword, name_keyword, software in routing_rules:
        if line_keyword not in product_line:
            continue
        if name_keyword is not None and name_keyword not in product_name:
            continue
        # .get avoids a KeyError for question types outside the two buckets.
        bucket = soft_row_data[software].get(row["问题类型"])
        if bucket is None:
            skipped_unknown_type += 1
        else:
            bucket.append((idx, row))
        break

if skipped_unknown_type:
    print(f"跳过问题类型无法识别的数据: {skipped_unknown_type}条")

# Count the rows available for every software / question-type pair.
counts = {
    software: {type_name: len(rows) for type_name, rows in types.items()}
    for software, types in soft_row_data.items()
}
total_count = sum(sum(per_type.values()) for per_type in counts.values())

print(f"原始数据总量: {total_count}条")
for software, types in counts.items():
    print(f"{software}: 基本功能 {types['基本功能']}条, 高级功能 {types['高级功能']}条")

# Per-category quota. Floor division keeps quota * categories <= total_target;
# rounding up (167 * 12 = 2004) could overshoot the 2000-row target. The
# shortfall is filled by the top-up step below.
total_target = 2000
categories_count = sum(len(types) for types in soft_row_data.values())
per_category_target = total_target // categories_count

# Draw the quota from every category.
balanced_data = []
extracted_counts = {}
extracted_indices = set()  # row indices already placed in the sample

for software, types in soft_row_data.items():
    extracted_counts[software] = {}

    for type_name, rows in types.items():
        # Take everything when the category is small; otherwise sample.
        if len(rows) <= per_category_target:
            extracted = rows
        else:
            extracted = random.sample(rows, per_category_target)

        extracted_counts[software][type_name] = len(extracted)
        for idx, row in extracted:
            extracted_indices.add(idx)
            balanced_data.append(row)

# Top up from the rows that were not selected until the target is reached
# (or the leftover rows run out).
remaining_target = total_target - len(balanced_data)
if remaining_target > 0:
    remaining_data = [
        row
        for types in soft_row_data.values()
        for rows in types.values()
        for idx, row in rows
        if idx not in extracted_indices
    ]

    if len(remaining_data) >= remaining_target:
        additional_data = random.sample(remaining_data, remaining_target)
    else:
        additional_data = remaining_data

    balanced_data.extend(additional_data)

# Report the per-category quota draw (top-up rows are not included here).
print(f"\n均衡提取后数据总量: {len(balanced_data)}条")
for software, types in extracted_counts.items():
    print(f"{software}: 基本功能 {types['基本功能']}条, 高级功能 {types['高级功能']}条")

# Persist the balanced sample.
balanced_df = pd.DataFrame(balanced_data)
balanced_df.to_excel(balanced_output_excel, index=False)
print(f"\n已将均衡提取的{len(balanced_data)}条数据保存至'{balanced_output_excel}'")
|
||||||
@@ -160,7 +160,7 @@ class SoftwareFunctionSlots(SlotBase):
|
|||||||
self.project_type="单工程"
|
self.project_type="单工程"
|
||||||
missing_slots = {}
|
missing_slots = {}
|
||||||
if not self.software_name:
|
if not self.software_name:
|
||||||
missing_slots["software_name"] = f"{SoftwareFunctionSlots.model_fields['software_name'].description},可选值:{', '.join([name.value for name in SoftwareName if name not in [SoftwareName.UNKNOWN, SoftwareName.ALIASES]])}"
|
missing_slots["software_name"] = f"{SoftwareFunctionSlots.model_fields['software_name'].description},支持的软件:{', '.join([name.value for name in SoftwareName if name not in [SoftwareName.UNKNOWN, SoftwareName.ALIASES]])}"
|
||||||
if not self.function_name:
|
if not self.function_name:
|
||||||
missing_slots["function_name"] = SoftwareFunctionSlots.model_fields["function_name"].description
|
missing_slots["function_name"] = SoftwareFunctionSlots.model_fields["function_name"].description
|
||||||
if not self.operation:
|
if not self.operation:
|
||||||
@@ -181,7 +181,7 @@ class SoftwareTroubleShootingSlots(SlotBase):
|
|||||||
"""检查必填槽位是否都存在"""
|
"""检查必填槽位是否都存在"""
|
||||||
missing_slots = {}
|
missing_slots = {}
|
||||||
if not self.software_name:
|
if not self.software_name:
|
||||||
missing_slots["software_name"] = f"{SoftwareTroubleShootingSlots.model_fields['software_name'].description},可选值:{', '.join([name.value for name in SoftwareName if name not in [SoftwareName.UNKNOWN, SoftwareName.ALIASES]])}"
|
missing_slots["software_name"] = f"{SoftwareTroubleShootingSlots.model_fields['software_name'].description},支持的软件:{', '.join([name.value for name in SoftwareName if name not in [SoftwareName.UNKNOWN, SoftwareName.ALIASES]])}"
|
||||||
if not self.function_name:
|
if not self.function_name:
|
||||||
missing_slots["function_name"] = SoftwareTroubleShootingSlots.model_fields["function_name"].description
|
missing_slots["function_name"] = SoftwareTroubleShootingSlots.model_fields["function_name"].description
|
||||||
if not self.error_message:
|
if not self.error_message:
|
||||||
@@ -191,7 +191,7 @@ class SoftwareTroubleShootingSlots(SlotBase):
|
|||||||
# 2. 业务问题
|
# 2. 业务问题
|
||||||
# 2.1 专业咨询
|
# 2.1 专业咨询
|
||||||
class ProfessionalConsultingSlots(SlotBase):
|
class ProfessionalConsultingSlots(SlotBase):
|
||||||
scene_subject: str = Field(default="", description="场景主体")
|
scene_subject: str = Field(default="", description="业务主体。即询问的业务对象(规范、标准、费用等)")
|
||||||
business_scene: str = Field(default="", description="业务场景描述")
|
business_scene: str = Field(default="", description="业务场景描述")
|
||||||
software_name: Optional[str] = Field(default="", description="软件名称")
|
software_name: Optional[str] = Field(default="", description="软件名称")
|
||||||
|
|
||||||
@@ -266,7 +266,6 @@ class InstallationDownloadSlots(SlotBase):
|
|||||||
missing_slots = {}
|
missing_slots = {}
|
||||||
if not self.software_name and not self.file_name:
|
if not self.software_name and not self.file_name:
|
||||||
missing_slots["software_name"] = f"{InstallationDownloadSlots.model_fields['software_name'].description},"
|
missing_slots["software_name"] = f"{InstallationDownloadSlots.model_fields['software_name'].description},"
|
||||||
f"可选值:{', '.join([name.value for name in SoftwareName if name not in [SoftwareName.UNKNOWN, SoftwareName.ALIASES]])}"
|
|
||||||
missing_slots["file_name"] = InstallationDownloadSlots.model_fields["file_name"].description
|
missing_slots["file_name"] = InstallationDownloadSlots.model_fields["file_name"].description
|
||||||
if not self.operation_stage:
|
if not self.operation_stage:
|
||||||
missing_slots["operation_stage"] = InstallationDownloadSlots.model_fields["operation_stage"].description
|
missing_slots["operation_stage"] = InstallationDownloadSlots.model_fields["operation_stage"].description
|
||||||
|
|||||||
@@ -304,8 +304,8 @@ class IntentRecognizer:
|
|||||||
|
|
||||||
rewrite_start_time = time.time()
|
rewrite_start_time = time.time()
|
||||||
# 准备问题改写提示
|
# 准备问题改写提示
|
||||||
# terms_dict = [term.model_dump(exclude={"description"}) for term in keywords.terms]
|
terms_dict = [term.model_dump(exclude={"description"}) for term in keywords.terms]
|
||||||
terms_dict = [term.model_dump() for term in keywords.terms]
|
# terms_dict = [term.model_dump() for term in keywords.terms]
|
||||||
keywords_str = json.dumps(terms_dict, ensure_ascii=False)
|
keywords_str = json.dumps(terms_dict, ensure_ascii=False)
|
||||||
query_rewrite_parser = PydanticOutputParser(pydantic_object=QueryRewrite)
|
query_rewrite_parser = PydanticOutputParser(pydantic_object=QueryRewrite)
|
||||||
# formatted_prompt = query_rewrite_prompt.format(query=query,
|
# formatted_prompt = query_rewrite_prompt.format(query=query,
|
||||||
@@ -401,27 +401,27 @@ class IntentRecognizer:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 步骤3: 进行意图识别和槽位填充
|
# 步骤3: 进行意图识别和槽位填充
|
||||||
result = self._process_intent_and_slot(rewrite.rewrite, conversation_context, chat_history, previous_slots)
|
# result = self._process_intent_and_slot(rewrite.rewrite, conversation_context, chat_history, previous_slots)
|
||||||
result.update({"keywords": keywords_terms.model_dump(),
|
# result.update({"keywords": keywords_terms.model_dump(),
|
||||||
"rewrite": rewrite.model_dump(),
|
# "rewrite": rewrite.model_dump(),
|
||||||
"query_keys": query_keys})
|
# "query_keys": query_keys})
|
||||||
return result
|
# return result
|
||||||
# # 步骤3: 进行意图分类
|
# 步骤3: 进行意图分类
|
||||||
# classification = self._classify_intent(query)
|
classification = self._classify_intent(rewrite.rewrite, conversation_context, chat_history, previous_slots)
|
||||||
|
|
||||||
# # 步骤4: 进行槽位填充
|
# 步骤4: 进行槽位填充
|
||||||
# # 如果是有效分类,进行槽位填充
|
# 如果是有效分类,进行槽位填充
|
||||||
# slot_filling_result = {}
|
slot_filling_result = {}
|
||||||
# if classification.vertical_classification not in ["其他", "闲聊"] and classification.sub_classification not in ["其他", "闲聊"]:
|
if classification.vertical_classification not in ["其他", "闲聊"] and classification.sub_classification not in ["其他", "闲聊"]:
|
||||||
# slot_filling_result = self._fill_slots(rewrite.rewrite, classification)
|
slot_filling_result = self._fill_slots(rewrite.rewrite, classification, conversation_context, chat_history, previous_slots)
|
||||||
|
|
||||||
# return {
|
return {
|
||||||
# "classification": classification.model_dump(),
|
"classification": classification.model_dump(),
|
||||||
# "keywords": keywords_terms.model_dump(),
|
"keywords": keywords_terms.model_dump(),
|
||||||
# "rewrite": rewrite.model_dump(),
|
"rewrite": rewrite.model_dump(),
|
||||||
# "query_keys": query_keys,
|
"query_keys": query_keys,
|
||||||
# "slot_filling": slot_filling_result
|
"slot_filling": slot_filling_result
|
||||||
# }
|
}
|
||||||
|
|
||||||
|
|
||||||
def _fill_slots(self, query: str, classification: Classification, conversation_context: str = "",
|
def _fill_slots(self, query: str, classification: Classification, conversation_context: str = "",
|
||||||
|
|||||||
@@ -127,11 +127,13 @@ query_rewrite_prompt_pro_old="""
|
|||||||
query_rewrite_prompt_pro="""
|
query_rewrite_prompt_pro="""
|
||||||
# 电力造价问答优化工程师(精简版)
|
# 电力造价问答优化工程师(精简版)
|
||||||
**角色**:基于历史对话和术语库重构问题,提升知识库检索准确率。
|
**角色**:基于历史对话和术语库重构问题,提升知识库检索准确率。
|
||||||
|
最高准则:保持问题核心意图,但允许在指代消除、背景继承下添加隐含功能词。但重构后的问题,所有引入的主体背景等均要来源于历史对话、聊天背景或术语库,不得凭空捏造未提及的内容。
|
||||||
|
|
||||||
## 核心原则
|
## 核心原则
|
||||||
1. 语义保真 → 保持问题核心意图
|
1. **指代消除 → 当指示代词("那"/"这")出现时,强制继承历史对话的最新核心主题(如功能或任务),并应用到当前主体。**
|
||||||
2. 术语规范 → 同义词转标准词并【】标记
|
2. 背景继承 → 补充历史对话和聊天背景中的隐含信息(包括主题和功能)。
|
||||||
3. 背景继承 → 补充历史对话的隐含信息
|
3. 术语规范 → 同义词转标准词并【】标记。提问中的同义词(synonymous)替换为标准词(name)
|
||||||
|
4. 语义保真 → 保持问题核心意图,但允许在指代消除、背景继承下添加隐含功能词。
|
||||||
|
|
||||||
## 处理流程
|
## 处理流程
|
||||||
### 一、输入解析
|
### 一、输入解析
|
||||||
@@ -155,37 +157,30 @@ query_rewrite_prompt_pro="""
|
|||||||
### 二、重构决策树
|
### 二、重构决策树
|
||||||
```mermaid
|
```mermaid
|
||||||
graph TD
|
graph TD
|
||||||
A[输入问题] --> B{{匹配关键词或上下文?}}
|
A[输入问题] --> B{{包含指示代词?}}
|
||||||
B -- 是 --> C[执行重构]
|
B -- 是 --> C[提取历史最新主题]
|
||||||
B -- 否 --> D[直接输出原始问题]
|
C --> D{{主题是否明确?}}
|
||||||
C --> E[补充缺失背景]
|
D -- 是 --> E[继承主题到当前问题]
|
||||||
E --> F[同义词替换+【】标记]
|
E --> F[执行重构]
|
||||||
F --> G[保留原生专业术语]
|
D -- 否 --> F
|
||||||
|
F --> G[补充缺失背景]
|
||||||
|
G --> H[同义词替换+【】标记]
|
||||||
|
H --> I[保留原生专业术语]
|
||||||
|
B -- 否 --> I
|
||||||
```
|
```
|
||||||
|
|
||||||
### 三、重构优先级
|
### 三、重构优先级
|
||||||
1. **背景补充**
|
1. **指代消除 → 当指示代词出现时,优先继承历史对话的核心主题(如功能词),并替换当前问题的动词部分。**
|
||||||
- 历史对话中确定的背景信息需要保留(例:"这软件"→"【配网工程计价通D3软件】")
|
2. 背景继承 → 历史对话中确定的背景信息需要保留。
|
||||||
|
3. 术语处理 → 同义词转标准词 + 【】标记。
|
||||||
2. **术语处理**
|
4. 同义词转标准词 → 将提问中的同义词(synonymous)替换为标准词(name)
|
||||||
- 同义词转标准词 → 将提问中的同义词(synonymous)替换为标准词(name)
|
5. 结构优化 → 保持原问题的5W2H特征,指代消除、背景继承下允许微调意图。
|
||||||
- 存在即标记 → 【计算式】
|
|
||||||
|
|
||||||
3. **结构优化**
|
|
||||||
- 保持原问题的5W2H特征,确保问题意图不发生改变。
|
|
||||||
- 明确指代关系("该功能"→"【批量导入】功能")
|
|
||||||
|
|
||||||
## 输出规范
|
## 输出规范
|
||||||
{output_format}
|
{output_format}
|
||||||
|
|
||||||
## 典型案例
|
|
||||||
| 场景 | 输入问题 | 输出结果 |
|
|
||||||
|---------------------|-----------------------------------|------------------------------------------|
|
|
||||||
| 强上下文关联 | “怎么升级旧版工程” | {{"rewrite":"【西藏Z1】如何执行【老版本定额升级】?"}} |
|
|
||||||
| 弱术语匹配 | “界面文字太小怎么办” | 原样输出 |
|
|
||||||
| 代词+背景继承 | “这个定额如何导入” | {{"rewrite":"【山东定额】如何执行【批量导入定额】?"}}|
|
|
||||||
|
|
||||||
## 质量自检
|
## 质量自检
|
||||||
|
- [] **主题是否合理继承?**(当有代词时,历史主题必须注入)
|
||||||
- [] 核心诉求是否保留?
|
- [] 核心诉求是否保留?
|
||||||
- [] 背景信息是否合理补充?
|
- [] 背景信息是否合理补充?
|
||||||
- [] 术语标记是否完整【】?
|
- [] 术语标记是否完整【】?
|
||||||
|
|||||||
@@ -19,7 +19,6 @@ import requests
|
|||||||
|
|
||||||
API_KEY_LIST=[
|
API_KEY_LIST=[
|
||||||
"sk-kvgfuqeqvpmfsccykyoohheshclcrtvjlnewratvrjpkpbkc",
|
"sk-kvgfuqeqvpmfsccykyoohheshclcrtvjlnewratvrjpkpbkc",
|
||||||
"sk-zhnbqnpuumuuvegnvbgoggxafpukbzchpgrugpkobiwkzsar",
|
|
||||||
"sk-kzhxlqvqcxlnbdgnpalqnzumkmspepkttkgbophnkqanainw",
|
"sk-kzhxlqvqcxlnbdgnpalqnzumkmspepkttkgbophnkqanainw",
|
||||||
"sk-bzttugqtlskrvguvhckwamdssvgmgnrqpsialpdbskfsyyak",
|
"sk-bzttugqtlskrvguvhckwamdssvgmgnrqpsialpdbskfsyyak",
|
||||||
"sk-tovmogiablsoeabwgqyvevpcfichyjpuzqdymmvksspdrtqt",
|
"sk-tovmogiablsoeabwgqyvevpcfichyjpuzqdymmvksspdrtqt",
|
||||||
|
|||||||
Reference in New Issue
Block a user