Add new files and update existing files

This commit is contained in:
chentianrui
2024-08-16 19:02:06 +08:00
parent aa2cecc997
commit a9b5dc94fe
+37 -22
View File
@@ -1,7 +1,6 @@
import os import os
import json import json
import sys import sys
from llama_index.core import VectorStoreIndex, SQLDatabase from llama_index.core import VectorStoreIndex, SQLDatabase
from llama_index.core.indices.struct_store import SQLTableRetrieverQueryEngine from llama_index.core.indices.struct_store import SQLTableRetrieverQueryEngine
from llama_index.core.objects import SQLTableNodeMapping, ObjectIndex from llama_index.core.objects import SQLTableNodeMapping, ObjectIndex
@@ -31,11 +30,10 @@ def save_results_to_file(question, result, file_path):
json.dump(result_data, file, ensure_ascii=False) json.dump(result_data, file, ensure_ascii=False)
file.write('\n') file.write('\n')
def main(questions_file): def main(questions_file, query_type):
# 更新环境变量 # 更新环境变量
os.environ['TOP_K'] = str(5) # 向量的TOP_K值 os.environ['TOP_K'] = str(5) # 向量的TOP_K值
os.environ['LLM_TEMPERATURE'] = str(0.1) # 温度值 os.environ['similarity_top_k'] = str(1) # SQL的TOP_K值固定为1
os.environ['similarity_top_k'] = str(5) # SQL的TOP_K值
init_settings() init_settings()
init_observability() init_observability()
@@ -43,8 +41,7 @@ def main(questions_file):
index = get_index() index = get_index()
top_k = int(os.getenv("TOP_K")) # 向量的TOP_K值 top_k = int(os.getenv("TOP_K")) # 向量的TOP_K值
temperature = float(os.getenv("LLM_TEMPERATURE")) # 温度 similarity_top_k = int(os.getenv("similarity_top_k")) # SQL的TOP_K
similarity_top_k = float(os.getenv("similarity_top_k")) # SQL的TOP_K值
filters = generate_filters([]) filters = generate_filters([])
engine = create_engine(os.getenv("SQL_DATABASE_URL", "")) engine = create_engine(os.getenv("SQL_DATABASE_URL", ""))
@@ -65,31 +62,49 @@ def main(questions_file):
results_file_path = os.path.join(script_dir, "parameters_results.json") results_file_path = os.path.join(script_dir, "parameters_results.json")
questions = read_questions(questions_file_path) questions = read_questions(questions_file_path)
# 如果文件为空,则写入参数值 # # 如果文件为空,则写入参数值
if not os.path.isfile(results_file_path): # if not os.path.isfile(results_file_path):
with open(results_file_path, 'w', encoding='utf-8') as file: # with open(results_file_path, 'w', encoding='utf-8') as file:
json.dump({ # json.dump({
"TOP_K": top_k, # "TOP_K": top_k,
"LLM_TEMPERATURE": temperature, # "similarity_top_k": similarity_top_k
"similarity_top_k": similarity_top_k # }, file, ensure_ascii=False)
}, file, ensure_ascii=False) # file.write('\n')
file.write('\n')
# 循环执行查询 # 循环执行查询
for i, question in enumerate(questions): for i, question in enumerate(questions):
print(f"Executing query {i+1}: {question}") print(f"Executing query {i+1}: {question}")
sql_query_result = sql_query_engine.query(question)
print(f"SQL查询结果: {sql_query_result}\n") # 对于每个问题,测试不同的温度值
save_results_to_file(question, f"SQL查询结果: {sql_query_result}", results_file_path) for temperature in range(1, 11): # 从1到10
temperature_value = temperature / 10.0 # 从0.1到1.0
os.environ['LLM_TEMPERATURE'] = str(temperature_value)
if query_type == "vector":
query_engine = index.as_query_engine(
similarity_top_k=top_k, filters=filters
)
query_result = query_engine.query(question)
print(f"Vector Query Result: {query_result}\n")
save_results_to_file(question, f"Current parameters: TOP_K={top_k}, similarity_top_k={similarity_top_k}, Temperature: {temperature_value:.1f}, Vector Query Result: {query_result}", results_file_path)
elif query_type == "sql":
sql_query_result = sql_query_engine.query(question)
print(f"SQL Query Result: {sql_query_result}\n")
save_results_to_file(question, f"Current parameters: TOP_K={top_k}, similarity_top_k={similarity_top_k}, Temperature: {temperature_value:.1f}, SQL Query Result: {sql_query_result}", results_file_path)
else:
print("无效的查询类型,请选择 'vector''sql'")
sys.exit(1)
if __name__ == "__main__": if __name__ == "__main__":
if len(sys.argv) < 2: if len(sys.argv) < 3:
print("请提供questions.json文件的路径") print("请提供questions.json文件的路径和查询类型(vector 或 sql")
sys.exit(1) sys.exit(1)
questions_file = sys.argv[1] questions_file = sys.argv[1]
query_type = sys.argv[2].lower()
from phoenix.trace import using_project from phoenix.trace import using_project
with using_project(questions_file) as obj: with using_project(questions_file) as obj:
main(questions_file) main(questions_file, query_type)