参数优化针对问题做出了调整
This commit is contained in:
+49
-18
@@ -27,6 +27,13 @@ init_observability()
|
|||||||
# 读取文档
|
# 读取文档
|
||||||
documents = SimpleDirectoryReader("D:/LLM_model/text2sql/zjdataai-app-test/backend/data-test").load_data()
|
documents = SimpleDirectoryReader("D:/LLM_model/text2sql/zjdataai-app-test/backend/data-test").load_data()
|
||||||
|
|
||||||
|
# 参数字典
|
||||||
|
param_dict = {
|
||||||
|
"chunk_size": [512, 1024],
|
||||||
|
"top_k": [1, 5],
|
||||||
|
"temperature": [0.1, 1.0]
|
||||||
|
}
|
||||||
|
|
||||||
# 辅助函数
|
# 辅助函数
|
||||||
def _build_index(chunk_size, documents):
|
def _build_index(chunk_size, documents):
|
||||||
# 构建索引
|
# 构建索引
|
||||||
@@ -56,8 +63,18 @@ async def _evaluate_query_engine_async(query_engine, questions):
|
|||||||
|
|
||||||
return total_correct, len(results)
|
return total_correct, len(results)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# 生成问题
|
||||||
|
question_generator = DatasetGenerator.from_documents(documents)
|
||||||
|
eval_questions = question_generator.generate_questions_from_nodes(1) # 假设生成10个问题
|
||||||
|
|
||||||
|
# 打印生成的问题
|
||||||
|
for i, q in enumerate(eval_questions, start=1):
|
||||||
|
print(f"问题 {i}: {q}")
|
||||||
|
|
||||||
# 目标函数
|
# 目标函数
|
||||||
def objective_function(params_dict, documents, question_count):
|
def objective_function(params_dict, documents, questions):
|
||||||
chunk_size = params_dict["chunk_size"]
|
chunk_size = params_dict["chunk_size"]
|
||||||
top_k = params_dict["top_k"]
|
top_k = params_dict["top_k"]
|
||||||
temperature = params_dict["temperature"]
|
temperature = params_dict["temperature"]
|
||||||
@@ -70,27 +87,25 @@ def objective_function(params_dict, documents, question_count):
|
|||||||
similarity_top_k=top_k, temperature=temperature
|
similarity_top_k=top_k, temperature=temperature
|
||||||
)
|
)
|
||||||
|
|
||||||
# 生成问题
|
|
||||||
question_generator = DatasetGenerator.from_documents(documents)
|
|
||||||
eval_questions = question_generator.generate_questions_from_nodes(question_count)
|
|
||||||
|
|
||||||
# 评估查询引擎
|
# 评估查询引擎
|
||||||
correct, total = evaluate_query_engine(query_engine, eval_questions)
|
correct, total = 0, len(questions)
|
||||||
|
question_answers = [] # 添加列表来收集问题和答案
|
||||||
|
|
||||||
|
for question in questions:
|
||||||
|
response = query_engine.query(question)
|
||||||
|
if response is not None:
|
||||||
|
question_answers.append((question, response.response))
|
||||||
|
eval_result = FaithfulnessEvaluator().evaluate_response(response=response, query_str=question)
|
||||||
|
if eval_result.passing:
|
||||||
|
correct += 1
|
||||||
|
|
||||||
# 计算分数
|
# 计算分数
|
||||||
score = correct / total if total > 0 else 0
|
score = correct / total if total > 0 else 0
|
||||||
return RunResult(score=score, params=params_dict)
|
return RunResult(score=score, params=params_dict, question_answers=question_answers)
|
||||||
|
|
||||||
# 参数字典
|
|
||||||
param_dict = {
|
|
||||||
"chunk_size": [512, 1024],
|
|
||||||
"top_k": [1, 5],
|
|
||||||
"temperature": [0.1, 1.0]
|
|
||||||
}
|
|
||||||
|
|
||||||
# 创建 ParamTuner 实例
|
# 创建 ParamTuner 实例
|
||||||
param_tuner = ParamTuner(
|
param_tuner = ParamTuner(
|
||||||
param_fn=lambda params_dict: objective_function(params_dict, documents, 1),
|
param_fn=lambda params_dict: objective_function(params_dict, documents, eval_questions),
|
||||||
param_dict=param_dict,
|
param_dict=param_dict,
|
||||||
show_progress=True,
|
show_progress=True,
|
||||||
)
|
)
|
||||||
@@ -101,7 +116,23 @@ best_result = results.best_run_result
|
|||||||
best_top_k = best_result.params["top_k"]
|
best_top_k = best_result.params["top_k"]
|
||||||
best_chunk_size = best_result.params["chunk_size"]
|
best_chunk_size = best_result.params["chunk_size"]
|
||||||
best_temperature = best_result.params["temperature"]
|
best_temperature = best_result.params["temperature"]
|
||||||
print(f"Score: {best_result.score}")
|
print(f"得分: {best_result.score}")
|
||||||
print(f"Top-k: {best_top_k}")
|
print(f"Top-k: {best_top_k}")
|
||||||
print(f"Chunk size: {best_chunk_size}")
|
print(f"文本块大小: {best_chunk_size}")
|
||||||
print(f"Temperature: {best_temperature}")
|
print(f"温度: {best_temperature}")
|
||||||
|
|
||||||
|
# 使用最佳参数再次运行查询引擎,并打印问题与答案
|
||||||
|
best_vector_index = _build_index(best_chunk_size, documents)
|
||||||
|
best_query_engine = best_vector_index.as_query_engine(
|
||||||
|
similarity_top_k=best_top_k, temperature=best_temperature
|
||||||
|
)
|
||||||
|
|
||||||
|
best_question_answers = []
|
||||||
|
for question in eval_questions:
|
||||||
|
response = best_query_engine.query(question)
|
||||||
|
if response is not None:
|
||||||
|
best_question_answers.append((question, response.response))
|
||||||
|
|
||||||
|
# 打印最佳参数下的问题与答案
|
||||||
|
for i, (question, answer) in enumerate(best_question_answers, start=1):
|
||||||
|
print(f"最佳参数 - 问题 {i}: {question}\n答案: {answer}\n")
|
||||||
Reference in New Issue
Block a user