Files
zjdataai-app/backend/test1/test_parameters.py
T
2024-08-15 19:08:58 +08:00

102 lines
3.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import os
from ctypes import cast
import sys
from llama_index.core import VectorStoreIndex, SQLDatabase
from llama_index.core.indices.struct_store import SQLTableRetrieverQueryEngine
from llama_index.core.objects import SQLTableNodeMapping, ObjectIndex
from llama_index.readers.database import DatabaseReader
from sqlalchemy import create_engine
from app.api.routers.chat import generate_filters
from app.engine import get_index, makeDescriptionByEngine
from app.engine.loaders.db import CustomDatabaseReader
from app.engine.vectordb import get_vector_store
from app.observability import init_observability
from app.settings import init_settings
from dotenv import load_dotenv
# Populate the process environment from a .env file (python-dotenv) at import
# time, before main() reads settings such as SQL_DATABASE_URL via os.getenv.
load_dotenv()
def read_questions(file_path):
    """Extract the questions listed in a UTF-8 text file.

    Only lines containing the substring ``"question"`` are considered. For
    each such line, everything after the first colon is taken as the question
    text and stripped of surrounding whitespace. Both the ASCII colon (``:``)
    and the full-width colon (``：``) are accepted as the separator.

    Lines that mention ``"question"`` but have no separator, or whose
    question text is empty, are skipped instead of raising.

    Args:
        file_path: Path of the questions file to read.

    Returns:
        A list of question strings, in file order.

    Raises:
        OSError: If the file cannot be opened.
    """
    questions = []
    with open(file_path, 'r', encoding='utf-8') as file:
        for line in file:
            if "question" not in line:
                continue
            # Locate the earliest separator of either kind; str.split("")
            # is not usable here because an empty separator raises ValueError.
            positions = [
                pos for pos in (line.find(":"), line.find("：")) if pos != -1
            ]
            if not positions:
                continue
            # Keep the whole remainder so colons inside the question survive.
            question_part = line[min(positions) + 1:].strip()
            if question_part:
                questions.append(question_part)
    return questions
def save_results_to_file(question, result, file_path):
    """Append one question/result record to a UTF-8 text file.

    The record consists of a question line, a result line and a trailing
    blank line. The file is opened in append mode, so it is created when
    missing and existing content is preserved.

    Args:
        question: The question text that was executed.
        result: The result to record; formatted with ``str()`` semantics.
        file_path: Path of the results file to append to.
    """
    with open(file_path, mode='a', encoding='utf-8') as out:
        record = f"问题: {question}\n" + f"结果: {result}\n\n"
        out.write(record)
def main():
    """Run every question from a questions file through the SQL query engine.

    The path of the questions file is taken from the first command-line
    argument; the script prints a usage message and exits with status 1 when
    it is missing. The parameters under test (``TOP_K``, ``LLM_TEMPERATURE``
    and ``similarity_top_k``) are written into the environment before the
    application settings are initialised, then read back and recorded as a
    header in ``query_results.txt`` (next to this script) when that file is
    missing or empty. Each question's SQL query result is printed and
    appended to the same file.
    """
    # Read questions_file_path from the command line.
    if len(sys.argv) < 2:
        print("请提供questions.txt文件的路径")
        sys.exit(1)
    questions_file_path = sys.argv[1]

    # Pin the parameters under test before init_settings() runs
    # (presumably it reads them from the environment — verify in app.settings).
    os.environ['TOP_K'] = str(5)  # top-k for the vector query
    os.environ['LLM_TEMPERATURE'] = str(0.1)  # LLM temperature
    os.environ['similarity_top_k'] = str(5)  # top-k for the SQL table retriever

    init_settings()
    init_observability()

    # `index` and `filters` are only consumed by the disabled vector query
    # further down; they are kept so that path can be re-enabled as-is.
    index = get_index()
    top_k = int(os.getenv("TOP_K"))
    temperature = float(os.getenv("LLM_TEMPERATURE"))
    # A top-k is a count of results, so it must be an int: a float such as
    # 5.0 is not a valid retriever similarity_top_k.
    similarity_top_k = int(os.getenv("similarity_top_k"))
    filters = generate_filters([])

    engine = create_engine(os.getenv("SQL_DATABASE_URL", ""))
    sql_database = SQLDatabase(engine)
    table_schema_objs = makeDescriptionByEngine(sql_database)
    table_node_mapping = SQLTableNodeMapping(sql_database)

    # Build the SQL query tool: index the table schema objects, then retrieve
    # the most similar tables for each question.
    sql_obj_index = ObjectIndex.from_objects(
        table_schema_objs,
        table_node_mapping,
        index_cls=VectorStoreIndex,
    )
    sql_query_engine = SQLTableRetrieverQueryEngine(
        sql_database,
        sql_obj_index.as_retriever(similarity_top_k=similarity_top_k),
    )

    questions = read_questions(questions_file_path)
    script_dir = os.path.dirname(os.path.abspath(__file__))
    results_file_path = os.path.join(script_dir, "query_results.txt")

    # Write the parameter header when the results file is missing or empty.
    # The existence check matters: os.path.getsize raises on a missing file,
    # which is the normal state on the first run.
    if (
        not os.path.exists(results_file_path)
        or os.path.getsize(results_file_path) == 0
    ):
        with open(results_file_path, 'w', encoding='utf-8') as file:
            file.write(f"TOP_K: {top_k}\n")
            file.write(f"LLM_TEMPERATURE: {temperature}\n")
            file.write(f"similarity_top_k: {similarity_top_k}\n\n")

    # Execute the queries one by one.
    for i, question in enumerate(questions):
        print(f"Executing query {i+1}: {question}")

        # Disabled vector-store query, kept verbatim for re-enabling:
        # query_engine = index.as_query_engine(
        #     similarity_top_k=top_k, filters=filters
        # )
        # query_result = query_engine.query(question)
        # print(f"向量查询结果: {query_result}\n")
        # save_results_to_file(question, f"向量查询结果: {query_result}", results_file_path)

        sql_query_result = sql_query_engine.query(question)
        print(f"SQL查询结果: {sql_query_result}\n")
        save_results_to_file(question, f"SQL查询结果: {sql_query_result}", results_file_path)
if __name__ == "__main__":
    # Imported here rather than at module top, so phoenix is only needed
    # when this file is executed as a script.
    from phoenix.trace import using_project

    # Run the whole test inside the "ly_zjapp_test" project context
    # (presumably so its traces are grouped under that Phoenix project).
    with using_project("ly_zjapp_test"):
        main()