上传问题改写、意图识别模块代码
This commit is contained in:
@@ -0,0 +1,36 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
File: DataModels.py
|
||||
Author: oyyz
|
||||
Date: 2025-05-13
|
||||
Description: 提取和分类的数据模型
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
# 定义输出模型
|
||||
# Output model for a single professional term.
# Identity is defined by ``name`` alone: two terms with the same name hash and
# compare equal even if their synonyms or descriptions differ, which lets
# callers de-duplicate terms with a set or dict.
class Term(BaseModel):
    # Canonical name of the term.
    name: str = Field(description="专业名词")
    # Alternative names for the same term.
    synonymous: List[str] = Field(description="同义词列表")
    # Free-text description; defaults to an empty string.
    description: str = Field(default="", description="描述信息")

    def __hash__(self):
        """Hash on ``name`` only, consistent with ``__eq__``."""
        return hash(self.name)

    def __eq__(self, other):
        """Return True only when ``other`` is a ``Term`` with the same name."""
        return isinstance(other, Term) and self.name == other.name
|
||||
|
||||
# Wrapper model around a list of Term objects.
class TermList(BaseModel):
    # The terms, in the order they were supplied.
    terms: List[Term] = Field(description="专业名词列表")
|
||||
|
||||
# Two-level intent classification result.
class Classification(BaseModel):
    # First-level (vertical domain) category.
    vertical_classification:str = Field(description="垂直领域一级分类")
    # Second-level category under the first-level one.
    sub_classification:str = Field(description="一级分类下的二级分类")
|
||||
|
||||
# Result of rewriting a user question.
class QueryRewrite(BaseModel):
    # The rewritten question text.
    rewrite:str = Field(description="问题改写")
|
||||
@@ -0,0 +1,289 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
File: IntentRecognition.py
|
||||
Author: oyyz
|
||||
Date: 2025-05-13
|
||||
Description: 意图分类、改写核心逻辑
|
||||
"""
|
||||
|
||||
import os
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain.output_parsers import PydanticOutputParser
|
||||
import json
|
||||
from typing import List, Tuple
|
||||
import re
|
||||
from .PromptTemplates import classification_prompt, query_rewrite_prompt, extract_nouns_prompt, classification_info
|
||||
from .DataModels import Classification, QueryRewrite, Term, TermList
|
||||
from .ProfessionalNounVector import ProfessionalNounRetriever
|
||||
from rag2_0.tool.ModelTool import XinferenceReRankerModel, OpenAiLLM
|
||||
|
||||
|
||||
class IntentRecognizer:
    """
    Intent recognition and question rewriting.

    Extracts keywords from a user question with an LLM, matches them against
    the professional-noun vector index, rewrites the question and classifies
    its intent.
    """
|
||||
def __init__(self, api_key: str = None, base_url: str = None, model_name: str = "gpt-3.5-turbo", vector_index_dir: str = None):
    """
    Build the LLM client, the output parsers, the suffix keywords and the noun retriever.

    Args:
        api_key: API key forwarded to the LLM client when truthy. The same value
            is also passed to ProfessionalNounRetriever.
        base_url: API base URL forwarded to the LLM client when truthy.
        model_name: Name of the model to use.
        vector_index_dir: Directory of the vector index handed to the retriever;
            None lets the retriever pick its own default.
    """
    # A low temperature reduces randomness so results are more deterministic.
    llm_params = {"temperature": 0.2, "model": model_name}

    # Forward the key / endpoint only when the caller actually supplied them.
    optional_params = {"api_key": api_key, "base_url": base_url}
    llm_params.update({key: value for key, value in optional_params.items() if value})

    self.llm = OpenAiLLM(**llm_params)

    # One parser per structured output the LLM is asked to produce.
    self.classification_parser = PydanticOutputParser(pydantic_object=Classification)
    self.query_rewrite_parser = PydanticOutputParser(pydantic_object=QueryRewrite)
    self.terms_list_parser = PydanticOutputParser(pydantic_object=TermList)

    # Suffix keyword entries, loaded from the default JSON file.
    self.suffix_keywords = self._load_suffix_keywords()

    # Vector retriever over the professional-noun index.
    self.noun_retriever = ProfessionalNounRetriever(api_key=api_key, index_dir=vector_index_dir)
|
||||
|
||||
def _load_suffix_keywords(self, filepath: str = None) -> List[dict]:
    """
    Load the suffix keyword entries from a JSON file.

    Args:
        filepath: Path of the JSON file. When None, the default
            ``../../data/nouns/suffix_keywords.json`` relative to this module is used.

    Returns:
        The parsed JSON array, unchanged. ``judge_define_suffix`` reads a
        ``name`` key from every entry, so entries are expected to be dicts.

    Raises:
        RuntimeError: If the file cannot be read or parsed, or if its top-level
            value is not a JSON array.
    """
    if filepath is None:
        current_dir = os.path.dirname(os.path.abspath(__file__))
        filepath = os.path.join(current_dir, "..", "..", "data", "nouns", "suffix_keywords.json")

    try:
        with open(filepath, "r", encoding="utf-8") as f:
            suffix_data = json.load(f)
    except Exception as e:
        raise RuntimeError(f"加载后缀关键词失败: {e}") from e

    # Fail here with a clear message rather than later, when the entries are iterated.
    if not isinstance(suffix_data, list):
        raise RuntimeError(f"加载后缀关键词失败: {filepath} 应为JSON数组")

    return suffix_data
|
||||
|
||||
def classify_intent(self, query: str, keywords: TermList) -> Classification:
    """
    Classify the intent of a user question with the LLM.

    Args:
        query: The user's question.
        keywords: Terms matched for the question; serialized to JSON for the prompt.

    Returns:
        The parsed classification result.

    Raises:
        RuntimeError: If the LLM response cannot be parsed into a Classification.
    """
    terms_dict = [term.model_dump() for term in keywords.terms]
    substitutions = {
        "user_input": query,
        "classification_info": classification_info,
        "output_format": self.classification_parser.get_format_instructions(),
        "keywords": json.dumps(terms_dict, ensure_ascii=False),
    }
    # Fill every placeholder in ONE pass. Chained str.replace calls would also
    # expand placeholder tokens that happen to appear inside the user's own
    # text (e.g. a question containing "{keywords}").
    # NOTE(review): classification_prompt as currently written does not appear
    # to contain a {keywords} placeholder, in which case the keywords never
    # reach the model - confirm whether that is intended.
    formatted_prompt = re.sub(
        r"\{(user_input|classification_info|output_format|keywords)\}",
        lambda match: substitutions[match.group(1)],
        classification_prompt,
    )

    # Call the LLM.
    response = self.llm.invoke(formatted_prompt, False)

    # Parse the reply into the Classification model.
    try:
        return self.classification_parser.parse(response.content.strip())
    except Exception as e:
        raise RuntimeError(f"解析分类结果时出错: {e}") from e
|
||||
|
||||
def extract_keywords_with_llm(self, query: str) -> List[Term]:
    """
    Extract professional keywords from a user question with the LLM.

    Args:
        query: The user's question.

    Returns:
        The terms parsed from the LLM response.

    Raises:
        RuntimeError: If the LLM response cannot be parsed into a TermList.
    """
    substitutions = {
        "content": query,
        "output_format": self.terms_list_parser.get_format_instructions(),
    }
    # Single-pass substitution: a question that itself contains the text
    # "{output_format}" must not have that token expanded.
    formatted_prompt = re.sub(
        r"\{(content|output_format)\}",
        lambda match: substitutions[match.group(1)],
        extract_nouns_prompt,
    )

    # Call the LLM.
    response = self.llm.invoke(formatted_prompt, False)

    try:
        # Parse the reply into a TermList and hand back its terms.
        parsed_output = self.terms_list_parser.parse(response.content)
        return parsed_output.terms
    except Exception as e:
        raise RuntimeError(f"无法解析LLM关键词提取响应: {e}") from e
|
||||
|
||||
def match_keywords(self, query: str) -> Tuple[TermList, List[str]]:
    """
    Match professional terms for a question using LLM extraction plus vector retrieval.

    Args:
        query: The user's question.

    Returns:
        A tuple ``(terms, query_keys)``: the matched terms (reranked against the
        question when any were found) and the keyword names the LLM extracted.

    Raises:
        RuntimeError: If keyword extraction or vector retrieval fails.
    """
    # Step 1: extract keywords from the question with the LLM.
    try:
        query_keys = [term.name for term in self.extract_keywords_with_llm(query)]
    except Exception as e:
        raise RuntimeError(f"LLM关键词提取失败: {e}") from e

    # Step 2: look up similar professional nouns for every extracted keyword.
    # A dict keyed by name de-duplicates exactly like Term.__eq__/__hash__ while
    # keeping first-seen order, so the reranker input is stable between runs.
    matched_by_name = {}
    try:
        for current_key in query_keys:
            vector_results = self.noun_retriever.query(current_key, top_k=3, use_intersection=True)

            for result in vector_results:
                term = Term(
                    name=result.get('name'),
                    # The index metadata stores this list under "synonyms";
                    # "synonymous" is still accepted for other producers.
                    synonymous=result.get('synonymous') or result.get('synonyms') or [],
                    description=result.get('description', '')
                )
                matched_by_name.setdefault(term.name, term)
    except Exception as e:
        raise RuntimeError(f"向量检索关键词时出错: {e}") from e

    matched_terms = list(matched_by_name.values())

    # Step 3: rerank the candidates against the original question, keeping the top 5.
    if matched_terms:
        txts = [
            "名称:" + term.name + "|" + "同义词:" + ";".join(term.synonymous) + "|" + "描述:" + term.description
            for term in matched_terms
        ]
        xinference_reranker = XinferenceReRankerModel()
        rerank_results = xinference_reranker.rerank(query, txts, top_k=5)
        # Each rerank result refers back to a candidate by its position in txts.
        matched_terms = [matched_terms[result["index"]] for result in rerank_results]

    return TermList(terms=matched_terms), query_keys
|
||||
|
||||
def rewrite_query(self, query: str, keywords: TermList) -> QueryRewrite:
    """
    Rewrite a user question with the help of the matched terms.

    Args:
        query: The user's original question.
        keywords: Terms matched for the question.

    Returns:
        The parsed rewrite result.

    Raises:
        RuntimeError: If the LLM response cannot be parsed into a QueryRewrite.
    """
    # Term descriptions are left out of the payload sent to the model.
    serialized_terms = json.dumps(
        [term.model_dump(exclude={"description"}) for term in keywords.terms],
        ensure_ascii=False,
    )
    prompt_text = query_rewrite_prompt.format(
        query=query,
        output_format=self.query_rewrite_parser.get_format_instructions(),
        keywords=serialized_terms,
    )

    # Call the LLM.
    llm_reply = self.llm.invoke(prompt_text, False)

    # Parse the reply into the QueryRewrite model.
    try:
        return self.query_rewrite_parser.parse(llm_reply.content)
    except Exception as e:
        raise RuntimeError(f"解析问题改写结果时出错: {e}") from e
|
||||
|
||||
def judge_define_suffix(self, input_str: str) -> Tuple[bool, List[str]]:
    """
    Check whether the input mentions any of the configured suffix names.

    Matching is case-insensitive and tolerates an optional leading dot.

    Args:
        input_str: The text to scan.

    Returns:
        Tuple[bool, List[str]]: (whether any suffix matched, the matched
        suffixes as they appear in ``input_str``, in order of occurrence).
    """
    # Keep only usable names. Longest first, so that a short suffix cannot
    # shadow a longer one that starts with it (regex alternation takes the
    # first alternative that matches).
    names = sorted(
        {
            field.get('name')
            for field in self.suffix_keywords
            if isinstance(field.get('name'), str) and field.get('name')
        },
        key=lambda name: (-len(name), name),
    )
    # Without names the alternation would be empty and match the empty string
    # at every position, so report "no match" explicitly.
    if not names:
        return False, []

    pattern = r'(?:\.?)(' + '|'.join(re.escape(name) for name in names) + r')'

    # finditer yields every non-overlapping match; IGNORECASE ignores case.
    matched_suffixes = [match.group(1) for match in re.finditer(pattern, input_str, re.IGNORECASE)]

    return bool(matched_suffixes), matched_suffixes
|
||||
|
||||
def process_query(self, query: str) -> Tuple[Classification, TermList, QueryRewrite, List[str]]:
    """
    Run the full pipeline for a user question.

    Args:
        query: The user's original question.

    Returns:
        A tuple ``(classification, matched terms, rewrite, extracted keyword names)``.
        When either classification level is "其他" or "闲聊", the terms and
        keyword names are empty and the rewrite is the original question.
    """
    # NOTE: a short-circuit for file-suffix questions (based on
    # judge_define_suffix) used to sit here as commented-out code; it is
    # not active.

    # Step 1: match keywords.
    keywords_terms, query_keys = self.match_keywords(query)

    # Step 2: classify the intent. This runs before the rewrite so that
    # off-topic questions do not pay for an LLM rewrite that would be discarded.
    classification = self.classify_intent(query, keywords_terms)

    passthrough_labels = ("其他", "闲聊")
    if (classification.vertical_classification in passthrough_labels
            or classification.sub_classification in passthrough_labels):
        return classification, TermList(terms=[]), QueryRewrite(rewrite=query), []

    # Step 3: rewrite the question using the matched terms.
    rewrite = self.rewrite_query(
        query=query,
        keywords=keywords_terms
    )

    return classification, keywords_terms, rewrite, query_keys
|
||||
@@ -0,0 +1,321 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
File: ProfessionalNounVector.py
|
||||
Date: 2025-05-15
|
||||
Author: oyyz
|
||||
Description: 专业名词向量化和检索的核心逻辑
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import shutil
|
||||
from typing import List, Dict, Any, Tuple, Optional
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain_community.vectorstores import FAISS
|
||||
from rag2_0.tool.ModelTool import SiliconFlowEmbeddings
|
||||
import logging
|
||||
|
||||
def get_embedding_model(api_key: str = None) -> Embeddings:
    """
    Build the embedding model.

    Args:
        api_key: API key; when empty or None it is read from the
            ``SILICONFLOW_API_KEY`` environment variable.

    Returns:
        A SiliconFlowEmbeddings instance.

    Raises:
        ValueError: If no key is passed and the environment variable is unset.
    """
    # SECURITY: a literal API key used to be the fallback here. Secrets must not
    # live in source control; the previously committed key should be treated as
    # leaked and rotated.
    if not api_key:
        api_key = os.getenv("SILICONFLOW_API_KEY")
    if not api_key:
        raise ValueError("未提供SiliconFlow API密钥: 请传入api_key或设置环境变量SILICONFLOW_API_KEY")
    return SiliconFlowEmbeddings(api_key=api_key)
|
||||
|
||||
|
||||
class ProfessionalNounVectorizer:
    """Vectorizes professional nouns and saves them as a FAISS index."""
|
||||
|
||||
def __init__(self,
             embedding_model: Optional[Embeddings] = None,
             api_key: str = None,
             output_dir: str = None):
    """
    Set up the vectorizer.

    Args:
        embedding_model: Embedding model to use; when falsy, one is created via
            get_embedding_model.
        api_key: SiliconFlow API key, only used when embedding_model is not given.
        output_dir: Directory that receives the index; None selects
            ``../../data/nouns`` relative to this module.
    """
    # Use the supplied model, otherwise build the default one.
    self.embedding_model = embedding_model or get_embedding_model(api_key)

    # Resolve the output directory.
    if output_dir is None:
        current_dir = os.path.dirname(os.path.abspath(__file__))
        output_dir = os.path.join(current_dir, "..", "..", "data", "nouns")
    self.output_dir = output_dir

    # The index lives in a fixed sub-directory of the output directory.
    self.index_path = os.path.join(self.output_dir, "professional_nouns_index")
|
||||
|
||||
def _loadfile(self, file_paths: List[str]) -> List[Dict[str, Any]]:
    """
    Load several professional-term JSON files and merge their entries.

    Files that are missing, unreadable, not valid JSON, or whose top-level value
    is not an array are logged and skipped; entries from the remaining files are
    still returned.

    Args:
        file_paths: Paths of the JSON files.

    Returns:
        The merged list of terms, in file order.
    """
    merged_terms = []

    for file_path in file_paths:
        if not os.path.exists(file_path):
            logging.warning(f"文件不存在: {file_path}")
            continue

        # One bad file must not throw away what the other files contributed.
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        try:
            with open(file_path, "r", encoding="utf-8") as f:
                terms_data = json.load(f)
        except (OSError, ValueError) as e:
            logging.error(f"加载文件失败: {file_path}: {e}")
            continue

        if isinstance(terms_data, list):
            merged_terms.extend(terms_data)
            logging.info(f"从 {file_path} 加载了 {len(terms_data)} 条专业名词")
        else:
            logging.warning(f"文件格式错误: {file_path},应为JSON数组")

    logging.info(f"总共加载了 {len(merged_terms)} 条专业名词")
    return merged_terms
|
||||
|
||||
def vectorize_files_and_save(self, file_paths: List[str]) -> bool:
    """
    Load term files, build a FAISS index from them and save it.

    Args:
        file_paths: Paths of the JSON term files.

    Returns:
        True on success; False when no terms were loaded or any step failed.
    """
    try:
        terms = self._loadfile(file_paths)
        if not terms:
            logging.warning("未找到术语数据,退出处理")
            return False

        # De-duplicate by name: the first entry seen for a name wins, and
        # entries without a name are dropped.
        unique_terms = {}
        for term in terms:
            name = term.get("name", "")
            if name:
                unique_terms.setdefault(name, term)
        deduplicated_terms = list(unique_terms.values())
        logging.info(f"去重后剩余 {len(deduplicated_terms)} 条专业名词")

        # Prepare texts and metadata, build the index, then save it.
        texts, metadatas = self._prepare_terms_for_faiss(deduplicated_terms)
        self._save_index(self._create_index(texts, metadatas))

        logging.info("完成多文件专业名词向量化和索引创建")
        return True
    except Exception as e:
        logging.error(f"多文件向量化处理失败: {e}")
        return False
|
||||
|
||||
|
||||
def _prepare_terms_for_faiss(self, terms: List[Dict[str, Any]]) -> Tuple[List[str], List[Dict]]:
    """
    Turn terms into the parallel text / metadata lists used to build the index (internal).

    Each term contributes up to three texts - its name, its comma-joined
    synonyms (when any) and its description (when non-empty) - and every one of
    them carries the same metadata, so a hit on any text leads back to the term.

    Args:
        terms: Term dicts; ``name`` is required, ``synonymous`` and
            ``description`` are optional.

    Returns:
        The stripped texts and, at the same positions, their metadata dicts.
    """
    texts = []
    metadatas = []

    for term in terms:
        name = term["name"]
        entries = [name.strip()]
        synonyms = term.get("synonymous", [])
        description = term.get("description", "")

        if len(synonyms) > 0:
            entries.append(', '.join(synonyms).strip())
        if len(description) > 0:
            entries.append(description.strip())

        for entry in entries:
            texts.append(entry)
            # A fresh dict per text; note the metadata key is "synonyms".
            metadatas.append({
                "name": name,
                "synonyms": synonyms,
                "description": description
            })

    return texts, metadatas
|
||||
|
||||
def _create_index(self, texts: List[str], metadatas: List[Dict]) -> FAISS:
    """
    Build a FAISS index from texts and their metadata (internal).

    Args:
        texts: Texts to embed.
        metadatas: Metadata dicts, one per text.

    Returns:
        The FAISS index built with this instance's embedding model.
    """
    text_count = len(texts)
    logging.info(f"正在创建FAISS索引,共 {text_count} 条数据...")
    faiss_index = FAISS.from_texts(texts=texts, embedding=self.embedding_model, metadatas=metadatas)
    return faiss_index
|
||||
|
||||
def _save_index(self, faiss_index: FAISS) -> None:
    """
    Save the FAISS index to ``self.index_path`` (internal).

    The index is first written to a staging directory next to the target and
    only then swapped in, so a failed save leaves any existing index untouched.

    Args:
        faiss_index: The FAISS index to save.

    Raises:
        Exception: Whatever the underlying save or file operations raise, after logging.
    """
    staging_path = f"{self.index_path}.tmp"
    try:
        # Make sure the output directory exists.
        os.makedirs(self.output_dir, exist_ok=True)

        # Clear leftovers from an earlier interrupted save.
        if os.path.exists(staging_path):
            shutil.rmtree(staging_path)

        faiss_index.save_local(folder_path=staging_path)

        # The new index is fully written; only now replace the old one.
        if os.path.exists(self.index_path):
            shutil.rmtree(self.index_path)
        os.replace(staging_path, self.index_path)

        logging.info(f"FAISS索引已保存至 {self.index_path}")
    except Exception as e:
        logging.error(f"保存FAISS索引失败: {e}")
        shutil.rmtree(staging_path, ignore_errors=True)
        raise
|
||||
|
||||
|
||||
class ProfessionalNounRetriever:
    """Retrieves professional nouns from a FAISS index."""
|
||||
|
||||
def __init__(self,
             embedding_model: Optional[Embeddings] = None,
             api_key: str = None,
             index_dir: str = None):
    """
    Set up the retriever and load the index.

    Args:
        embedding_model: Embedding model to use; when falsy, one is created via
            get_embedding_model.
        api_key: SiliconFlow API key, only used when embedding_model is not given.
        index_dir: Directory of the index; None selects
            ``../../data/nouns/professional_nouns_index`` relative to this module.
    """
    # Use the supplied model, otherwise build the default one.
    self.embedding_model = embedding_model or get_embedding_model(api_key)

    # Resolve the index directory.
    if index_dir is None:
        current_dir = os.path.dirname(os.path.abspath(__file__))
        index_dir = os.path.join(current_dir, "..", "..", "data", "nouns", "professional_nouns_index")
    self.index_dir = index_dir

    # Load the index right away; it stays None when loading fails.
    self.faiss_index = None
    self._load_index()
|
||||
|
||||
def _load_index(self) -> None:
    """
    Load the FAISS index from ``self.index_dir`` (internal).

    On any failure a warning is logged and ``self.faiss_index`` is set to None
    instead of raising.
    """
    try:
        # Unsafe deserialization is enabled here; only point index_dir at
        # trusted data.
        loaded_index = FAISS.load_local(
            folder_path=self.index_dir,
            embeddings=self.embedding_model,
            allow_dangerous_deserialization=True,
        )
        self.faiss_index = loaded_index
        logging.info(f"成功从 {self.index_dir} 加载FAISS索引")
    except Exception as e:
        logging.warning(f"加载FAISS索引失败: {e}")
        self.faiss_index = None
|
||||
|
||||
def query(self, query_text: str, top_k: int = 5, use_intersection: bool = True) -> List[Dict]:
    """
    Query the FAISS index for the professional nouns most similar to a text.

    Three retrievers are run: plain similarity, MMR and similarity with a score
    threshold. Hits are identified by their metadata.

    Args:
        query_text: The text to search for.
        top_k: Number of results requested from the similarity and MMR retrievers.
        use_intersection: When True (default), return only hits found by all
            three retrievers, falling back to the plain similarity hits when
            nothing is common. When False, return the union of all hits.

    Returns:
        Metadata dicts of the selected hits, de-duplicated, in retrieval rank
        order. An empty list when the index is not loaded or the query fails.
    """
    # Nothing to search when the index failed to load.
    if self.faiss_index is None:
        logging.warning("FAISS索引未加载,无法执行查询")
        return []

    try:
        retrievers = (
            self.faiss_index.as_retriever(search_kwargs={"k": top_k}),
            self.faiss_index.as_retriever(
                search_type="mmr",
                search_kwargs={"k": top_k, "fetch_k": 3, "lambda_mult": 0.5}
            ),
            self.faiss_index.as_retriever(
                search_type="similarity_score_threshold",
                search_kwargs={"score_threshold": 0.5}
            ),
        )

        # Metadata dicts are not hashable, so compare their canonical JSON form.
        # dict.fromkeys de-duplicates while keeping each retriever's rank order.
        plain, mmr, thresholded = (
            list(dict.fromkeys(
                json.dumps(doc.metadata, sort_keys=True, ensure_ascii=False)
                for doc in retriever.invoke(query_text)
            ))
            for retriever in retrievers
        )

        if use_intersection:
            common = set(mmr) & set(thresholded)
            selected = [item for item in plain if item in common]
            # No hit is shared by all three: fall back to the plain similarity hits.
            if not selected:
                logging.warning("三种检索方式无交集,使用普通检索结果")
                selected = plain
        else:
            selected = list(dict.fromkeys(plain + mmr + thresholded))

        # Back from JSON strings to dicts.
        return [json.loads(item) for item in selected]

    except Exception as e:
        logging.error(f"查询FAISS索引失败: {e}")
        return []
|
||||
@@ -0,0 +1,130 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
File: PromptTemplates.py
|
||||
Author: oyyz
|
||||
Date: 2025-05-13
|
||||
Description: 提示词模板
|
||||
"""
|
||||
|
||||
extract_nouns_prompt="""
|
||||
【智能关键词提取助手】
|
||||
请根据用户问题自动识别核心关键词,并按照以下规则输出:
|
||||
1. 只输出最终关键词列表,不要解释说明
|
||||
2. 关键词提取范围包括但不限于以下内容:
|
||||
- 软件相关:功能模块/操作步骤/报错提示/扩展名后缀名
|
||||
- 造价专业:费用类型/计算标准/行业规范
|
||||
- 电力工程:项目类型/设备型号/工程阶段
|
||||
3. 自动展开缩写(如将'导excel'转为'Excel导入')
|
||||
4. 严格基于用户问题提取关键词,不要输出与用户问题无关的关键词
|
||||
|
||||
三、输出格式:
|
||||
{output_format}
|
||||
|
||||
四、用户问题:
|
||||
{content}
|
||||
|
||||
"""
|
||||
|
||||
# Taxonomy text inserted into the classification prompt: the first-level
# (vertical) categories followed by the second-level categories of each.
# The heading of the "安装下载注册" group now states four sub-categories,
# matching the four items listed under it.
classification_info="""【垂直领域分类】:
1. 软件问题 -- 指涉及软件使用、功能询问、软件故障排查等方面的提问或请求。
2. 业务问题 -- 指涉及电力造价领域专业知识、造价费用计算等电力造价业务知识
3. 安装下载注册 -- 指涉及软件(或插件)安装下载、注册、激活等操作类问题。
4. 其他 -- 指与软件或电力造价专业无关的日常对话、问候、感慨、情绪表达等。

【软件问题包括以下两类】:
1. 软件功能:询问软件功能的使用、操作、位置等
2. 故障排查:软件运行异常、软件报错、软件显示错误等

【业务问题包括以下两类】:
1. 专业咨询:涉及电力造价规范、工程计价规则问题、行业标准解读等
2. 数据问题:涉及电力造价费用、造价指标等

【安装下载注册包括以下四类】:
1. 后缀名查询:询问有关软件后缀名、工程文件扩展名等问题,例如:BDY3是什么文件?、用什么软件打开.BDY3文件?
2. 软件锁类:询问软件锁信息、锁注册号查询、许可证查询、锁激活问题等软件锁相关问题
3. 安装下载类:安装下载咨询、组件(插件)选择、环境配置等
4. 问题排查类:软件安装下载失败、报错,系统兼容性问题等

【其他】:
1. 其他"""
|
||||
|
||||
# Prompt template for intent classification.
# Placeholders: {classification_info} (the category taxonomy), {user_input}
# (the user's question) and {output_format} (format instructions).
# The few-shot example uses the first-level label "软件问题" so that it agrees
# with the category list in classification_info.
classification_prompt="""
用户正在使用电力造价软件或想询问电力造价领域相关知识,你需要根据用户的输入内容,将其归类为以下垂直领域之一:
{classification_info}

【用户输入】:
{user_input}

【输出格式要求】:
{output_format}

【示例】
用户输入1: 技改T1怎样新建工程
输出1:
{
"vertical_classification":"软件问题",
"sub_classification":"软件功能"
}
"""
|
||||
|
||||
query_rewrite_prompt = """
|
||||
|
||||
你是一名电力造价专业问答优化工程师,负责通过多维度信息整合重构用户问题以提升知识库检索准确率。请严格遵循以下流程处理:
|
||||
|
||||
# 任务处理框架
|
||||
## 第一阶段:输入分析
|
||||
1. 解析基础信息
|
||||
- 原始问题(需保留核心语义):{query}
|
||||
- 关键词集合:{keywords}
|
||||
|
||||
## 第二阶段:语义匹配验证
|
||||
2. 执行关键词校验
|
||||
- 建立意图关联矩阵,验证关键词与原始问题的语义一致性
|
||||
- 若存在≥1个有效关联词 → 进入重构流程
|
||||
- 若无有效关联 → 直接输出原始问题
|
||||
|
||||
## 第三阶段:专业重构
|
||||
3. 术语规范化处理
|
||||
a. 实施术语映射:将口语表达替换为知识库标准术语
|
||||
b. 执行结构优化:
|
||||
- 采用【术语标记】规范标注关键概念
|
||||
- 构建主谓宾明确的问题句式
|
||||
- 保持原问题时态与语态特征
|
||||
|
||||
# 输出规范
|
||||
{output_format}
|
||||
|
||||
# 示范案例库
|
||||
▶ 案例1(有效匹配)
|
||||
输入:
|
||||
原始问题:怎么把旧版西藏定额工程转到Z1新版
|
||||
关键词:【'老版本定额升级', '批量设置定额', '西藏造价软件Z1'】
|
||||
输出:
|
||||
{{"rewrite":"【西藏造价软件Z1】如何执行【老版本定额升级】操作?"}}
|
||||
|
||||
▶ 案例2(无效匹配)
|
||||
输入:
|
||||
原始问题:程序界面文字显示过小如何处理?
|
||||
关键词:【'定额升级', '工程批量导入'】
|
||||
输出:
|
||||
{{"rewrite":"程序界面文字显示过小如何处理?"}}
|
||||
|
||||
# 质量约束条款
|
||||
1. 语义内容保真原则
|
||||
- 禁止修改原问题核心诉求(如转换主语/变更操作对象)
|
||||
- 保留原始问题的限定条件
|
||||
|
||||
2. 术语使用规范
|
||||
- 仅使用检索返回的关键词进行术语替换
|
||||
- 新增术语必须来自关键词集合
|
||||
|
||||
3. 结构优化标准
|
||||
- 问题长度控制在20字内
|
||||
- 必须包含≥1个【标注术语】
|
||||
- 禁止添加解释性语句
|
||||
|
||||
4. 异常处理机制
|
||||
- 当关键词与问题无明显关联时,触发直通输出规则
|
||||
- 出现术语冲突时优先保留原始表述
|
||||
"""
|
||||
@@ -0,0 +1,5 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
from .ProfessionalNounVector import ProfessionalNounVectorizer, ProfessionalNounRetriever
|
||||
from .IntentRecognition import IntentRecognizer
|
||||
from .DataModels import Term, TermList, Classification, QueryRewrite
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Reference in New Issue
Block a user