# Source listing: Python, 145 lines, 4.9 KiB.


# --- Environment / runtime bootstrap ---------------------------------------
# NOTE: the order of the statements below is deliberate; keep it.
# load_dotenv() runs before the project (app.*) and llama_index imports,
# presumably so anything read from the environment at import time sees the
# values from the .env file — TODO confirm.
from dotenv import load_dotenv

load_dotenv()

import json
import sys

from app.observability import init_observability
from app.settings import init_settings

# nest_asyncio.apply() patches asyncio so an already-running event loop can be
# re-entered (per the nest_asyncio docs). It is applied before the llama_index
# imports; whether that ordering is required is not visible here — TODO confirm.
import nest_asyncio

nest_asyncio.apply()

from llama_index.core.node_parser import SentenceSplitter  # NOTE(review): appears unused in this file
from llama_index.core import SimpleDirectoryReader
from llama_index.core.evaluation import DatasetGenerator

# Project-local module holding the question-generation prompt templates;
# templates are looked up on it by attribute name.
import prompts

# Project-level initialisation, run once at import time before documents are
# loaded. Presumably configures the LLM/embedding settings and tracing hooks
# (both helpers are opaque from this file) — TODO confirm.
init_settings()
init_observability()


# Load every document (i.e. every table) from the data directory.
# The directory can be overridden with the TEST_DATA_DIR environment variable
# (which may come from the .env file loaded by load_dotenv()); it defaults to
# the original developer path so existing setups keep working.
# NOTE(review): select_document() picks documents by list position, so the
# order in which SimpleDirectoryReader returns them must line up with
# `table_names` — confirm the reader's file ordering (and that each file yields
# exactly one document) for the file types in this directory.
import os

DATA_DIR = os.environ.get(
    "TEST_DATA_DIR",
    "D:/LLM_model/text2sql/zjdataai-app-test/backend/data-test",
)
documents = SimpleDirectoryReader(DATA_DIR).load_data()


# Mapping from table name to its index in `documents`.
# Each table's index is its position in the tuple below, so this order must
# match the order in which the documents are loaded.
table_names = {
    name: position
    for position, name in enumerate(
        (
            "工程信息表",
            "其他费用表",
            "取费表",
            "项目划分表",
            "项目划分_费用预览表",
            "总算表",
            "工程量表",
        )
    )
}


# Mapping from the Chinese query-category labels (as given on the command line)
# to the attribute names of the matching prompt templates in `prompts`.
prompt_mapping = dict(
    [
        ("普通属性", "Attribute_Prompt"),  # general attribute
        ("金额查询", "Amount_Prompt"),  # amount query
        ("单位换算", "Units_Prompt"),  # unit conversion
        ("重名项目划分", "Name_Prompt"),  # same-name project item
        ("总体金额查询", "All_Amount_Prompt"),  # overall amount query
    ]
)


# Query categories configured for each table. Tables that are not listed here
# (项目划分表, 项目划分_费用预览表) have no categories configured.
table_prompt_mapping = {
    table: list(categories)
    for table, categories in (
        ("工程信息表", ("普通属性", "单位换算")),
        ("其他费用表", ("金额查询", "单位换算")),
        ("取费表", ("金额查询",)),
        ("总算表", ("金额查询", "总体金额查询")),
        ("工程量表", ("普通属性", "重名项目划分")),
    )
}


# Select a specific table by its name.
def select_document(documents, table_name, name_to_index=None):
    """Return the document for ``table_name`` wrapped in a one-element list.

    Args:
        documents: Sequence of loaded documents, addressed by position.
        table_name: Name of the table to select; must be a key of
            ``name_to_index``.
        name_to_index: Optional mapping of table name -> position in
            ``documents``. Defaults to the module-level ``table_names``.

    Returns:
        A list containing only the selected document (the form that is passed
        on as ``documents=`` to ``DatasetGenerator.from_documents``).

    Raises:
        ValueError: If ``table_name`` is not a known table.
        IndexError: If the table's position is outside ``documents`` (e.g. the
            data directory yielded fewer documents than expected).
    """
    if name_to_index is None:
        name_to_index = table_names
    if table_name not in name_to_index:
        raise ValueError(f"未找到名为 '{table_name}' 的表格")

    index = name_to_index[table_name]
    # Fail with context instead of a bare IndexError when fewer documents were
    # loaded than the table mapping expects.
    if not 0 <= index < len(documents):
        raise IndexError(
            f"表格 '{table_name}' 对应的文档索引 {index} 超出范围"
            f"(仅加载了 {len(documents)} 个文档)"
        )
    return [documents[index]]


# Select a prompt template.
def select_prompt(prompt_category):
    """Return the prompt template registered for a query category.

    Args:
        prompt_category: A key of ``prompt_mapping`` (e.g. "金额查询").

    Returns:
        The attribute of the ``prompts`` module named by
        ``prompt_mapping[prompt_category]``. Callers treat it as a format
        string (they call ``.format(num_questions_per_chunk=...)`` on it).

    Raises:
        ValueError: If the category is unknown, or ``prompts`` has no
            attribute with the mapped name.
    """
    prompt_name = prompt_mapping.get(prompt_category)
    if not prompt_name:
        raise ValueError(f"未找到名为 '{prompt_category}' 的提示词")
    try:
        return getattr(prompts, prompt_name)
    except AttributeError as err:
        # Explicit chaining marks this as a deliberate translation of the
        # lookup failure rather than an error raised while handling it.
        raise ValueError(f"未找到提示词 '{prompt_name}' 对应的函数") from err


def _split_qa_lines(lines):
    """Split each "question,answer" line at its first ASCII comma.

    Lines without a comma are reported on stdout and skipped.
    """
    pairs = []
    for line in lines:
        question, comma, answer = line.partition(",")
        if not comma:
            print(f"无法处理的问题和答案: {line}")
            continue
        pairs.append({"question": question.strip(), "answer": answer.strip()})
    return pairs


# Generate questions and answers.
def generate_questions_from_document(document, quest_prompt, num_questions):
    """Generate question/answer pairs for a document with DatasetGenerator.

    Args:
        document: List of documents to generate from (the value returned by
            ``select_document``).
        quest_prompt: Fully formatted question-generation prompt.
        num_questions: Passed both as ``num_questions_per_chunk`` and as the
            argument to ``generate_questions_from_nodes``.

    Returns:
        A list of ``{"question": ..., "answer": ...}`` dicts, one per generated
        line containing an ASCII comma (split at the first comma, both halves
        stripped). The raw generated lines are printed first.
    """
    generator = DatasetGenerator.from_documents(
        documents=document,
        question_gen_query=quest_prompt,
        num_questions_per_chunk=num_questions,
    )
    raw_lines = generator.generate_questions_from_nodes(num_questions)
    print(raw_lines)
    return _split_qa_lines(raw_lines)


def _parse_list_arg(raw):
    """Parse a CLI list argument such as "[a, b]" or "a,b" into clean names.

    Surrounding brackets are stripped, items are split on ASCII commas and
    whitespace-trimmed, and empty items (e.g. from a trailing comma) dropped.
    """
    items = (item.strip() for item in raw.strip("[]").split(","))
    return [item for item in items if item]


# Main entry point: generate questions for several tables with several prompt
# categories and merge all results into a single JSON file.
def main(documents, table_names_input, prompt_categories_input,
         num_questions_per_prompt, output_file="combined_test.json"):
    """Generate Q/A pairs for the selected tables and categories and save them.

    Args:
        documents: Loaded documents, addressed by ``table_names`` position.
        table_names_input: "all" (every table in ``table_prompt_mapping``) or
            a comma-separated list of table names, optionally in brackets.
        prompt_categories_input: "all" (every category configured for each
            table) or a comma-separated list of category labels, optionally in
            brackets. Categories not configured for a table are skipped with a
            message.
        num_questions_per_prompt: Number of questions requested per prompt.
        output_file: Path of the JSON output file; overwritten if it exists.
            Results are keyed by ``"test:<table>_<category>"``.

    Raises:
        ValueError: If a table name is unknown (from ``select_document``) or
            a configured prompt template is missing (from ``select_prompt``).
    """
    if table_names_input == "all":
        selected_tables = list(table_prompt_mapping.keys())
    else:
        selected_tables = _parse_list_arg(table_names_input)

    # The explicit category list does not depend on the table; parse it once.
    explicit_prompts = (
        None if prompt_categories_input == "all"
        else _parse_list_arg(prompt_categories_input)
    )

    all_results = {}
    for table_name in selected_tables:
        document = select_document(documents, table_name)
        # Some valid tables (e.g. 项目划分表) have no entry in
        # table_prompt_mapping; treat them as having no categories instead of
        # crashing with a KeyError.
        table_categories = table_prompt_mapping.get(table_name, [])
        selected_prompts = table_categories if explicit_prompts is None else explicit_prompts

        for prompt_category in selected_prompts:
            if prompt_category not in table_categories:
                print(f"跳过表格 '{table_name}' 的提示词 '{prompt_category}',因为该表中不包含该类别的信息")
                continue

            quest_prompt = select_prompt(prompt_category).format(
                num_questions_per_chunk=num_questions_per_prompt
            )
            qa_pairs = generate_questions_from_document(
                document, quest_prompt, num_questions_per_prompt
            )
            all_results[f"test:{table_name}_{prompt_category}"] = qa_pairs

    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(all_results, f, ensure_ascii=False, indent=4)

    print(f"All questions and answers have been saved to '{output_file}'")


# Command-line entry point. Usage:
#   python script.py <table_names_input> <prompt_categories_input> <num_questions_per_prompt>
# The two list arguments accept "all" or a comma-separated list (optionally
# wrapped in brackets).
if __name__ == "__main__":
    if len(sys.argv) != 4:
        # Report misuse on stderr with a non-zero exit status so that callers
        # (shell scripts, CI) can detect the failure.
        print(
            "Usage: python script.py <table_names_input> "
            "<prompt_categories_input> <num_questions_per_prompt>",
            file=sys.stderr,
        )
        sys.exit(2)

    table_names_input = sys.argv[1]
    prompt_categories_input = sys.argv[2]
    try:
        num_questions_per_prompt = int(sys.argv[3])
    except ValueError:
        sys.exit(f"<num_questions_per_prompt> must be an integer, got {sys.argv[3]!r}")
    if num_questions_per_prompt <= 0:
        sys.exit(f"<num_questions_per_prompt> must be positive, got {num_questions_per_prompt}")

    main(documents, table_names_input, prompt_categories_input, num_questions_per_prompt)