上传问题改写、意图识别模块代码
This commit is contained in:
@@ -0,0 +1,250 @@
|
||||
"""
|
||||
提问内容补全工具
|
||||
|
||||
此模块用于解析Excel文件中的提问和回答,调用LLM补全提问内容,
|
||||
并将原提问和补全后的提问保存到新的Excel文件中。
|
||||
|
||||
用法示例:
|
||||
completer = QuestionCompleter(input_path="历史提问数据(dislike).xlsx", output_path="补全后的提问数据.xlsx")
|
||||
completer.process()
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
from rag2_0.tool.ModelTool import OpenAiLLM
|
||||
from pydantic import BaseModel, Field
|
||||
from langchain.output_parsers import PydanticOutputParser
|
||||
import concurrent.futures
|
||||
from threading import Lock
|
||||
|
||||
class RewriteQuery(BaseModel):
|
||||
rewrite_query:str = Field(description="补全后的提问")
|
||||
software_name:str = Field(description="软件名称")
|
||||
|
||||
# 加载环境变量
|
||||
load_dotenv()
|
||||
|
||||
class QuestionCompleter:
|
||||
"""
|
||||
提问内容补全工具类
|
||||
|
||||
用于解析Excel文件中的提问和回答,调用LLM补全提问内容,
|
||||
并将原提问和补全后的提问保存到新的Excel文件中。
|
||||
"""
|
||||
|
||||
def __init__(self, input_path="/data/Rag2_0/data/excel/历史提问数据(dislike).xlsx",
|
||||
output_path="/data/Rag2_0/data/excel/历史提问数据(dislike)_补全后的提问数据.xlsx",
|
||||
question_column="提问", answer_column="回答", max_workers=10):
|
||||
"""
|
||||
初始化提问内容补全工具
|
||||
|
||||
参数:
|
||||
input_path (str): 输入Excel文件路径
|
||||
output_path (str): 输出Excel文件路径
|
||||
question_column (str): 提问列的名称
|
||||
answer_column (str): 回答列的名称
|
||||
max_workers (int): 最大线程数
|
||||
"""
|
||||
self.input_path = input_path
|
||||
self.output_path = output_path
|
||||
self.question_column = question_column
|
||||
self.answer_column = answer_column
|
||||
self.max_workers = max_workers
|
||||
self.rewrite_query_parser = PydanticOutputParser(pydantic_object=RewriteQuery)
|
||||
self.lock = Lock() # 添加线程锁
|
||||
|
||||
# 初始化LLM
|
||||
self.api_key = os.getenv("OPENAI_API_KEY")
|
||||
self.base_url = os.getenv("OPENAI_API_BASE")
|
||||
self.model = os.getenv("LLM_MODEL_NAME")
|
||||
|
||||
if not all([self.api_key, self.base_url, self.model]):
|
||||
raise ValueError("请设置 OPENAI_API_KEY, OPENAI_API_BASE, 和 LLM_MODEL_NAME 环境变量")
|
||||
|
||||
self.llm = OpenAiLLM(api_key=self.api_key, base_url=self.base_url, model=self.model)
|
||||
|
||||
# 读取Excel文件
|
||||
try:
|
||||
self.df = pd.read_excel(self.input_path)
|
||||
print(f"成功读取Excel文件: {self.input_path}")
|
||||
print(f"共有 {len(self.df)} 条记录")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"读取Excel文件失败: {str(e)}")
|
||||
|
||||
# 检查列是否存在
|
||||
if self.question_column not in self.df.columns:
|
||||
raise ValueError(f"Excel文件中不存在列: {self.question_column}")
|
||||
if self.answer_column not in self.df.columns:
|
||||
raise ValueError(f"Excel文件中不存在列: {self.answer_column}")
|
||||
|
||||
def create_completion_prompt(self, question, answer):
|
||||
"""
|
||||
创建用于补全提问的prompt
|
||||
|
||||
参数:
|
||||
question (str): 原始提问
|
||||
answer (str): 对应的回答
|
||||
|
||||
返回:
|
||||
str: 格式化的prompt
|
||||
"""
|
||||
prompt = f"""
|
||||
1、判断提问中是否缺少软件名称,如果不缺少,则直接返回原始提问
|
||||
2、如果缺少软件名称,则根据回答中的软件名称,补全提问
|
||||
3、补全后的提问需要保持问题原有意图不变
|
||||
|
||||
4、软件名称包括:
|
||||
配网D3软件(配网工程计价通D3)
|
||||
西藏Z1软件(西藏电力工程计价通Z1)
|
||||
主网计价通软件(电力建设计价通)
|
||||
技改检修工程计价通T1软件(技改检修工程计价通T1)
|
||||
技改检修清单计价通T1软件(技改检修清单计价通T1)
|
||||
储能C1软件(新型储能电站建设计价通C1)
|
||||
如果没有包含上述软件名称,则直接返回原始提问,software_name为空字符串
|
||||
{{
|
||||
"rewrite_query": "xxx",
|
||||
"software_name": ""
|
||||
}}
|
||||
|
||||
原始提问:{question}
|
||||
系统回答:{answer}
|
||||
|
||||
输出格式:
|
||||
{self.rewrite_query_parser.get_format_instructions()}
|
||||
|
||||
示例:
|
||||
例如,如果输入是:
|
||||
提问:这个软件怎么用?
|
||||
回答:Photoshop的使用方法是...
|
||||
|
||||
那么输出会是:
|
||||
{{
|
||||
"rewrite_query": "Photoshop这个软件怎么用?",
|
||||
"software_name": "Photoshop"
|
||||
}}
|
||||
|
||||
或者如果提问已经包含软件名称:
|
||||
提问:Photoshop怎么用?
|
||||
回答:Photoshop的使用方法是...
|
||||
|
||||
那么输出会是:
|
||||
{{
|
||||
"rewrite_query": "Photoshop怎么用?",
|
||||
"software_name": "Photoshop"
|
||||
}}
|
||||
|
||||
"""
|
||||
return prompt
|
||||
|
||||
def complete_question(self, question, answer):
|
||||
"""
|
||||
调用LLM补全提问内容
|
||||
|
||||
参数:
|
||||
question (str): 原始提问
|
||||
answer (str): 对应的回答
|
||||
|
||||
返回:
|
||||
str: 补全后的提问,如果补全失败则返回原始提问
|
||||
"""
|
||||
# 如果提问或回答为空,直接返回原始提问
|
||||
if pd.isna(question) or question.strip() == "" or pd.isna(answer) or answer.strip() == "":
|
||||
return question, ""
|
||||
|
||||
try:
|
||||
prompt = self.create_completion_prompt(question, answer)
|
||||
response = self.llm.invoke(prompt)
|
||||
completed_question = self.rewrite_query_parser.parse(response.content)
|
||||
return completed_question.rewrite_query, completed_question.software_name
|
||||
except Exception as e:
|
||||
print(f"补全提问失败: {str(e)}")
|
||||
return question, ""
|
||||
|
||||
def process_row(self, row):
|
||||
"""
|
||||
处理单行数据
|
||||
|
||||
参数:
|
||||
row: DataFrame中的一行
|
||||
|
||||
返回:
|
||||
dict: 处理结果
|
||||
"""
|
||||
original_question = row[self.question_column]
|
||||
answer = row[self.answer_column]
|
||||
|
||||
# 调用LLM补全提问
|
||||
completed_question, software_name = self.complete_question(original_question, answer)
|
||||
|
||||
# 创建结果字典
|
||||
result = {
|
||||
"原始提问": original_question,
|
||||
"补全后的提问": completed_question,
|
||||
"软件名称": software_name
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
def process(self):
|
||||
"""
|
||||
使用多线程处理所有提问并补全内容
|
||||
|
||||
读取Excel文件中的提问和回答,调用LLM补全提问内容,
|
||||
并将原提问和补全后的提问保存到新的Excel文件中
|
||||
"""
|
||||
results = []
|
||||
total = len(self.df)
|
||||
|
||||
# 使用进度条显示总体进度
|
||||
with tqdm(total=total, desc="补全提问") as pbar:
|
||||
# 创建线程池
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
|
||||
# 提交所有任务
|
||||
future_to_idx = {executor.submit(self.process_row, self.df.iloc[idx]): idx for idx in range(total)}
|
||||
|
||||
# 处理完成的任务
|
||||
for future in concurrent.futures.as_completed(future_to_idx):
|
||||
result = future.result()
|
||||
with self.lock:
|
||||
results.append(result)
|
||||
pbar.update(1)
|
||||
|
||||
# 将结果转换为DataFrame并保存
|
||||
results_df = pd.DataFrame(results)
|
||||
results_df.to_excel(self.output_path, index=False)
|
||||
print(f"处理完成,共处理 {len(results)} 条记录,结果已保存至 {self.output_path}")
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description='补全Excel文件中的提问内容')
|
||||
parser.add_argument('-i', '--input', type=str, default="/data/Rag2_0/data/excel/历史提问数据(dislike).xlsx",
|
||||
help='输入Excel文件路径')
|
||||
parser.add_argument('-o', '--output', type=str, default="/data/Rag2_0/data/excel/补全后的提问数据.xlsx",
|
||||
help='输出Excel文件路径')
|
||||
parser.add_argument('-q', '--question', type=str, default="提问",
|
||||
help='提问列的名称')
|
||||
parser.add_argument('-a', '--answer', type=str, default="回答",
|
||||
help='回答列的名称')
|
||||
parser.add_argument('-w', '--workers', type=int, default=50,
|
||||
help='最大线程数')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# 创建提问补全工具实例
|
||||
completer = QuestionCompleter(
|
||||
input_path=args.input,
|
||||
output_path=args.output,
|
||||
question_column=args.question,
|
||||
answer_column=args.answer,
|
||||
max_workers=args.workers
|
||||
)
|
||||
|
||||
# 执行处理
|
||||
completer.process()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user