96 lines
3.3 KiB
Python
96 lines
3.3 KiB
Python
import os
|
|
import json
|
|
import sys
|
|
|
|
from llama_index.core import VectorStoreIndex, SQLDatabase
|
|
from llama_index.core.indices.struct_store import SQLTableRetrieverQueryEngine
|
|
from llama_index.core.objects import SQLTableNodeMapping, ObjectIndex
|
|
from sqlalchemy import create_engine
|
|
|
|
from app.api.routers.chat import generate_filters
|
|
from app.engine import get_index, makeDescriptionByEngine
|
|
from app.engine.vectordb import get_vector_store
|
|
from app.observability import init_observability
|
|
from app.settings import init_settings
|
|
from dotenv import load_dotenv
|
|
|
|
# Load variables from a local .env file (python-dotenv) into os.environ at
# import time, so later os.getenv() lookups — e.g. SQL_DATABASE_URL in main()
# — can see them.
load_dotenv()
|
|
|
|
def read_questions(file_path):
    """Return the "question" text of every record in a JSON questions file.

    Args:
        file_path: Path to a UTF-8 JSON file whose top level is a list of
            objects, each of which carries a "question" key.

    Returns:
        list: The "question" values, in file order.

    Raises:
        KeyError: If any record has no "question" key.
    """
    with open(file_path, encoding='utf-8') as handle:
        payload = json.loads(handle.read())
    return [entry["question"] for entry in payload]
|
|
|
|
def save_results_to_file(question, result, file_path):
    """Append one ``{"question": ..., "result": ...}`` record as a JSON line.

    The record is serialized *before* the file is opened, so a value that
    cannot be JSON-encoded raises without leaving a half-written line behind.
    ``json.dump`` streams chunks straight to the file, which would corrupt
    the JSON-lines results file on failure.

    Args:
        question: The question text.
        result: The value to store under "result"; must be JSON-serializable.
        file_path: Path of the results file; created if missing, appended to
            otherwise.

    Raises:
        TypeError: If ``question`` or ``result`` is not JSON-serializable.
    """
    line = json.dumps({"question": question, "result": result}, ensure_ascii=False)
    with open(file_path, 'a', encoding='utf-8') as file:
        file.write(line + '\n')
|
|
|
|
def _build_sql_query_engine(similarity_top_k):
    """Build a table-retriever SQL query engine over the database at SQL_DATABASE_URL.

    Args:
        similarity_top_k: Number of table schemas the retriever returns per
            question (an int, since it is a result count).

    Raises:
        RuntimeError: If SQL_DATABASE_URL is unset or empty.
    """
    database_url = os.getenv("SQL_DATABASE_URL", "")
    if not database_url:
        # create_engine("") only fails with an opaque URL-parsing error.
        raise RuntimeError("SQL_DATABASE_URL is not set; cannot build the SQL query engine")
    engine = create_engine(database_url)
    sql_database = SQLDatabase(engine)
    table_schema_objs = makeDescriptionByEngine(sql_database)
    table_node_mapping = SQLTableNodeMapping(sql_database)
    # Create the SQL query tool: table schemas go into a vector index, so the
    # retriever returns the similarity_top_k closest tables for each question.
    sql_obj_index = ObjectIndex.from_objects(
        table_schema_objs,
        table_node_mapping,
        index_cls=VectorStoreIndex,
    )
    return SQLTableRetrieverQueryEngine(
        sql_database,
        sql_obj_index.as_retriever(similarity_top_k=similarity_top_k),
    )


def _resolve_questions_path(script_dir, questions_file):
    """Return the questions file path, preferring the copy next to this script.

    The script-relative path is the original behaviour. The path as given
    (absolute or relative to the CWD) is used only when no script-relative
    file exists but the given path does.
    """
    candidate = os.path.join(script_dir, questions_file)
    if os.path.isfile(candidate) or not os.path.isfile(questions_file):
        return candidate
    return questions_file


def _write_parameters_header(results_file_path, parameters):
    """Write *parameters* as the first JSON line of a new or empty results file."""
    # Write the parameter values only when the file is missing or empty. The
    # original checked only for a missing file, contradicting its own comment.
    if os.path.isfile(results_file_path) and os.path.getsize(results_file_path) > 0:
        return
    with open(results_file_path, 'w', encoding='utf-8') as file:
        json.dump(parameters, file, ensure_ascii=False)
        file.write('\n')


def main(questions_file):
    """Run every question through the SQL query engine and log the answers.

    Forces the retrieval/LLM parameters into the environment, then:
    initialises settings and observability, reads the questions, and builds a
    SQL table-retriever query engine. It writes a parameters header to
    ``parameters_results.json`` (next to this script) if that file is new or
    empty, and appends one JSON line per question.

    Args:
        questions_file: Path to the questions JSON file. It is resolved against
            this script's directory first; if no file exists there, the path
            is used as given.

    Raises:
        RuntimeError: If SQL_DATABASE_URL is unset or empty.
        FileNotFoundError: If the questions file cannot be found.
    """
    # Update the environment variables. These are assigned unconditionally so
    # every run uses the same parameters. init_settings() runs afterwards,
    # presumably so it picks them up — TODO confirm.
    os.environ['TOP_K'] = str(5)  # vector-retrieval TOP_K
    os.environ['LLM_TEMPERATURE'] = str(0.1)  # LLM temperature
    os.environ['similarity_top_k'] = str(5)  # SQL table-retriever TOP_K

    init_settings()
    init_observability()

    # NOTE(review): the return values of get_index() and generate_filters([])
    # were never used. The calls are kept only in case they have
    # initialisation side effects — confirm and remove if not.
    get_index()
    generate_filters([])

    top_k = int(os.getenv("TOP_K"))
    temperature = float(os.getenv("LLM_TEMPERATURE"))
    # A result count: must be int (was parsed with float(), yielding 5.0).
    similarity_top_k = int(os.getenv("similarity_top_k"))

    script_dir = os.path.dirname(os.path.abspath(__file__))
    questions_file_path = _resolve_questions_path(script_dir, questions_file)
    results_file_path = os.path.join(script_dir, "parameters_results.json")

    # Read questions before the (expensive) engine build so a bad path fails fast.
    questions = read_questions(questions_file_path)

    sql_query_engine = _build_sql_query_engine(similarity_top_k)

    _write_parameters_header(results_file_path, {
        "TOP_K": top_k,
        "LLM_TEMPERATURE": temperature,
        "similarity_top_k": similarity_top_k,
    })

    # Run the queries one by one.
    for i, question in enumerate(questions):
        print(f"Executing query {i+1}: {question}")
        sql_query_result = sql_query_engine.query(question)
        result_text = f"SQL查询结果: {sql_query_result}"
        print(f"{result_text}\n")
        save_results_to_file(question, result_text, results_file_path)
|
|
|
|
if __name__ == "__main__":
    # CLI entry point: the single positional argument is the questions file.
    # The run is traced under a Phoenix project named after that argument.
    cli_args = sys.argv[1:]
    if not cli_args:
        print("请提供questions.json文件的路径")
        sys.exit(1)
    questions_file = cli_args[0]

    # Deferred import: phoenix is only needed when running as a script.
    from phoenix.trace import using_project

    with using_project(questions_file):
        main(questions_file)
|