Files
2024-08-16 19:02:06 +08:00

110 lines
4.3 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import os
import json
import sys
from llama_index.core import VectorStoreIndex, SQLDatabase
from llama_index.core.indices.struct_store import SQLTableRetrieverQueryEngine
from llama_index.core.objects import SQLTableNodeMapping, ObjectIndex
from sqlalchemy import create_engine
from app.api.routers.chat import generate_filters
from app.engine import get_index, makeDescriptionByEngine
from app.engine.vectordb import get_vector_store
from app.observability import init_observability
from app.settings import init_settings
from dotenv import load_dotenv
# Load variables from a .env file into os.environ (python-dotenv) so that the
# os.getenv lookups in main(), e.g. SQL_DATABASE_URL, can see them.
# NOTE(review): this runs after the app.* imports above; if any of those modules
# read environment variables at import time they will not see .env values — confirm.
load_dotenv()
def read_questions(file_path):
    """Return the ``"question"`` value of every entry in a JSON array file.

    Args:
        file_path: Path to a UTF-8 encoded JSON file whose top level is a list
            of objects, each carrying a ``"question"`` key.

    Returns:
        The question values, in the same order as in the file.
    """
    with open(file_path, mode='r', encoding='utf-8') as source:
        entries = json.load(source)
    # Keep only the question text of each entry, preserving file order.
    return [entry["question"] for entry in entries]
def save_results_to_file(question, result, file_path):
    """Append one ``{"question": ..., "result": ...}`` record as a JSON line.

    Args:
        question: The question text that was asked.
        result: The (JSON-serialisable) result to store alongside it.
        file_path: File to append to; it is created if missing. Non-ASCII
            text is written as-is (UTF-8) rather than \\u-escaped.
    """
    record = dict(question=question, result=result)
    with open(file_path, mode='a', encoding='utf-8') as sink:
        json.dump(record, sink, ensure_ascii=False)
        # One record per line (JSON Lines layout).
        sink.write('\n')
def _build_sql_retriever(similarity_top_k):
    """Connect to SQL_DATABASE_URL and build a table-schema retriever over it.

    Args:
        similarity_top_k: Number of table schemas retrieved per query.

    Returns:
        A ``(sql_database, retriever)`` tuple suitable for constructing a
        ``SQLTableRetrieverQueryEngine``.
    """
    engine = create_engine(os.getenv("SQL_DATABASE_URL", ""))
    sql_database = SQLDatabase(engine)
    table_schema_objs = makeDescriptionByEngine(sql_database)
    table_node_mapping = SQLTableNodeMapping(sql_database)
    sql_obj_index = ObjectIndex.from_objects(
        table_schema_objs,
        table_node_mapping,
        index_cls=VectorStoreIndex,
    )
    return sql_database, sql_obj_index.as_retriever(similarity_top_k=similarity_top_k)


def main(questions_file, query_type):
    """Run every question through a vector or SQL query engine at several temperatures.

    Each question from ``questions_file`` is queried once per LLM temperature
    0.1, 0.2, ..., 1.0. Every answer is printed and appended as a JSON line to
    ``parameters_results.json`` in this script's directory.

    Args:
        questions_file: Path to the questions JSON file, joined onto this
            script's directory (an absolute path is used unchanged).
        query_type: ``"vector"`` or ``"sql"``.

    Raises:
        SystemExit: With status 1 when ``query_type`` is neither ``"vector"``
            nor ``"sql"``.
    """
    # Reject a bad query type before any expensive setup. Previously it was only
    # checked inside the question loop, so an empty questions file hid the error.
    if query_type not in ("vector", "sql"):
        print("无效的查询类型,请选择 'vector' 或 'sql'")
        sys.exit(1)

    # Pin the retrieval parameters for this run via the environment.
    os.environ['TOP_K'] = str(5)  # top-k for vector retrieval
    os.environ['similarity_top_k'] = str(1)  # top-k tables for SQL retrieval, fixed at 1
    init_settings()
    init_observability()
    top_k = int(os.getenv("TOP_K"))
    similarity_top_k = int(os.getenv("similarity_top_k"))

    # Build only the backend that will actually be queried: create_engine("")
    # raises when SQL_DATABASE_URL is unset, which used to break vector-only runs.
    if query_type == "vector":
        index = get_index()
        filters = generate_filters([])
        label = "Vector"
    else:
        sql_database, sql_retriever = _build_sql_retriever(similarity_top_k)
        label = "SQL"

    script_dir = os.path.dirname(os.path.abspath(__file__))
    questions_file_path = os.path.join(script_dir, questions_file)
    results_file_path = os.path.join(script_dir, "parameters_results.json")
    questions = read_questions(questions_file_path)

    for i, question in enumerate(questions, start=1):
        print(f"Executing query {i}: {question}")
        # Sweep the temperature from 0.1 to 1.0 in steps of 0.1.
        for step in range(1, 11):
            temperature_value = step / 10.0
            os.environ['LLM_TEMPERATURE'] = str(temperature_value)
            # Changing the environment variable alone does not touch settings that
            # init_settings() already applied before this loop, so re-run it and
            # rebuild the query engine for every temperature.
            # NOTE(review): assumes init_settings() reads LLM_TEMPERATURE when it
            # configures the LLM — confirm in app.settings.
            init_settings()
            if query_type == "vector":
                query_engine = index.as_query_engine(
                    similarity_top_k=top_k, filters=filters
                )
            else:
                query_engine = SQLTableRetrieverQueryEngine(sql_database, sql_retriever)
            query_result = query_engine.query(question)
            print(f"{label} Query Result: {query_result}\n")
            save_results_to_file(
                question,
                f"Current parameters: TOP_K={top_k}, similarity_top_k={similarity_top_k}, "
                f"Temperature: {temperature_value:.1f}, {label} Query Result: {query_result}",
                results_file_path,
            )
if __name__ == "__main__":
    # Usage: <script> <questions.json> <vector|sql>
    if len(sys.argv) < 3:
        # Fixed: the usage message was missing its closing parenthesis.
        print("请提供questions.json文件的路径和查询类型(vector 或 sql)")
        sys.exit(1)
    questions_file = sys.argv[1]
    query_type = sys.argv[2].lower()  # accept e.g. "SQL" as well as "sql"
    # Imported here, inside the script guard, so phoenix is only needed when
    # this file is executed directly.
    from phoenix.trace import using_project
    # Attribute traces emitted during this run to a Phoenix project named after
    # the questions file.
    with using_project(questions_file):
        main(questions_file, query_type)