Files
QueryRewrite/rag2_0/intent_recognition/IntentRecognition.py
T

870 lines
37 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
File: IntentRecognition.py
Author: oyyz
Date: 2025-05-13
Description: 意图分类、改写核心逻辑
"""
import logging
import os
import threading
from langchain.output_parsers import PydanticOutputParser
import json
from typing import List, Tuple, Dict, Any, Optional
import re
import jieba
import time
from .PromptTemplates import (classification_prompt, query_rewrite_prompt,
extract_nouns_prompt, classification_info,
slot_filling_prompt, step_back_prompt,
follow_up_questions_prompt, hyde_prompt, multi_questions_prompt)
from .Multi_PromptTemplates import (
intent_and_slot_prompt, output_example,
generate_slot_mapping_doc, query_rewrite_prompt_pro,
)
from .DataModels import (
Classification, QueryRewrite, Term, TermList,
SoftwareFunctionSlots, SoftwareTroubleShootingSlots, ProfessionalConsultingSlots,
DataProblemSlots, FileExtensionConsultingSlots, SoftwareLockSlots,
InstallationDownloadSlots, ProblemDiagnosisSlots, OtherSlots, IntentAndSlotResult,
StepBackPrompt, FollowUpQuestions, HypotheticalDocument, MultiQuestions
)
from .ProfessionalNounVector import ProfessionalNounRetriever
from rag2_0.tool.ModelTool import XinferenceReRankerModel, OpenAiLLM, SiliconFlowReRankerModel
class IntentRecognizer:
"""
意图识别和问题改写类
"""
def __init__(self, api_key: str = None, base_url: str = None, model_name: str = "gpt-3.5-turbo", vector_index_dir: str = None):
    """
    Build the LLM client, load the suffix keyword table and create the noun retriever.

    Args:
        api_key: API key. Forwarded to the LLM client only when truthy, and
            always forwarded to the noun retriever.
            NOTE(review): the fallback used when it is omitted (e.g. an
            environment variable) lives in OpenAiLLM - confirm there.
        base_url: Custom API base URL. Forwarded to the LLM client only when truthy.
        model_name: Model name passed to the LLM client as ``model``.
        vector_index_dir: Passed to ProfessionalNounRetriever as ``index_dir``.
    """
    # Low temperature / top_p keep the model output close to deterministic.
    llm_kwargs = dict(temperature=0.2, top_p=0.7, model=model_name)
    # Optional settings are only included when the caller supplied a value.
    for option, value in (("api_key", api_key), ("base_url", base_url)):
        if value:
            llm_kwargs[option] = value
    self._llm = OpenAiLLM(**llm_kwargs)
    # Suffix keyword table used by _judge_define_suffix.
    self._suffix_keywords = self._load_suffix_keywords()
    # Vector retriever for professional nouns.
    self._noun_retriever = ProfessionalNounRetriever(api_key=api_key, index_dir=vector_index_dir)
def _load_suffix_keywords(self, filepath: str = None) -> List[str]:
"""
加载后缀关键词列表
Args:
filepath: 后缀关键词文件路径,默认为None使用默认路径
Returns:
后缀关键词列表
"""
try:
# 如果未指定路径,使用默认路径
if filepath is None:
current_dir = os.path.dirname(os.path.abspath(__file__))
filepath = os.path.join(current_dir, "..", "..", "data", "nouns", "suffix_keywords.json")
# 读取JSON文件
with open(filepath, "r", encoding="utf-8") as f:
suffix_data = json.load(f)
# 添加额外的固定后缀
return suffix_data
except Exception as e:
raise RuntimeError(f"加载后缀关键词失败: {e}") from e
def _classify_intent(self, query: str, conversation_context: str = "",
                     chat_history: List[Dict[str, str]] = None,
                     previous_slots: Dict[str, Any] = None) -> Classification:
    """
    Classify the intent of a user query with the LLM.

    Args:
        query: Text to classify (inserted into the prompt as ``user_input``).
        conversation_context: Background information about the conversation.
        chat_history: Previous turns; serialised to JSON for the prompt
            (None is rendered as ``null``).
        previous_slots: Accepted for signature parity with the other steps;
            not used by this method.

    Returns:
        The Classification parsed from the LLM response.

    Raises:
        RuntimeError: If the LLM call or the parsing of its output fails.
    """
    started_at = time.time()
    parser = PydanticOutputParser(pydantic_object=Classification)
    prompt_fields = dict(
        user_input=query,
        classification_info=classification_info,
        output_format=parser.get_format_instructions(),
        conversation_context=conversation_context,
        chat_history=json.dumps(chat_history, ensure_ascii=False),
    )
    prompt_text = classification_prompt.format(**prompt_fields)
    try:
        llm_reply = self._llm.invoke(prompt_text, False)
        # The timing covers prompt construction and the LLM round trip, not parsing.
        elapsed = time.time() - started_at
        logging.info(f"意图分类耗时统计 - 总耗时: {elapsed:.2f}秒")
        return parser.parse(llm_reply.content.strip())
    except Exception as e:
        raise RuntimeError(f"解析分类结果时出错: {e}") from e
def _tokenize_with_jieba(self, query: str) -> List[str]:
    """
    Split a query into tokens with jieba (precise mode).

    Tokens that are empty or whitespace-only, and tokens consisting of a
    single non-word, non-space character (punctuation), are dropped.

    Args:
        query: User query.

    Returns:
        The remaining tokens in their original order.
    """
    return [
        piece
        for piece in jieba.cut(query, cut_all=False)
        if piece.strip() and not re.match(r'^[^\w\s]$', piece)
    ]
def _extract_keywords_with_llm(self, query: str, use_jieba: bool = False) -> List[Term]:
    """
    Extract candidate professional keywords from a query.

    Args:
        query: User query.
        use_jieba: When True the LLM is bypassed: the query is tokenised with
            jieba and every token longer than one character becomes a Term
            with empty synonyms and description.

    Returns:
        The extracted terms.
    """
    if use_jieba:
        # Single-character tokens are discarded.
        return [
            Term(name=piece, synonymous=[], description="")
            for piece in self._tokenize_with_jieba(query)
            if len(piece) > 1
        ]

    parser = PydanticOutputParser(pydantic_object=TermList)
    # Placeholders are substituted with str.replace (not str.format), in this
    # order: the query first, then the format instructions.
    prompt_text = extract_nouns_prompt.replace("{content}", query)
    prompt_text = prompt_text.replace("{output_format}", parser.get_format_instructions())
    llm_reply = self._llm.invoke(prompt_text, False)
    return parser.parse(llm_reply.content).terms
def _rerank_matched_terms(self, query_key: str, matched_terms: set, top_k: int = 2, rerank_score:float = 0.6) -> List[Term]:
    """
    Re-rank recalled terms by relevance to one query keyword.

    Args:
        query_key: Keyword the terms were recalled for.
        matched_terms: Set of recalled Term objects.
        top_k: Number of results requested from the reranker. When the set
            holds no more than ``top_k`` terms, all of them are returned
            without calling the reranker (and without score filtering).
        rerank_score: Minimum reranker score a term needs to be kept.

    Returns:
        The kept terms, in the order returned by the reranker.

    Raises:
        RuntimeError: If building the texts or the reranker call fails.
    """
    if not matched_terms:
        return []
    if len(matched_terms) <= top_k:
        return list(matched_terms)
    try:
        # Materialise once so reranker indices map back to the same ordering.
        candidates = list(matched_terms)
        # Only name and synonyms are sent to the reranker; the description is left out.
        documents = [
            "|".join(("名称:" + item.name, "同义词:" + ";".join(item.synonymous)))
            for item in candidates
        ]
        reranker = XinferenceReRankerModel()
        ranked = reranker.rerank(query_key, documents, top_k=top_k)
        return [candidates[hit["index"]] for hit in ranked if hit["score"] >= rerank_score]
    except Exception as e:
        raise RuntimeError(f"_rerank_matched_terms重排失败:{e}") from e
def _match_keywords(self, query: str, use_jieba: bool = False) -> Tuple[TermList, List[str]]:
    """
    Match professional terms for a query: keyword extraction, vector recall, rerank.

    Args:
        query: User question.
        use_jieba: Forwarded to _extract_keywords_with_llm.

    Returns:
        A tuple ``(term_list, query_keys)``: the matched terms wrapped in a
        TermList, and the names of the keywords extracted from the query.

    Raises:
        RuntimeError: If keyword extraction or vector retrieval / reranking fails.
    """
    overall_start = time.time()

    # Step 1: extract the keywords of the query.
    try:
        extract_start = time.time()
        query_keys = [item.name for item in self._extract_keywords_with_llm(query, use_jieba)]
        llm_time = time.time() - extract_start
    except Exception as e:
        raise RuntimeError(f"LLM关键词提取失败: {e}") from e

    # Step 2: recall similar professional nouns per keyword, then rerank them.
    matched_terms = []
    try:
        recall_start = time.time()
        for key in query_keys:
            candidates = set()
            for hit in self._noun_retriever.query(key, top_k=5, use_intersection=False):
                # A ';'-joined synonym string is split into a list (written
                # back into the hit itself).
                if isinstance(hit.get('synonymous', []), str):
                    hit['synonymous'] = hit['synonymous'].split(';')
                candidates.add(Term(
                    name=hit.get('name'),
                    synonymous=hit.get('synonymous', []),
                    description=hit.get('description', ''),
                ))
            if candidates:
                matched_terms.extend(self._rerank_matched_terms(key, candidates))
        vector_time = time.time() - recall_start
    except Exception as e:
        raise RuntimeError(f"向量检索关键词时出错: {e}") from e

    term_list = TermList(terms=list(matched_terms))
    total_time = time.time() - overall_start
    logging.info(f"关键词匹配耗时统计 - 总耗时: {total_time:.2f}秒, 问题关键词提取: {llm_time:.2f}秒, 向量检索+重排序: {vector_time:.2f}秒")
    return term_list, query_keys
def _rewrite_query(self, query: str, keywords: TermList, query_keys:List[str], chat_history: List[Dict[str, str]] = None, context: str = "") -> QueryRewrite:
    """
    Rewrite the user question with the LLM, using the matched terms as hints.

    Args:
        query: Original user question.
        keywords: Matched terms; each is serialised without its
            ``description`` field before being put into the prompt.
        query_keys: Keywords extracted from the query. Not used by this
            method; kept in the signature for callers.
        chat_history: Previous turns. Serialised to JSON (``[]`` when empty
            or None), the same way the other prompts of this class do it,
            instead of being interpolated as a Python repr.
        context: Background information about the conversation.

    Returns:
        The QueryRewrite parsed from the LLM response.

    Raises:
        RuntimeError: If the LLM call or the parsing of its output fails.
    """
    rewrite_start_time = time.time()
    terms_dict = [term.model_dump(exclude={"description"}) for term in keywords.terms]
    keywords_str = json.dumps(terms_dict, ensure_ascii=False)
    query_rewrite_parser = PydanticOutputParser(pydantic_object=QueryRewrite)
    formatted_prompt = query_rewrite_prompt_pro.format(
        query=query,
        output_format=query_rewrite_parser.get_format_instructions(),
        keywords=keywords_str,
        chat_history=json.dumps(chat_history or [], ensure_ascii=False),
        context=context,
    )
    try:
        response = self._llm.invoke(formatted_prompt, False)
        parsed_output = query_rewrite_parser.parse(response.content)
        rewrite_time = time.time() - rewrite_start_time
        logging.info(f"问题改写耗时统计 - 总耗时: {rewrite_time:.2f}秒")
        return parsed_output
    except Exception as e:
        raise RuntimeError(f"解析问题改写结果时出错: {e}") from e
def _judge_define_suffix(self, input_str: str) -> Tuple[bool, List[str]]:
"""
判断输入字符串是否包含定义的后缀,并返回所有匹配到的后缀名列表
Args:
input_str: 输入字符串
Returns:
Tuple[bool, List[str]]: (是否包含定义的后缀, 匹配到的后缀名列表)
"""
# 构建正则表达式模式,匹配大小写不敏感且前面可能带有.
pattern = r'(?:\.?)(' + '|'.join(re.escape(field.get('name')) for field in self._suffix_keywords) + r')'
# 使用 re.IGNORECASE 标志来忽略大小写,findall找到所有匹配
matches = re.finditer(pattern, input_str, re.IGNORECASE)
matched_suffixes = [match.group(1) for match in matches]
return bool(matched_suffixes), matched_suffixes
def process_query(self, query: str, conversation_context: str = "",
                  chat_history: List[Dict[str, str]] = None,
                  previous_slots: Dict[str, Any] = None,
                  use_jieba: bool = False,
                  enable_query_expansion: bool = False) -> Dict[str, Any]:
    """
    Run the full pipeline for one user question.

    Steps: keyword matching -> query rewrite -> intent classification ->
    slot filling (skipped for the "其他" / "闲聊" classes). When
    ``enable_query_expansion`` is set, four expansion generators (step-back,
    follow-up, HyDE, multi-question) run in background threads while the
    main pipeline executes.

    Args:
        query: Original user question.
        conversation_context: Background information about the conversation.
        chat_history: Previous turns; defaults to an empty list.
        previous_slots: Slots collected in earlier turns; defaults to an empty dict.
        use_jieba: Use jieba tokenisation instead of the LLM for keyword extraction.
        enable_query_expansion: Also produce the ``query_expand`` section.

    Returns:
        Dict with ``classification``, ``keywords``, ``rewrite``,
        ``query_keys`` and ``slot_filling``; plus ``query_expand`` when
        expansion is enabled. ``query_expand["all"]`` is the de-duplicated,
        order-preserving list of non-empty expanded questions.

    Raises:
        RuntimeError: Propagated from the keyword, rewrite, classification or
            slot-filling steps.
    """
    if chat_history is None:
        chat_history = []
    if previous_slots is None:
        previous_slots = {}

    # NOTE: the file-extension short-circuit based on _judge_define_suffix is
    # currently disabled; every query goes through the full pipeline.

    # Start the expansion workers first so they overlap with the main pipeline.
    expansion_jobs = []
    if enable_query_expansion:
        expansion_args = (query, chat_history, conversation_context)
        expansion_jobs = [
            self._run_in_thread(generator, args=expansion_args)
            for generator in (
                self._generate_step_back_prompt,
                self._generate_follow_up_questions,
                self._generate_hypothetical_document,
                self._generate_multi_questions,
            )
        ]

    # Step 1: keyword matching.
    keywords_terms, query_keys = self._match_keywords(query, use_jieba)
    # Step 2: query rewrite.
    rewrite = self._rewrite_query(
        query=query,
        keywords=keywords_terms,
        query_keys=query_keys,
        chat_history=chat_history,
        context=conversation_context
    )
    # Step 3: intent classification on the rewritten query.
    classification = self._classify_intent(rewrite.rewrite, conversation_context, chat_history, previous_slots)
    # Step 4: slot filling, only for task-like classes.
    slot_filling_result = {}
    non_task_labels = ["其他", "闲聊"]
    if (classification.vertical_classification not in non_task_labels
            and classification.sub_classification not in non_task_labels):
        slot_filling_result = self._fill_slots(rewrite.rewrite, classification, conversation_context, chat_history, previous_slots)

    result = {
        "classification": classification.model_dump(),
        "keywords": keywords_terms.model_dump(),
        "rewrite": rewrite.model_dump(),
        "query_keys": query_keys,
        "slot_filling": slot_filling_result
    }
    if not enable_query_expansion:
        return result

    # Wait for the expansion workers.
    start_time = time.time()
    for thread, _ in expansion_jobs:
        thread.join()
    end_time = time.time()
    logging.info(f"问题扩展环节耗时统计 - 总耗时: {end_time - start_time:.2f}秒")

    def first_result(container, fallback):
        # _run_in_thread stores None when the worker failed, so a non-empty
        # container alone does not mean a usable result.
        if container and container[0] is not None:
            return container[0]
        return fallback

    step_back_result = first_result(
        expansion_jobs[0][1], StepBackPrompt(original_query=query, step_back_query=query))
    follow_up_result = first_result(
        expansion_jobs[1][1], FollowUpQuestions(original_query=query, follow_up_query=query))
    hyde_result = first_result(
        expansion_jobs[2][1], HypotheticalDocument(original_query=query, hypothetical_answer=""))
    multi_questions_result = first_result(
        expansion_jobs[3][1], MultiQuestions(original_query=query, sub_questions=[query]))

    # Work on a copy: appending to the model's own list would leak the extra
    # questions into multi_questions_result.model_dump() below.
    candidates = list(multi_questions_result.sub_questions)
    candidates.extend([
        query,
        step_back_result.step_back_query,
        follow_up_result.follow_up_query,
        hyde_result.hypothetical_answer,
    ])
    # De-duplicate while keeping first-seen order; drop empty entries such as
    # the empty HyDE fallback.
    all_questions = list(dict.fromkeys(item for item in candidates if item))

    result["query_expand"] = {
        "all": all_questions,
        "step_back": step_back_result.model_dump(),
        "follow_up": follow_up_result.model_dump(),
        "hyde": hyde_result.model_dump(),
        "multi_questions": multi_questions_result.model_dump()
    }
    return result
def _fill_slots(self, query: str, classification: Classification, conversation_context: str = "",
                chat_history: List[Dict[str, str]] = None,
                previous_slots: Dict[str, Any] = None,) -> Dict[str, Any]:
    """
    Fill the slots that belong to a classification and report completeness.

    Args:
        query: Question to extract slot values from.
        classification: Intent classification result; selects the slot model.
        conversation_context: Background information about the conversation.
        chat_history: Previous turns.
        previous_slots: Slots collected in earlier turns.

    Returns:
        Dict with ``is_complete``, ``missing_slots`` (both taken from the
        model's ``check_required_slots()``) and ``filled_data`` (the dumped
        slot model).

    Raises:
        RuntimeError: If no slot model matches the classification.
    """
    model_cls = self._get_slot_model(classification)
    if not model_cls:
        raise RuntimeError("未找到匹配的槽位模型")

    began = time.time()
    slots = self._fill_slots_with_llm(query, classification, model_cls, conversation_context, chat_history, previous_slots)
    logging.info(f"槽位填充耗时统计 - 总耗时: {time.time() - began:.2f}秒")

    complete, missing = slots.check_required_slots()
    return dict(is_complete=complete, missing_slots=missing, filled_data=slots.model_dump())
def _get_slot_model(self, classification: Classification) -> Optional[type]:
    """
    Map a classification to its slot model class.

    Args:
        classification: Intent classification result.

    Returns:
        The slot model class for the (vertical, sub) pair; OtherSlots for the
        "其他" vertical regardless of the sub class; None when nothing matches.
    """
    vertical = classification.vertical_classification
    # The catch-all vertical does not look at the sub classification.
    if vertical == "其他":
        return OtherSlots

    routes = (
        ("软件问题", "软件功能", SoftwareFunctionSlots),
        ("软件问题", "故障排查", SoftwareTroubleShootingSlots),
        ("业务问题", "专业咨询", ProfessionalConsultingSlots),
        ("业务问题", "数据问题", DataProblemSlots),
        ("安装下载注册", "后缀名咨询", FileExtensionConsultingSlots),
        ("安装下载注册", "软件锁类", SoftwareLockSlots),
        ("安装下载注册", "安装下载类", InstallationDownloadSlots),
        ("安装下载注册", "问题排查类", ProblemDiagnosisSlots),
    )
    return next(
        (
            model
            for route_vertical, route_sub, model in routes
            if vertical == route_vertical and classification.sub_classification == route_sub
        ),
        None,
    )
def _fill_slots_with_llm(self, query: str,
                         classification: Classification,
                         slot_model_class: type,
                         conversation_context: str = "",
                         chat_history: List[Dict[str, str]] = None,
                         previous_slots: Dict[str, Any] = None) -> Any:
    """
    Ask the LLM to fill the slots of ``slot_model_class`` for a query.

    This is best effort: when the LLM call or the parsing fails, the failure
    is logged and an empty instance of ``slot_model_class`` is returned
    instead of raising.

    Args:
        query: Question to extract slot values from.
        classification: Intent classification result (both levels are put
            into the prompt).
        slot_model_class: Slot model class to parse the answer into.
        conversation_context: Background information about the conversation.
        chat_history: Previous turns; serialised to JSON for the prompt.
        previous_slots: Slots collected in earlier turns; serialised to JSON
            for the prompt.

    Returns:
        An instance of ``slot_model_class`` (empty on failure).
    """
    slot_parser = PydanticOutputParser(pydantic_object=slot_model_class)
    formatted_prompt = slot_filling_prompt.format(
        query=query,
        vertical_classification=classification.vertical_classification,
        sub_classification=classification.sub_classification,
        output_format=slot_parser.get_format_instructions(),
        conversation_context=conversation_context,
        chat_history=json.dumps(chat_history,ensure_ascii=False),
        previous_slots=json.dumps(previous_slots,ensure_ascii=False),
    )
    try:
        response = self._llm.invoke(formatted_prompt, False)
        return slot_parser.parse(response.content)
    except Exception as e:
        # Keep the best-effort fallback, but leave a trace so that empty slot
        # results can be diagnosed.
        logging.warning(f"槽位填充失败,返回空槽位: {e}")
        return slot_model_class()
def _generate_step_back_prompt(self, query: str, chat_history: List[Dict[str, str]] = None, conversation_context: str = "") -> StepBackPrompt:
    """
    Generate a step-back version of the query with the LLM.

    Args:
        query: Original user question.
        chat_history: Previous turns; rendered as ``[]`` when empty or None.
        conversation_context: Background information about the conversation.

    Returns:
        The parsed StepBackPrompt. When the LLM call or parsing fails, the
        error is logged and a StepBackPrompt whose step-back query equals the
        original query is returned.
    """
    started_at = time.time()
    parser = PydanticOutputParser(pydantic_object=StepBackPrompt)
    history_json = "[]" if not chat_history else json.dumps(chat_history, ensure_ascii=False)
    prompt_text = step_back_prompt.format(
        query=query,
        chat_history=history_json,
        conversation_context=conversation_context,
        output_format=parser.get_format_instructions()
    )
    try:
        llm_reply = self._llm.invoke(prompt_text, False)
        step_back = parser.parse(llm_reply.content)
    except Exception as e:
        # Fall back to the original query.
        logging.error(f"后退提示生成失败: {e}")
        return StepBackPrompt(original_query=query, step_back_query=query)
    elapsed = time.time() - started_at
    logging.debug(f"后退提示生成耗时统计 - 总耗时: {elapsed:.2f}秒")
    return step_back
def _generate_follow_up_questions(self, query: str, chat_history: List[Dict[str, str]] = None, conversation_context: str = "") -> FollowUpQuestions:
    """
    Generate a follow-up question for the query with the LLM.

    Args:
        query: Original user question.
        chat_history: Previous turns; rendered as ``[]`` when empty or None.
        conversation_context: Background information about the conversation.

    Returns:
        The parsed FollowUpQuestions. When the LLM call or parsing fails, the
        error is logged and a FollowUpQuestions whose follow-up query equals
        the original query is returned.
    """
    started_at = time.time()
    parser = PydanticOutputParser(pydantic_object=FollowUpQuestions)
    history_json = "[]" if not chat_history else json.dumps(chat_history, ensure_ascii=False)
    prompt_text = follow_up_questions_prompt.format(
        query=query,
        chat_history=history_json,
        conversation_context=conversation_context,
        output_format=parser.get_format_instructions()
    )
    try:
        llm_reply = self._llm.invoke(prompt_text, False)
        follow_up = parser.parse(llm_reply.content)
    except Exception as e:
        # Fall back to the original query.
        logging.error(f"后续问题生成失败: {e}")
        return FollowUpQuestions(original_query=query, follow_up_query=query)
    elapsed = time.time() - started_at
    logging.debug(f"后续问题生成耗时统计 - 总耗时: {elapsed:.2f}秒")
    return follow_up
def _generate_hypothetical_document(self, query: str, chat_history: List[Dict[str, str]] = None, conversation_context: str = "") -> HypotheticalDocument:
    """
    Generate a hypothetical answer document (HyDE) for the query with the LLM.

    Args:
        query: Original user question.
        chat_history: Previous turns; rendered as ``[]`` when empty or None.
        conversation_context: Background information about the conversation.

    Returns:
        The parsed HypotheticalDocument. When the LLM call or parsing fails,
        the error is logged and a HypotheticalDocument with an empty
        hypothetical answer is returned.
    """
    started_at = time.time()
    parser = PydanticOutputParser(pydantic_object=HypotheticalDocument)
    history_json = "[]" if not chat_history else json.dumps(chat_history, ensure_ascii=False)
    prompt_text = hyde_prompt.format(
        query=query,
        chat_history=history_json,
        conversation_context=conversation_context,
        output_format=parser.get_format_instructions()
    )
    try:
        llm_reply = self._llm.invoke(prompt_text, False)
        document = parser.parse(llm_reply.content)
    except Exception as e:
        # Fall back to an empty hypothetical answer.
        logging.error(f"假设性文档生成失败: {e}")
        return HypotheticalDocument(original_query=query, hypothetical_answer="")
    elapsed = time.time() - started_at
    logging.debug(f"假设性文档生成耗时统计 - 总耗时: {elapsed:.2f}秒")
    return document
def _generate_multi_questions(self, query: str, chat_history: List[Dict[str, str]] = None, conversation_context: str = "") -> MultiQuestions:
    """
    Generate several sub-questions for the query with the LLM.

    Args:
        query: Original user question.
        chat_history: Previous turns; rendered as ``[]`` when empty or None.
        conversation_context: Background information about the conversation.

    Returns:
        The parsed MultiQuestions. When the LLM call or parsing fails, the
        error is logged together with the raw LLM content (``None`` if the
        call itself failed) and a MultiQuestions whose only sub-question is
        the original query is returned.
    """
    multi_questions_start_time = time.time()
    multi_questions_parser = PydanticOutputParser(pydantic_object=MultiQuestions)
    formatted_prompt = multi_questions_prompt.format(
        query=query,
        chat_history=json.dumps(chat_history, ensure_ascii=False) if chat_history else "[]",
        conversation_context=conversation_context,
        output_format=multi_questions_parser.get_format_instructions()
    )
    # Bound before the try block so the error handler can always refer to it,
    # even when the LLM call raised before assigning a response.
    response = None
    try:
        response = self._llm.invoke(formatted_prompt, False)
        parsed_output = multi_questions_parser.parse(response.content)
        multi_questions_time = time.time() - multi_questions_start_time
        logging.debug(f"多角度问题生成耗时统计 - 总耗时: {multi_questions_time:.2f}秒")
        return parsed_output
    except Exception as e:
        raw_content = getattr(response, "content", None)
        logging.error(f"多角度问题生成失败: {e}LLM返回内容:{raw_content}")
        return MultiQuestions(original_query=query, sub_questions=[query])
def _run_in_thread(self, func, args=(), kwargs={}):
"""
在线程中执行函数并返回结果
Args:
func: 要执行的函数
args: 函数的位置参数
kwargs: 函数的关键字参数
Returns:
(thread, result_container): 线程对象和存放结果的容器
"""
result_container = []
def thread_target():
try:
result = func(*args, **kwargs)
result_container.append(result)
except Exception as e:
logging.error(f"线程执行函数 {func.__name__} 时出错: {e}")
result_container.append(None)
thread = threading.Thread(target=thread_target)
thread.start()
return thread, result_container
def _process_intent_and_slot(self, user_input: str, conversation_context: str = "",
                             chat_history: List[Dict[str, str]] = None,
                             previous_slots: Dict[str, Any] = None) -> Dict[str, Any]:
    """
    Run intent recognition and slot filling with one combined prompt.

    The LLM answer is cross-checked against _get_slot_model:
    - no slot model for the returned classification: classify again with
      _classify_intent and fill slots with _fill_slots;
    - slot object of a different type than expected: keep the classification
      and refill the slots with _fill_slots;
    - otherwise the slots returned by the LLM are used directly.

    Args:
        user_input: Current user input.
        conversation_context: Background information about the conversation.
        chat_history: Previous turns; defaults to an empty list.
        previous_slots: Slots collected in earlier turns; defaults to an empty dict.

    Returns:
        Dict with ``classification`` (dumped model) and ``slot_filling``
        (``is_complete`` / ``missing_slots`` / ``filled_data``).

    Raises:
        RuntimeError: If parsing or any recovery step fails. Errors raised by
            the LLM call itself propagate unchanged.
    """
    if chat_history is None:
        chat_history = []
    if previous_slots is None:
        previous_slots = {}

    slot_mapping_doc = generate_slot_mapping_doc()
    parser = PydanticOutputParser(pydantic_object=IntentAndSlotResult)
    formatted_prompt = intent_and_slot_prompt.format(
        conversation_context=conversation_context,
        chat_history=json.dumps(chat_history, ensure_ascii=False),
        previous_slots=json.dumps(previous_slots, ensure_ascii=False),
        user_input=user_input,
        slot_mapping_doc=slot_mapping_doc,
        output_format=parser.get_format_instructions(),
        classification_info=classification_info
    )
    # The output example is appended after formatting, so it is not subject
    # to str.format placeholder substitution.
    llm_start_time = time.time()
    response = self._llm.invoke(formatted_prompt + output_example, False)
    llm_time = time.time() - llm_start_time
    try:
        parsed = parser.parse(response.content)
        classification = parsed.classification
        slot_filling = parsed.slots
        expected_slot_model = self._get_slot_model(classification)

        if expected_slot_model is None:
            # The LLM returned an inconsistent first/second level pair:
            # classify again and fill the slots from scratch.
            classification = self._classify_intent(user_input, conversation_context, chat_history, previous_slots)
            fill_slots = self._fill_slots(user_input, classification, conversation_context, chat_history, previous_slots)
            logging.warning("重新分类与槽点填充")
            return {
                "classification": classification.model_dump(),
                "slot_filling": fill_slots
            }

        if expected_slot_model.__name__ != type(slot_filling).__name__:
            # The slots do not belong to the classification: refill them.
            refilled = self._fill_slots(user_input, classification, conversation_context, chat_history, previous_slots)
            logging.warning("重新填充槽点")
            return {
                "classification": classification.model_dump(),
                "slot_filling": refilled
            }

        # Only inspect the LLM-provided slots once they are known to be of
        # the expected type, so a missing or mismatched slot object reaches
        # the recovery branches above instead of failing here.
        is_complete, missing_slots = slot_filling.check_required_slots()
        logging.info(f"意图识别+槽位LLM调用耗时: {llm_time:.2f}秒")
        return {
            "classification": classification.model_dump(),
            "slot_filling": {
                "is_complete": is_complete,
                "missing_slots": missing_slots,
                "filled_data": slot_filling.model_dump()
            }
        }
    except Exception as e:
        logging.error(f"process_intent_and_slot error:{e}")
        raise RuntimeError(f"process_intent_and_slot error:{e}") from e