Files
zjdataai-app/backend/test1/query_test.py
T
2024-08-15 19:08:58 +08:00

140 lines
5.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import re
import os
import sys
from ctypes import cast
from llama_index.core import VectorStoreIndex, SQLDatabase
from llama_index.core.indices.struct_store import SQLTableRetrieverQueryEngine
from llama_index.core.objects import SQLTableNodeMapping, ObjectIndex
from llama_index.readers.database import DatabaseReader
from sqlalchemy import create_engine
from app.api.routers.chat import generate_filters
from app.engine import get_index, makeDescriptionByEngine
from app.engine.loaders.db import CustomDatabaseReader
from app.engine.vectordb import get_vector_store
from app.observability import init_observability
from app.settings import init_settings
from dotenv import load_dotenv
load_dotenv()
def read_questions_and_answers(file_path):
    """Load ``(question, answer)`` pairs from a UTF-8 text file.

    A line yields a pair only when it contains a ``question`` label followed
    by a colon and, later on the same line, an ``answer`` label followed
    (anywhere after it) by at least one digit. Lines that do not fit are
    skipped.

    NOTE(review): the previous implementation split on an empty string, which
    raises ``ValueError`` for every line. The intended separator is assumed to
    be a colon (ASCII ``:`` or full-width ``:``) -- confirm against the real
    questions file format.

    Args:
        file_path: Path of the questions file.

    Returns:
        A list of ``(question_text, answer_digits)`` tuples; both items are
        strings, and ``answer_digits`` is the first run of digits that follows
        the ``answer`` label.
    """
    # Compiled once, outside the per-line loop.
    # Question text: everything between the colon after "question" and the
    # "answer" label.
    question_pattern = re.compile(r"question[^::]*[::](.*?)answer")
    # Answer value: first run of digits after the "answer" label.
    answer_pattern = re.compile(r"answer.*?(\d+)")

    questions_and_answers = []
    with open(file_path, 'r', encoding='utf-8') as file:
        for line in file:
            question_match = question_pattern.search(line)
            answer_match = answer_pattern.search(line)
            if question_match is None or answer_match is None:
                continue
            # Drop surrounding whitespace and any field separator left
            # between the question text and the "answer" label.
            question_text = question_match.group(1).strip()
            question_text = question_text.rstrip(",,;;::").strip()
            if not question_text:
                continue
            questions_and_answers.append((question_text, answer_match.group(1)))
    return questions_and_answers
def save_results_to_file(question, result, correct_answer, file_path):
    """Append one question / query result / expected answer record.

    The record is three labelled lines followed by a blank line, appended to
    ``file_path`` (UTF-8); the file is created if it does not exist.
    """
    record = (
        f"问题: {question}\n"
        f"查询结果: {result}\n"
        f"正确答案: {correct_answer}\n\n"
    )
    with open(file_path, mode='a', encoding='utf-8') as out:
        out.write(record)
def log_incorrect_answers(question, correct_answer, result, log_file_path):
    """Append a mismatch record to the incorrect-answers log.

    Writes the question, the expected answer and the value actually obtained
    as three labelled lines followed by a blank line. ``log_file_path`` is
    opened in append mode with UTF-8 encoding.
    """
    entry_lines = [
        f"错误问题: {question}\n",
        f"正确答案: {correct_answer}\n",
        f"查询结果: {result}\n\n",
    ]
    with open(log_file_path, mode='a', encoding='utf-8') as log:
        log.writelines(entry_lines)
def main():
    """Run every question in a questions file through one query engine.

    Command line: ``<script> <questions_file_path> <query_type>`` where
    ``query_type`` is ``vector`` or ``sql`` (case-insensitive).

    For each question the raw engine response is appended to
    ``query_results.txt`` next to this script. The first run of digits in the
    response is then compared with the expected answer: a match is appended
    to ``query_results.txt`` again (number only), while a mismatch or a
    response without digits is appended to ``incorrect_answers_log.txt``.

    Exits with status 1 when arguments are missing or the query type is
    invalid. The query type is validated before any initialisation so a typo
    fails immediately.
    """
    # Read the questions file path and the query type from the command line.
    if len(sys.argv) < 3:
        print("请提供questions.txt文件的路径和查询类型(vector 或 sql")
        sys.exit(1)
    questions_file_path = sys.argv[1]
    query_type = sys.argv[2].lower()
    if query_type not in ("vector", "sql"):
        print("无效的查询类型,请选择 'vector' 或 'sql'")
        sys.exit(1)

    # Result and log files live in the directory that contains this script.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    results_file_path = os.path.join(script_dir, "query_results.txt")
    log_file_path = os.path.join(script_dir, "incorrect_answers_log.txt")

    # Pin the tunables for this run; they are set before init_settings() is
    # called.
    os.environ['TOP_K'] = str(5)  # top-k for vector retrieval
    os.environ['LLM_TEMPERATURE'] = str(0.1)  # LLM temperature
    os.environ['similarity_top_k'] = str(5)  # top-k for SQL table retrieval
    init_settings()
    init_observability()

    top_k = int(os.getenv("TOP_K"))
    temperature = float(os.getenv("LLM_TEMPERATURE"))
    # A top-k is a count, so parse it as int (it used to be parsed as float
    # and handed to the retriever as 5.0).
    similarity_top_k = int(os.getenv("similarity_top_k"))

    # Build only the engine for the requested mode, once, before the loop.
    # Vector mode therefore no longer needs SQL_DATABASE_URL to be set.
    if query_type == "vector":
        index = get_index()
        filters = generate_filters([])
        query_engine = index.as_query_engine(
            similarity_top_k=top_k, filters=filters
        )
        result_label = "向量查询结果"
    else:
        engine = create_engine(os.getenv("SQL_DATABASE_URL", ""))
        sql_database = SQLDatabase(engine)
        table_schema_objs = makeDescriptionByEngine(sql_database)
        table_node_mapping = SQLTableNodeMapping(sql_database)
        # Object index over the table schemas, used to pick tables per query.
        sql_obj_index = ObjectIndex.from_objects(
            table_schema_objs,
            table_node_mapping,
            index_cls=VectorStoreIndex,
        )
        query_engine = SQLTableRetrieverQueryEngine(
            sql_database,
            sql_obj_index.as_retriever(similarity_top_k=similarity_top_k),
        )
        result_label = "SQL查询结果"

    questions_and_answers = read_questions_and_answers(questions_file_path)

    # Write the parameter header when the results file is missing or empty.
    # os.path.getsize() raises if the file does not exist, so check first.
    if (not os.path.exists(results_file_path)
            or os.path.getsize(results_file_path) == 0):
        with open(results_file_path, 'w', encoding='utf-8') as file:
            file.write(f"TOP_K: {top_k}\n")
            file.write(f"LLM_TEMPERATURE: {temperature}\n")
            file.write(f"similarity_top_k: {similarity_top_k}\n\n")

    # Run the queries one by one.
    for number, (question, correct_answer) in enumerate(questions_and_answers, start=1):
        print(f"Executing query {number}: {question}")
        query_result = query_engine.query(question)
        print(f"{result_label}: {query_result}\n")
        save_results_to_file(
            question, f"{result_label}: {query_result}", correct_answer, results_file_path
        )

        # The first run of digits in the response is taken as the answer.
        number_match = re.search(r"(\d+)", str(query_result))
        if number_match is None:
            log_incorrect_answers(question, correct_answer, "未找到有效数字", log_file_path)
            continue

        query_number = number_match.group(1)
        # Both sides are digit strings, so this is an exact textual match.
        if query_number == correct_answer:
            save_results_to_file(question, query_number, correct_answer, results_file_path)
        else:
            log_incorrect_answers(question, correct_answer, query_number, log_file_path)
if __name__ == "__main__":
    # Imported only when the file is executed as a script.
    from phoenix.trace import using_project

    # Run the whole evaluation inside the "ly_zjapp_test" project context.
    with using_project("ly_zjapp_test"):
        main()