# Script setup. The import / side-effect order below is kept exactly as in the
# original script: the .env file is loaded first, and nest_asyncio is applied
# before the llama_index imports.
from dotenv import load_dotenv

load_dotenv()

from typing import List, Dict
import json
import sys
import os

from app.observability import init_observability
from app.settings import init_settings
import nest_asyncio

nest_asyncio.apply()

from llama_index.core import SimpleDirectoryReader
from llama_index.core.evaluation import DatasetGenerator
import prompts
from app.engine.loaders.markdownReader import ChunkMarkdownReader
from llama_index.core.readers.json import JSONReader
from llama_index.core.schema import Document
from llama_index.core.prompts.base import PromptTemplate
from llama_index.core.prompts.prompt_type import PromptType
from app.engine.loaders import get_document_Types, getFileCacahePath

init_settings()
init_observability()

# Template used to generate questions from a chunk of context.
# `{context_str}` and `{query_str}` are the placeholders left in the template.
text_question_template = """\
给定的上下文信息如下:
---------------------
{context_str}
---------------------
根据给定的上下文信息而非先验知识。
仅围绕下面的问题描述生成多个问题。
{query_str}
"""

# Template used to answer a generated question from the same context.
text_qa_template = """\
给定的上下文信息如下:
---------------------
{context_str}
---------------------
根据给定的上下文信息而非先验知识,回答问题。
问题: {query_str}
回答: """


class FileLoader:
    """Load the cached documents of every project and remember their table names.

    A "table name" here is the file stem (base name without extension) taken
    from each document's ``file_name`` metadata entry.
    """

    def __init__(self) -> None:
        # project flag -> de-duplicated file stems, in first-seen order.
        self._prjTabels: Dict[str, List[str]] = {}

    def load(self):
        """Read the documents of every project returned by ``get_document_Types``.

        Returns:
            A dict mapping each project flag to the list of documents loaded
            from that project's directory below ``getFileCacahePath()``.
        """
        rootPath = getFileCacahePath()
        prjDocs = {}
        for prjFlag in get_document_Types():
            # Every "_"-separated part of the flag is one directory level.
            # Joining the parts through os.path.join (rather than replacing
            # "_" with a literal backslash) yields the same path on Windows
            # and a valid one on other platforms.
            filePath = os.path.join(rootPath, *prjFlag.split('_'))
            documents = SimpleDirectoryReader(
                input_dir=filePath,
                file_extractor=self._get_FileExtrator(),
            ).load_data()
            prjDocs[prjFlag] = documents
            self._add_tables(prjFlag, documents)
        return prjDocs

    def get_TableNames(self):
        """Return the project-flag -> table-names mapping filled in by ``load``."""
        return self._prjTabels

    def _add_tables(self, prjFlag: str, documents: List[Document]):
        """Record the distinct file stems of ``documents`` under ``prjFlag``.

        Nothing is recorded when ``documents`` yields no file names, so a
        project without documents has no entry in the mapping.
        """
        # A dict keeps insertion order and gives O(1) duplicate detection.
        stems: Dict[str, None] = {}
        for doc in documents:
            fileBaseName = doc.metadata['file_name']
            stems.setdefault(os.path.splitext(os.path.basename(fileBaseName))[0])
        if stems:
            self._prjTabels[prjFlag] = list(stems)

    def _get_FileExtrator(self):
        """Return the extension -> reader mapping handed to SimpleDirectoryReader."""
        return {
            ".md": ChunkMarkdownReader(),
        }


# Mapping from the Chinese prompt category labels to the prompt names used in the Python code
prompt_mapping = {
    "普通属性": "Attribute_Prompt",
    "金额查询": "Amount_Prompt",
    "单位换算": "Units_Prompt",
    "重名项目划分": "Name_Prompt",
    "总体金额查询": "All_Amount_Prompt",
}

# Query categories to generate for each kind of table. The keys are matched
# as substrings of the actual table name (see get_Prompts).
table_prompt_mapping = {
    "工程信息": ["普通属性", "单位换算"],
    "其他费用": ["金额查询", "单位换算"],
    "取费表": ["金额查询"],
    "总算表": ["金额查询"],
    "工程量": ["普通属性", "重名项目划分"],
}


def select_documents(documents: "List[Document]", table_name: str):
    """Return the documents whose file stem equals ``table_name``.

    The stem is the base name of the document's ``file_name`` metadata entry
    with its extension removed.

    Raises:
        ValueError: if no document matches ``table_name``.
    """
    # `Document` is quoted as a forward reference so the annotation is not
    # evaluated when the function is defined.
    matches = [
        doc
        for doc in documents
        if os.path.splitext(os.path.basename(doc.metadata['file_name']))[0] == table_name
    ]
    if not matches:
        raise ValueError(f"未找到名为 '{table_name}' 的节点")
    return matches


def select_prompt(prompt_category):
    """Look up the prompt object registered for a Chinese category label.

    The label is translated through ``prompt_mapping`` and the resulting name
    is fetched as an attribute of the ``prompts`` module.

    Raises:
        ValueError: if the category is unknown, or if the ``prompts`` module
            has no attribute with the mapped name.
    """
    prompt_name = prompt_mapping.get(prompt_category)
    if not prompt_name:
        raise ValueError(f"未找到名为 '{prompt_category}' 的提示词")
    try:
        return getattr(prompts, prompt_name)
    except AttributeError as err:
        # Keep the original AttributeError attached as the cause.
        raise ValueError(f"未找到提示词 '{prompt_name}' 对应的函数") from err


def get_Prompts(tableName: str):
    """Return the prompt categories for the first mapping key contained in ``tableName``.

    Returns an empty list when no key of ``table_prompt_mapping`` occurs in
    the table name. The result is a copy, so callers cannot modify the
    module-level mapping by accident.
    """
    # The loop variable is deliberately not called `prompts`: that name is
    # the imported prompts module used by select_prompt.
    for name, categories in table_prompt_mapping.items():
        if name in tableName:
            return list(categories)
    return []


def generate_questions_from_document(document, quest_prompt, num_questions):
    """Generate question/answer pairs for ``document`` with a DatasetGenerator.

    Args:
        document: documents passed to ``DatasetGenerator.from_documents``.
        quest_prompt: the question-generation query text.
        num_questions: used both as ``num_questions_per_chunk`` and as the
            argument of ``generate_dataset_from_nodes``.

    Returns:
        A list of ``{"question": ..., "answer": ...}`` dicts with surrounding
        whitespace stripped from both fields.
    """
    question_generator = DatasetGenerator.from_documents(
        documents=document,
        text_question_template=PromptTemplate(text_question_template),
        question_gen_query=quest_prompt,
        text_qa_template=PromptTemplate(text_qa_template, prompt_type=PromptType.QUESTION_ANSWER),
        num_questions_per_chunk=num_questions,
    )
    eval_questions = question_generator.generate_dataset_from_nodes(num_questions)
    print(eval_questions)
    return [
        {"question": question.strip(), "answer": answer.strip()}
        for question, answer in eval_questions.qr_pairs
    ]


def main(documents: "List[Document]", table_names: List[str], num_questions_per_prompt, filePath: str):
    """Generate Q&A pairs for several tables and prompts and write them to one JSON file.

    For each table name the matching documents are selected, every prompt
    category configured for that table is run, and the results are stored
    under the key ``test:<table_name>_<prompt_category>``. The combined dict
    is written to ``filePath`` as UTF-8 JSON.

    Raises:
        ValueError: propagated from ``select_documents`` / ``select_prompt``
            when a table or prompt cannot be found.
    """
    all_results = {}
    for table_name in table_names:
        table_name = table_name.strip()
        tbDocs = select_documents(documents, table_name)
        for prompt_category in get_Prompts(table_name):
            quest_prompt = select_prompt(prompt_category).format(
                num_questions_per_chunk=num_questions_per_prompt
            )
            qa_pairs = generate_questions_from_document(tbDocs, quest_prompt, num_questions_per_prompt)
            all_results[f"test:{table_name}_{prompt_category}"] = qa_pairs

    # Write the merged results of all tables and prompts to the output file.
    with open(filePath, "w", encoding="utf-8") as f:
        json.dump(all_results, f, ensure_ascii=False, indent=4)


# Read the command-line arguments and run the generation per project.
if __name__ == "__main__":
    # sys.argv[0] is the script path, so the two user arguments are
    # sys.argv[1] (table names, or "all") and sys.argv[2] (question count).
    # Example: python script.py [总算表] 2
    if len(sys.argv) != 3:
        raise ValueError("Usage: python script.py <table_names|all> <num_questions_per_prompt>")
    table_names_input = sys.argv[1]
    num_questions_per_prompt = int(sys.argv[2])

    # Joined component-wise so the path is valid on every platform.
    que_Dir = os.path.join(os.getcwd(), 'unit_test', 'Quetions')
    os.makedirs(que_Dir, exist_ok=True)

    loader = FileLoader()
    prjDocs = loader.load()
    prjTableNames = loader.get_TableNames()
    for prjFlg, documents in prjDocs.items():
        if table_names_input == "all":
            # A project without recorded tables has no entry in the mapping.
            table_names = prjTableNames.get(prjFlg, [])
        else:
            table_names = table_names_input.strip('[]').split(',')
        filePath = os.path.join(que_Dir, f'{prjFlg}.json')
        main(documents, table_names, num_questions_per_prompt, filePath)