""" 提问内容补全工具 此模块用于解析Excel文件中的提问和回答,调用LLM补全提问内容, 并将原提问和补全后的提问保存到新的Excel文件中。 用法示例: completer = QuestionCompleter(input_path="历史提问数据(dislike).xlsx", output_path="补全后的提问数据.xlsx") completer.process() """ import pandas as pd from tqdm import tqdm import os from dotenv import load_dotenv from rag2_0.tool.ModelTool import OpenAiLLM from pydantic import BaseModel, Field from langchain.output_parsers import PydanticOutputParser import concurrent.futures from threading import Lock class RewriteQuery(BaseModel): rewrite_query:str = Field(description="补全后的提问") software_name:str = Field(description="软件名称") # 加载环境变量 load_dotenv() class QuestionCompleter: """ 提问内容补全工具类 用于解析Excel文件中的提问和回答,调用LLM补全提问内容, 并将原提问和补全后的提问保存到新的Excel文件中。 """ def __init__(self, input_path="/data/Rag2_0/data/excel/历史提问数据(dislike).xlsx", output_path="/data/Rag2_0/data/excel/历史提问数据(dislike)_补全后的提问数据.xlsx", question_column="提问", answer_column="回答", max_workers=10): """ 初始化提问内容补全工具 参数: input_path (str): 输入Excel文件路径 output_path (str): 输出Excel文件路径 question_column (str): 提问列的名称 answer_column (str): 回答列的名称 max_workers (int): 最大线程数 """ self.input_path = input_path self.output_path = output_path self.question_column = question_column self.answer_column = answer_column self.max_workers = max_workers self.rewrite_query_parser = PydanticOutputParser(pydantic_object=RewriteQuery) self.lock = Lock() # 添加线程锁 # 初始化LLM self.api_key = os.getenv("OPENAI_API_KEY") self.base_url = os.getenv("OPENAI_API_BASE") self.model = os.getenv("LLM_MODEL_NAME") if not all([self.api_key, self.base_url, self.model]): raise ValueError("请设置 OPENAI_API_KEY, OPENAI_API_BASE, 和 LLM_MODEL_NAME 环境变量") self.llm = OpenAiLLM(api_key=self.api_key, base_url=self.base_url, model=self.model) # 读取Excel文件 try: self.df = pd.read_excel(self.input_path) print(f"成功读取Excel文件: {self.input_path}") print(f"共有 {len(self.df)} 条记录") except Exception as e: raise RuntimeError(f"读取Excel文件失败: {str(e)}") # 检查列是否存在 if self.question_column not in self.df.columns: raise ValueError(f"Excel文件中不存在列: {self.question_column}") if self.answer_column not in self.df.columns: raise ValueError(f"Excel文件中不存在列: {self.answer_column}") def create_completion_prompt(self, question, answer): """ 创建用于补全提问的prompt 参数: question (str): 原始提问 answer (str): 对应的回答 返回: str: 格式化的prompt """ prompt = f""" 1、判断提问中是否缺少软件名称,如果不缺少,则直接返回原始提问 2、如果缺少软件名称,则根据回答中的软件名称,补全提问 3、补全后的提问需要保持问题原有意图不变 4、软件名称包括: 配网D3软件(配网工程计价通D3) 西藏Z1软件(西藏电力工程计价通Z1) 主网计价通软件(电力建设计价通) 技改检修工程计价通T1软件(技改检修工程计价通T1) 技改检修清单计价通T1软件(技改检修清单计价通T1) 储能C1软件(新型储能电站建设计价通C1) 如果没有包含上述软件名称,则直接返回原始提问,software_name为空字符串 {{ "rewrite_query": "xxx", "software_name": "" }} 原始提问:{question} 系统回答:{answer} 输出格式: {self.rewrite_query_parser.get_format_instructions()} 示例: 例如,如果输入是: 提问:这个软件怎么用? 回答:Photoshop的使用方法是... 那么输出会是: {{ "rewrite_query": "Photoshop这个软件怎么用?", "software_name": "Photoshop" }} 或者如果提问已经包含软件名称: 提问:Photoshop怎么用? 回答:Photoshop的使用方法是... 那么输出会是: {{ "rewrite_query": "Photoshop怎么用?", "software_name": "Photoshop" }} """ return prompt def complete_question(self, question, answer): """ 调用LLM补全提问内容 参数: question (str): 原始提问 answer (str): 对应的回答 返回: str: 补全后的提问,如果补全失败则返回原始提问 """ # 如果提问或回答为空,直接返回原始提问 if pd.isna(question) or question.strip() == "" or pd.isna(answer) or answer.strip() == "": return question, "" try: prompt = self.create_completion_prompt(question, answer) response = self.llm.invoke(prompt) completed_question = self.rewrite_query_parser.parse(response.content) return completed_question.rewrite_query, completed_question.software_name except Exception as e: print(f"补全提问失败: {str(e)}") return question, "" def process_row(self, row): """ 处理单行数据 参数: row: DataFrame中的一行 返回: dict: 处理结果 """ original_question = row[self.question_column] answer = row[self.answer_column] # 调用LLM补全提问 completed_question, software_name = self.complete_question(original_question, answer) # 创建结果字典 result = { "原始提问": original_question, "补全后的提问": completed_question, "软件名称": software_name } return result def process(self): """ 使用多线程处理所有提问并补全内容 读取Excel文件中的提问和回答,调用LLM补全提问内容, 并将原提问和补全后的提问保存到新的Excel文件中 """ results = [] total = len(self.df) # 使用进度条显示总体进度 with tqdm(total=total, desc="补全提问") as pbar: # 创建线程池 with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor: # 提交所有任务 future_to_idx = {executor.submit(self.process_row, self.df.iloc[idx]): idx for idx in range(total)} # 处理完成的任务 for future in concurrent.futures.as_completed(future_to_idx): result = future.result() with self.lock: results.append(result) pbar.update(1) # 将结果转换为DataFrame并保存 results_df = pd.DataFrame(results) results_df.to_excel(self.output_path, index=False) print(f"处理完成,共处理 {len(results)} 条记录,结果已保存至 {self.output_path}") def main(): """主函数""" import argparse parser = argparse.ArgumentParser(description='补全Excel文件中的提问内容') parser.add_argument('-i', '--input', type=str, default="/data/Rag2_0/data/excel/历史提问数据(dislike).xlsx", help='输入Excel文件路径') parser.add_argument('-o', '--output', type=str, default="/data/Rag2_0/data/excel/补全后的提问数据.xlsx", help='输出Excel文件路径') parser.add_argument('-q', '--question', type=str, default="提问", help='提问列的名称') parser.add_argument('-a', '--answer', type=str, default="回答", help='回答列的名称') parser.add_argument('-w', '--workers', type=int, default=50, help='最大线程数') args = parser.parse_args() # 创建提问补全工具实例 completer = QuestionCompleter( input_path=args.input, output_path=args.output, question_column=args.question, answer_column=args.answer, max_workers=args.workers ) # 执行处理 completer.process() if __name__ == "__main__": main()