"""Batch-run natural-language questions through a text-to-SQL query engine.

Reads a JSON array of ``{"question": ...}`` objects, sends each question to a
``SQLTableRetrieverQueryEngine`` built over ``SQL_DATABASE_URL`` and appends
the results, one JSON object per line, to ``parameters_results.json`` in this
script's directory. The first line of that file records the parameters used.
"""
import os
import json
import sys

from dotenv import load_dotenv
from llama_index.core import VectorStoreIndex, SQLDatabase
from llama_index.core.indices.struct_store import SQLTableRetrieverQueryEngine
from llama_index.core.objects import SQLTableNodeMapping, ObjectIndex
from sqlalchemy import create_engine

from app.api.routers.chat import generate_filters
from app.engine import get_index, makeDescriptionByEngine
from app.engine.vectordb import get_vector_store
from app.observability import init_observability
from app.settings import init_settings

load_dotenv()

# Evaluation parameters. main() exports them as environment variables before
# calling init_settings().
# NOTE(review): presumably init_settings() reads TOP_K / LLM_TEMPERATURE from
# the environment — confirm in app.settings.
DEFAULT_TOP_K = 5  # TOP_K for vector retrieval
DEFAULT_LLM_TEMPERATURE = 0.1  # LLM temperature
DEFAULT_SIMILARITY_TOP_K = 5  # TOP_K for SQL table retrieval

RESULTS_FILE_NAME = "parameters_results.json"


def read_questions(file_path):
    """Return the ``"question"`` value of every item in a JSON array file.

    :param file_path: path to a UTF-8 JSON file holding a list of objects,
        each of which has a ``"question"`` key.
    :returns: list of questions in file order.
    :raises KeyError: if an item has no ``"question"`` key.
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        data = json.load(file)
    return [item["question"] for item in data]


def save_results_to_file(question, result, file_path):
    """Append one ``{"question": ..., "result": ...}`` JSON line to *file_path*.

    The file is opened in append mode on every call. Each record is therefore
    written out as soon as its query finishes, and earlier records are kept if
    a later query raises.
    """
    result_data = {"question": question, "result": result}
    with open(file_path, 'a', encoding='utf-8') as file:
        json.dump(result_data, file, ensure_ascii=False)
        file.write('\n')


def _build_sql_query_engine(similarity_top_k):
    """Build a query engine over ``SQL_DATABASE_URL`` that retrieves table schemas.

    :param similarity_top_k: number of table schemas the retriever returns
        for each question.
    :raises ValueError: if ``SQL_DATABASE_URL`` is unset or empty.
    """
    database_url = os.getenv("SQL_DATABASE_URL", "")
    if not database_url:
        # create_engine("") fails with an opaque URL-parse error; fail clearly.
        raise ValueError("SQL_DATABASE_URL environment variable is not set")
    engine = create_engine(database_url)
    sql_database = SQLDatabase(engine)
    table_schema_objs = makeDescriptionByEngine(sql_database)
    table_node_mapping = SQLTableNodeMapping(sql_database)
    # Put the table schema objects into a vector index so the retriever can
    # select the schemas most similar to each question.
    sql_obj_index = ObjectIndex.from_objects(
        table_schema_objs,
        table_node_mapping,
        index_cls=VectorStoreIndex,
    )
    return SQLTableRetrieverQueryEngine(
        sql_database,
        sql_obj_index.as_retriever(similarity_top_k=similarity_top_k),
    )


def _write_parameters_header(file_path, parameters):
    """Write *parameters* as the first JSON line of *file_path* if it is missing or empty."""
    if os.path.isfile(file_path) and os.path.getsize(file_path) > 0:
        return
    with open(file_path, 'w', encoding='utf-8') as file:
        json.dump(parameters, file, ensure_ascii=False)
        file.write('\n')


def main(questions_file):
    """Run every question in *questions_file* and append the results to disk.

    :param questions_file: path to the questions JSON file. It is joined onto
        this script's directory, so a relative path is resolved against the
        script's directory rather than the current working directory.
    """
    # Export the parameters before init_settings() runs.
    os.environ['TOP_K'] = str(DEFAULT_TOP_K)
    os.environ['LLM_TEMPERATURE'] = str(DEFAULT_LLM_TEMPERATURE)
    os.environ['similarity_top_k'] = str(DEFAULT_SIMILARITY_TOP_K)
    init_settings()
    init_observability()
    # NOTE(review): the returned index is not used below. The call is kept in
    # case loading the index has side effects this script relies on; confirm
    # and remove if it does not.
    get_index()

    top_k = DEFAULT_TOP_K
    temperature = DEFAULT_LLM_TEMPERATURE
    # A retrieval top-k is a count, so keep it an int (it was parsed as 5.0).
    similarity_top_k = DEFAULT_SIMILARITY_TOP_K

    sql_query_engine = _build_sql_query_engine(similarity_top_k)

    script_dir = os.path.dirname(os.path.abspath(__file__))
    questions_file_path = os.path.join(script_dir, questions_file)
    results_file_path = os.path.join(script_dir, RESULTS_FILE_NAME)

    questions = read_questions(questions_file_path)

    # Record the parameters once, at the top of a new (or empty) results file.
    _write_parameters_header(results_file_path, {
        "TOP_K": top_k,
        "LLM_TEMPERATURE": temperature,
        "similarity_top_k": similarity_top_k,
    })

    for i, question in enumerate(questions, start=1):
        print(f"Executing query {i}: {question}")
        sql_query_result = sql_query_engine.query(question)
        result_text = f"SQL查询结果: {sql_query_result}"
        print(f"{result_text}\n")
        save_results_to_file(question, result_text, results_file_path)


if __name__ == "__main__":
    if len(sys.argv) < 2:
        print("请提供questions.json文件的路径", file=sys.stderr)
        sys.exit(1)
    questions_file = sys.argv[1]
    # Imported here rather than at the top, so importing this module does not
    # load phoenix.
    from phoenix.trace import using_project
    # Spans created inside this block are attributed to a Phoenix project
    # named after the questions file.
    with using_project(questions_file):
        main(questions_file)