250 lines
8.8 KiB
Python
250 lines
8.8 KiB
Python
"""
|
|
提问内容补全工具
|
|
|
|
此模块用于解析Excel文件中的提问和回答,调用LLM补全提问内容,
|
|
并将原提问和补全后的提问保存到新的Excel文件中。
|
|
|
|
用法示例:
|
|
completer = QuestionCompleter(input_path="历史提问数据(dislike).xlsx", output_path="补全后的提问数据.xlsx")
|
|
completer.process()
|
|
"""
|
|
|
|
import pandas as pd
|
|
from tqdm import tqdm
|
|
import os
|
|
from dotenv import load_dotenv
|
|
from rag2_0.tool.ModelTool import OpenAiLLM
|
|
from pydantic import BaseModel, Field
|
|
from langchain.output_parsers import PydanticOutputParser
|
|
import concurrent.futures
|
|
from threading import Lock
|
|
|
|
class RewriteQuery(BaseModel):
|
|
rewrite_query:str = Field(description="补全后的提问")
|
|
software_name:str = Field(description="软件名称")
|
|
|
|
# 加载环境变量
|
|
load_dotenv()
|
|
|
|
class QuestionCompleter:
|
|
"""
|
|
提问内容补全工具类
|
|
|
|
用于解析Excel文件中的提问和回答,调用LLM补全提问内容,
|
|
并将原提问和补全后的提问保存到新的Excel文件中。
|
|
"""
|
|
|
|
def __init__(self, input_path="/data/Rag2_0/data/excel/历史提问数据(dislike).xlsx",
|
|
output_path="/data/Rag2_0/data/excel/历史提问数据(dislike)_补全后的提问数据.xlsx",
|
|
question_column="提问", answer_column="回答", max_workers=10):
|
|
"""
|
|
初始化提问内容补全工具
|
|
|
|
参数:
|
|
input_path (str): 输入Excel文件路径
|
|
output_path (str): 输出Excel文件路径
|
|
question_column (str): 提问列的名称
|
|
answer_column (str): 回答列的名称
|
|
max_workers (int): 最大线程数
|
|
"""
|
|
self.input_path = input_path
|
|
self.output_path = output_path
|
|
self.question_column = question_column
|
|
self.answer_column = answer_column
|
|
self.max_workers = max_workers
|
|
self.rewrite_query_parser = PydanticOutputParser(pydantic_object=RewriteQuery)
|
|
self.lock = Lock() # 添加线程锁
|
|
|
|
# 初始化LLM
|
|
self.api_key = os.getenv("OPENAI_API_KEY")
|
|
self.base_url = os.getenv("OPENAI_API_BASE")
|
|
self.model = os.getenv("LLM_MODEL_NAME")
|
|
|
|
if not all([self.api_key, self.base_url, self.model]):
|
|
raise ValueError("请设置 OPENAI_API_KEY, OPENAI_API_BASE, 和 LLM_MODEL_NAME 环境变量")
|
|
|
|
self.llm = OpenAiLLM(api_key=self.api_key, base_url=self.base_url, model=self.model)
|
|
|
|
# 读取Excel文件
|
|
try:
|
|
self.df = pd.read_excel(self.input_path)
|
|
print(f"成功读取Excel文件: {self.input_path}")
|
|
print(f"共有 {len(self.df)} 条记录")
|
|
except Exception as e:
|
|
raise RuntimeError(f"读取Excel文件失败: {str(e)}")
|
|
|
|
# 检查列是否存在
|
|
if self.question_column not in self.df.columns:
|
|
raise ValueError(f"Excel文件中不存在列: {self.question_column}")
|
|
if self.answer_column not in self.df.columns:
|
|
raise ValueError(f"Excel文件中不存在列: {self.answer_column}")
|
|
|
|
def create_completion_prompt(self, question, answer):
|
|
"""
|
|
创建用于补全提问的prompt
|
|
|
|
参数:
|
|
question (str): 原始提问
|
|
answer (str): 对应的回答
|
|
|
|
返回:
|
|
str: 格式化的prompt
|
|
"""
|
|
prompt = f"""
|
|
1、判断提问中是否缺少软件名称,如果不缺少,则直接返回原始提问
|
|
2、如果缺少软件名称,则根据回答中的软件名称,补全提问
|
|
3、补全后的提问需要保持问题原有意图不变
|
|
|
|
4、软件名称包括:
|
|
配网D3软件(配网工程计价通D3)
|
|
西藏Z1软件(西藏电力工程计价通Z1)
|
|
主网计价通软件(电力建设计价通)
|
|
技改检修工程计价通T1软件(技改检修工程计价通T1)
|
|
技改检修清单计价通T1软件(技改检修清单计价通T1)
|
|
储能C1软件(新型储能电站建设计价通C1)
|
|
如果没有包含上述软件名称,则直接返回原始提问,software_name为空字符串
|
|
{{
|
|
"rewrite_query": "xxx",
|
|
"software_name": ""
|
|
}}
|
|
|
|
原始提问:{question}
|
|
系统回答:{answer}
|
|
|
|
输出格式:
|
|
{self.rewrite_query_parser.get_format_instructions()}
|
|
|
|
示例:
|
|
例如,如果输入是:
|
|
提问:这个软件怎么用?
|
|
回答:Photoshop的使用方法是...
|
|
|
|
那么输出会是:
|
|
{{
|
|
"rewrite_query": "Photoshop这个软件怎么用?",
|
|
"software_name": "Photoshop"
|
|
}}
|
|
|
|
或者如果提问已经包含软件名称:
|
|
提问:Photoshop怎么用?
|
|
回答:Photoshop的使用方法是...
|
|
|
|
那么输出会是:
|
|
{{
|
|
"rewrite_query": "Photoshop怎么用?",
|
|
"software_name": "Photoshop"
|
|
}}
|
|
|
|
"""
|
|
return prompt
|
|
|
|
def complete_question(self, question, answer):
|
|
"""
|
|
调用LLM补全提问内容
|
|
|
|
参数:
|
|
question (str): 原始提问
|
|
answer (str): 对应的回答
|
|
|
|
返回:
|
|
str: 补全后的提问,如果补全失败则返回原始提问
|
|
"""
|
|
# 如果提问或回答为空,直接返回原始提问
|
|
if pd.isna(question) or question.strip() == "" or pd.isna(answer) or answer.strip() == "":
|
|
return question, ""
|
|
|
|
try:
|
|
prompt = self.create_completion_prompt(question, answer)
|
|
response = self.llm.invoke(prompt)
|
|
completed_question = self.rewrite_query_parser.parse(response.content)
|
|
return completed_question.rewrite_query, completed_question.software_name
|
|
except Exception as e:
|
|
print(f"补全提问失败: {str(e)}")
|
|
return question, ""
|
|
|
|
def process_row(self, row):
|
|
"""
|
|
处理单行数据
|
|
|
|
参数:
|
|
row: DataFrame中的一行
|
|
|
|
返回:
|
|
dict: 处理结果
|
|
"""
|
|
original_question = row[self.question_column]
|
|
answer = row[self.answer_column]
|
|
|
|
# 调用LLM补全提问
|
|
completed_question, software_name = self.complete_question(original_question, answer)
|
|
|
|
# 创建结果字典
|
|
result = {
|
|
"原始提问": original_question,
|
|
"补全后的提问": completed_question,
|
|
"软件名称": software_name
|
|
}
|
|
|
|
return result
|
|
|
|
def process(self):
|
|
"""
|
|
使用多线程处理所有提问并补全内容
|
|
|
|
读取Excel文件中的提问和回答,调用LLM补全提问内容,
|
|
并将原提问和补全后的提问保存到新的Excel文件中
|
|
"""
|
|
results = []
|
|
total = len(self.df)
|
|
|
|
# 使用进度条显示总体进度
|
|
with tqdm(total=total, desc="补全提问") as pbar:
|
|
# 创建线程池
|
|
with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
|
|
# 提交所有任务
|
|
future_to_idx = {executor.submit(self.process_row, self.df.iloc[idx]): idx for idx in range(total)}
|
|
|
|
# 处理完成的任务
|
|
for future in concurrent.futures.as_completed(future_to_idx):
|
|
result = future.result()
|
|
with self.lock:
|
|
results.append(result)
|
|
pbar.update(1)
|
|
|
|
# 将结果转换为DataFrame并保存
|
|
results_df = pd.DataFrame(results)
|
|
results_df.to_excel(self.output_path, index=False)
|
|
print(f"处理完成,共处理 {len(results)} 条记录,结果已保存至 {self.output_path}")
|
|
|
|
def main():
|
|
"""主函数"""
|
|
import argparse
|
|
|
|
parser = argparse.ArgumentParser(description='补全Excel文件中的提问内容')
|
|
parser.add_argument('-i', '--input', type=str, default="/data/Rag2_0/data/excel/历史提问数据(dislike).xlsx",
|
|
help='输入Excel文件路径')
|
|
parser.add_argument('-o', '--output', type=str, default="/data/Rag2_0/data/excel/补全后的提问数据.xlsx",
|
|
help='输出Excel文件路径')
|
|
parser.add_argument('-q', '--question', type=str, default="提问",
|
|
help='提问列的名称')
|
|
parser.add_argument('-a', '--answer', type=str, default="回答",
|
|
help='回答列的名称')
|
|
parser.add_argument('-w', '--workers', type=int, default=50,
|
|
help='最大线程数')
|
|
|
|
args = parser.parse_args()
|
|
|
|
# 创建提问补全工具实例
|
|
completer = QuestionCompleter(
|
|
input_path=args.input,
|
|
output_path=args.output,
|
|
question_column=args.question,
|
|
answer_column=args.answer,
|
|
max_workers=args.workers
|
|
)
|
|
|
|
# 执行处理
|
|
completer.process()
|
|
|
|
if __name__ == "__main__":
|
|
main() |