#!/usr/bin/env python # -*- coding: utf-8 -*- """ File: intent_recognition_example.py Date: 2025-05-14 Description: 意图识别和问题改写示例 """ import os from dotenv import load_dotenv import pandas as pd import logging import json import concurrent.futures from tqdm import tqdm import time import sys import argparse from typing import List, Dict, Any, Optional from langchain.output_parsers import PydanticOutputParser from pydantic import BaseModel, Field sys.path.append(os.getcwd()) from rag2_0.intent_recognition import IntentRecognizer from rag2_0.dify.DifyQueryRetrieval import DifyQueryRetrieval from rag2_0.intent_recognition.DataModels import Classification from rag2_0.tool.ModelTool import OpenAiLLM # 加载环境变量 load_dotenv() # 示例查询 examples_query = """ PE2211PK0801是什么软件""" conversation_context="" chat_history=[ { "user": "怎么新建工程啊" }, { "assistant": "您好!为了更准确地帮助您解决问题,请提供以下信息:\n\n- **软件名称**:请确认您使用的软件名称,可选值如下:\n - 配网工程计价通D3软件\n - 新型储能电站建设计价通C1软件\n - 西藏电力工程计价通Z1软件\n - 技改检修工程计价通T1软件\n - 技改检修清单计价通T1软件\n - 主网电力建设计价通软件\n\n请补充以上信息,以便我们能够更好地指导您如何新建工程。" }, { "user": "配网工程计价通D3软件" }, { "assistant": """好的,针对您使用**配网工程计价通D3软件**新建工程的问题,以下是具体的操作步骤:\n\n## 新建工程\n\n### 功能入口\n【主页】界面——"新建"按钮\n\n### 操作步骤\n1. 在"新建窗口"选择对应工程模板。\n2. 设置工程名称。\n3. 确定后根据新建向导完善"电压等级"、"地区类型"等参数。\n4. 点击"确定"即可完成新建工程。\n\n![进入新建工程窗口](https://172.20.0.145/files/362491a3-69c3-45b0-a037-b5f96b91bf73/image-preview)\n\n希望这些步骤能帮助您顺利完成新建工程。如果还有其他问题,欢迎随时提问!\n""" } ] previous_slots={ "software_name": "配网工程计价通D3软件", "function_name": "新建工程", "operation": "如何新建工程", "project_type": None, "software_version": None, "operation_steps": None } class QueryRewriteProcessor: """ 查询改写处理器,用于意图识别、问题改写和文档检索 """ def __init__(self, api_key: str = None, base_url: str = None, model_name: str = None, dify_api_key: str = "dataset-skLjmPVonjHo119OWNf3kAmY", dify_base_url: str = "http://172.20.0.145/v1"): """ 初始化查询改写处理器 Args: api_key: API密钥,默认使用环境变量 base_url: API基础URL,默认使用环境变量 model_name: 模型名称,默认使用环境变量或默认模型 dify_api_key: Dify API密钥 dify_base_url: Dify API基础URL """ # 初始化意图识别器 self.api_key = api_key or os.getenv("OPENAI_API_KEY") self.base_url = base_url or os.getenv("OPENAI_API_BASE") self.model_name = model_name or os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo") self.recognizer = IntentRecognizer(api_key=self.api_key, base_url=self.base_url, model_name=self.model_name) self.dify_query_retrieval = DifyQueryRetrieval(api_key=dify_api_key, base_url=dify_base_url) def is_retrieved_doc_relevant(self, query: str, retrieved_doc: List[Dict[str, Any]]) -> Dict[str, Any]: """ 使用LLM判断检索出的文档是否与用户提问相关 Args: query: 用户提问 retrieved_doc: 检索出的文档列表 Returns: 包含相关性判断结果的字典,包括is_relevant(布尔值)和explanation(解释) """ # 如果没有检索到文档,直接返回不相关 if not retrieved_doc or len(retrieved_doc) == 0: return { "is_relevant": False, "explanation": "没有检索到任何文档", "relevance_score": 0.0 } doc_text_list = json.dumps(retrieved_doc, ensure_ascii=False, indent=2) class TempModel(BaseModel): can_solve_problem: bool = Field(description="是否能解决用户问题") relevance_score: int = Field(description="相关性评分,0-100分") explanation: str = Field(description="解释文档是否能解决(回答)提问") class all_relevant_document(BaseModel): most_relevant_document: list[TempModel] = Field(description="最相关的文档的判断结果") parser = PydanticOutputParser(pydantic_object=all_relevant_document) # 构建提示词 prompt = f"""请判断以下检索文档列表中是否与用户提问相关,能够解决用户的问题,并给出相关性评分(0-100分)。输出最相关的文档的判断结果。 用户提问: {query} 检索文档列表: {doc_text_list} 请按照以下JSON格式返回结果: json``` {{ "most_relevant_document":[{{ "can_solve_problem": true, "relevance_score": 60, "explanation":"xxxx" }}] }} ``` """ try: # 初始化LLM并调用 llm = OpenAiLLM(api_key=self.api_key, base_url=self.base_url, model="deepseek-ai/DeepSeek-R1", response_format={"type": "json_object"}) response = llm.invoke(prompt) result_list = parser.parse(response.content).most_relevant_document # 如果列表为空,返回默认的不相关结果 if not result_list: return { "is_relevant": False, "explanation": "无法解析文档相关性结果", "relevance_score": 0.0 } # 找出分数最高的文档 max_score_doc = max(result_list, key=lambda x: x.relevance_score) return { "is_relevant": max_score_doc.can_solve_problem, "relevance_score": max_score_doc.relevance_score, "explanation": max_score_doc.explanation } except Exception as e: logging.error(f"判断文档相关性时出错: {str(e)}", exc_info=True) return { "is_relevant": False, "explanation": f"判断过程出错: {str(e)}", "relevance_score": 0.0 } def load_questions_from_excel(self, file_path=None): """ 从Excel文件中读取提问数据 Args: file_path: Excel文件路径,如果为None则使用默认路径 Returns: 提问列表 """ try: # 读取Excel文件的第一列数据 df = pd.read_excel(file_path) questions = df.iloc[:, 0].tolist() # 获取第一列数据 logging.info(f"成功从{file_path}读取了{len(questions)}条提问") return questions except Exception as e: logging.error(f"读取Excel文件时出错: {e}", exc_info=True) return [] def process_query(self, query: str, conversation_context: str = "", chat_history: List[Dict[str, str]] = None, previous_slots: Dict[str, str] = None, enable_retrieval: bool = False): """ 处理单个查询,支持重试机制,并包含槽位填充 Args: query: 查询字符串 conversation_context: 对话上下文 chat_history: 聊天历史记录 previous_slots: 之前识别的槽位信息 enable_retrieval: 是否启用文档检索功能,默认为False Returns: 处理结果字典 """ max_retries = 3 retry_count = 0 while retry_count <= max_retries: try: # 使用process_query方法处理查询 result = self.recognizer.process_query(query, conversation_context=conversation_context, chat_history=chat_history, previous_slots=previous_slots, enable_query_expansion=True) # 提取分类信息 classification = result["classification"] original_query = result["rewrite"]["rewrite"] query_list = result["query_expand"]["all"] soft_name = result.get("slot_filling", {}).get("filled_data", {}).get("software_name","") # 将字典转换为Classification对象 classification_obj = Classification(**classification) # 根据enable_retrieval参数决定是否进行文档检索 retrieved_doc = None if enable_retrieval: retrieved_doc = self.dify_query_retrieval.retrieve(original_query, query_list, classification_obj, soft_name) # 判断检索文档是否相关 relevance_result = {} if retrieved_doc: # 判断文档相关性 relevance_result = self.is_retrieved_doc_relevant(query, retrieved_doc) else: relevance_result = { "is_relevant": False, "explanation": "没有检索到文档" if enable_retrieval else "文档检索功能未启用", "relevance_score": 0.0 } retrieved_doc_titles=[] if retrieved_doc: retrieved_doc_titles=[doc["title"].split("/")[-1] for doc in retrieved_doc] # 提取槽位填充信息 slot_filling = result.get("slot_filling", {}) slot_filling_str = "" if slot_filling and "filled_data" in slot_filling: # 格式化槽位填充结果 slot_filling_str = json.dumps({ "是否完整": slot_filling.get("is_complete", False), "缺失槽位": slot_filling.get("missing_slots", {}), "填充数据": slot_filling.get("filled_data", {}) }, ensure_ascii=False, indent=2) # 处理成功,返回结果 return { "问题": query, "问题分类": f"{classification['vertical_classification']} - {classification['sub_classification']}", "问题改写": result["rewrite"]["rewrite"], "槽位信息": slot_filling_str, "检索的文档": "\n".join(retrieved_doc_titles), "检索的内容": json.dumps(retrieved_doc, ensure_ascii=False, indent=2) if retrieved_doc else "", "文档是否相关": "相关" if relevance_result["is_relevant"] else "不相关", "文档相关性解释": relevance_result["explanation"] } except Exception as e: logging.error(f"处理问题 '{query}' 时出错: ",exc_info=True) retry_count += 1 # 如果已经重试了最大次数,则记录错误并返回错误结果 if retry_count > max_retries: return { "问题": query, "问题分类": "处理出错", "问题改写": "处理出错", "槽位信息": "处理出错", "检索的文档": f"重试 {max_retries} 次后失败: {str(e)}", "检索的内容":"", "文档是否相关": "处理出错", "文档相关性解释": "处理出错" } else: # 可以在这里添加延迟,避免过快重试 time.sleep(10) def save_results_to_excel(self, results, output_file, is_final=False): """ 将结果保存到Excel文件 Args: results: 结果列表 output_file: 输出文件路径 is_final: 是否为最终保存,如果是则使用完整文件名,否则添加临时标记 Returns: None """ # 过滤掉None值 valid_results = [r for r in results if r is not None] if not valid_results: logging.warning("没有有效结果可保存") return # 创建DataFrame results_df = pd.DataFrame(valid_results) # 根据是否为最终保存确定文件名 if not is_final: file_name, file_ext = os.path.splitext(output_file) temp_output_file = f"{file_name}_temp{file_ext}" else: temp_output_file = output_file # 使用ExcelWriter设置格式 with pd.ExcelWriter(temp_output_file, engine='xlsxwriter') as writer: results_df.to_excel(writer, index=False, sheet_name='Sheet1') # 获取工作簿和工作表对象 workbook = writer.book worksheet = writer.sheets['Sheet1'] # 设置列宽(单位:像素) # 定义列宽(厘米转为Excel单位,1cm约等于4.7个Excel单位) worksheet.set_column('A:A', 60) # 提问列 60个Excel单位 worksheet.set_column('B:B', 20) # 问题拆解 20个Excel单位 worksheet.set_column('C:C', 20) # 一级分类 20个Excel单位 worksheet.set_column('D:D', 20) # 二级分类 20个Excel单位 worksheet.set_column('E:E', 60) # 问题改写 60个Excel单位 worksheet.set_column('F:F', 60) # 检索到的关键词 60个Excel单位 worksheet.set_column('G:G', 80) # 槽位填充 80个Excel单位 worksheet.set_column('H:H', 60) # 文档相关性 60个Excel单位 # 设置所有行高为20磅 for i in range(len(results_df) + 1): # +1 是为了包括表头 worksheet.set_row(i, 20) logging.info(f"已保存{len(valid_results)}条结果至: {temp_output_file}") def process_batch(self, questions: List[str], max_workers: int = 2, enable_retrieval: bool = False, output_file: str = None, save_interval: int = 100): """ 批量处理多个问题 Args: questions: 问题列表 max_workers: 并发处理的最大线程数,默认为2 enable_retrieval: 是否启用文档检索功能,默认为False output_file: 输出文件路径,如果为None则不保存结果 save_interval: 临时保存的间隔,每处理这么多问题就临时保存一次结果,默认为100 Returns: 处理结果列表 """ logging.info(f"共有 {len(questions)} 个问题需要处理,使用 {max_workers} 个并发线程") logging.info(f"文档检索功能状态: {'已启用' if enable_retrieval else '未启用'}") logging.info(f"每处理 {save_interval} 个问题将临时保存一次结果") # 创建一个与输入顺序相同的结果列表 results = [None] * len(questions) # 使用线程池进行并发处理 with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: # 提交所有任务并记录它们的索引 future_to_index = {} for idx, query in enumerate(questions): if not query or query.strip() == "": continue future = executor.submit(self.process_query, query, enable_retrieval=enable_retrieval) future_to_index[future] = idx # 用于跟踪已完成的问题数量 completed_count = 0 # 使用tqdm显示进度条 for future in tqdm(concurrent.futures.as_completed(future_to_index), total=len(future_to_index), desc="处理进度"): idx = future_to_index[future] result = future.result() # 将结果放在与输入相同的位置 results[idx] = result # 增加已完成的问题计数 completed_count += 1 # 检查是否需要临时保存 if output_file and completed_count % save_interval == 0: # 临时保存当前结果 self.save_results_to_excel(results, output_file, is_final=False) logging.info(f"已临时保存 {completed_count} 个问题的处理结果") # 如果提供了输出文件路径,则保存最终结果 if output_file: self.save_results_to_excel(results, output_file, is_final=True) return results def parse_arguments(): """解析命令行参数""" parser = argparse.ArgumentParser(description='意图识别和问题改写工具') input_file="data/excel/1500条点踩软件问题测试.xlsx" ouput_file="data/excel/1500条点踩软件问题测试_意图分类.xlsx" # 添加数据文件路径参数 parser.add_argument('--input', '-i', type=str, default=input_file, help='输入Excel文件路径,包含待处理的提问数据(第一列)') parser.add_argument('--output', '-o', type=str,default=ouput_file, help='输出Excel文件路径,用于保存处理结果') # 添加处理相关参数 parser.add_argument('--max_workers', '-w', type=int, default=2, help='并发处理的最大线程数,默认为20') parser.add_argument('--enable_retrieval', '-r', action='store_true', help='是否启用文档检索功能,默认不启用') return parser.parse_args() def main(): """ 意图识别和问题改写示例 """ # 解析命令行参数 args = parse_arguments() # 从环境变量中获取配置,命令行参数优先 api_key = os.getenv("OPENAI_API_KEY") base_url = os.getenv("OPENAI_API_BASE") model_name = os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo") enable_retrieval = args.enable_retrieval # 初始化查询改写处理器 processor = QueryRewriteProcessor(api_key=api_key, base_url=base_url, model_name=model_name) # 读取提问数据 current_dir = os.path.dirname(os.path.abspath(__file__)) data_file = args.input if args.input else os.path.join(current_dir, "..", "..", "data", "excel", "1500条点踩软件问题测试.xlsx") output_file = args.output if args.output else os.path.join(current_dir, "..", "..", "data", "excel", "1500条点踩软件问题_槽位(分类)填充结果.xlsx") # 检测是否为调试模式 is_debug =hasattr(sys, 'gettrace') and sys.gettrace() is not None if is_debug: examples = examples_query.strip().split("\n") else: examples = processor.load_questions_from_excel(data_file) if not is_debug: # 批量处理问题 results = processor.process_batch(questions=examples, max_workers=args.max_workers, enable_retrieval=enable_retrieval, output_file=output_file) logging.info(f"所有处理完成,最终结果已保存至: {output_file}") else: logging.info(f"文档检索功能状态: 已启用") for idx, query in enumerate(examples): if query.strip() == "": continue # 在调试模式下使用完整的参数 print(json.dumps(processor.process_query( query, enable_retrieval=True ), ensure_ascii=False, indent=2)) def setup_logging(): # 配置日志输出到控制台 logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(name)s - [%(thread)d] - %(levelname)s - %(message)s', handlers=[ logging.StreamHandler() # 添加控制台处理器 ] ) logging.getLogger('httpx').setLevel(logging.WARNING) logging.getLogger('openai').setLevel(logging.WARNING) if __name__ == "__main__": setup_logging() main()