diff --git a/backend/unit_test/corr_test.py b/backend/unit_test/corr_test.py
index d97e902..9a09116 100644
--- a/backend/unit_test/corr_test.py
+++ b/backend/unit_test/corr_test.py
@@ -1,5 +1,20 @@
-import json
 from dotenv import load_dotenv
+load_dotenv()
+
+from llama_index.core.evaluation import CorrectnessEvaluator
+from app.engine import get_chat_engine
+from app.engine.index import get_index
+from app.observability import init_observability
+from app.settings import init_settings
+
+init_settings()
+init_observability()
+
+index = get_index()
+
+
+import os
+import json
 import asyncio
 import nest_asyncio
 nest_asyncio.apply()
@@ -55,31 +70,14 @@ DEFAULT_EVAL_TEMPLATE = ChatPromptTemplate(
     ]
 )
 
-from app.api.routers.models import ChatData, Message
-from llama_index.core.chat_engine.types import BaseChatEngine, NodeWithScore
-from llama_index.core.vector_stores.types import MetadataFilter, MetadataFilters
-from llama_index.core.evaluation import CorrectnessEvaluator
-from app.engine import get_chat_engine
-from app.api.routers.chat import generate_filters
-from app.engine.index import get_index
-from app.observability import init_observability
-from app.settings import init_settings
-
-
-load_dotenv()
-
-
-init_settings()
-init_observability()
-
-index = get_index()
 
 # 初始化聊天引擎和评估器
 chat_engine = get_chat_engine()
 corr_evaluator_qwen = CorrectnessEvaluator()
 
 # 加载本地问题回答文件
-file_path = 'D:/LLM_model/text2sql/zjdataai-app-test/backend/unit_test/test.json'
+script_dir = os.path.dirname(os.path.abspath(__file__))
+file_path = os.path.join(script_dir, 'questions_and_answers.json')
 output_file_path = file_path.replace('.json', '_test.json')
 
 with open(file_path, 'r', encoding='utf-8') as f:
@@ -88,8 +86,13 @@ with open(file_path, 'r', encoding='utf-8') as f:
 # 异步函数用于评估查询
 async def evaluate_query(question, answer, index, output_file):
     response = await chat_engine.astream_chat(question)
-    content_str = str(response.sources[0])
-
+
+    # Guard against an empty sources list before indexing into it.
+    if response.sources:
+        content_str = str(response.sources[0])
+    else:
+        content_str = "<无回答>"
+
     result = corr_evaluator_qwen.evaluate(
         query=question,
         response=content_str,
@@ -101,13 +104,13 @@ async def evaluate_query(question, answer, index, output_file):
         "问题": question,
         "答案": answer,
         "回答": result.response,
-        "得分(0~5)": result.score,
+        "得分(1~5)": result.score,
         "评价": result.feedback
     }
 
     with open(output_file, 'a', encoding='utf-8') as f:
         f.write(json.dumps(result_dict, ensure_ascii=False, indent=4))
-        f.write(',')
+        f.write(',\n')
 
 # 主异步函数
 async def main():
diff --git a/backend/unit_test/question.py b/backend/unit_test/question.py
new file mode 100644
index 0000000..95c2c27
--- /dev/null
+++ b/backend/unit_test/question.py
@@ -0,0 +1,68 @@
+from dotenv import load_dotenv
+load_dotenv()
+
+from app.observability import init_observability
+from app.settings import init_settings
+
+import nest_asyncio
+nest_asyncio.apply()
+
+from llama_index.core.node_parser import SentenceSplitter
+from llama_index.core import SimpleDirectoryReader
+
+from llama_index.core.evaluation import DatasetGenerator
+
+import json
+from pathlib import Path
+
+init_settings()
+init_observability()
+
+# Resolve every path from this script's location so the run does not depend
+# on the current working directory or on Windows-only path separators.
+script_dir = Path(__file__).resolve().parent
+data_dir = script_dir.parent / "data-test"
+output_path = script_dir / "questions_and_answers.json"
+
+documents = SimpleDirectoryReader(str(data_dir)).load_data()
+
+# NOTE(review): splitter is not referenced again in this script.
+splitter = SentenceSplitter(chunk_size=512)
+
+# question_generator = DatasetGenerator.from_documents(documents)
+quest_prompt = (
+    "你是一个电力造价工程相关的项目经理,现在给你一些上下文信息,"
+    "你需要根据现有的上下文信息,来生成{num_questions_per_chunk}个电力造价工程相关的问题和对应的回答,"
+    "问题的实例应该是这种类型的:'人工费的费率是多少?,费率是100','前期工作管理费用的金额是多少?,金额是0',"
+    "这种类似的问题和答案,生成的问题和答案必须一一对应,要符合文件里的内容,不要生成一些无关的问题,不要生成一些重复的问题,"
+    "不要生成一些过于简单的问题,不要生成一些过于复杂的问题。"
+)
+
+question_generator = DatasetGenerator.from_documents(
+    documents=documents,
+    question_gen_query=quest_prompt,
+    num_questions_per_chunk=5  # number of questions generated per chunk
+)
+
+eval_questions = question_generator.generate_questions_from_nodes(5)
+
+# print(eval_questions)
+
+# Convert each generated "question?,answer" string into a JSON-ready record.
+qa_pairs = []
+for qa in eval_questions:
+    # Split at the first question mark; entries without one are reported and skipped.
+    question, mark, answer = qa.partition("?")
+    if not mark:
+        print(f"无法处理的问题和答案: {qa}")
+        continue
+    qa_pairs.append({
+        # Keep the question mark on the question, and drop the comma that the
+        # prompt's "question?,answer" format places in front of the answer.
+        "question": question.strip() + mark,
+        "answer": answer.lstrip(",, \t").strip()
+    })
+
+# Save the pairs as a JSON file next to this script.
+with open(output_path, "w", encoding="utf-8") as f:
+    json.dump(qa_pairs, f, ensure_ascii=False, indent=4)