From 4a8c79e83dc12043b333654632447e3555880fa1 Mon Sep 17 00:00:00 2001 From: chentianrui Date: Thu, 29 Aug 2024 15:09:55 +0800 Subject: [PATCH] =?UTF-8?q?=E5=8F=82=E6=95=B0=E4=BC=98=E5=8C=96=E9=92=88?= =?UTF-8?q?=E5=AF=B9=E9=97=AE=E9=A2=98=E5=81=9A=E5=87=BA=E4=BA=86=E8=B0=83?= =?UTF-8?q?=E6=95=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/test1/ParamTuner.py | 67 +++++++++++++++++++++++++++---------- 1 file changed, 49 insertions(+), 18 deletions(-) diff --git a/backend/test1/ParamTuner.py b/backend/test1/ParamTuner.py index a8bbb12..5e31b57 100644 --- a/backend/test1/ParamTuner.py +++ b/backend/test1/ParamTuner.py @@ -27,6 +27,13 @@ init_observability() # 读取文档 documents = SimpleDirectoryReader("D:/LLM_model/text2sql/zjdataai-app-test/backend/data-test").load_data() +# 参数字典 +param_dict = { + "chunk_size": [512, 1024], + "top_k": [1, 5], + "temperature": [0.1, 1.0] +} + # 辅助函数 def _build_index(chunk_size, documents): # 构建索引 @@ -56,8 +63,18 @@ async def _evaluate_query_engine_async(query_engine, questions): return total_correct, len(results) + + +# 生成问题 +question_generator = DatasetGenerator.from_documents(documents) +eval_questions = question_generator.generate_questions_from_nodes(1) # 假设生成10个问题 + +# 打印生成的问题 +for i, q in enumerate(eval_questions, start=1): + print(f"问题 {i}: {q}") + # 目标函数 -def objective_function(params_dict, documents, question_count): +def objective_function(params_dict, documents, questions): chunk_size = params_dict["chunk_size"] top_k = params_dict["top_k"] temperature = params_dict["temperature"] @@ -70,27 +87,25 @@ def objective_function(params_dict, documents, question_count): similarity_top_k=top_k, temperature=temperature ) - # 生成问题 - question_generator = DatasetGenerator.from_documents(documents) - eval_questions = question_generator.generate_questions_from_nodes(question_count) - # 评估查询引擎 - correct, total = evaluate_query_engine(query_engine, eval_questions) + correct, total = 0, len(questions) + question_answers = [] # 添加列表来收集问题和答案 + for question in questions: + response = query_engine.query(question) + if response is not None: + question_answers.append((question, response.response)) + eval_result = FaithfulnessEvaluator().evaluate_response(response=response, query_str=question) + if eval_result.passing: + correct += 1 + # 计算分数 score = correct / total if total > 0 else 0 - return RunResult(score=score, params=params_dict) - -# 参数字典 -param_dict = { - "chunk_size": [512, 1024], - "top_k": [1, 5], - "temperature": [0.1, 1.0] -} + return RunResult(score=score, params=params_dict, question_answers=question_answers) # 创建 ParamTuner 实例 param_tuner = ParamTuner( - param_fn=lambda params_dict: objective_function(params_dict, documents, 1), + param_fn=lambda params_dict: objective_function(params_dict, documents, eval_questions), param_dict=param_dict, show_progress=True, ) @@ -101,7 +116,23 @@ best_result = results.best_run_result best_top_k = best_result.params["top_k"] best_chunk_size = best_result.params["chunk_size"] best_temperature = best_result.params["temperature"] -print(f"Score: {best_result.score}") +print(f"得分: {best_result.score}") print(f"Top-k: {best_top_k}") -print(f"Chunk size: {best_chunk_size}") -print(f"Temperature: {best_temperature}") \ No newline at end of file +print(f"文本块大小: {best_chunk_size}") +print(f"温度: {best_temperature}") + +# 使用最佳参数再次运行查询引擎,并打印问题与答案 +best_vector_index = _build_index(best_chunk_size, documents) +best_query_engine = best_vector_index.as_query_engine( + similarity_top_k=best_top_k, temperature=best_temperature +) + +best_question_answers = [] +for question in eval_questions: + response = best_query_engine.query(question) + if response is not None: + best_question_answers.append((question, response.response)) + +# 打印最佳参数下的问题与答案 +for i, (question, answer) in enumerate(best_question_answers, start=1): + print(f"最佳参数 - 问题 {i}: {question}\n答案: {answer}\n") \ No newline at end of file