新增单元测试

This commit is contained in:
wanyaokun
2024-09-12 13:58:42 +08:00
parent 47437044cb
commit c262aec6bd
13 changed files with 374 additions and 1493 deletions
+190
View File
@@ -0,0 +1,190 @@
from dotenv import load_dotenv
load_dotenv()
from typing import List,Dict
import json,sys,os
from app.observability import init_observability
from app.settings import init_settings
import nest_asyncio
nest_asyncio.apply()
from llama_index.core import SimpleDirectoryReader
from llama_index.core.evaluation import DatasetGenerator
import prompts
from app.engine.loaders.markdownReader import ChunkMarkdownReader
from llama_index.core.readers.json import JSONReader
from llama_index.core.schema import Document
from llama_index.core.prompts.base import PromptTemplate
from llama_index.core.prompts.prompt_type import PromptType
from app.engine.loaders import get_document_Types,getFileCacahePath
init_settings()
init_observability()
text_question_template = """\
给定的上下文信息如下:
---------------------
{context_str}
---------------------
根据给定的上下文信息而非先验知识。
仅围绕下面的问题描述生成多个问题。
{query_str}
"""
text_qa_template = """\
给定的上下文信息如下:
---------------------
{context_str}
---------------------
根据给定的上下文信息而非先验知识,回答问题。
问题: {query_str}
回答:
"""
class FileLoader:
def __init__(self) -> None:
self._prjTabels:Dict[str,any] = {}
def load(self):
rootPath = getFileCacahePath()
prjFlags = get_document_Types()
prjDocs = {}
for prjFlag in prjFlags:
filePath = os.path.join(rootPath,prjFlag.replace('_','\\'))
extrator = self._get_FileExtrator()
documents = SimpleDirectoryReader(input_dir = filePath,file_extractor = extrator).load_data()
prjDocs[prjFlag] = documents
self._add_tables(prjFlag,documents)
return prjDocs
def get_TableNames(self):
return self._prjTabels
def _add_tables(self,prjFlag:str,documents:List[Document]):
fileNames = []
for doc in documents:
meta = doc.metadata
fileBaseName = meta['file_name']
fileName = os.path.splitext(os.path.basename(fileBaseName))[0]
if fileName not in fileNames:
fileNames.append(fileName)
if len(fileNames) > 0:
self._prjTabels[prjFlag] = fileNames
def _get_FileExtrator(self):
parser = {
".md" : ChunkMarkdownReader(),
}
return parser
# 定义中文提示词和Python代码中提示词名称的映射
prompt_mapping = {
"普通属性": "Attribute_Prompt",
"金额查询": "Amount_Prompt",
"单位换算": "Units_Prompt",
"重名项目划分": "Name_Prompt",
"总体金额查询": "All_Amount_Prompt"
}
# 定义表格与其对应的查询类别
table_prompt_mapping = {
"工程信息": ["普通属性", "单位换算"],
"其他费用": ["金额查询", "单位换算"],
"取费表": ["金额查询"],
"总算表": ["金额查询"],
"工程量": ["普通属性", "重名项目划分"]
}
# 根据表格名称选择特定的表格
def select_documents(documents:List[Document], table_name:str):
docNodes = []
for doc in documents:
meta = doc.metadata
fileBaseName = meta['file_name']
fileName = os.path.splitext(os.path.basename(fileBaseName))[0]
if table_name == fileName:
docNodes.append(doc)
if len(docNodes) == 0:
raise ValueError(f"未找到名为 '{table_name}' 的节点")
return docNodes
# 选择提示词
def select_prompt(prompt_category):
prompt_name = prompt_mapping.get(prompt_category)
if not prompt_name:
raise ValueError(f"未找到名为 '{prompt_category}' 的提示词")
try:
return getattr(prompts, prompt_name)
except AttributeError:
raise ValueError(f"未找到提示词 '{prompt_name}' 对应的函数")
def get_Prompts(tableName:str):
for name,prompts in table_prompt_mapping.items():
if name in tableName:
return prompts
return []
# 生成问题和答案
def generate_questions_from_document(document, quest_prompt, num_questions):
question_generator = DatasetGenerator.from_documents(
documents=document,
text_question_template = PromptTemplate(text_question_template),
question_gen_query=quest_prompt,
text_qa_template = PromptTemplate(text_qa_template, prompt_type=PromptType.QUESTION_ANSWER),
num_questions_per_chunk=num_questions
)
eval_questions = question_generator.generate_dataset_from_nodes(num_questions)
print(eval_questions)
qa_pairs = []
for qa in eval_questions.qr_pairs:
qa_pairs.append({
"question": qa[0].strip(),
"answer": qa[1].strip()
})
return qa_pairs
# 主函数,控制生成多个表格的问题和使用多个提示词,并将结果合并到一个文件中
def main(documents:List[Document], table_names:List[str],num_questions_per_prompt,filePath:str):
all_results = {}
for table_name in table_names:
table_name = table_name.strip()
tbDocs = select_documents(documents, table_name)
selected_prompts = get_Prompts(table_name)
for prompt_category in selected_prompts:
quest_prompt = select_prompt(prompt_category).format(num_questions_per_chunk=num_questions_per_prompt)
qa_pairs = generate_questions_from_document(tbDocs, quest_prompt, num_questions_per_prompt)
label = f"test:{table_name}_{prompt_category}"
all_results[label] = qa_pairs
# 自动生成输出文件名
with open(filePath, "w", encoding="utf-8") as f:
json.dump(all_results, f, ensure_ascii=False, indent=4)
# 获取命令行参数
if __name__ == "__main__":
if len(sys.argv) != 2:
raise ValueError("Usage: python script.py <table_names_input> <prompt_categories_input> <num_questions_per_prompt>")
table_names_input = sys.argv[0]
num_questions_per_prompt = int(sys.argv[1])
# table_names_input = '[总算表]'
# num_questions_per_prompt = 2
que_Dir = os.path.join(os.getcwd(),f'unit_test\\Quetions')
loader = FileLoader()
prjDocs = loader.load()
prjTableNames = loader.get_TableNames()
for prjFlg,documents in prjDocs.items():
if table_names_input == "all":
table_names = prjTableNames[prjFlg]
else:
table_names = table_names_input.strip('[]').split(',')
os.makedirs(que_Dir,exist_ok = True)
filePath =os.path.join(que_Dir,f'{prjFlg}.json')
main(documents, table_names, num_questions_per_prompt,filePath)