From de34c3938c21246b8caccac00f5b7af70d36a1fa Mon Sep 17 00:00:00 2001 From: chentianrui Date: Thu, 29 Aug 2024 12:02:53 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E4=BA=86=E5=8F=82=E6=95=B0?= =?UTF-8?q?=E8=AF=84=E4=BC=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/test1/ParamTuner.py | 107 ++++++++++++++++++++++++++++++++++++ 1 file changed, 107 insertions(+) create mode 100644 backend/test1/ParamTuner.py diff --git a/backend/test1/ParamTuner.py b/backend/test1/ParamTuner.py new file mode 100644 index 0000000..a8bbb12 --- /dev/null +++ b/backend/test1/ParamTuner.py @@ -0,0 +1,107 @@ +import nest_asyncio +nest_asyncio.apply() +from llama_index.core import SimpleDirectoryReader +from llama_index.core.node_parser import SentenceSplitter +from llama_index.core import VectorStoreIndex +from llama_index.core.evaluation import ( + FaithfulnessEvaluator, + DatasetGenerator, + CorrectnessEvaluator, + SemanticSimilarityEvaluator, +) +from llama_index.experimental.param_tuner import ParamTuner +from llama_index.experimental.param_tuner.base import RunResult +from llama_index.llms.openai import OpenAI + +import asyncio + +# 初始化环境 +from app.observability import init_observability +from app.settings import init_settings +from dotenv import load_dotenv + +load_dotenv() +init_settings() +init_observability() + +# 读取文档 +documents = SimpleDirectoryReader("D:/LLM_model/text2sql/zjdataai-app-test/backend/data-test").load_data() + +# 辅助函数 +def _build_index(chunk_size, documents): + # 构建索引 + splitter = SentenceSplitter(chunk_size=chunk_size) + vector_index = VectorStoreIndex.from_documents( + documents, transformations=[splitter], + ) + return vector_index + +# 评估函数 +def evaluate_query_engine(query_engine, questions): + loop = asyncio.get_event_loop() + correct, total = loop.run_until_complete(_evaluate_query_engine_async(query_engine, questions)) + return correct, total + +async def _evaluate_query_engine_async(query_engine, questions): + c = [query_engine.aquery(q) for q in questions] + gathering_future = asyncio.gather(*c) + results = await gathering_future + + total_correct = 0 + for r in results: + eval_result = ( + 1 if FaithfulnessEvaluator().evaluate_response(response=r).passing else 0 + ) + total_correct += eval_result + + return total_correct, len(results) + +# 目标函数 +def objective_function(params_dict, documents, question_count): + chunk_size = params_dict["chunk_size"] + top_k = params_dict["top_k"] + temperature = params_dict["temperature"] + + # 构建索引 + vector_index = _build_index(chunk_size, documents) + + # 查询引擎 + query_engine = vector_index.as_query_engine( + similarity_top_k=top_k, temperature=temperature + ) + + # 生成问题 + question_generator = DatasetGenerator.from_documents(documents) + eval_questions = question_generator.generate_questions_from_nodes(question_count) + + # 评估查询引擎 + correct, total = evaluate_query_engine(query_engine, eval_questions) + + # 计算分数 + score = correct / total if total > 0 else 0 + return RunResult(score=score, params=params_dict) + +# 参数字典 +param_dict = { + "chunk_size": [512, 1024], + "top_k": [1, 5], + "temperature": [0.1, 1.0] +} + +# 创建 ParamTuner 实例 +param_tuner = ParamTuner( + param_fn=lambda params_dict: objective_function(params_dict, documents, 1), + param_dict=param_dict, + show_progress=True, +) + +# 调用 tune 方法 +results = param_tuner.tune() +best_result = results.best_run_result +best_top_k = best_result.params["top_k"] +best_chunk_size = best_result.params["chunk_size"] +best_temperature = best_result.params["temperature"] +print(f"Score: {best_result.score}") +print(f"Top-k: {best_top_k}") +print(f"Chunk size: {best_chunk_size}") +print(f"Temperature: {best_temperature}") \ No newline at end of file