新增单元测试
This commit is contained in:
@@ -0,0 +1,190 @@
|
||||
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
from typing import List,Dict
|
||||
import json,sys,os
|
||||
from app.observability import init_observability
|
||||
from app.settings import init_settings
|
||||
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
|
||||
from llama_index.core import SimpleDirectoryReader
|
||||
from llama_index.core.evaluation import DatasetGenerator
|
||||
|
||||
import prompts
|
||||
from app.engine.loaders.markdownReader import ChunkMarkdownReader
|
||||
from llama_index.core.readers.json import JSONReader
|
||||
from llama_index.core.schema import Document
|
||||
from llama_index.core.prompts.base import PromptTemplate
|
||||
from llama_index.core.prompts.prompt_type import PromptType
|
||||
from app.engine.loaders import get_document_Types,getFileCacahePath
|
||||
|
||||
init_settings()
|
||||
init_observability()
|
||||
|
||||
|
||||
# Prompt used to generate evaluation questions. ``{context_str}`` receives the
# source chunk and ``{query_str}`` the question-generation instruction; it is
# wrapped in a PromptTemplate by generate_questions_from_document.
# The literal must contain only the prompt text: stray separator lines inside
# the triple quotes would be sent to the model verbatim.
text_question_template = """\
给定的上下文信息如下:
---------------------
{context_str}
---------------------
根据给定的上下文信息而非先验知识。
仅围绕下面的问题描述生成多个问题。
{query_str}
"""
|
||||
|
||||
# Prompt used to answer a generated question from the same context.
# ``{context_str}`` receives the source chunk and ``{query_str}`` the question;
# it is wrapped in a QUESTION_ANSWER PromptTemplate by
# generate_questions_from_document.
# The literal must contain only the prompt text: stray separator lines inside
# the triple quotes would be sent to the model verbatim.
text_qa_template = """\
给定的上下文信息如下:
---------------------
{context_str}
---------------------
根据给定的上下文信息而非先验知识,回答问题。
问题: {query_str}
回答:
"""
|
||||
|
||||
class FileLoader:
    """Load the documents of every project and record their table names.

    A "table name" is the file name of a document without directory and
    extension, taken from the document's ``file_name`` metadata entry.
    """

    def __init__(self) -> None:
        # Project flag -> de-duplicated table names, in first-seen order.
        # Only projects that produced at least one table name get an entry.
        self._prjTabels: Dict[str, List[str]] = {}

    def load(self):
        """Read the directory of each project and return its documents.

        Returns:
            A dict mapping every project flag returned by
            ``get_document_Types()`` to the documents that
            ``SimpleDirectoryReader`` loaded from that project's directory.
        """
        rootPath = getFileCacahePath()
        prjDocs = {}
        for prjFlag in get_document_Types():
            # Underscores in the flag stand for directory levels. os.sep is a
            # backslash on Windows, so the resulting path is unchanged there,
            # and the lookup also resolves to nested directories elsewhere.
            filePath = os.path.join(rootPath, prjFlag.replace('_', os.sep))
            documents = SimpleDirectoryReader(
                input_dir=filePath,
                file_extractor=self._get_FileExtrator(),
            ).load_data()
            prjDocs[prjFlag] = documents
            self._add_tables(prjFlag, documents)
        return prjDocs

    def get_TableNames(self):
        """Return the project-flag -> table-names mapping built by ``load``."""
        return self._prjTabels

    def _add_tables(self, prjFlag: str, documents: List[Document]):
        """Record the distinct table names found in ``documents``.

        Args:
            prjFlag: Key under which the names are stored.
            documents: Documents whose ``metadata['file_name']`` is inspected.

        Nothing is stored when ``documents`` yields no names.
        """
        # dict.fromkeys removes duplicates while keeping first-seen order.
        fileNames = list(dict.fromkeys(
            os.path.splitext(os.path.basename(doc.metadata['file_name']))[0]
            for doc in documents
        ))
        if fileNames:
            self._prjTabels[prjFlag] = fileNames

    def _get_FileExtrator(self):
        """Return the extension -> reader mapping handed to the directory reader."""
        return {
            ".md": ChunkMarkdownReader(),
        }
|
||||
|
||||
# Maps each Chinese prompt-category label to the name of the corresponding
# attribute in the ``prompts`` module.
prompt_mapping = {
    "普通属性": "Attribute_Prompt",
    "金额查询": "Amount_Prompt",
    "单位换算": "Units_Prompt",
    "重名项目划分": "Name_Prompt",
    "总体金额查询": "All_Amount_Prompt",
}
|
||||
|
||||
# Maps a table-name fragment to the prompt categories used for tables whose
# name contains that fragment.
table_prompt_mapping = {
    "工程信息": ["普通属性", "单位换算"],
    "其他费用": ["金额查询", "单位换算"],
    "取费表": ["金额查询"],
    "总算表": ["金额查询"],
    "工程量": ["普通属性", "重名项目划分"],
}
|
||||
|
||||
# Pick the documents that belong to one table.
def select_documents(documents: List[Document], table_name: str):
    """Return the documents whose file name (without extension) is ``table_name``.

    Args:
        documents: Documents carrying a ``file_name`` entry in their metadata.
        table_name: Exact file stem to look for.

    Returns:
        The matching documents, in their original order.

    Raises:
        ValueError: If no document matches.
    """
    def stem_of(doc):
        # Directory and extension are dropped before comparing.
        return os.path.splitext(os.path.basename(doc.metadata['file_name']))[0]

    docNodes = [doc for doc in documents if table_name == stem_of(doc)]
    if not docNodes:
        raise ValueError(f"未找到名为 '{table_name}' 的节点")
    return docNodes
|
||||
|
||||
# Resolve a prompt category to the prompt object.
def select_prompt(prompt_category):
    """Return the attribute of the ``prompts`` module registered for a category.

    Args:
        prompt_category: A key of ``prompt_mapping``.

    Returns:
        The attribute of ``prompts`` whose name ``prompt_mapping`` gives for
        the category.

    Raises:
        ValueError: If the category has no (non-empty) entry in
            ``prompt_mapping``, or if ``prompts`` lacks the mapped attribute.
    """
    prompt_name = prompt_mapping.get(prompt_category)
    if not prompt_name:
        raise ValueError(f"未找到名为 '{prompt_category}' 的提示词")

    try:
        selected = getattr(prompts, prompt_name)
    except AttributeError:
        raise ValueError(f"未找到提示词 '{prompt_name}' 对应的函数")
    else:
        return selected
|
||||
|
||||
def get_Prompts(tableName: str):
    """Return the prompt categories that apply to ``tableName``.

    The first key of ``table_prompt_mapping`` that occurs as a substring of
    ``tableName`` decides the result; an empty list is returned when no key
    matches.
    """
    matching = (
        categories
        for fragment, categories in table_prompt_mapping.items()
        if fragment in tableName
    )
    return next(matching, [])
|
||||
|
||||
# Generate question/answer pairs.
def generate_questions_from_document(document, quest_prompt, num_questions):
    """Generate question/answer pairs for the given documents.

    Args:
        document: Documents handed to ``DatasetGenerator.from_documents``.
        quest_prompt: Question-generation query passed as ``question_gen_query``.
        num_questions: Used both as ``num_questions_per_chunk`` and as the
            argument of ``generate_dataset_from_nodes``.

    Returns:
        A list of ``{"question": ..., "answer": ...}`` dicts with surrounding
        whitespace stripped from both values.
    """
    question_template = PromptTemplate(text_question_template)
    answer_template = PromptTemplate(text_qa_template, prompt_type=PromptType.QUESTION_ANSWER)

    question_generator = DatasetGenerator.from_documents(
        documents=document,
        text_question_template=question_template,
        question_gen_query=quest_prompt,
        text_qa_template=answer_template,
        num_questions_per_chunk=num_questions,
    )

    eval_questions = question_generator.generate_dataset_from_nodes(num_questions)
    # The generated dataset is echoed to stdout.
    print(eval_questions)

    return [
        {"question": pair[0].strip(), "answer": pair[1].strip()}
        for pair in eval_questions.qr_pairs
    ]
|
||||
|
||||
# Driver: generate questions for several tables with every prompt category
# that applies to them, and merge all results into one file.
def main(documents: List[Document], table_names: List[str], num_questions_per_prompt, filePath: str):
    """Generate QA pairs per table and prompt category and dump them as JSON.

    Args:
        documents: Documents of one project.
        table_names: Table names to process; each is stripped of surrounding
            whitespace before use.
        num_questions_per_prompt: Formatted into each prompt as
            ``num_questions_per_chunk`` and forwarded to
            ``generate_questions_from_document``.
        filePath: Destination of the UTF-8 JSON output.

    Raises:
        ValueError: Propagated from ``select_documents`` / ``select_prompt``.
    """
    all_results = {}
    for raw_name in table_names:
        table_name = raw_name.strip()
        tbDocs = select_documents(documents, table_name)
        for prompt_category in get_Prompts(table_name):
            quest_prompt = select_prompt(prompt_category).format(
                num_questions_per_chunk=num_questions_per_prompt
            )
            # One result entry per (table, category) combination.
            all_results[f"test:{table_name}_{prompt_category}"] = generate_questions_from_document(
                tbDocs, quest_prompt, num_questions_per_prompt
            )

    # Written even when no pairs were produced (an empty JSON object).
    with open(filePath, "w", encoding="utf-8") as out_file:
        json.dump(all_results, out_file, ensure_ascii=False, indent=4)
|
||||
|
||||
# Command-line entry point.
#
# Usage: python script.py <table_names_input> <num_questions_per_prompt>
#   <table_names_input>        "all", or a comma-separated list that may be
#                              wrapped in brackets, e.g. [总算表]
#   <num_questions_per_prompt> integer, e.g. 2
if __name__ == "__main__":
    # sys.argv[0] is the script path itself, so two user-supplied arguments
    # mean three entries in sys.argv.
    if len(sys.argv) != 3:
        raise ValueError("Usage: python script.py <table_names_input> <num_questions_per_prompt>")
    table_names_input = sys.argv[1]
    num_questions_per_prompt = int(sys.argv[2])

    # Join the path components separately so the output directory is nested
    # correctly on every platform.
    que_Dir = os.path.join(os.getcwd(), 'unit_test', 'Quetions')

    loader = FileLoader()
    prjDocs = loader.load()
    prjTableNames = loader.get_TableNames()

    os.makedirs(que_Dir, exist_ok=True)
    for prjFlg, documents in prjDocs.items():
        if table_names_input == "all":
            # A project without any recorded table has no entry in the mapping.
            table_names = prjTableNames.get(prjFlg, [])
        else:
            table_names = table_names_input.strip('[]').split(',')
        # One output file per project.
        filePath = os.path.join(que_Dir, f'{prjFlg}.json')
        main(documents, table_names, num_questions_per_prompt, filePath)
|
||||
Reference in New Issue
Block a user