优化意图识别模块,新增文档相关性判断功能,更新DifyQueryRetrieval类以支持多线程检索,增强数据模型,改进日志记录,调整Excel数据验证逻辑,更新多个提示词模板以提升用户体验。

This commit is contained in:
2025-06-24 15:03:01 +08:00
parent d957b4374e
commit 4386cfac41
8 changed files with 737 additions and 191 deletions
+124 -38
View File
@@ -16,13 +16,92 @@ from tqdm import tqdm
import time
import sys
import argparse
from typing import List, Dict
from typing import List, Dict, Any, Optional
from langchain.output_parsers import PydanticOutputParser
from pydantic import BaseModel, Field
sys.path.append(os.getcwd())
from rag2_0.intent_recognition import IntentRecognizer
from rag2_0.dify.DifyQueryRetrieval import DifyQueryRetrieval
from rag2_0.intent_recognition.DataModels import Classification
from rag2_0.tool.ModelTool import OpenAiLLM
# 加载环境变量
load_dotenv()
dify_query_retrieval = DifyQueryRetrieval(api_key="dataset-skLjmPVonjHo119OWNf3kAmY", base_url="http://172.20.0.145/v1")
def is_retrieved_doc_relevant(query: str, retrieved_doc: List[Dict[str, Any]], api_key: str = None, base_url: str = None, model_name: str = None) -> Dict[str, Any]:
"""
使用LLM判断检索出的文档是否与用户提问相关
Args:
query: 用户提问
retrieved_doc: 检索出的文档列表
api_key: API密钥,默认使用环境变量
base_url: API基础URL,默认使用环境变量
model_name: 模型名称,默认使用环境变量或默认模型
Returns:
包含相关性判断结果的字典,包括is_relevant(布尔值)和explanation(解释)
"""
# 使用环境变量或参数值
api_key = api_key or os.getenv("OPENAI_API_KEY")
base_url = base_url or os.getenv("OPENAI_API_BASE")
model_name = model_name or os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")
# 如果没有检索到文档,直接返回不相关
if not retrieved_doc or len(retrieved_doc) == 0:
return {
"is_relevant": False,
"explanation": "没有检索到任何文档",
"relevance_score": 0.0
}
# 构建文档内容
doc_contents = []
for i, doc in enumerate(retrieved_doc[:3]): # 只取前3个文档进行判断
content = doc.get("content", "")
title = doc.get("title", "")
doc_contents.append(f"文档{i+1}标题: {title}\n文档{i+1}内容: {content}")
doc_text = "\n\n".join(doc_contents)
class TempModel(BaseModel):
is_relevant: bool = Field(description="是否与用户提问相关")
relevance_score: int = Field(description="相关性评分,0-100分")
explanation: str = Field(description="解释各个文档与提问的相关性或不相关性")
parser = PydanticOutputParser(pydantic_object=TempModel)
# 构建提示词
prompt = f"""请判断以下检索文档是否与用户提问相关,并给出相关性评分(0-100分)。
用户提问: {query}
检索文档:
{doc_text}
请按照以下JSON格式返回结果:
{parser.get_format_instructions()}
"""
try:
# 初始化LLM并调用
llm = OpenAiLLM(api_key=api_key, base_url=base_url, model="deepseek-ai/DeepSeek-R1", response_format={"type": "json_object"})
response = llm.invoke(prompt)
result = parser.parse(response.content)
return {
"is_relevant": result.is_relevant,
"relevance_score": result.relevance_score,
"explanation": result.explanation
}
except Exception as e:
logging.error(f"判断文档相关性时出错: {str(e)}")
return {
"is_relevant": False,
"explanation": f"判断过程出错: {str(e)}",
"relevance_score": 0.0
}
# 读取Excel文件中的提问数据
def load_questions_from_excel(file_path=None):
"""
@@ -70,23 +149,33 @@ def process_query(recognizer: IntentRecognizer, query: str, conversation_context
enable_query_expansion=True)
# 提取分类信息
classification = result["classification"]
original_query = result["rewrite"]["rewrite"]
query_list = result["query_expand"]["all"]
soft_name = result.get("slot_filling", {}).get("filled_data", {}).get("software_name","")
# 将字典转换为Classification对象
classification_obj = Classification(**classification)
retrieved_doc=dify_query_retrieval.retrieve(original_query, query_list, classification_obj, soft_name)
# 提取关键词信息
keywords = result["keywords"]
keywords_str = ""
if keywords and keywords.get("terms"):
term_details = []
for term in keywords["terms"]:
term_info = {
"名称": term["name"],
"同义词": ";".join(term["synonymous"]) if term["synonymous"] else "",
"描述": term["description"]
}
term_details.append(term_info)
# 将term_details转换为JSON字符串,确保中文正确显示
keywords_str = json.dumps(term_details, ensure_ascii=False, indent=2)
# 判断检索文档是否相关
relevance_result = {}
if retrieved_doc:
# 获取API密钥和基础URL
api_key = os.getenv("OPENAI_API_KEY")
base_url = os.getenv("OPENAI_API_BASE")
model_name = os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")
# 判断文档相关性
relevance_result = is_retrieved_doc_relevant(query, retrieved_doc, api_key, base_url, model_name)
else:
retrieved_doc_str = []
relevance_result = {
"is_relevant": False,
"explanation": "没有检索到文档",
"relevance_score": 0.0
}
retrieved_doc_titles=[]
if retrieved_doc:
retrieved_doc_titles=[doc["title"].split("/")[-1] for doc in retrieved_doc]
# 提取槽位填充信息
slot_filling = result.get("slot_filling", {})
slot_filling_str = ""
@@ -97,30 +186,31 @@ def process_query(recognizer: IntentRecognizer, query: str, conversation_context
"缺失槽位": slot_filling.get("missing_slots", {}),
"填充数据": slot_filling.get("filled_data", {})
}, ensure_ascii=False, indent=2)
# 处理成功,返回结果
return {
"提问": query,
"问题拆解": result["query_keys"],
"一级分类": classification["vertical_classification"],
"二级分类": classification["sub_classification"],
"问题分类": f"{classification['vertical_classification']} - {classification['sub_classification']}",
"问题改写": result["rewrite"]["rewrite"],
"检索的关键词": keywords_str,
"槽位填充": slot_filling_str
"槽位填充": slot_filling_str,
"检索的文档": "\n".join(retrieved_doc_titles),
"文档是否相关": "相关" if relevance_result["is_relevant"] else "不相关",
"文档相关性解释": relevance_result["explanation"]
}
except Exception as e:
logging.error(f"处理问题 '{query}' 时出错: ",exc_info=True)
retry_count += 1
# 如果已经重试了最大次数,则记录错误并返回错误结果
if retry_count > max_retries:
logging.error(f"处理问题 '{query}' 时出错: {e.__class__}{e}")
return {
"提问": query,
"一级分类": "处理出错",
"二级分类": "处理出错",
"问题分类": "处理出错",
"问题改写": "处理出错",
"检索的关键词": f"重试 {max_retries} 次后失败: {str(e)}",
"槽位填充": "处理出错"
"槽位填充": "处理出错",
"检索的文档": f"重试 {max_retries} 次后失败: {str(e)}",
"文档是否相关": "处理出错",
"文档相关性解释": "处理出错"
}
else:
# 可以在这里添加延迟,避免过快重试
@@ -172,6 +262,7 @@ def save_results_to_excel(results, output_file, is_final=False):
worksheet.set_column('E:E', 60) # 问题改写 60个Excel单位
worksheet.set_column('F:F', 60) # 检索到的关键词 60个Excel单位
worksheet.set_column('G:G', 80) # 槽位填充 80个Excel单位
worksheet.set_column('H:H', 60) # 文档相关性 60个Excel单位
# 设置所有行高为20磅
for i in range(len(results_df) + 1): # +1 是为了包括表头
@@ -222,7 +313,7 @@ def parse_arguments():
help='API基础URL,默认使用环境变量中的配置')
# 添加处理相关参数
parser.add_argument('--max_workers', '-w', type=int, default=20,
parser.add_argument('--max_workers', '-w', type=int, default=2,
help='并发处理的最大线程数,默认为20')
parser.add_argument('--debug', '-d', action='store_true',
help='启用调试模式,使用示例查询而非从文件读取')
@@ -249,12 +340,12 @@ def main():
# 读取提问数据
current_dir = os.path.dirname(os.path.abspath(__file__))
data_file = args.input if args.input else os.path.join(current_dir, "..", "..", "data", "excel", "历史提问数据(dislike)_提问明确.xlsx")
output_file = args.output if args.output else os.path.join(current_dir, "..", "..", "data", "excel", "历史提问数据(dislike)_槽位(分类)填充结果.xlsx")
data_file = args.input if args.input else os.path.join(current_dir, "..", "..", "data", "excel", "1500条点踩软件问题测试.xlsx")
output_file = args.output if args.output else os.path.join(current_dir, "..", "..", "data", "excel", "1500条点踩软件问题_槽位(分类)填充结果.xlsx")
# 检测是否为调试模式
is_debug = args.debug or (hasattr(sys, 'gettrace') and sys.gettrace() is not None)
is_debug = False
if is_debug:
# 如果提供了查询参数,使用它;否则使用默认示例
if args.query:
@@ -287,12 +378,7 @@ def main():
result = future.result()
# 将结果放在与输入相同的位置
results[idx] = result
completed += 1
# 每处理batch_size条数据保存一次
# if completed % batch_size == 0:
# logging.info(f"已完成 {completed}/{len(examples)} 条,保存中间结果...")
# save_results_to_excel(results, output_file, is_final=False)
# 处理完所有数据后,保存最终结果
save_results_to_excel(results, output_file, is_final=True)
@@ -308,7 +394,7 @@ def setup_logging():
# 配置日志输出到控制台
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
format='%(asctime)s - %(name)s - [%(thread)d] - %(levelname)s - %(message)s',
handlers=[
logging.StreamHandler() # 添加控制台处理器
]