435 lines
19 KiB
Python
Executable File
435 lines
19 KiB
Python
Executable File
#!/usr/bin/env python
|
|
# -*- coding: utf-8 -*-
|
|
"""
|
|
File: intent_recognition_example.py
|
|
Date: 2025-05-14
|
|
Description: 意图识别和问题改写示例
|
|
"""
|
|
|
|
import os
|
|
from dotenv import load_dotenv
|
|
import pandas as pd
|
|
import logging
|
|
import json
|
|
import concurrent.futures
|
|
from tqdm import tqdm
|
|
import time
|
|
import sys
|
|
import argparse
|
|
from typing import List, Dict, Any, Optional
|
|
from langchain.output_parsers import PydanticOutputParser
|
|
from pydantic import BaseModel, Field
|
|
sys.path.append(os.getcwd())
|
|
from rag2_0.intent_recognition import IntentRecognizer
|
|
from rag2_0.dify.DifyQueryRetrieval import DifyQueryRetrieval
|
|
from rag2_0.intent_recognition.DataModels import Classification
|
|
from rag2_0.tool.ModelTool import OpenAiLLM
|
|
# 加载环境变量
|
|
load_dotenv()
|
|
|
|
# 示例查询
|
|
examples_query = """ PE2211PK0801是什么软件"""
|
|
conversation_context=""
|
|
chat_history=[
|
|
{
|
|
"user": "怎么新建工程啊"
|
|
},
|
|
{
|
|
"assistant": "您好!为了更准确地帮助您解决问题,请提供以下信息:\n\n- **软件名称**:请确认您使用的软件名称,可选值如下:\n - 配网工程计价通D3软件\n - 新型储能电站建设计价通C1软件\n - 西藏电力工程计价通Z1软件\n - 技改检修工程计价通T1软件\n - 技改检修清单计价通T1软件\n - 主网电力建设计价通软件\n\n请补充以上信息,以便我们能够更好地指导您如何新建工程。"
|
|
},
|
|
{
|
|
"user": "配网工程计价通D3软件"
|
|
},
|
|
{
|
|
"assistant": """好的,针对您使用**配网工程计价通D3软件**新建工程的问题,以下是具体的操作步骤:\n\n## 新建工程\n\n### 功能入口\n【主页】界面——"新建"按钮\n\n### 操作步骤\n1. 在"新建窗口"选择对应工程模板。\n2. 设置工程名称。\n3. 确定后根据新建向导完善"电压等级"、"地区类型"等参数。\n4. 点击"确定"即可完成新建工程。\n\n\n\n希望这些步骤能帮助您顺利完成新建工程。如果还有其他问题,欢迎随时提问!\n"""
|
|
}
|
|
]
|
|
previous_slots={
|
|
"software_name": "配网工程计价通D3软件",
|
|
"function_name": "新建工程",
|
|
"operation": "如何新建工程",
|
|
"project_type": None,
|
|
"software_version": None,
|
|
"operation_steps": None
|
|
}
|
|
|
|
class QueryRewriteProcessor:
|
|
"""
|
|
查询改写处理器,用于意图识别、问题改写和文档检索
|
|
"""
|
|
def __init__(self,
|
|
api_key: str = None,
|
|
base_url: str = None,
|
|
model_name: str = None,
|
|
dify_api_key: str = "dataset-skLjmPVonjHo119OWNf3kAmY",
|
|
dify_base_url: str = "http://172.20.0.145/v1"):
|
|
"""
|
|
初始化查询改写处理器
|
|
|
|
Args:
|
|
api_key: API密钥,默认使用环境变量
|
|
base_url: API基础URL,默认使用环境变量
|
|
model_name: 模型名称,默认使用环境变量或默认模型
|
|
dify_api_key: Dify API密钥
|
|
dify_base_url: Dify API基础URL
|
|
"""
|
|
# 初始化意图识别器
|
|
self.api_key = api_key or os.getenv("OPENAI_API_KEY")
|
|
self.base_url = base_url or os.getenv("OPENAI_API_BASE")
|
|
self.model_name = model_name or os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")
|
|
|
|
self.recognizer = IntentRecognizer(api_key=self.api_key, base_url=self.base_url, model_name=self.model_name)
|
|
self.dify_query_retrieval = DifyQueryRetrieval(api_key=dify_api_key, base_url=dify_base_url)
|
|
|
|
def is_retrieved_doc_relevant(self, query: str, retrieved_doc: List[Dict[str, Any]]) -> Dict[str, Any]:
|
|
"""
|
|
使用LLM判断检索出的文档是否与用户提问相关
|
|
|
|
Args:
|
|
query: 用户提问
|
|
retrieved_doc: 检索出的文档列表
|
|
|
|
Returns:
|
|
包含相关性判断结果的字典,包括is_relevant(布尔值)和explanation(解释)
|
|
"""
|
|
# 如果没有检索到文档,直接返回不相关
|
|
if not retrieved_doc or len(retrieved_doc) == 0:
|
|
return {
|
|
"is_relevant": False,
|
|
"explanation": "没有检索到任何文档",
|
|
"relevance_score": 0.0
|
|
}
|
|
|
|
doc_text_list = json.dumps(retrieved_doc, ensure_ascii=False, indent=2)
|
|
class TempModel(BaseModel):
|
|
can_solve_problem: bool = Field(description="是否能解决用户问题")
|
|
relevance_score: int = Field(description="相关性评分,0-100分")
|
|
explanation: str = Field(description="解释文档是否能解决(回答)提问")
|
|
|
|
class most_relevant_document(BaseModel):
|
|
most_relevant_document: TempModel = Field(description="最相关的文档的判断结果")
|
|
|
|
parser = PydanticOutputParser(pydantic_object=most_relevant_document)
|
|
# 构建提示词
|
|
prompt = f"""请判断以下检索文档列表中是否与用户提问相关,能够解决用户的问题,并给出相关性评分(0-100分)。输出最相关的文档的判断结果。
|
|
|
|
用户提问: {query}
|
|
|
|
检索文档列表:
|
|
{doc_text_list}
|
|
|
|
请按照以下JSON格式返回结果:
|
|
{parser.get_format_instructions()}
|
|
"""
|
|
|
|
try:
|
|
# 初始化LLM并调用
|
|
llm = OpenAiLLM(api_key=self.api_key, base_url=self.base_url, model="deepseek-ai/DeepSeek-R1", response_format={"type": "json_object"})
|
|
response = llm.invoke(prompt)
|
|
|
|
result = parser.parse(response.content).most_relevant_document
|
|
|
|
return {
|
|
"is_relevant": result.can_solve_problem,
|
|
"relevance_score": result.relevance_score,
|
|
"explanation": result.explanation
|
|
}
|
|
except Exception as e:
|
|
logging.error(f"判断文档相关性时出错: {str(e)}", exc_info=True)
|
|
return {
|
|
"is_relevant": False,
|
|
"explanation": f"判断过程出错: {str(e)}",
|
|
"relevance_score": 0.0
|
|
}
|
|
|
|
def load_questions_from_excel(self, file_path=None):
|
|
"""
|
|
从Excel文件中读取提问数据
|
|
|
|
Args:
|
|
file_path: Excel文件路径,如果为None则使用默认路径
|
|
|
|
Returns:
|
|
提问列表
|
|
"""
|
|
try:
|
|
# 读取Excel文件的第一列数据
|
|
df = pd.read_excel(file_path)
|
|
questions = df.iloc[:, 0].tolist() # 获取第一列数据
|
|
logging.info(f"成功从{file_path}读取了{len(questions)}条提问")
|
|
return questions
|
|
except Exception as e:
|
|
logging.error(f"读取Excel文件时出错: {e}", exc_info=True)
|
|
return []
|
|
|
|
def process_query(self, query: str, conversation_context: str = "", chat_history: List[Dict[str, str]] = None, previous_slots: Dict[str, str] = None, enable_retrieval: bool = False):
|
|
"""
|
|
处理单个查询,支持重试机制,并包含槽位填充
|
|
|
|
Args:
|
|
query: 查询字符串
|
|
conversation_context: 对话上下文
|
|
chat_history: 聊天历史记录
|
|
previous_slots: 之前识别的槽位信息
|
|
enable_retrieval: 是否启用文档检索功能,默认为False
|
|
|
|
Returns:
|
|
处理结果字典
|
|
"""
|
|
max_retries = 3
|
|
retry_count = 0
|
|
|
|
while retry_count <= max_retries:
|
|
try:
|
|
# 使用process_query方法处理查询
|
|
result = self.recognizer.process_query(query,
|
|
conversation_context=conversation_context,
|
|
chat_history=chat_history,
|
|
previous_slots=previous_slots,
|
|
enable_query_expansion=True)
|
|
# 提取分类信息
|
|
classification = result["classification"]
|
|
original_query = result["rewrite"]["rewrite"]
|
|
query_list = result["query_expand"]["all"]
|
|
soft_name = result.get("slot_filling", {}).get("filled_data", {}).get("software_name","")
|
|
# 将字典转换为Classification对象
|
|
classification_obj = Classification(**classification)
|
|
|
|
# 根据enable_retrieval参数决定是否进行文档检索
|
|
retrieved_doc = None
|
|
if enable_retrieval:
|
|
retrieved_doc = self.dify_query_retrieval.retrieve(original_query, query_list, classification_obj, soft_name)
|
|
|
|
# 判断检索文档是否相关
|
|
relevance_result = {}
|
|
if retrieved_doc:
|
|
# 判断文档相关性
|
|
relevance_result = self.is_retrieved_doc_relevant(query, retrieved_doc)
|
|
else:
|
|
relevance_result = {
|
|
"is_relevant": False,
|
|
"explanation": "没有检索到文档" if enable_retrieval else "文档检索功能未启用",
|
|
"relevance_score": 0.0
|
|
}
|
|
|
|
retrieved_doc_titles=[]
|
|
if retrieved_doc:
|
|
retrieved_doc_titles=[doc["title"].split("/")[-1] for doc in retrieved_doc]
|
|
# 提取槽位填充信息
|
|
slot_filling = result.get("slot_filling", {})
|
|
slot_filling_str = ""
|
|
if slot_filling and "filled_data" in slot_filling:
|
|
# 格式化槽位填充结果
|
|
slot_filling_str = json.dumps({
|
|
"是否完整": slot_filling.get("is_complete", False),
|
|
"缺失槽位": slot_filling.get("missing_slots", {}),
|
|
"填充数据": slot_filling.get("filled_data", {})
|
|
}, ensure_ascii=False, indent=2)
|
|
|
|
# 处理成功,返回结果
|
|
return {
|
|
"问题": query,
|
|
"问题分类": f"{classification['vertical_classification']} - {classification['sub_classification']}",
|
|
"问题改写": result["rewrite"]["rewrite"],
|
|
"槽位信息": slot_filling_str,
|
|
"检索的文档": "\n".join(retrieved_doc_titles),
|
|
"检索的内容": json.dumps(retrieved_doc, ensure_ascii=False, indent=2) if retrieved_doc else "",
|
|
"文档是否相关": "相关" if relevance_result["is_relevant"] else "不相关",
|
|
"文档相关性解释": relevance_result["explanation"]
|
|
}
|
|
except Exception as e:
|
|
logging.error(f"处理问题 '{query}' 时出错: ",exc_info=True)
|
|
retry_count += 1
|
|
|
|
# 如果已经重试了最大次数,则记录错误并返回错误结果
|
|
if retry_count > max_retries:
|
|
return {
|
|
"问题": query,
|
|
"问题分类": "处理出错",
|
|
"问题改写": "处理出错",
|
|
"槽位信息": "处理出错",
|
|
"检索的文档": f"重试 {max_retries} 次后失败: {str(e)}",
|
|
"检索的内容":"",
|
|
"文档是否相关": "处理出错",
|
|
"文档相关性解释": "处理出错"
|
|
}
|
|
else:
|
|
# 可以在这里添加延迟,避免过快重试
|
|
time.sleep(10)
|
|
|
|
def save_results_to_excel(self, results, output_file, is_final=False):
|
|
"""
|
|
将结果保存到Excel文件
|
|
|
|
Args:
|
|
results: 结果列表
|
|
output_file: 输出文件路径
|
|
is_final: 是否为最终保存,如果是则使用完整文件名,否则添加临时标记
|
|
|
|
Returns:
|
|
None
|
|
"""
|
|
# 过滤掉None值
|
|
valid_results = [r for r in results if r is not None]
|
|
|
|
if not valid_results:
|
|
logging.warning("没有有效结果可保存")
|
|
return
|
|
|
|
# 创建DataFrame
|
|
results_df = pd.DataFrame(valid_results)
|
|
|
|
# 根据是否为最终保存确定文件名
|
|
if not is_final:
|
|
file_name, file_ext = os.path.splitext(output_file)
|
|
temp_output_file = f"{file_name}_temp{file_ext}"
|
|
else:
|
|
temp_output_file = output_file
|
|
|
|
# 使用ExcelWriter设置格式
|
|
with pd.ExcelWriter(temp_output_file, engine='xlsxwriter') as writer:
|
|
results_df.to_excel(writer, index=False, sheet_name='Sheet1')
|
|
|
|
# 获取工作簿和工作表对象
|
|
workbook = writer.book
|
|
worksheet = writer.sheets['Sheet1']
|
|
|
|
# 设置列宽(单位:像素)
|
|
# 定义列宽(厘米转为Excel单位,1cm约等于4.7个Excel单位)
|
|
worksheet.set_column('A:A', 60) # 提问列 60个Excel单位
|
|
worksheet.set_column('B:B', 20) # 问题拆解 20个Excel单位
|
|
worksheet.set_column('C:C', 20) # 一级分类 20个Excel单位
|
|
worksheet.set_column('D:D', 20) # 二级分类 20个Excel单位
|
|
worksheet.set_column('E:E', 60) # 问题改写 60个Excel单位
|
|
worksheet.set_column('F:F', 60) # 检索到的关键词 60个Excel单位
|
|
worksheet.set_column('G:G', 80) # 槽位填充 80个Excel单位
|
|
worksheet.set_column('H:H', 60) # 文档相关性 60个Excel单位
|
|
|
|
# 设置所有行高为20磅
|
|
for i in range(len(results_df) + 1): # +1 是为了包括表头
|
|
worksheet.set_row(i, 20)
|
|
|
|
logging.info(f"已保存{len(valid_results)}条结果至: {temp_output_file}")
|
|
|
|
def process_batch(self, questions: List[str], max_workers: int = 2, enable_retrieval: bool = False, output_file: str = None):
|
|
"""
|
|
批量处理多个问题
|
|
|
|
Args:
|
|
questions: 问题列表
|
|
max_workers: 并发处理的最大线程数,默认为2
|
|
enable_retrieval: 是否启用文档检索功能,默认为False
|
|
output_file: 输出文件路径,如果为None则不保存结果
|
|
|
|
Returns:
|
|
处理结果列表
|
|
"""
|
|
logging.info(f"共有 {len(questions)} 个问题需要处理,使用 {max_workers} 个并发线程")
|
|
logging.info(f"文档检索功能状态: {'已启用' if enable_retrieval else '未启用'}")
|
|
|
|
# 创建一个与输入顺序相同的结果列表
|
|
results = [None] * len(questions)
|
|
|
|
# 使用线程池进行并发处理
|
|
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
|
|
# 提交所有任务并记录它们的索引
|
|
future_to_index = {}
|
|
for idx, query in enumerate(questions):
|
|
if not query or query.strip() == "":
|
|
continue
|
|
future = executor.submit(self.process_query, query, enable_retrieval=enable_retrieval)
|
|
future_to_index[future] = idx
|
|
|
|
# 使用tqdm显示进度条
|
|
for future in tqdm(concurrent.futures.as_completed(future_to_index), total=len(future_to_index), desc="处理进度"):
|
|
idx = future_to_index[future]
|
|
result = future.result()
|
|
# 将结果放在与输入相同的位置
|
|
results[idx] = result
|
|
|
|
# 如果提供了输出文件路径,则保存结果
|
|
if output_file:
|
|
self.save_results_to_excel(results, output_file, is_final=True)
|
|
|
|
return results
|
|
|
|
def parse_arguments():
|
|
"""解析命令行参数"""
|
|
parser = argparse.ArgumentParser(description='意图识别和问题改写工具')
|
|
input_file="data/excel/1500条点踩软件问题测试.xlsx"
|
|
ouput_file="data/excel/1500条点踩软件问题测试_意图分类.xlsx"
|
|
# 添加数据文件路径参数
|
|
parser.add_argument('--input', '-i', type=str, default=input_file,
|
|
help='输入Excel文件路径,包含待处理的提问数据(第一列)')
|
|
parser.add_argument('--output', '-o', type=str,default=ouput_file,
|
|
help='输出Excel文件路径,用于保存处理结果')
|
|
|
|
# 添加处理相关参数
|
|
parser.add_argument('--max_workers', '-w', type=int, default=2,
|
|
help='并发处理的最大线程数,默认为20')
|
|
|
|
parser.add_argument('--enable_retrieval', '-r', action='store_true',
|
|
help='是否启用文档检索功能,默认不启用')
|
|
|
|
return parser.parse_args()
|
|
|
|
def main():
|
|
"""
|
|
意图识别和问题改写示例
|
|
"""
|
|
|
|
# 解析命令行参数
|
|
args = parse_arguments()
|
|
|
|
# 从环境变量中获取配置,命令行参数优先
|
|
api_key = os.getenv("OPENAI_API_KEY")
|
|
base_url = os.getenv("OPENAI_API_BASE")
|
|
model_name = os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")
|
|
enable_retrieval = args.enable_retrieval
|
|
|
|
# 初始化查询改写处理器
|
|
processor = QueryRewriteProcessor(api_key=api_key, base_url=base_url, model_name=model_name)
|
|
|
|
# 读取提问数据
|
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
|
data_file = args.input if args.input else os.path.join(current_dir, "..", "..", "data", "excel", "1500条点踩软件问题测试.xlsx")
|
|
output_file = args.output if args.output else os.path.join(current_dir, "..", "..", "data", "excel", "1500条点踩软件问题_槽位(分类)填充结果.xlsx")
|
|
|
|
# 检测是否为调试模式
|
|
is_debug =hasattr(sys, 'gettrace') and sys.gettrace() is not None
|
|
if is_debug:
|
|
examples = examples_query.strip().split("\n")
|
|
else:
|
|
examples = processor.load_questions_from_excel(data_file)
|
|
|
|
if not is_debug:
|
|
# 批量处理问题
|
|
results = processor.process_batch(questions=examples, max_workers=args.max_workers, enable_retrieval=enable_retrieval, output_file=output_file)
|
|
logging.info(f"所有处理完成,最终结果已保存至: {output_file}")
|
|
else:
|
|
logging.info(f"文档检索功能状态: {'已启用' if enable_retrieval else '未启用'}")
|
|
for idx, query in enumerate(examples):
|
|
if query.strip() == "":
|
|
continue
|
|
# 在调试模式下使用完整的参数
|
|
print(json.dumps(processor.process_query(
|
|
query,
|
|
enable_retrieval=enable_retrieval
|
|
), ensure_ascii=False, indent=2))
|
|
|
|
def setup_logging():
|
|
# 配置日志输出到控制台
|
|
logging.basicConfig(
|
|
level=logging.INFO,
|
|
format='%(asctime)s - %(name)s - [%(thread)d] - %(levelname)s - %(message)s',
|
|
handlers=[
|
|
logging.StreamHandler() # 添加控制台处理器
|
|
]
|
|
)
|
|
logging.getLogger('httpx').setLevel(logging.WARNING)
|
|
logging.getLogger('openai').setLevel(logging.WARNING)
|
|
|
|
if __name__ == "__main__":
|
|
setup_logging()
|
|
logging.info("意图识别示例程序开始运行...")
|
|
main() |