453 lines
18 KiB
Python
453 lines
18 KiB
Python
#!/usr/bin/env python
|
|
# -*- coding: utf-8 -*-
|
|
"""
|
|
File: IntentRecognition.py
|
|
Author: oyyz
|
|
Date: 2025-05-13
|
|
Description: 意图分类、改写核心逻辑
|
|
"""
|
|
|
|
import os
|
|
from langchain_openai import ChatOpenAI
|
|
from langchain.output_parsers import PydanticOutputParser
|
|
import json
|
|
from typing import List, Tuple, Dict, Any, Optional, Union
|
|
import re
|
|
from .PromptTemplates import classification_prompt, query_rewrite_prompt, extract_nouns_prompt, classification_info, slot_filling_prompt
|
|
from .DataModels import (
|
|
Classification, QueryRewrite, Term, TermList,
|
|
SoftwareFunction, TroubleShooting, ProfessionalConsulting,
|
|
DataProblem, FileExtensionConsulting, SoftwareLock,
|
|
InstallationDownload, ProblemDiagnosis
|
|
)
|
|
from .ProfessionalNounVector import ProfessionalNounRetriever
|
|
from rag2_0.tool.ModelTool import XinferenceReRankerModel, OpenAiLLM, SiliconFlowReRankerModel
|
|
|
|
|
|
class IntentRecognizer:
|
|
"""
|
|
意图识别和问题改写类
|
|
"""
|
|
def __init__(self, api_key: str = None, base_url: str = None, model_name: str = "gpt-3.5-turbo", vector_index_dir: str = None):
|
|
"""
|
|
初始化意图识别器
|
|
|
|
Args:
|
|
api_key: OpenAI API密钥,如果为None则从环境变量获取
|
|
base_url: OpenAI API基础URL,如果为None则使用默认URL
|
|
model_name: 要使用的模型名称
|
|
vector_index_dir: 向量索引目录,如果为None则使用默认目录
|
|
"""
|
|
# 初始化LLM
|
|
llm_params = {
|
|
"temperature": 0.2, # 降低随机性,使结果更确定
|
|
"model": model_name
|
|
}
|
|
|
|
# 如果提供了API密钥,则使用提供的密钥
|
|
if api_key:
|
|
llm_params["api_key"] = api_key
|
|
|
|
# 如果提供了自定义URL,则使用提供的URL
|
|
if base_url:
|
|
llm_params["base_url"] = base_url
|
|
|
|
self.llm = OpenAiLLM(**llm_params)
|
|
|
|
# 准备分类解析器
|
|
self.classification_parser = PydanticOutputParser(pydantic_object=Classification)
|
|
|
|
# 准备问题改写解析器
|
|
self.query_rewrite_parser = PydanticOutputParser(pydantic_object=QueryRewrite)
|
|
|
|
# 准备术语列表解析器
|
|
self.terms_list_parser = PydanticOutputParser(pydantic_object=TermList)
|
|
|
|
# 加载suffix关键词
|
|
self.suffix_keywords = self._load_suffix_keywords()
|
|
|
|
# 初始化向量检索器
|
|
self.noun_retriever = ProfessionalNounRetriever(api_key=api_key, index_dir=vector_index_dir)
|
|
|
|
def _load_suffix_keywords(self, filepath: str = None) -> List[str]:
|
|
"""
|
|
加载后缀关键词列表
|
|
|
|
Args:
|
|
filepath: 后缀关键词文件路径,默认为None使用默认路径
|
|
|
|
Returns:
|
|
后缀关键词列表
|
|
"""
|
|
try:
|
|
# 如果未指定路径,使用默认路径
|
|
if filepath is None:
|
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
|
filepath = os.path.join(current_dir, "..", "..", "data", "nouns", "suffix_keywords.json")
|
|
|
|
# 读取JSON文件
|
|
with open(filepath, "r", encoding="utf-8") as f:
|
|
suffix_data = json.load(f)
|
|
|
|
# 添加额外的固定后缀
|
|
return suffix_data
|
|
|
|
except Exception as e:
|
|
raise RuntimeError(f"加载后缀关键词失败: {e}") from e
|
|
|
|
def classify_intent(self, query: str, keywords: TermList) -> Classification:
|
|
"""
|
|
对用户输入进行意图分类
|
|
|
|
Args:
|
|
content: 用户输入内容
|
|
keywords: 匹配到的关键词列表
|
|
rewrite: 重写的问题
|
|
Returns:
|
|
分类结果
|
|
"""
|
|
formatted_prompt = classification_prompt.format(user_input=query,
|
|
classification_info=classification_info,
|
|
output_format=self.classification_parser.get_format_instructions())
|
|
# 将关键词列表转换为JSON字符串
|
|
terms_dict = [term.model_dump() for term in keywords.terms]
|
|
keywords_str = json.dumps(terms_dict, ensure_ascii=False)
|
|
formatted_prompt = formatted_prompt.replace("{keywords}", keywords_str)
|
|
# 调用LLM
|
|
response = self.llm.invoke(formatted_prompt, False)
|
|
|
|
# 解析输出
|
|
try:
|
|
# 尝试直接解析JSON响应
|
|
parsed_output = self.classification_parser.parse(response.content.strip())
|
|
return parsed_output
|
|
except Exception as e:
|
|
raise RuntimeError(f"解析分类结果时出错: {e}") from e
|
|
|
|
def extract_keywords_with_llm(self, query: str) -> List[Term]:
|
|
"""
|
|
使用LLM从用户查询中提取专业关键词
|
|
|
|
Args:
|
|
query: 用户查询
|
|
|
|
Returns:
|
|
提取的术语列表
|
|
"""
|
|
# 准备提示词
|
|
formatted_prompt = extract_nouns_prompt.replace("{content}", query)
|
|
formatted_prompt = formatted_prompt.replace("{output_format}", self.terms_list_parser.get_format_instructions())
|
|
|
|
# 调用LLM
|
|
response = self.llm.invoke(formatted_prompt, False)
|
|
|
|
try:
|
|
# 尝试使用Pydantic解析器解析TermList
|
|
parsed_output = self.terms_list_parser.parse(response.content)
|
|
return parsed_output.terms
|
|
except Exception as e:
|
|
raise RuntimeError(f"无法解析LLM关键词提取响应: {e}") from e
|
|
|
|
def rerank_matched_terms(self, query_key: str, matched_terms: set, top_k: int = 2) -> List[Term]:
|
|
"""
|
|
对召回的专业术语进行重排序,按与用户查询的相关性排序
|
|
|
|
Args:
|
|
query: 用户查询
|
|
matched_terms: 匹配到的专业术语集合
|
|
query_keys: 用户查询中提取的关键词列表
|
|
|
|
Returns:
|
|
重排序后的专业术语列表
|
|
"""
|
|
if not matched_terms:
|
|
return []
|
|
|
|
try:
|
|
# 将每个术语转换为可用于重排序的文本表示
|
|
term_texts = ["名称:" + term.name + "|" + "同义词:" + ";".join(term.synonymous) + "|" + "描述:" + term.description for term in matched_terms]
|
|
|
|
# 使用重排序模型
|
|
xinference_reranker = SiliconFlowReRankerModel()
|
|
rerank_results = xinference_reranker.rerank(query_key, term_texts, top_k=top_k)
|
|
|
|
# 将matched_terms转换为列表以便按索引访问
|
|
matched_terms_list = list(matched_terms)
|
|
|
|
# 根据重排序结果获取排序后的术语列表
|
|
reranked_terms = [matched_terms_list[result["index"]] for result in rerank_results if result["score"] >= 0.6]
|
|
|
|
return reranked_terms
|
|
|
|
except Exception as e:
|
|
return list(matched_terms)
|
|
|
|
def match_keywords(self, query: str) -> Tuple[TermList, List[str]]:
|
|
"""
|
|
从用户问题中匹配关键词,结合LLM提取和向量检索
|
|
|
|
Args:
|
|
query: 用户问题
|
|
|
|
Returns:
|
|
匹配到的关键词列表
|
|
"""
|
|
query_keys=[]
|
|
# 步骤2: 使用LLM提取查询中的关键词
|
|
try:
|
|
extracted_terms = self.extract_keywords_with_llm(query)
|
|
for term in extracted_terms:
|
|
query_keys.append(term.name)
|
|
except Exception as e:
|
|
raise RuntimeError(f"LLM关键词提取失败: {e}") from e
|
|
|
|
matched_terms = [] # 存储匹配到的Term对象
|
|
# 步骤3: 使用向量检索找到相似的专业名词
|
|
try:
|
|
# 对matched_terms中的每个关键字进行向量检索
|
|
for current_key in query_keys:
|
|
vector_results = self.noun_retriever.query(current_key, top_k=3, use_intersection=True)
|
|
current_key_terms = set()
|
|
# 添加向量检索结果
|
|
for result in vector_results:
|
|
if isinstance(result.get('synonymous', []), str):
|
|
result['synonymous'] = result['synonymous'].split(';')
|
|
term = Term(
|
|
name=result.get('name'),
|
|
synonymous=result.get('synonymous', []),
|
|
description=result.get('description', '')
|
|
)
|
|
current_key_terms.add(term)
|
|
reranked_terms = self.rerank_matched_terms(current_key, current_key_terms)
|
|
matched_terms.extend(reranked_terms)
|
|
except Exception as e:
|
|
raise RuntimeError(f"向量检索关键词时出错: {e}") from e
|
|
|
|
# 提取所有Term对象的名称并排序
|
|
# 将set类型的matched_terms转换为TermList类型
|
|
term_list = TermList(terms=list(matched_terms))
|
|
return term_list, query_keys
|
|
|
|
def rewrite_query(self, query: str, keywords: TermList) -> QueryRewrite:
|
|
"""
|
|
对用户问题进行改写
|
|
|
|
Args:
|
|
query: 用户原始问题
|
|
keywords: 匹配到的关键词列表
|
|
|
|
Returns:
|
|
改写结果
|
|
"""
|
|
# 准备问题改写提示
|
|
terms_dict = [term.model_dump(exclude={"description"}) for term in keywords.terms]
|
|
keywords_str = json.dumps(terms_dict, ensure_ascii=False)
|
|
formatted_prompt = query_rewrite_prompt.format(query=query,
|
|
output_format=self.query_rewrite_parser.get_format_instructions(),
|
|
keywords=keywords_str)
|
|
|
|
|
|
# 调用LLM
|
|
response = self.llm.invoke(formatted_prompt, False)
|
|
|
|
# 解析输出
|
|
try:
|
|
# 尝试直接解析JSON响应
|
|
parsed_output = self.query_rewrite_parser.parse(response.content)
|
|
return parsed_output
|
|
except Exception as e:
|
|
raise RuntimeError(f"解析问题改写结果时出错: {e}") from e
|
|
|
|
def judge_define_suffix(self, input_str: str) -> Tuple[bool, List[str]]:
|
|
"""
|
|
判断输入字符串是否包含定义的后缀,并返回所有匹配到的后缀名列表
|
|
|
|
Args:
|
|
input_str: 输入字符串
|
|
|
|
Returns:
|
|
Tuple[bool, List[str]]: (是否包含定义的后缀, 匹配到的后缀名列表)
|
|
"""
|
|
|
|
# 构建正则表达式模式,匹配大小写不敏感且前面可能带有.
|
|
pattern = r'(?:\.?)(' + '|'.join(re.escape(field.get('name')) for field in self.suffix_keywords) + r')'
|
|
|
|
# 使用 re.IGNORECASE 标志来忽略大小写,findall找到所有匹配
|
|
matches = re.finditer(pattern, input_str, re.IGNORECASE)
|
|
matched_suffixes = [match.group(1) for match in matches]
|
|
|
|
return bool(matched_suffixes), matched_suffixes
|
|
|
|
def process_query(self, query: str) -> Tuple[Classification, TermList, QueryRewrite, List[str]]:
|
|
"""
|
|
处理用户问题的完整流程
|
|
|
|
Args:
|
|
query: 用户原始问题
|
|
|
|
Returns:
|
|
(意图分类结果, 匹配的关键词列表, 问题改写结果)的元组
|
|
"""
|
|
# 是否是扩展名
|
|
# is_suffix, matched_suffixes = self.judge_define_suffix(query)
|
|
# if is_suffix:
|
|
# # 将所有匹配到的后缀名作为Term添加到结果中
|
|
# suffix_terms = []
|
|
# for suffix in matched_suffixes:
|
|
# term_dict = next((item for item in self.suffix_keywords if item['name'].lower() == suffix.lower()), None)
|
|
# if term_dict:
|
|
# suffix_term = Term(
|
|
# name=term_dict.get('name'),
|
|
# synonymous=term_dict.get('synonymous', []),
|
|
# description=json.dumps(term_dict.get('description', ''), ensure_ascii=False)
|
|
# )
|
|
# suffix_terms.append(suffix_term)
|
|
|
|
# return Classification(vertical_classification="安装下载", sub_classification="查询"), TermList(terms=suffix_terms), QueryRewrite(rewrite=query), matched_suffixes
|
|
|
|
# 步骤1: 匹配关键词
|
|
keywords_terms, query_keys = self.match_keywords(query)
|
|
|
|
# 步骤2: 问题改写
|
|
rewrite = self.rewrite_query(
|
|
query=query,
|
|
keywords=keywords_terms
|
|
)
|
|
|
|
# 步骤3: 进行意图分类
|
|
classification = self.classify_intent(query, keywords_terms)
|
|
if classification.vertical_classification == "其他" or classification.sub_classification == "其他":
|
|
return classification, TermList(terms=[]), QueryRewrite(rewrite=query), []
|
|
|
|
if classification.vertical_classification == "闲聊" or classification.sub_classification == "闲聊":
|
|
return classification, TermList(terms=[]), QueryRewrite(rewrite=query),[]
|
|
|
|
# rewrite = QueryRewrite(rewrite=query)
|
|
return classification, keywords_terms, rewrite, query_keys
|
|
|
|
def fill_slots(self, query: str, classification: Classification) -> Dict[str, Any]:
|
|
"""
|
|
根据分类结果对问题进行槽位填充
|
|
|
|
Args:
|
|
query: 用户原始问题
|
|
classification: 意图分类结果
|
|
keywords: 匹配的关键词列表
|
|
|
|
Returns:
|
|
填充后的槽位数据模型
|
|
"""
|
|
# 根据分类结果选择对应的数据模型
|
|
slot_model = self._get_slot_model(classification)
|
|
if not slot_model:
|
|
return {"error": "未找到匹配的槽位模型"}
|
|
|
|
# 使用LLM进行槽位填充
|
|
filled_slots = self._fill_slots_with_llm(query, classification, slot_model)
|
|
|
|
# 检查必填槽位是否都已填充
|
|
is_complete, missing_slots = filled_slots.check_required_slots()
|
|
|
|
return {
|
|
"is_complete": is_complete,
|
|
"missing_slots": missing_slots,
|
|
"filled_data": filled_slots.model_dump()
|
|
}
|
|
|
|
def _get_slot_model(self, classification: Classification) -> Optional[type]:
|
|
"""
|
|
根据分类结果获取对应的槽位模型类
|
|
|
|
Args:
|
|
classification: 意图分类结果
|
|
|
|
Returns:
|
|
对应的槽位模型类
|
|
"""
|
|
# 软件问题
|
|
if classification.vertical_classification == "软件问题":
|
|
if classification.sub_classification == "软件功能":
|
|
return SoftwareFunction
|
|
elif classification.sub_classification == "故障排查":
|
|
return TroubleShooting
|
|
|
|
# 业务问题
|
|
elif classification.vertical_classification == "业务问题":
|
|
if classification.sub_classification == "专业咨询":
|
|
return ProfessionalConsulting
|
|
elif classification.sub_classification == "数据问题":
|
|
return DataProblem
|
|
|
|
# 安装下载注册
|
|
elif classification.vertical_classification == "安装下载注册":
|
|
if classification.sub_classification == "后缀名咨询":
|
|
return FileExtensionConsulting
|
|
elif classification.sub_classification == "软件锁类":
|
|
return SoftwareLock
|
|
elif classification.sub_classification == "安装下载类":
|
|
return InstallationDownload
|
|
elif classification.sub_classification == "问题排查类":
|
|
return ProblemDiagnosis
|
|
|
|
return None
|
|
|
|
def _fill_slots_with_llm(self, query: str, classification: Classification, slot_model_class: type) -> Any:
|
|
"""
|
|
使用LLM进行槽位填充
|
|
|
|
Args:
|
|
query: 用户原始问题
|
|
classification: 意图分类结果
|
|
slot_model_class: 槽位模型类
|
|
|
|
Returns:
|
|
填充后的槽位数据模型实例
|
|
"""
|
|
# 准备提示词
|
|
slot_parser = PydanticOutputParser(pydantic_object=slot_model_class)
|
|
model_schema = json.dumps(slot_model_class.model_json_schema(), ensure_ascii=False)
|
|
|
|
formatted_prompt = slot_filling_prompt.format(
|
|
query=query,
|
|
vertical_classification=classification.vertical_classification,
|
|
sub_classification=classification.sub_classification,
|
|
output_format=slot_parser.get_format_instructions()
|
|
)
|
|
|
|
# 调用LLM
|
|
response = self.llm.invoke(formatted_prompt, False)
|
|
|
|
try:
|
|
# 尝试解析LLM响应
|
|
parsed_output = slot_parser.parse(response.content)
|
|
return parsed_output
|
|
except Exception as e:
|
|
# 如果解析失败,创建一个空的模型实例
|
|
empty_instance = slot_model_class()
|
|
return empty_instance
|
|
|
|
def process_query_with_slots(self, query: str) -> Dict[str, Any]:
|
|
"""
|
|
处理用户问题的完整流程,包括槽位填充
|
|
|
|
Args:
|
|
query: 用户原始问题
|
|
|
|
Returns:
|
|
包含分类、关键词、改写和槽位填充结果的字典
|
|
"""
|
|
# 执行基本处理流程
|
|
classification, keywords, rewrite, query_keys = self.process_query(query)
|
|
|
|
# 如果是有效分类,进行槽位填充
|
|
slot_filling_result = {}
|
|
if classification.vertical_classification not in ["其他", "闲聊"] and classification.sub_classification not in ["其他", "闲聊"]:
|
|
slot_filling_result = self.fill_slots(rewrite.rewrite, classification)
|
|
|
|
return {
|
|
"classification": classification.model_dump(),
|
|
"keywords": keywords.model_dump(),
|
|
"rewrite": rewrite.model_dump(),
|
|
"query_keys": query_keys,
|
|
"slot_filling": slot_filling_result
|
|
} |