优化意图识别模块,新增文档相关性判断功能,更新DifyQueryRetrieval类以支持多线程检索,增强数据模型,改进日志记录,调整Excel数据验证逻辑,更新多个提示词模板以提升用户体验。
This commit is contained in:
@@ -16,13 +16,92 @@ from tqdm import tqdm
|
||||
import time
|
||||
import sys
|
||||
import argparse
|
||||
from typing import List, Dict
|
||||
from typing import List, Dict, Any, Optional
|
||||
from langchain.output_parsers import PydanticOutputParser
|
||||
from pydantic import BaseModel, Field
|
||||
sys.path.append(os.getcwd())
|
||||
from rag2_0.intent_recognition import IntentRecognizer
|
||||
from rag2_0.dify.DifyQueryRetrieval import DifyQueryRetrieval
|
||||
from rag2_0.intent_recognition.DataModels import Classification
|
||||
from rag2_0.tool.ModelTool import OpenAiLLM
|
||||
# 加载环境变量
|
||||
load_dotenv()
|
||||
|
||||
dify_query_retrieval = DifyQueryRetrieval(api_key="dataset-skLjmPVonjHo119OWNf3kAmY", base_url="http://172.20.0.145/v1")
|
||||
|
||||
def is_retrieved_doc_relevant(query: str, retrieved_doc: List[Dict[str, Any]], api_key: str = None, base_url: str = None, model_name: str = None) -> Dict[str, Any]:
|
||||
"""
|
||||
使用LLM判断检索出的文档是否与用户提问相关
|
||||
|
||||
Args:
|
||||
query: 用户提问
|
||||
retrieved_doc: 检索出的文档列表
|
||||
api_key: API密钥,默认使用环境变量
|
||||
base_url: API基础URL,默认使用环境变量
|
||||
model_name: 模型名称,默认使用环境变量或默认模型
|
||||
|
||||
Returns:
|
||||
包含相关性判断结果的字典,包括is_relevant(布尔值)和explanation(解释)
|
||||
"""
|
||||
# 使用环境变量或参数值
|
||||
api_key = api_key or os.getenv("OPENAI_API_KEY")
|
||||
base_url = base_url or os.getenv("OPENAI_API_BASE")
|
||||
model_name = model_name or os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")
|
||||
|
||||
# 如果没有检索到文档,直接返回不相关
|
||||
if not retrieved_doc or len(retrieved_doc) == 0:
|
||||
return {
|
||||
"is_relevant": False,
|
||||
"explanation": "没有检索到任何文档",
|
||||
"relevance_score": 0.0
|
||||
}
|
||||
|
||||
# 构建文档内容
|
||||
doc_contents = []
|
||||
for i, doc in enumerate(retrieved_doc[:3]): # 只取前3个文档进行判断
|
||||
content = doc.get("content", "")
|
||||
title = doc.get("title", "")
|
||||
doc_contents.append(f"文档{i+1}标题: {title}\n文档{i+1}内容: {content}")
|
||||
|
||||
doc_text = "\n\n".join(doc_contents)
|
||||
class TempModel(BaseModel):
|
||||
is_relevant: bool = Field(description="是否与用户提问相关")
|
||||
relevance_score: int = Field(description="相关性评分,0-100分")
|
||||
explanation: str = Field(description="解释各个文档与提问的相关性或不相关性")
|
||||
|
||||
parser = PydanticOutputParser(pydantic_object=TempModel)
|
||||
# 构建提示词
|
||||
prompt = f"""请判断以下检索文档是否与用户提问相关,并给出相关性评分(0-100分)。
|
||||
|
||||
用户提问: {query}
|
||||
|
||||
检索文档:
|
||||
{doc_text}
|
||||
|
||||
请按照以下JSON格式返回结果:
|
||||
{parser.get_format_instructions()}
|
||||
"""
|
||||
|
||||
try:
|
||||
# 初始化LLM并调用
|
||||
llm = OpenAiLLM(api_key=api_key, base_url=base_url, model="deepseek-ai/DeepSeek-R1", response_format={"type": "json_object"})
|
||||
response = llm.invoke(prompt)
|
||||
|
||||
result = parser.parse(response.content)
|
||||
|
||||
return {
|
||||
"is_relevant": result.is_relevant,
|
||||
"relevance_score": result.relevance_score,
|
||||
"explanation": result.explanation
|
||||
}
|
||||
except Exception as e:
|
||||
logging.error(f"判断文档相关性时出错: {str(e)}")
|
||||
return {
|
||||
"is_relevant": False,
|
||||
"explanation": f"判断过程出错: {str(e)}",
|
||||
"relevance_score": 0.0
|
||||
}
|
||||
|
||||
# 读取Excel文件中的提问数据
|
||||
def load_questions_from_excel(file_path=None):
|
||||
"""
|
||||
@@ -70,23 +149,33 @@ def process_query(recognizer: IntentRecognizer, query: str, conversation_context
|
||||
enable_query_expansion=True)
|
||||
# 提取分类信息
|
||||
classification = result["classification"]
|
||||
original_query = result["rewrite"]["rewrite"]
|
||||
query_list = result["query_expand"]["all"]
|
||||
soft_name = result.get("slot_filling", {}).get("filled_data", {}).get("software_name","")
|
||||
# 将字典转换为Classification对象
|
||||
classification_obj = Classification(**classification)
|
||||
retrieved_doc=dify_query_retrieval.retrieve(original_query, query_list, classification_obj, soft_name)
|
||||
|
||||
# 提取关键词信息
|
||||
keywords = result["keywords"]
|
||||
keywords_str = ""
|
||||
if keywords and keywords.get("terms"):
|
||||
term_details = []
|
||||
for term in keywords["terms"]:
|
||||
term_info = {
|
||||
"名称": term["name"],
|
||||
"同义词": ";".join(term["synonymous"]) if term["synonymous"] else "",
|
||||
"描述": term["description"]
|
||||
}
|
||||
term_details.append(term_info)
|
||||
|
||||
# 将term_details转换为JSON字符串,确保中文正确显示
|
||||
keywords_str = json.dumps(term_details, ensure_ascii=False, indent=2)
|
||||
|
||||
# 判断检索文档是否相关
|
||||
relevance_result = {}
|
||||
if retrieved_doc:
|
||||
# 获取API密钥和基础URL
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
base_url = os.getenv("OPENAI_API_BASE")
|
||||
model_name = os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")
|
||||
# 判断文档相关性
|
||||
relevance_result = is_retrieved_doc_relevant(query, retrieved_doc, api_key, base_url, model_name)
|
||||
else:
|
||||
retrieved_doc_str = []
|
||||
relevance_result = {
|
||||
"is_relevant": False,
|
||||
"explanation": "没有检索到文档",
|
||||
"relevance_score": 0.0
|
||||
}
|
||||
|
||||
retrieved_doc_titles=[]
|
||||
if retrieved_doc:
|
||||
retrieved_doc_titles=[doc["title"].split("/")[-1] for doc in retrieved_doc]
|
||||
# 提取槽位填充信息
|
||||
slot_filling = result.get("slot_filling", {})
|
||||
slot_filling_str = ""
|
||||
@@ -97,30 +186,31 @@ def process_query(recognizer: IntentRecognizer, query: str, conversation_context
|
||||
"缺失槽位": slot_filling.get("missing_slots", {}),
|
||||
"填充数据": slot_filling.get("filled_data", {})
|
||||
}, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
# 处理成功,返回结果
|
||||
return {
|
||||
"提问": query,
|
||||
"问题拆解": result["query_keys"],
|
||||
"一级分类": classification["vertical_classification"],
|
||||
"二级分类": classification["sub_classification"],
|
||||
"问题分类": f"{classification['vertical_classification']} - {classification['sub_classification']}",
|
||||
"问题改写": result["rewrite"]["rewrite"],
|
||||
"检索的关键词": keywords_str,
|
||||
"槽位填充": slot_filling_str
|
||||
"槽位填充": slot_filling_str,
|
||||
"检索的文档": "\n".join(retrieved_doc_titles),
|
||||
"文档是否相关": "相关" if relevance_result["is_relevant"] else "不相关",
|
||||
"文档相关性解释": relevance_result["explanation"]
|
||||
}
|
||||
except Exception as e:
|
||||
logging.error(f"处理问题 '{query}' 时出错: ",exc_info=True)
|
||||
retry_count += 1
|
||||
|
||||
# 如果已经重试了最大次数,则记录错误并返回错误结果
|
||||
if retry_count > max_retries:
|
||||
logging.error(f"处理问题 '{query}' 时出错: {e.__class__}{e}")
|
||||
return {
|
||||
"提问": query,
|
||||
"一级分类": "处理出错",
|
||||
"二级分类": "处理出错",
|
||||
"问题分类": "处理出错",
|
||||
"问题改写": "处理出错",
|
||||
"检索的关键词": f"重试 {max_retries} 次后失败: {str(e)}",
|
||||
"槽位填充": "处理出错"
|
||||
"槽位填充": "处理出错",
|
||||
"检索的文档": f"重试 {max_retries} 次后失败: {str(e)}",
|
||||
"文档是否相关": "处理出错",
|
||||
"文档相关性解释": "处理出错"
|
||||
}
|
||||
else:
|
||||
# 可以在这里添加延迟,避免过快重试
|
||||
@@ -172,6 +262,7 @@ def save_results_to_excel(results, output_file, is_final=False):
|
||||
worksheet.set_column('E:E', 60) # 问题改写 60个Excel单位
|
||||
worksheet.set_column('F:F', 60) # 检索到的关键词 60个Excel单位
|
||||
worksheet.set_column('G:G', 80) # 槽位填充 80个Excel单位
|
||||
worksheet.set_column('H:H', 60) # 文档相关性 60个Excel单位
|
||||
|
||||
# 设置所有行高为20磅
|
||||
for i in range(len(results_df) + 1): # +1 是为了包括表头
|
||||
@@ -222,7 +313,7 @@ def parse_arguments():
|
||||
help='API基础URL,默认使用环境变量中的配置')
|
||||
|
||||
# 添加处理相关参数
|
||||
parser.add_argument('--max_workers', '-w', type=int, default=20,
|
||||
parser.add_argument('--max_workers', '-w', type=int, default=2,
|
||||
help='并发处理的最大线程数,默认为20')
|
||||
parser.add_argument('--debug', '-d', action='store_true',
|
||||
help='启用调试模式,使用示例查询而非从文件读取')
|
||||
@@ -249,12 +340,12 @@ def main():
|
||||
|
||||
# 读取提问数据
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
data_file = args.input if args.input else os.path.join(current_dir, "..", "..", "data", "excel", "历史提问数据(dislike)_提问明确.xlsx")
|
||||
output_file = args.output if args.output else os.path.join(current_dir, "..", "..", "data", "excel", "历史提问数据(dislike)_槽位(分类)填充结果.xlsx")
|
||||
data_file = args.input if args.input else os.path.join(current_dir, "..", "..", "data", "excel", "1500条点踩软件问题测试.xlsx")
|
||||
output_file = args.output if args.output else os.path.join(current_dir, "..", "..", "data", "excel", "1500条点踩软件问题_槽位(分类)填充结果.xlsx")
|
||||
|
||||
# 检测是否为调试模式
|
||||
is_debug = args.debug or (hasattr(sys, 'gettrace') and sys.gettrace() is not None)
|
||||
|
||||
is_debug = False
|
||||
if is_debug:
|
||||
# 如果提供了查询参数,使用它;否则使用默认示例
|
||||
if args.query:
|
||||
@@ -287,12 +378,7 @@ def main():
|
||||
result = future.result()
|
||||
# 将结果放在与输入相同的位置
|
||||
results[idx] = result
|
||||
|
||||
completed += 1
|
||||
# 每处理batch_size条数据保存一次
|
||||
# if completed % batch_size == 0:
|
||||
# logging.info(f"已完成 {completed}/{len(examples)} 条,保存中间结果...")
|
||||
# save_results_to_excel(results, output_file, is_final=False)
|
||||
|
||||
# 处理完所有数据后,保存最终结果
|
||||
save_results_to_excel(results, output_file, is_final=True)
|
||||
@@ -308,7 +394,7 @@ def setup_logging():
|
||||
# 配置日志输出到控制台
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
format='%(asctime)s - %(name)s - [%(thread)d] - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.StreamHandler() # 添加控制台处理器
|
||||
]
|
||||
|
||||
@@ -390,7 +390,7 @@ class ExcelDataValidator:
|
||||
return index, False, "槽位填充", error_reason, slot_confidence
|
||||
|
||||
# 4. 验证检索内容
|
||||
if retrieve_content:
|
||||
if retrieve_content and retrieve_content != "" and pd.notna(retrieve_content):
|
||||
if self.debug:
|
||||
logging.info(f" 验证检索内容...")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user