191 lines
6.5 KiB
Python
191 lines
6.5 KiB
Python
|
|
# Load variables from a .env file before the app.* modules below are imported.
# NOTE(review): presumably those modules read environment variables - confirm.
from dotenv import load_dotenv
load_dotenv()

from typing import List,Dict
import json,sys,os
from app.observability import init_observability
from app.settings import init_settings

# nest_asyncio is applied at import time, before any LlamaIndex import.
# NOTE(review): presumably needed so nested event loops work - confirm.
import nest_asyncio
nest_asyncio.apply()

from llama_index.core import SimpleDirectoryReader
from llama_index.core.evaluation import DatasetGenerator

# `prompts` is a project module; select_prompt() looks prompt objects up on it by name.
import prompts
from app.engine.loaders.markdownReader import ChunkMarkdownReader
from llama_index.core.readers.json import JSONReader
from llama_index.core.schema import Document
from llama_index.core.prompts.base import PromptTemplate
from llama_index.core.prompts.prompt_type import PromptType
from app.engine.loaders import get_document_Types,getFileCacahePath

# Run the project's settings and observability initialisers at import time.
# Their effects are defined in app.settings / app.observability, not in this file.
init_settings()
init_observability()
|
|
|
|
|
|
# Question-generation prompt (Chinese). It shows the model the context in
# {context_str} and asks it to generate several questions that stay within the
# task description supplied in {query_str}, using the context rather than
# prior knowledge. Wrapped in a PromptTemplate by
# generate_questions_from_document().
text_question_template = """\
给定的上下文信息如下:
---------------------
{context_str}
---------------------
根据给定的上下文信息而非先验知识。
仅围绕下面的问题描述生成多个问题。
{query_str}
"""
|
|
|
|
# Answer-generation prompt (Chinese). It shows the model the context in
# {context_str} and asks it to answer the question in {query_str} from that
# context rather than from prior knowledge. Wrapped in a PromptTemplate with
# PromptType.QUESTION_ANSWER by generate_questions_from_document().
text_qa_template = """\
给定的上下文信息如下:
---------------------
{context_str}
---------------------
根据给定的上下文信息而非先验知识,回答问题。
问题: {query_str}
回答:
"""
|
|
|
|
class FileLoader:
    """Load the cached documents of every project and index their table names.

    A "table name" is the stem (base name without extension) of a loaded
    document's source file, taken from the document's ``file_name`` metadata.
    """

    def __init__(self) -> None:
        # project flag -> de-duplicated table names, in first-seen order
        self._prjTabels: Dict[str, List[str]] = {}

    @staticmethod
    def _project_dir(rootPath: str, prjFlag: str) -> str:
        """Return the directory of a project below ``rootPath``.

        Each ``_``-separated part of ``prjFlag`` is one directory level. The
        parts are joined with ``os.path.join`` so the path is valid on every
        platform, not only where ``\\`` is the path separator.
        """
        return os.path.join(rootPath, *prjFlag.split('_'))

    def load(self):
        """Read the documents of every project returned by ``get_document_Types()``.

        As a side effect, the table names found for each project are recorded
        and can be retrieved afterwards with ``get_TableNames()``.

        Returns:
            dict mapping each project flag to the documents returned by
            ``SimpleDirectoryReader.load_data()`` for that project's directory.
        """
        rootPath = getFileCacahePath()
        prjDocs = {}
        for prjFlag in get_document_Types():
            filePath = self._project_dir(rootPath, prjFlag)
            # A fresh extractor mapping is built per project, as before.
            extrator = self._get_FileExtrator()
            documents = SimpleDirectoryReader(
                input_dir=filePath,
                file_extractor=extrator,
            ).load_data()
            prjDocs[prjFlag] = documents
            self._add_tables(prjFlag, documents)
        return prjDocs

    def get_TableNames(self):
        """Return the mapping of project flag to table names built by ``load()``."""
        return self._prjTabels

    def _add_tables(self, prjFlag: str, documents: List[Document]):
        """Record the distinct file stems of ``documents`` under ``prjFlag``.

        Nothing is stored when ``documents`` yields no names, so a project
        without documents has no key in the mapping.

        Raises:
            KeyError: if a document's metadata has no ``file_name`` entry.
        """
        stems = (
            os.path.splitext(os.path.basename(doc.metadata['file_name']))[0]
            for doc in documents
        )
        # dict.fromkeys removes duplicates while keeping first-seen order.
        fileNames = list(dict.fromkeys(stems))
        if fileNames:
            self._prjTabels[prjFlag] = fileNames

    def _get_FileExtrator(self):
        """Return the file-extension -> reader mapping given to SimpleDirectoryReader."""
        return {
            ".md": ChunkMarkdownReader(),
        }
|
|
|
|
# Maps each Chinese prompt-category label to the name of the attribute in the
# `prompts` module that holds the corresponding prompt (see select_prompt()).
prompt_mapping: Dict[str, str] = {
    "普通属性": "Attribute_Prompt",
    "金额查询": "Amount_Prompt",
    "单位换算": "Units_Prompt",
    "重名项目划分": "Name_Prompt",
    "总体金额查询": "All_Amount_Prompt",
}
|
|
|
|
# Maps a table-name keyword to the prompt categories used for that table.
# get_Prompts() matches each key as a substring of the table name and returns
# the categories of the first key that matches, so key order matters.
table_prompt_mapping: Dict[str, List[str]] = {
    "工程信息": ["普通属性", "单位换算"],
    "其他费用": ["金额查询", "单位换算"],
    "取费表": ["金额查询"],
    "总算表": ["金额查询"],
    "工程量": ["普通属性", "重名项目划分"],
}
|
|
|
|
def select_documents(documents:List[Document], table_name:str):
    """Pick the documents that belong to one table.

    A document belongs to ``table_name`` when the stem of its ``file_name``
    metadata (base name without extension) is exactly equal to ``table_name``.

    Returns:
        The matching documents, in their original order.

    Raises:
        ValueError: if no document matches ``table_name``.
    """
    def _stem(doc):
        return os.path.splitext(os.path.basename(doc.metadata['file_name']))[0]

    matched = [doc for doc in documents if table_name == _stem(doc)]
    if not matched:
        raise ValueError(f"未找到名为 '{table_name}' 的节点")
    return matched
|
|
|
|
def select_prompt(prompt_category):
    """Return the prompt registered for a Chinese prompt-category label.

    The label is translated through ``prompt_mapping`` into an attribute name,
    and that attribute is then read from the ``prompts`` module.

    Raises:
        ValueError: if ``prompt_category`` has no entry in ``prompt_mapping``,
            or if the ``prompts`` module has no attribute with the mapped
            name (the original ``AttributeError`` is kept as ``__cause__``).
    """
    prompt_name = prompt_mapping.get(prompt_category)
    if not prompt_name:
        raise ValueError(f"未找到名为 '{prompt_category}' 的提示词")
    try:
        return getattr(prompts, prompt_name)
    except AttributeError as err:
        # Chain explicitly so the traceback shows the failed lookup as the cause.
        raise ValueError(f"未找到提示词 '{prompt_name}' 对应的函数") from err
|
|
|
|
def get_Prompts(tableName:str):
    """Return the prompt categories configured for ``tableName``.

    The keys of ``table_prompt_mapping`` are tried in definition order; the
    categories of the first key that occurs as a substring of ``tableName``
    are returned. An empty list is returned when no key matches.
    """
    candidates = (
        categories
        for keyword, categories in table_prompt_mapping.items()
        if keyword in tableName
    )
    return next(candidates, [])
|
|
|
|
def generate_questions_from_document(document, quest_prompt, num_questions):
    """Generate question/answer pairs for the given documents.

    Builds a ``DatasetGenerator`` from ``document`` using the module-level
    Chinese question and QA templates, with ``quest_prompt`` as the
    question-generation query. ``num_questions`` is passed both as
    ``num_questions_per_chunk`` and as the argument of
    ``generate_dataset_from_nodes``. The generated dataset is printed.

    Returns:
        A list of ``{"question": ..., "answer": ...}`` dicts with surrounding
        whitespace stripped from both values.
    """
    question_tmpl = PromptTemplate(text_question_template)
    answer_tmpl = PromptTemplate(text_qa_template, prompt_type=PromptType.QUESTION_ANSWER)
    generator = DatasetGenerator.from_documents(
        documents=document,
        text_question_template=question_tmpl,
        question_gen_query=quest_prompt,
        text_qa_template=answer_tmpl,
        num_questions_per_chunk=num_questions,
    )

    dataset = generator.generate_dataset_from_nodes(num_questions)
    print(dataset)

    return [
        {"question": question.strip(), "answer": answer.strip()}
        for question, answer in dataset.qr_pairs
    ]
|
|
|
|
def main(documents:List[Document], table_names:List[str],num_questions_per_prompt,filePath:str):
    """Generate Q&A pairs for several tables and write them to one JSON file.

    For every table name (whitespace-stripped) the matching documents are
    selected, and one set of Q&A pairs is generated per prompt category
    configured for that table. Each set is stored under the key
    ``"test:<table>_<category>"``. The combined result is written to
    ``filePath`` as UTF-8 JSON after all tables have been processed.

    Raises:
        ValueError: propagated from ``select_documents`` / ``select_prompt``
            when a table or prompt cannot be found.
    """
    all_results = {}
    for raw_name in table_names:
        name = raw_name.strip()
        table_docs = select_documents(documents, name)
        for category in get_Prompts(name):
            question_prompt = select_prompt(category).format(
                num_questions_per_chunk=num_questions_per_prompt
            )
            all_results[f"test:{name}_{category}"] = generate_questions_from_document(
                table_docs, question_prompt, num_questions_per_prompt
            )

    # Write the merged results to the output path chosen by the caller.
    with open(filePath, "w", encoding="utf-8") as out:
        json.dump(all_results, out, ensure_ascii=False, indent=4)
|
|
|
|
# Command-line entry point.
#
# Usage:   python script.py <table_names_input> <num_questions_per_prompt>
# Example: python script.py "[总算表]" 2
#
# <table_names_input> is either the literal "all" (use every table found for
# each project) or a comma-separated list of table names, optionally wrapped
# in square brackets. One JSON file per project is written to
# <cwd>/unit_test/Quetions/<project flag>.json.
if __name__ == "__main__":
    # sys.argv[0] is the script path, so two user arguments mean len == 3.
    if len(sys.argv) != 3:
        raise ValueError("Usage: python script.py <table_names_input> <num_questions_per_prompt>")
    table_names_input = sys.argv[1]
    num_questions_per_prompt = int(sys.argv[2])

    # Join the components separately so the path is valid on every platform.
    que_Dir = os.path.join(os.getcwd(), 'unit_test', 'Quetions')
    os.makedirs(que_Dir, exist_ok=True)

    loader = FileLoader()
    prjDocs = loader.load()
    prjTableNames = loader.get_TableNames()
    for prjFlg, documents in prjDocs.items():
        if table_names_input == "all":
            # A project without documents has no entry in the table-name map.
            table_names = prjTableNames.get(prjFlg, [])
        else:
            table_names = table_names_input.strip('[]').split(',')
        filePath = os.path.join(que_Dir, f'{prjFlg}.json')
        main(documents, table_names, num_questions_per_prompt, filePath)
|