diff --git a/backend/test1/test_parameters.py b/backend/test1/test_parameters.py index 479c2f1..f95efda 100644 --- a/backend/test1/test_parameters.py +++ b/backend/test1/test_parameters.py @@ -1,7 +1,6 @@ import os import json import sys - from llama_index.core import VectorStoreIndex, SQLDatabase from llama_index.core.indices.struct_store import SQLTableRetrieverQueryEngine from llama_index.core.objects import SQLTableNodeMapping, ObjectIndex @@ -31,11 +30,10 @@ def save_results_to_file(question, result, file_path): json.dump(result_data, file, ensure_ascii=False) file.write('\n') -def main(questions_file): +def main(questions_file, query_type): # 更新环境变量 os.environ['TOP_K'] = str(5) # 向量的TOP_K值 - os.environ['LLM_TEMPERATURE'] = str(0.1) # 温度值 - os.environ['similarity_top_k'] = str(5) # SQL的TOP_K值 + os.environ['similarity_top_k'] = str(1) # SQL的TOP_K值固定为1 init_settings() init_observability() @@ -43,8 +41,7 @@ def main(questions_file): index = get_index() top_k = int(os.getenv("TOP_K")) # 向量的TOP_K值 - temperature = float(os.getenv("LLM_TEMPERATURE")) # 温度值 - similarity_top_k = float(os.getenv("similarity_top_k")) # SQL的TOP_K值 + similarity_top_k = int(os.getenv("similarity_top_k")) # SQL的TOP_K值 filters = generate_filters([]) engine = create_engine(os.getenv("SQL_DATABASE_URL", "")) @@ -65,31 +62,49 @@ def main(questions_file): results_file_path = os.path.join(script_dir, "parameters_results.json") questions = read_questions(questions_file_path) - - # 如果文件为空,则写入参数值 - if not os.path.isfile(results_file_path): - with open(results_file_path, 'w', encoding='utf-8') as file: - json.dump({ - "TOP_K": top_k, - "LLM_TEMPERATURE": temperature, - "similarity_top_k": similarity_top_k - }, file, ensure_ascii=False) - file.write('\n') + + # # 如果文件为空,则写入参数值 + # if not os.path.isfile(results_file_path): + # with open(results_file_path, 'w', encoding='utf-8') as file: + # json.dump({ + # "TOP_K": top_k, + # "similarity_top_k": similarity_top_k + # }, file, ensure_ascii=False) + # file.write('\n') # 循环执行查询 for i, question in enumerate(questions): print(f"Executing query {i+1}: {question}") - sql_query_result = sql_query_engine.query(question) - print(f"SQL查询结果: {sql_query_result}\n") - save_results_to_file(question, f"SQL查询结果: {sql_query_result}", results_file_path) + + # 对于每个问题,测试不同的温度值 + for temperature in range(1, 11): # 从1到10 + temperature_value = temperature / 10.0 # 从0.1到1.0 + + os.environ['LLM_TEMPERATURE'] = str(temperature_value) + + if query_type == "vector": + query_engine = index.as_query_engine( + similarity_top_k=top_k, filters=filters + ) + query_result = query_engine.query(question) + print(f"Vector Query Result: {query_result}\n") + save_results_to_file(question, f"Current parameters: TOP_K={top_k}, similarity_top_k={similarity_top_k}, Temperature: {temperature_value:.1f}, Vector Query Result: {query_result}", results_file_path) + elif query_type == "sql": + sql_query_result = sql_query_engine.query(question) + print(f"SQL Query Result: {sql_query_result}\n") + save_results_to_file(question, f"Current parameters: TOP_K={top_k}, similarity_top_k={similarity_top_k}, Temperature: {temperature_value:.1f}, SQL Query Result: {sql_query_result}", results_file_path) + else: + print("无效的查询类型,请选择 'vector' 或 'sql'") + sys.exit(1) if __name__ == "__main__": - if len(sys.argv) < 2: - print("请提供questions.json文件的路径") + if len(sys.argv) < 3: + print("请提供questions.json文件的路径和查询类型(vector 或 sql)") sys.exit(1) questions_file = sys.argv[1] + query_type = sys.argv[2].lower() from phoenix.trace import using_project with using_project(questions_file) as obj: - main(questions_file) + main(questions_file, query_type) \ No newline at end of file