1、删除不再使用的.cursorrules文件。
2、更新poetry.lock以反映Poetry版本的变化,并添加jieba依赖。
3、重构意图识别逻辑以支持多轮对话,优化槽位填充和意图分类功能,增强代码可读性和可维护性。
This commit is contained in:
@@ -7,18 +7,27 @@ Date: 2025-05-13
|
||||
Description: 意图分类、改写核心逻辑
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain.output_parsers import PydanticOutputParser
|
||||
import json
|
||||
from typing import List, Tuple, Dict, Any, Optional, Union
|
||||
from typing import List, Tuple, Dict, Any, Optional
|
||||
import re
|
||||
from .PromptTemplates import classification_prompt, query_rewrite_prompt, extract_nouns_prompt, classification_info, slot_filling_prompt
|
||||
import jieba
|
||||
from .PromptTemplates import (classification_prompt, query_rewrite_prompt,
|
||||
extract_nouns_prompt, classification_info,
|
||||
slot_filling_prompt)
|
||||
|
||||
from .Multi_PromptTemplates import (
|
||||
intent_and_slot_prompt, output_example,
|
||||
generate_slot_mapping_doc, query_rewrite_prompt_pro,
|
||||
)
|
||||
|
||||
from .DataModels import (
|
||||
Classification, QueryRewrite, Term, TermList,
|
||||
SoftwareFunction, TroubleShooting, ProfessionalConsulting,
|
||||
DataProblem, FileExtensionConsulting, SoftwareLock,
|
||||
InstallationDownload, ProblemDiagnosis
|
||||
SoftwareFunctionSlots, SoftwareTroubleShootingSlots, ProfessionalConsultingSlots,
|
||||
DataProblemSlots, FileExtensionConsultingSlots, SoftwareLockSlots,
|
||||
InstallationDownloadSlots, ProblemDiagnosisSlots, OtherSlots, IntentAndSlotResult
|
||||
)
|
||||
from .ProfessionalNounVector import ProfessionalNounRetriever
|
||||
from rag2_0.tool.ModelTool import XinferenceReRankerModel, OpenAiLLM, SiliconFlowReRankerModel
|
||||
@@ -52,22 +61,13 @@ class IntentRecognizer:
|
||||
if base_url:
|
||||
llm_params["base_url"] = base_url
|
||||
|
||||
self.llm = OpenAiLLM(**llm_params)
|
||||
|
||||
# 准备分类解析器
|
||||
self.classification_parser = PydanticOutputParser(pydantic_object=Classification)
|
||||
|
||||
# 准备问题改写解析器
|
||||
self.query_rewrite_parser = PydanticOutputParser(pydantic_object=QueryRewrite)
|
||||
|
||||
# 准备术语列表解析器
|
||||
self.terms_list_parser = PydanticOutputParser(pydantic_object=TermList)
|
||||
self._llm = OpenAiLLM(**llm_params)
|
||||
|
||||
# 加载suffix关键词
|
||||
self.suffix_keywords = self._load_suffix_keywords()
|
||||
self._suffix_keywords = self._load_suffix_keywords()
|
||||
|
||||
# 初始化向量检索器
|
||||
self.noun_retriever = ProfessionalNounRetriever(api_key=api_key, index_dir=vector_index_dir)
|
||||
self._noun_retriever = ProfessionalNounRetriever(api_key=api_key, index_dir=vector_index_dir)
|
||||
|
||||
def _load_suffix_keywords(self, filepath: str = None) -> List[str]:
|
||||
"""
|
||||
@@ -95,7 +95,7 @@ class IntentRecognizer:
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"加载后缀关键词失败: {e}") from e
|
||||
|
||||
def classify_intent(self, query: str, keywords: TermList) -> Classification:
|
||||
def _classify_intent(self, query: str) -> Classification:
|
||||
"""
|
||||
对用户输入进行意图分类
|
||||
|
||||
@@ -106,49 +106,85 @@ class IntentRecognizer:
|
||||
Returns:
|
||||
分类结果
|
||||
"""
|
||||
classification_parser = PydanticOutputParser(pydantic_object=Classification)
|
||||
formatted_prompt = classification_prompt.format(user_input=query,
|
||||
classification_info=classification_info,
|
||||
output_format=self.classification_parser.get_format_instructions())
|
||||
# 将关键词列表转换为JSON字符串
|
||||
terms_dict = [term.model_dump() for term in keywords.terms]
|
||||
keywords_str = json.dumps(terms_dict, ensure_ascii=False)
|
||||
formatted_prompt = formatted_prompt.replace("{keywords}", keywords_str)
|
||||
output_format=classification_parser.get_format_instructions())
|
||||
|
||||
# 调用LLM
|
||||
response = self.llm.invoke(formatted_prompt, False)
|
||||
response = self._llm.invoke(formatted_prompt, False)
|
||||
|
||||
# 解析输出
|
||||
try:
|
||||
# 尝试直接解析JSON响应
|
||||
parsed_output = self.classification_parser.parse(response.content.strip())
|
||||
parsed_output = classification_parser.parse(response.content.strip())
|
||||
return parsed_output
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"解析分类结果时出错: {e}") from e
|
||||
|
||||
def extract_keywords_with_llm(self, query: str) -> List[Term]:
|
||||
def _tokenize_with_jieba(self, query: str) -> List[str]:
|
||||
"""
|
||||
使用LLM从用户查询中提取专业关键词
|
||||
使用jieba分词器对查询进行分词
|
||||
|
||||
Args:
|
||||
query: 用户查询
|
||||
|
||||
Returns:
|
||||
分词后的词语列表
|
||||
"""
|
||||
# 使用jieba进行分词
|
||||
seg_list = jieba.cut(query, cut_all=False)
|
||||
|
||||
# 过滤掉停用词和标点符号
|
||||
filtered_tokens = []
|
||||
for token in seg_list:
|
||||
# 过滤掉空格和标点符号
|
||||
if token.strip() and not re.match(r'^[^\w\s]$', token):
|
||||
filtered_tokens.append(token)
|
||||
|
||||
return filtered_tokens
|
||||
|
||||
def _extract_keywords_with_llm(self, query: str, use_jieba: bool = False) -> List[Term]:
|
||||
"""
|
||||
使用LLM从用户查询中提取专业关键词
|
||||
|
||||
Args:
|
||||
query: 用户查询
|
||||
use_jieba: 是否使用jieba分词辅助提取关键词
|
||||
|
||||
Returns:
|
||||
提取的术语列表
|
||||
"""
|
||||
# 准备提示词
|
||||
formatted_prompt = extract_nouns_prompt.replace("{content}", query)
|
||||
formatted_prompt = formatted_prompt.replace("{output_format}", self.terms_list_parser.get_format_instructions())
|
||||
|
||||
# 调用LLM
|
||||
response = self.llm.invoke(formatted_prompt, False)
|
||||
|
||||
try:
|
||||
# 尝试使用Pydantic解析器解析TermList
|
||||
parsed_output = self.terms_list_parser.parse(response.content)
|
||||
return parsed_output.terms
|
||||
# 如果使用jieba分词
|
||||
if use_jieba:
|
||||
# 先使用jieba分词
|
||||
tokens = self._tokenize_with_jieba(query)
|
||||
|
||||
# 构建术语列表
|
||||
terms = []
|
||||
for token in tokens:
|
||||
if len(token) > 1: # 过滤掉单字词
|
||||
terms.append(Term(name=token, synonymous=[], description=""))
|
||||
|
||||
return terms
|
||||
else:
|
||||
# 使用LLM提取关键词
|
||||
# 准备提示词
|
||||
formatted_prompt = extract_nouns_prompt.replace("{content}", query)
|
||||
terms_list_parser = PydanticOutputParser(pydantic_object=TermList)
|
||||
formatted_prompt = formatted_prompt.replace("{output_format}", terms_list_parser.get_format_instructions())
|
||||
|
||||
# 调用LLM
|
||||
response = self._llm.invoke(formatted_prompt, False)
|
||||
|
||||
# 尝试使用Pydantic解析器解析TermList
|
||||
parsed_output = terms_list_parser.parse(response.content)
|
||||
return parsed_output.terms
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"无法解析LLM关键词提取响应: {e}") from e
|
||||
|
||||
def rerank_matched_terms(self, query_key: str, matched_terms: set, top_k: int = 2) -> List[Term]:
|
||||
|
||||
def _rerank_matched_terms(self, query_key: str, matched_terms: set, top_k: int = 2) -> List[Term]:
|
||||
"""
|
||||
对召回的专业术语进行重排序,按与用户查询的相关性排序
|
||||
|
||||
@@ -182,31 +218,32 @@ class IntentRecognizer:
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"SiliconFlowReRankerModel重排失败:{e}") from e
|
||||
|
||||
def match_keywords(self, query: str) -> Tuple[TermList, List[str]]:
|
||||
def _match_keywords(self, query: str, use_jieba: bool = False) -> Tuple[TermList, List[str]]:
|
||||
"""
|
||||
从用户问题中匹配关键词,结合LLM提取和向量检索
|
||||
|
||||
Args:
|
||||
query: 用户问题
|
||||
use_jieba: 是否使用jieba分词辅助提取关键词
|
||||
|
||||
Returns:
|
||||
匹配到的关键词列表
|
||||
"""
|
||||
query_keys=[]
|
||||
# 步骤2: 使用LLM提取查询中的关键词
|
||||
# 步骤1: 使用LLM提取查询中的关键词
|
||||
try:
|
||||
extracted_terms = self.extract_keywords_with_llm(query)
|
||||
extracted_terms = self._extract_keywords_with_llm(query, use_jieba)
|
||||
for term in extracted_terms:
|
||||
query_keys.append(term.name)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"LLM关键词提取失败: {e}") from e
|
||||
|
||||
matched_terms = [] # 存储匹配到的Term对象
|
||||
# 步骤3: 使用向量检索找到相似的专业名词
|
||||
# 步骤2: 使用向量检索找到相似的专业名词
|
||||
try:
|
||||
# 对matched_terms中的每个关键字进行向量检索
|
||||
for current_key in query_keys:
|
||||
vector_results = self.noun_retriever.query(current_key, top_k=5, use_intersection=False)
|
||||
vector_results = self._noun_retriever.query(current_key, top_k=5, use_intersection=False)
|
||||
current_key_terms = set()
|
||||
# 添加向量检索结果
|
||||
for result in vector_results:
|
||||
@@ -218,8 +255,9 @@ class IntentRecognizer:
|
||||
description=result.get('description', '')
|
||||
)
|
||||
current_key_terms.add(term)
|
||||
reranked_terms = self.rerank_matched_terms(current_key, current_key_terms)
|
||||
matched_terms.extend(reranked_terms)
|
||||
if len(current_key_terms) > 0:
|
||||
reranked_terms = self._rerank_matched_terms(current_key, current_key_terms)
|
||||
matched_terms.extend(reranked_terms)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"向量检索关键词时出错: {e}") from e
|
||||
|
||||
@@ -228,7 +266,7 @@ class IntentRecognizer:
|
||||
term_list = TermList(terms=list(matched_terms))
|
||||
return term_list, query_keys
|
||||
|
||||
def rewrite_query(self, query: str, keywords: TermList) -> QueryRewrite:
|
||||
def _rewrite_query(self, query: str, keywords: TermList, chat_history: List[Dict[str, str]] = None, context: str = "") -> QueryRewrite:
|
||||
"""
|
||||
对用户问题进行改写
|
||||
|
||||
@@ -242,23 +280,28 @@ class IntentRecognizer:
|
||||
# 准备问题改写提示
|
||||
terms_dict = [term.model_dump(exclude={"description"}) for term in keywords.terms]
|
||||
keywords_str = json.dumps(terms_dict, ensure_ascii=False)
|
||||
formatted_prompt = query_rewrite_prompt.format(query=query,
|
||||
output_format=self.query_rewrite_parser.get_format_instructions(),
|
||||
keywords=keywords_str)
|
||||
|
||||
query_rewrite_parser = PydanticOutputParser(pydantic_object=QueryRewrite)
|
||||
# formatted_prompt = query_rewrite_prompt.format(query=query,
|
||||
# output_format=query_rewrite_parser.get_format_instructions(),
|
||||
# keywords=keywords_str)
|
||||
formatted_prompt = query_rewrite_prompt_pro.format(query=query,
|
||||
output_format=query_rewrite_parser.get_format_instructions(),
|
||||
keywords=keywords_str,
|
||||
chat_history=chat_history,
|
||||
context=context)
|
||||
|
||||
# 调用LLM
|
||||
response = self.llm.invoke(formatted_prompt, False)
|
||||
response = self._llm.invoke(formatted_prompt, False)
|
||||
|
||||
# 解析输出
|
||||
try:
|
||||
# 尝试直接解析JSON响应
|
||||
parsed_output = self.query_rewrite_parser.parse(response.content)
|
||||
parsed_output = query_rewrite_parser.parse(response.content)
|
||||
return parsed_output
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"解析问题改写结果时出错: {e}") from e
|
||||
|
||||
def judge_define_suffix(self, input_str: str) -> Tuple[bool, List[str]]:
|
||||
def _judge_define_suffix(self, input_str: str) -> Tuple[bool, List[str]]:
|
||||
"""
|
||||
判断输入字符串是否包含定义的后缀,并返回所有匹配到的后缀名列表
|
||||
|
||||
@@ -270,7 +313,7 @@ class IntentRecognizer:
|
||||
"""
|
||||
|
||||
# 构建正则表达式模式,匹配大小写不敏感且前面可能带有.
|
||||
pattern = r'(?:\.?)(' + '|'.join(re.escape(field.get('name')) for field in self.suffix_keywords) + r')'
|
||||
pattern = r'(?:\.?)(' + '|'.join(re.escape(field.get('name')) for field in self._suffix_keywords) + r')'
|
||||
|
||||
# 使用 re.IGNORECASE 标志来忽略大小写,findall找到所有匹配
|
||||
matches = re.finditer(pattern, input_str, re.IGNORECASE)
|
||||
@@ -278,23 +321,30 @@ class IntentRecognizer:
|
||||
|
||||
return bool(matched_suffixes), matched_suffixes
|
||||
|
||||
def process_query(self, query: str) -> Tuple[Classification, TermList, QueryRewrite, List[str]]:
|
||||
def process_query(self, query: str, conversation_context: str = "",
|
||||
chat_history: List[Dict[str, str]] = None,
|
||||
previous_slots: Dict[str, Any] = None,
|
||||
use_jieba: bool = False) -> Dict[str, Any]:
|
||||
"""
|
||||
处理用户问题的完整流程
|
||||
|
||||
Args:
|
||||
query: 用户原始问题
|
||||
conversation_context: 会话背景信息
|
||||
chat_history: 历史对话记录,格式为[{"user": "content"}, {"assistant": "content"}]
|
||||
previous_slots: 历史槽位信息
|
||||
use_jieba: 是否使用jieba分词辅助提取关键词
|
||||
|
||||
Returns:
|
||||
(意图分类结果, 匹配的关键词列表, 问题改写结果)的元组
|
||||
包含分类、关键词、改写和槽位填充结果的字典
|
||||
"""
|
||||
# 是否是扩展名
|
||||
# is_suffix, matched_suffixes = self.judge_define_suffix(query)
|
||||
# is_suffix, matched_suffixes = self._judge_define_suffix(query)
|
||||
# if is_suffix:
|
||||
# # 将所有匹配到的后缀名作为Term添加到结果中
|
||||
# suffix_terms = []
|
||||
# for suffix in matched_suffixes:
|
||||
# term_dict = next((item for item in self.suffix_keywords if item['name'].lower() == suffix.lower()), None)
|
||||
# term_dict = next((item for item in self._suffix_keywords if item['name'].lower() == suffix.lower()), None)
|
||||
# if term_dict:
|
||||
# suffix_term = Term(
|
||||
# name=term_dict.get('name'),
|
||||
@@ -306,26 +356,41 @@ class IntentRecognizer:
|
||||
# return Classification(vertical_classification="安装下载", sub_classification="查询"), TermList(terms=suffix_terms), QueryRewrite(rewrite=query), matched_suffixes
|
||||
|
||||
# 步骤1: 匹配关键词
|
||||
keywords_terms, query_keys = self.match_keywords(query)
|
||||
keywords_terms, query_keys = self._match_keywords(query, use_jieba)
|
||||
|
||||
# 步骤2: 问题改写
|
||||
rewrite = self.rewrite_query(
|
||||
rewrite = self._rewrite_query(
|
||||
query=query,
|
||||
keywords=keywords_terms
|
||||
keywords=keywords_terms,
|
||||
chat_history=chat_history,
|
||||
context=conversation_context
|
||||
)
|
||||
|
||||
# 步骤3: 进行意图识别和槽位填充
|
||||
result = self._process_intent_and_slot(query, conversation_context, chat_history, previous_slots)
|
||||
result.update({"keywords": keywords_terms.model_dump(),
|
||||
"rewrite": rewrite.model_dump(),
|
||||
"query_keys": query_keys})
|
||||
return result
|
||||
# # 步骤3: 进行意图分类
|
||||
# classification = self._classify_intent(query)
|
||||
|
||||
# 步骤3: 进行意图分类
|
||||
classification = self.classify_intent(query, keywords_terms)
|
||||
if classification.vertical_classification == "其他" or classification.sub_classification == "其他":
|
||||
return classification, TermList(terms=[]), QueryRewrite(rewrite=query), []
|
||||
# # 步骤4: 进行槽位填充
|
||||
# # 如果是有效分类,进行槽位填充
|
||||
# slot_filling_result = {}
|
||||
# if classification.vertical_classification not in ["其他", "闲聊"] and classification.sub_classification not in ["其他", "闲聊"]:
|
||||
# slot_filling_result = self._fill_slots(rewrite.rewrite, classification)
|
||||
|
||||
# return {
|
||||
# "classification": classification.model_dump(),
|
||||
# "keywords": keywords_terms.model_dump(),
|
||||
# "rewrite": rewrite.model_dump(),
|
||||
# "query_keys": query_keys,
|
||||
# "slot_filling": slot_filling_result
|
||||
# }
|
||||
|
||||
if classification.vertical_classification == "闲聊" or classification.sub_classification == "闲聊":
|
||||
return classification, TermList(terms=[]), QueryRewrite(rewrite=query),[]
|
||||
|
||||
# rewrite = QueryRewrite(rewrite=query)
|
||||
return classification, keywords_terms, rewrite, query_keys
|
||||
|
||||
def fill_slots(self, query: str, classification: Classification) -> Dict[str, Any]:
|
||||
def _fill_slots(self, query: str, classification: Classification) -> Dict[str, Any]:
|
||||
"""
|
||||
根据分类结果对问题进行槽位填充
|
||||
|
||||
@@ -340,7 +405,7 @@ class IntentRecognizer:
|
||||
# 根据分类结果选择对应的数据模型
|
||||
slot_model = self._get_slot_model(classification)
|
||||
if not slot_model:
|
||||
return {"error": "未找到匹配的槽位模型"}
|
||||
raise RuntimeError("未找到匹配的槽位模型")
|
||||
|
||||
# 使用LLM进行槽位填充
|
||||
filled_slots = self._fill_slots_with_llm(query, classification, slot_model)
|
||||
@@ -356,7 +421,7 @@ class IntentRecognizer:
|
||||
|
||||
def _get_slot_model(self, classification: Classification) -> Optional[type]:
|
||||
"""
|
||||
根据分类结果获取对应的槽位模型类
|
||||
根据分类结果获取对应的槽位模型类,用于统一提示词处理
|
||||
|
||||
Args:
|
||||
classification: 意图分类结果
|
||||
@@ -367,31 +432,33 @@ class IntentRecognizer:
|
||||
# 软件问题
|
||||
if classification.vertical_classification == "软件问题":
|
||||
if classification.sub_classification == "软件功能":
|
||||
return SoftwareFunction
|
||||
return SoftwareFunctionSlots
|
||||
elif classification.sub_classification == "故障排查":
|
||||
return TroubleShooting
|
||||
return SoftwareTroubleShootingSlots
|
||||
|
||||
# 业务问题
|
||||
elif classification.vertical_classification == "业务问题":
|
||||
if classification.sub_classification == "专业咨询":
|
||||
return ProfessionalConsulting
|
||||
return ProfessionalConsultingSlots
|
||||
elif classification.sub_classification == "数据问题":
|
||||
return DataProblem
|
||||
return DataProblemSlots
|
||||
|
||||
# 安装下载注册
|
||||
elif classification.vertical_classification == "安装下载注册":
|
||||
if classification.sub_classification == "后缀名咨询":
|
||||
return FileExtensionConsulting
|
||||
return FileExtensionConsultingSlots
|
||||
elif classification.sub_classification == "软件锁类":
|
||||
return SoftwareLock
|
||||
return SoftwareLockSlots
|
||||
elif classification.sub_classification == "安装下载类":
|
||||
return InstallationDownload
|
||||
return InstallationDownloadSlots
|
||||
elif classification.sub_classification == "问题排查类":
|
||||
return ProblemDiagnosis
|
||||
return ProblemDiagnosisSlots
|
||||
|
||||
# 其他
|
||||
elif classification.vertical_classification == "其他":
|
||||
return OtherSlots
|
||||
|
||||
return None
|
||||
|
||||
count=1
|
||||
|
||||
def _fill_slots_with_llm(self, query: str, classification: Classification, slot_model_class: type) -> Any:
|
||||
"""
|
||||
@@ -416,7 +483,7 @@ class IntentRecognizer:
|
||||
)
|
||||
|
||||
# 调用LLM
|
||||
response = self.llm.invoke(formatted_prompt, False)
|
||||
response = self._llm.invoke(formatted_prompt, False)
|
||||
|
||||
try:
|
||||
# 尝试解析LLM响应
|
||||
@@ -426,29 +493,88 @@ class IntentRecognizer:
|
||||
# 如果解析失败,创建一个空的模型实例
|
||||
empty_instance = slot_model_class()
|
||||
return empty_instance
|
||||
|
||||
def process_query_with_slots(self, query: str) -> Dict[str, Any]:
|
||||
|
||||
def _process_intent_and_slot(self, user_input: str, conversation_context: str = "",
|
||||
chat_history: List[Dict[str, str]] = None,
|
||||
previous_slots: Dict[str, Any] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
处理用户问题的完整流程,包括槽位填充
|
||||
使用统一提示词同时进行意图识别和槽位填充
|
||||
|
||||
Args:
|
||||
query: 用户原始问题
|
||||
user_input: 当前用户输入
|
||||
conversation_context: 会话背景信息
|
||||
chat_history: 历史对话记录,格式为[{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
|
||||
previous_slots: 历史槽位信息
|
||||
|
||||
Returns:
|
||||
包含分类、关键词、改写和槽位填充结果的字典
|
||||
包含意图分类和槽位填充结果的字典
|
||||
"""
|
||||
# 执行基本处理流程
|
||||
classification, keywords, rewrite, query_keys = self.process_query(query)
|
||||
# 初始化默认值
|
||||
if chat_history is None:
|
||||
chat_history = []
|
||||
|
||||
# 如果是有效分类,进行槽位填充
|
||||
slot_filling_result = {}
|
||||
if classification.vertical_classification not in ["其他", "闲聊"] and classification.sub_classification not in ["其他", "闲聊"]:
|
||||
slot_filling_result = self.fill_slots(rewrite.rewrite, classification)
|
||||
if previous_slots is None:
|
||||
previous_slots = {}
|
||||
|
||||
# 生成槽位映射文档
|
||||
slot_mapping_doc = generate_slot_mapping_doc()
|
||||
|
||||
return {
|
||||
"classification": classification.model_dump(),
|
||||
"keywords": keywords.model_dump(),
|
||||
"rewrite": rewrite.model_dump(),
|
||||
"query_keys": query_keys,
|
||||
"slot_filling": slot_filling_result
|
||||
}
|
||||
# 准备提示词
|
||||
parser = PydanticOutputParser(pydantic_object=IntentAndSlotResult)
|
||||
formatted_prompt = intent_and_slot_prompt.format(
|
||||
conversation_context=conversation_context,
|
||||
chat_history=json.dumps(chat_history, ensure_ascii=False),
|
||||
previous_slots=json.dumps(previous_slots, ensure_ascii=False),
|
||||
user_input=user_input,
|
||||
slot_mapping_doc=slot_mapping_doc,
|
||||
output_format=parser.get_format_instructions(),
|
||||
classification_info=classification_info
|
||||
)
|
||||
# 调用LLM
|
||||
response = self._llm.invoke(formatted_prompt + output_example, False)
|
||||
|
||||
try:
|
||||
# 解析LLM响应为JSON
|
||||
result_json = parser.parse(response.content)
|
||||
classification=result_json.classification
|
||||
slot_filling=result_json.slots
|
||||
is_complete, missing_slots = slot_filling.check_required_slots()
|
||||
expected_slot_model = self._get_slot_model(classification)
|
||||
|
||||
# 添加容错处理,发生概率较低,但仍需处理
|
||||
if expected_slot_model is None:
|
||||
# 添加容错处理,应对LLM返回错误分类信息,一级分类跟二级分类错乱
|
||||
# 重新分类
|
||||
classification = self._classify_intent(user_input)
|
||||
fill_slots = self._fill_slots(user_input, classification)
|
||||
result = {
|
||||
"classification": classification.model_dump(),
|
||||
"slot_filling": fill_slots
|
||||
}
|
||||
logging.warning(f"重新分类与槽点填充")
|
||||
return result
|
||||
elif expected_slot_model.__name__ != type(slot_filling).__name__:
|
||||
# 添加容错处理,应对LLM槽位与分类不匹配。重新填充槽位
|
||||
slot_filling = self._fill_slots(user_input, classification)
|
||||
result = {
|
||||
"classification": classification.model_dump(),
|
||||
"slot_filling": slot_filling
|
||||
}
|
||||
logging.warning(f"重新填充槽点")
|
||||
return result
|
||||
|
||||
# 构建最终结果
|
||||
result = {
|
||||
"classification": classification.model_dump(),
|
||||
"slot_filling": {
|
||||
"is_complete": is_complete,
|
||||
"missing_slots": missing_slots,
|
||||
"filled_data": slot_filling.model_dump()
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"process_intent_and_slot error:{e}")
|
||||
raise RuntimeError(f"process_intent_and_slot error:{e}") from e
|
||||
Reference in New Issue
Block a user