增加了参数评估
This commit is contained in:
@@ -0,0 +1,107 @@
|
|||||||
|
import nest_asyncio
|
||||||
|
nest_asyncio.apply()
|
||||||
|
from llama_index.core import SimpleDirectoryReader
|
||||||
|
from llama_index.core.node_parser import SentenceSplitter
|
||||||
|
from llama_index.core import VectorStoreIndex
|
||||||
|
from llama_index.core.evaluation import (
|
||||||
|
FaithfulnessEvaluator,
|
||||||
|
DatasetGenerator,
|
||||||
|
CorrectnessEvaluator,
|
||||||
|
SemanticSimilarityEvaluator,
|
||||||
|
)
|
||||||
|
from llama_index.experimental.param_tuner import ParamTuner
|
||||||
|
from llama_index.experimental.param_tuner.base import RunResult
|
||||||
|
from llama_index.llms.openai import OpenAI
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
# 初始化环境
|
||||||
|
from app.observability import init_observability
|
||||||
|
from app.settings import init_settings
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
init_settings()
|
||||||
|
init_observability()
|
||||||
|
|
||||||
|
# 读取文档
|
||||||
|
documents = SimpleDirectoryReader("D:/LLM_model/text2sql/zjdataai-app-test/backend/data-test").load_data()
|
||||||
|
|
||||||
|
# 辅助函数
|
||||||
|
def _build_index(chunk_size, documents):
|
||||||
|
# 构建索引
|
||||||
|
splitter = SentenceSplitter(chunk_size=chunk_size)
|
||||||
|
vector_index = VectorStoreIndex.from_documents(
|
||||||
|
documents, transformations=[splitter],
|
||||||
|
)
|
||||||
|
return vector_index
|
||||||
|
|
||||||
|
# 评估函数
|
||||||
|
def evaluate_query_engine(query_engine, questions):
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
correct, total = loop.run_until_complete(_evaluate_query_engine_async(query_engine, questions))
|
||||||
|
return correct, total
|
||||||
|
|
||||||
|
async def _evaluate_query_engine_async(query_engine, questions):
|
||||||
|
c = [query_engine.aquery(q) for q in questions]
|
||||||
|
gathering_future = asyncio.gather(*c)
|
||||||
|
results = await gathering_future
|
||||||
|
|
||||||
|
total_correct = 0
|
||||||
|
for r in results:
|
||||||
|
eval_result = (
|
||||||
|
1 if FaithfulnessEvaluator().evaluate_response(response=r).passing else 0
|
||||||
|
)
|
||||||
|
total_correct += eval_result
|
||||||
|
|
||||||
|
return total_correct, len(results)
|
||||||
|
|
||||||
|
# 目标函数
|
||||||
|
def objective_function(params_dict, documents, question_count):
|
||||||
|
chunk_size = params_dict["chunk_size"]
|
||||||
|
top_k = params_dict["top_k"]
|
||||||
|
temperature = params_dict["temperature"]
|
||||||
|
|
||||||
|
# 构建索引
|
||||||
|
vector_index = _build_index(chunk_size, documents)
|
||||||
|
|
||||||
|
# 查询引擎
|
||||||
|
query_engine = vector_index.as_query_engine(
|
||||||
|
similarity_top_k=top_k, temperature=temperature
|
||||||
|
)
|
||||||
|
|
||||||
|
# 生成问题
|
||||||
|
question_generator = DatasetGenerator.from_documents(documents)
|
||||||
|
eval_questions = question_generator.generate_questions_from_nodes(question_count)
|
||||||
|
|
||||||
|
# 评估查询引擎
|
||||||
|
correct, total = evaluate_query_engine(query_engine, eval_questions)
|
||||||
|
|
||||||
|
# 计算分数
|
||||||
|
score = correct / total if total > 0 else 0
|
||||||
|
return RunResult(score=score, params=params_dict)
|
||||||
|
|
||||||
|
# 参数字典
|
||||||
|
param_dict = {
|
||||||
|
"chunk_size": [512, 1024],
|
||||||
|
"top_k": [1, 5],
|
||||||
|
"temperature": [0.1, 1.0]
|
||||||
|
}
|
||||||
|
|
||||||
|
# 创建 ParamTuner 实例
|
||||||
|
param_tuner = ParamTuner(
|
||||||
|
param_fn=lambda params_dict: objective_function(params_dict, documents, 1),
|
||||||
|
param_dict=param_dict,
|
||||||
|
show_progress=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 调用 tune 方法
|
||||||
|
results = param_tuner.tune()
|
||||||
|
best_result = results.best_run_result
|
||||||
|
best_top_k = best_result.params["top_k"]
|
||||||
|
best_chunk_size = best_result.params["chunk_size"]
|
||||||
|
best_temperature = best_result.params["temperature"]
|
||||||
|
print(f"Score: {best_result.score}")
|
||||||
|
print(f"Top-k: {best_top_k}")
|
||||||
|
print(f"Chunk size: {best_chunk_size}")
|
||||||
|
print(f"Temperature: {best_temperature}")
|
||||||
Reference in New Issue
Block a user