"""Generate question/answer test sets from table documents.

For each selected table and each prompt category configured for that table,
the script asks LlamaIndex's ``DatasetGenerator`` for questions, splits every
generated string at its first comma into a question and an answer, and writes
all results into a single JSON file (``combined_test.json``).

Command line::

    python script.py <table_names|all> <prompt_categories|all> <num_questions_per_prompt>

List arguments are comma separated and may be wrapped in square brackets,
e.g. ``[工程信息表,取费表]``.
"""
from dotenv import load_dotenv

# Load the .env file before the remaining imports run.
load_dotenv()

import json
import sys

from app.observability import init_observability
from app.settings import init_settings
import nest_asyncio

nest_asyncio.apply()

from llama_index.core.node_parser import SentenceSplitter
from llama_index.core import SimpleDirectoryReader
from llama_index.core.evaluation import DatasetGenerator
import prompts

init_settings()
init_observability()

# Load every document from the data directory (the original note says these
# are the tables).
documents = SimpleDirectoryReader("D:/LLM_model/text2sql/zjdataai-app-test/backend/data-test").load_data()

# Table name -> index of that table inside `documents`.
# NOTE(review): this assumes the reader returns exactly one document per table
# in this order -- confirm against the contents of the data directory.
table_names = {
    "工程信息表": 0,
    "其他费用表": 1,
    "取费表": 2,
    "项目划分表": 3,
    "项目划分_费用预览表": 4,
    "总算表": 5,
    "工程量表": 6
}

# Prompt category (as typed by the user) -> attribute name in the `prompts` module.
prompt_mapping = {
    "普通属性": "Attribute_Prompt",
    "金额查询": "Amount_Prompt",
    "单位换算": "Units_Prompt",
    "重名项目划分": "Name_Prompt",
    "总体金额查询": "All_Amount_Prompt"
}

# Table name -> prompt categories that apply to that table.
# Tables listed in `table_names` but absent here have no categories configured.
table_prompt_mapping = {
    "工程信息表": ["普通属性", "单位换算"],
    "其他费用表": ["金额查询", "单位换算"],
    "取费表": ["金额查询"],
    "总算表": ["金额查询", "总体金额查询"],
    "工程量表": ["普通属性", "重名项目划分"]
}


def select_document(documents, table_name):
    """Return a one-element list holding the document for ``table_name``.

    Args:
        documents: Sequence of loaded documents, indexed as in ``table_names``.
        table_name: Key of ``table_names``.

    Raises:
        ValueError: If ``table_name`` is not a key of ``table_names``.
    """
    if table_name not in table_names:
        raise ValueError(f"未找到名为 '{table_name}' 的表格")
    index = table_names[table_name]
    # A list is returned because it is passed on as the `documents` argument
    # of DatasetGenerator.from_documents.
    return [documents[index]]


def select_prompt(prompt_category):
    """Return the prompt object from ``prompts`` for ``prompt_category``.

    Args:
        prompt_category: Key of ``prompt_mapping``.

    Raises:
        ValueError: If the category is unknown, or if the ``prompts`` module
            has no attribute with the mapped name.
    """
    prompt_name = prompt_mapping.get(prompt_category)
    if not prompt_name:
        raise ValueError(f"未找到名为 '{prompt_category}' 的提示词")
    try:
        return getattr(prompts, prompt_name)
    except AttributeError as err:
        # Chain the cause so the missing attribute stays visible in the traceback.
        raise ValueError(f"未找到提示词 '{prompt_name}' 对应的函数") from err


def generate_questions_from_document(document, quest_prompt, num_questions):
    """Generate question/answer pairs for one table.

    Args:
        document: List of documents handed to ``DatasetGenerator.from_documents``.
        quest_prompt: Question-generation query text.
        num_questions: Used both as the per-chunk question count and as the
            argument of ``generate_questions_from_nodes``.

    Returns:
        A list of ``{"question": ..., "answer": ...}`` dicts. Generated strings
        that contain no comma are reported on stdout and left out.
    """
    question_generator = DatasetGenerator.from_documents(
        documents=document,
        question_gen_query=quest_prompt,
        num_questions_per_chunk=num_questions
    )

    eval_questions = question_generator.generate_questions_from_nodes(num_questions)
    print(eval_questions)

    qa_pairs = []
    for qa in eval_questions:
        # Everything before the first comma is the question, the rest the answer.
        question, separator, answer = qa.partition(",")
        if not separator:
            print(f"无法处理的问题和答案: {qa}")
            continue
        qa_pairs.append({
            "question": question.strip(),
            "answer": answer.strip()
        })

    return qa_pairs


def _split_list_argument(raw):
    """Split a command-line list such as ``[a, b]`` into stripped items."""
    return [item.strip() for item in raw.strip().strip('[]').split(',')]


def main(documents, table_names_input, prompt_categories_input, num_questions_per_prompt):
    """Generate questions for the selected tables and prompts and save them.

    Args:
        documents: Loaded documents, indexed as in ``table_names``.
        table_names_input: ``"all"`` (every table in ``table_prompt_mapping``)
            or a comma-separated list of table names, optionally in brackets.
        prompt_categories_input: ``"all"`` (every category configured for the
            table) or a comma-separated list of categories, optionally in
            brackets. Categories not configured for a table are skipped.
        num_questions_per_prompt: Number of questions requested per prompt.

    Raises:
        ValueError: If a table name or prompt category is unknown.

    Side effects:
        Writes all results to ``combined_test.json`` in the working directory.
    """
    if table_names_input == "all":
        selected_tables = list(table_prompt_mapping.keys())
    else:
        selected_tables = _split_list_argument(table_names_input)

    # The requested categories do not depend on the table, so parse them once.
    # None means "use whatever is configured for each table".
    if prompt_categories_input == "all":
        requested_prompts = None
    else:
        requested_prompts = _split_list_argument(prompt_categories_input)

    # Resolve every table before generating anything: an unknown table name
    # then fails immediately instead of discarding earlier generation work.
    selected = [(name, select_document(documents, name)) for name in selected_tables]

    all_results = {}

    for table_name, document in selected:
        # Some tables in `table_names` have no entry in `table_prompt_mapping`;
        # skip them instead of failing with a KeyError.
        allowed_prompts = table_prompt_mapping.get(table_name)
        if not allowed_prompts:
            print(f"跳过表格 '{table_name}',因为该表未配置任何查询类别")
            continue

        selected_prompts = allowed_prompts if requested_prompts is None else requested_prompts

        for prompt_category in selected_prompts:
            if prompt_category not in allowed_prompts:
                print(f"跳过表格 '{table_name}' 的提示词 '{prompt_category}',因为该表中不包含该类别的信息")
                continue

            quest_prompt = select_prompt(prompt_category).format(num_questions_per_chunk=num_questions_per_prompt)
            qa_pairs = generate_questions_from_document(document, quest_prompt, num_questions_per_prompt)

            label = f"test:{table_name}_{prompt_category}"
            all_results[label] = qa_pairs

    # All tables and prompts are merged into one fixed output file.
    output_file = "combined_test.json"

    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(all_results, f, ensure_ascii=False, indent=4)

    print(f"All questions and answers have been saved to '{output_file}'")


# Command-line entry point.
if __name__ == "__main__":
    if len(sys.argv) != 4:
        print("Usage: python script.py <table_names|all> <prompt_categories|all> <num_questions_per_prompt>")
    else:
        table_names_input = sys.argv[1]
        prompt_categories_input = sys.argv[2]
        num_questions_per_prompt = int(sys.argv[3])

        main(documents, table_names_input, prompt_categories_input, num_questions_per_prompt)