Files
QueryRewrite/rag2_0/demo/Test.py
T

250 lines
8.8 KiB
Python

"""
提问内容补全工具
此模块用于解析Excel文件中的提问和回答,调用LLM补全提问内容,
并将原提问和补全后的提问保存到新的Excel文件中。
用法示例:
completer = QuestionCompleter(input_path="历史提问数据(dislike).xlsx", output_path="补全后的提问数据.xlsx")
completer.process()
"""
import pandas as pd
from tqdm import tqdm
import os
from dotenv import load_dotenv
from rag2_0.tool.ModelTool import OpenAiLLM
from pydantic import BaseModel, Field
from langchain.output_parsers import PydanticOutputParser
import concurrent.futures
from threading import Lock
class RewriteQuery(BaseModel):
rewrite_query:str = Field(description="补全后的提问")
software_name:str = Field(description="软件名称")
# 加载环境变量
load_dotenv()
class QuestionCompleter:
"""
提问内容补全工具类
用于解析Excel文件中的提问和回答,调用LLM补全提问内容,
并将原提问和补全后的提问保存到新的Excel文件中。
"""
def __init__(self, input_path="/data/Rag2_0/data/excel/历史提问数据(dislike).xlsx",
output_path="/data/Rag2_0/data/excel/历史提问数据(dislike)_补全后的提问数据.xlsx",
question_column="提问", answer_column="回答", max_workers=10):
"""
初始化提问内容补全工具
参数:
input_path (str): 输入Excel文件路径
output_path (str): 输出Excel文件路径
question_column (str): 提问列的名称
answer_column (str): 回答列的名称
max_workers (int): 最大线程数
"""
self.input_path = input_path
self.output_path = output_path
self.question_column = question_column
self.answer_column = answer_column
self.max_workers = max_workers
self.rewrite_query_parser = PydanticOutputParser(pydantic_object=RewriteQuery)
self.lock = Lock() # 添加线程锁
# 初始化LLM
self.api_key = os.getenv("OPENAI_API_KEY")
self.base_url = os.getenv("OPENAI_API_BASE")
self.model = os.getenv("LLM_MODEL_NAME")
if not all([self.api_key, self.base_url, self.model]):
raise ValueError("请设置 OPENAI_API_KEY, OPENAI_API_BASE, 和 LLM_MODEL_NAME 环境变量")
self.llm = OpenAiLLM(api_key=self.api_key, base_url=self.base_url, model=self.model)
# 读取Excel文件
try:
self.df = pd.read_excel(self.input_path)
print(f"成功读取Excel文件: {self.input_path}")
print(f"共有 {len(self.df)} 条记录")
except Exception as e:
raise RuntimeError(f"读取Excel文件失败: {str(e)}")
# 检查列是否存在
if self.question_column not in self.df.columns:
raise ValueError(f"Excel文件中不存在列: {self.question_column}")
if self.answer_column not in self.df.columns:
raise ValueError(f"Excel文件中不存在列: {self.answer_column}")
def create_completion_prompt(self, question, answer):
"""
创建用于补全提问的prompt
参数:
question (str): 原始提问
answer (str): 对应的回答
返回:
str: 格式化的prompt
"""
prompt = f"""
1、判断提问中是否缺少软件名称,如果不缺少,则直接返回原始提问
2、如果缺少软件名称,则根据回答中的软件名称,补全提问
3、补全后的提问需要保持问题原有意图不变
4、软件名称包括:
配网D3软件(配网工程计价通D3)
西藏Z1软件(西藏电力工程计价通Z1)
主网计价通软件(电力建设计价通)
技改检修工程计价通T1软件(技改检修工程计价通T1)
技改检修清单计价通T1软件(技改检修清单计价通T1)
储能C1软件(新型储能电站建设计价通C1)
如果没有包含上述软件名称,则直接返回原始提问,software_name为空字符串
{{
"rewrite_query": "xxx",
"software_name": ""
}}
原始提问:{question}
系统回答:{answer}
输出格式:
{self.rewrite_query_parser.get_format_instructions()}
示例:
例如,如果输入是:
提问:这个软件怎么用?
回答:Photoshop的使用方法是...
那么输出会是:
{{
"rewrite_query": "Photoshop这个软件怎么用?",
"software_name": "Photoshop"
}}
或者如果提问已经包含软件名称:
提问:Photoshop怎么用?
回答:Photoshop的使用方法是...
那么输出会是:
{{
"rewrite_query": "Photoshop怎么用?",
"software_name": "Photoshop"
}}
"""
return prompt
def complete_question(self, question, answer):
"""
调用LLM补全提问内容
参数:
question (str): 原始提问
answer (str): 对应的回答
返回:
str: 补全后的提问,如果补全失败则返回原始提问
"""
# 如果提问或回答为空,直接返回原始提问
if pd.isna(question) or question.strip() == "" or pd.isna(answer) or answer.strip() == "":
return question, ""
try:
prompt = self.create_completion_prompt(question, answer)
response = self.llm.invoke(prompt)
completed_question = self.rewrite_query_parser.parse(response.content)
return completed_question.rewrite_query, completed_question.software_name
except Exception as e:
print(f"补全提问失败: {str(e)}")
return question, ""
def process_row(self, row):
"""
处理单行数据
参数:
row: DataFrame中的一行
返回:
dict: 处理结果
"""
original_question = row[self.question_column]
answer = row[self.answer_column]
# 调用LLM补全提问
completed_question, software_name = self.complete_question(original_question, answer)
# 创建结果字典
result = {
"原始提问": original_question,
"补全后的提问": completed_question,
"软件名称": software_name
}
return result
def process(self):
"""
使用多线程处理所有提问并补全内容
读取Excel文件中的提问和回答,调用LLM补全提问内容,
并将原提问和补全后的提问保存到新的Excel文件中
"""
results = []
total = len(self.df)
# 使用进度条显示总体进度
with tqdm(total=total, desc="补全提问") as pbar:
# 创建线程池
with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
# 提交所有任务
future_to_idx = {executor.submit(self.process_row, self.df.iloc[idx]): idx for idx in range(total)}
# 处理完成的任务
for future in concurrent.futures.as_completed(future_to_idx):
result = future.result()
with self.lock:
results.append(result)
pbar.update(1)
# 将结果转换为DataFrame并保存
results_df = pd.DataFrame(results)
results_df.to_excel(self.output_path, index=False)
print(f"处理完成,共处理 {len(results)} 条记录,结果已保存至 {self.output_path}")
def main():
"""主函数"""
import argparse
parser = argparse.ArgumentParser(description='补全Excel文件中的提问内容')
parser.add_argument('-i', '--input', type=str, default="/data/Rag2_0/data/excel/历史提问数据(dislike).xlsx",
help='输入Excel文件路径')
parser.add_argument('-o', '--output', type=str, default="/data/Rag2_0/data/excel/补全后的提问数据.xlsx",
help='输出Excel文件路径')
parser.add_argument('-q', '--question', type=str, default="提问",
help='提问列的名称')
parser.add_argument('-a', '--answer', type=str, default="回答",
help='回答列的名称')
parser.add_argument('-w', '--workers', type=int, default=50,
help='最大线程数')
args = parser.parse_args()
# 创建提问补全工具实例
completer = QuestionCompleter(
input_path=args.input,
output_path=args.output,
question_column=args.question,
answer_column=args.answer,
max_workers=args.workers
)
# 执行处理
completer.process()
if __name__ == "__main__":
main()