Files
zjdataai-app/backend/test1/query_test.py
T
2024-08-16 11:17:27 +08:00

155 lines
6.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import re
import os
import sys
import json
from ctypes import cast
from llama_index.core import VectorStoreIndex, SQLDatabase
from llama_index.core.indices.struct_store import SQLTableRetrieverQueryEngine
from llama_index.core.objects import SQLTableNodeMapping, ObjectIndex
from sqlalchemy import create_engine
from app.api.routers.chat import generate_filters
from app.engine import get_index, makeDescriptionByEngine
from app.engine.loaders.db import CustomDatabaseReader
from app.engine.vectordb import get_vector_store
from app.observability import init_observability
from app.settings import init_settings
from dotenv import load_dotenv
load_dotenv()
def read_questions_and_answers(file_path):
    """Load (question, expected number) pairs from a JSON question file.

    The file is parsed with ``json.load`` and iterated entry by entry; each
    entry is read through ``.get("question")`` and ``.get("answer")``.
    The first integer or decimal literal found in the answer is kept, as a
    string, as the expected value.

    Entries are skipped when the question is missing/blank or when the answer
    contains no digits. Non-string values (for example a JSON number used as
    the answer) are converted to text before matching instead of raising.

    Args:
        file_path: Path of the UTF-8 encoded JSON file to read.

    Returns:
        A list of ``(question, answer_value)`` tuples, where ``question`` is
        whitespace-stripped and ``answer_value`` is the matched numeric text.
    """
    number_pattern = re.compile(r"(\d+\.?\d*)")

    with open(file_path, 'r', encoding='utf-8') as file:
        data = json.load(file)

    questions_and_answers = []
    for entry in data:
        raw_question = entry.get("question")
        raw_answer = entry.get("answer")
        # JSON null is treated like a missing key; everything else is coerced
        # to text so numeric answers such as 7 or 3.5 are still usable.
        question = "" if raw_question is None else str(raw_question).strip()
        answer_text = "" if raw_answer is None else str(raw_answer)

        answer_match = number_pattern.search(answer_text)
        if question and answer_match:
            questions_and_answers.append((question, answer_match.group(1)))
    return questions_and_answers
def save_results_to_file(question, result, correct_answer, file_path):
    """Append one question / result / expected-answer record to a text file.

    The record is three labelled lines followed by a blank line. The file is
    opened in append mode with UTF-8 encoding, so earlier records are kept.

    Args:
        question: The question text that was asked.
        result: The query result to record (formatted into the line as text).
        correct_answer: The expected answer for the question.
        file_path: Path of the file to append to.
    """
    with open(file_path, 'a', encoding='utf-8') as out:
        record = "".join(
            (
                f"问题: {question}\n",
                f"查询结果: {result}\n",
                f"正确答案: {correct_answer}\n\n",
            )
        )
        out.write(record)
def log_incorrect_answers(question, correct_answer, result, log_file_path):
    """Append a mismatch record (question, expected, actual) to the log file.

    The record is three labelled lines followed by a blank line. The log is
    opened in append mode with UTF-8 encoding, so earlier records are kept.

    Args:
        question: The question whose result did not match.
        correct_answer: The expected answer for the question.
        result: What the query produced instead.
        log_file_path: Path of the log file to append to.
    """
    with open(log_file_path, 'a', encoding='utf-8') as log:
        log.writelines(
            [
                f"错误问题: {question}\n",
                f"正确答案: {correct_answer}\n",
                f"查询结果: {result}\n\n",
            ]
        )
def _extract_number(text):
    """Return the first integer or decimal literal in *text*, or None."""
    # Same pattern that is used for the expected answers, so that decimal
    # values are extracted consistently on both sides of the comparison.
    match = re.search(r"(\d+\.?\d*)", str(text))
    return match.group(1) if match else None


def _answers_match(found, expected):
    """Compare two numeric strings by value, so "3" and "3.0" are equal."""
    try:
        return float(found) == float(expected)
    except ValueError:
        # Fall back to the plain text comparison if either side is not
        # parseable as a number.
        return found == expected


def main(questions_file, query_type):
    """Run every question in *questions_file* through one query engine.

    For each (question, expected answer) pair the selected engine is queried,
    the raw result is appended to ``query_results.json``, and the first number
    in the result is compared with the expected answer. Matches are appended
    to the results file again (number only); mismatches and results without
    any number are appended to ``incorrect_answers_log.json``. All files live
    next to this script.

    NOTE: despite their ``.json`` names, the results and log files receive
    plain-text records, not JSON.

    Args:
        questions_file: File name of the questions JSON, relative to the
            directory containing this script. Created as an empty JSON array
            if it does not exist.
        query_type: ``"vector"`` to query the vector index, or ``"sql"`` to
            query the SQL table-retriever engine.

    Raises:
        SystemExit: With status 1 when *query_type* is neither ``"vector"``
            nor ``"sql"``.
    """
    # Reject an unknown mode before doing any file, settings, or DB setup.
    if query_type not in ("vector", "sql"):
        print("无效的查询类型,请选择 'vector' 或 'sql'")
        sys.exit(1)

    # All input/output files are resolved relative to this script.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    questions_file_path = os.path.join(script_dir, questions_file)
    results_file_path = os.path.join(script_dir, "query_results.json")
    log_file_path = os.path.join(script_dir, "incorrect_answers_log.json")

    # Any file that does not exist yet is created holding an empty JSON array.
    for path in (questions_file_path, results_file_path, log_file_path):
        if not os.path.exists(path):
            with open(path, 'w', encoding='utf-8') as file:
                json.dump([], file)

    # Set the tuning parameters in the environment before settings are
    # initialised, then read them back for local use.
    os.environ['TOP_K'] = str(5)  # top-k for the vector query
    os.environ['LLM_TEMPERATURE'] = str(0.1)  # LLM temperature
    os.environ['similarity_top_k'] = str(5)  # top-k for SQL table retrieval
    init_settings()
    init_observability()
    index = get_index()
    top_k = int(os.getenv("TOP_K"))
    temperature = float(os.getenv("LLM_TEMPERATURE"))
    # A top-k is a count, so it is parsed as an int rather than a float.
    similarity_top_k = int(os.getenv("similarity_top_k"))
    filters = generate_filters([])

    engine = create_engine(os.getenv("SQL_DATABASE_URL", ""))
    sql_database = SQLDatabase(engine)
    table_schema_objs = makeDescriptionByEngine(sql_database)
    table_node_mapping = SQLTableNodeMapping(sql_database)
    # Build the SQL query tool on top of an object index of table schemas.
    sql_obj_index = ObjectIndex.from_objects(
        table_schema_objs,
        table_node_mapping,
        index_cls=VectorStoreIndex,
    )
    sql_query_engine = SQLTableRetrieverQueryEngine(
        sql_database,
        sql_obj_index.as_retriever(similarity_top_k=similarity_top_k),
    )

    # Pick the engine once instead of rebuilding it for every question.
    if query_type == "vector":
        query_engine = index.as_query_engine(
            similarity_top_k=top_k, filters=filters
        )
        result_label = "向量查询结果"
    else:
        query_engine = sql_query_engine
        result_label = "SQL查询结果"

    questions_and_answers = read_questions_and_answers(questions_file_path)

    # Write the parameter header when the results file has no records yet.
    # A freshly created file holds the "[]" placeholder, so a size check
    # against zero alone would never fire.
    with open(results_file_path, 'r', encoding='utf-8') as file:
        existing_content = file.read().strip()
    if existing_content in ("", "[]"):
        with open(results_file_path, 'w', encoding='utf-8') as file:
            file.write(f"TOP_K: {top_k}\n")
            file.write(f"LLM_TEMPERATURE: {temperature}\n")
            file.write(f"similarity_top_k: {similarity_top_k}\n\n")

    # Run the queries one by one.
    for i, (question, correct_answer) in enumerate(questions_and_answers, start=1):
        print(f"Executing query {i}: {question}")
        query_result = query_engine.query(question)
        print(f"{result_label}: {query_result}\n")
        save_results_to_file(
            question,
            f"{result_label}: {query_result}",
            correct_answer,
            results_file_path,
        )

        # Compare the first number in the result with the expected answer.
        query_number = _extract_number(query_result)
        if query_number is None:
            log_incorrect_answers(question, correct_answer, "未找到有效数字", log_file_path)
        elif _answers_match(query_number, correct_answer):
            save_results_to_file(question, query_number, correct_answer, results_file_path)
        else:
            log_incorrect_answers(question, correct_answer, query_number, log_file_path)
if __name__ == "__main__":
    # Usage: <script> <questions.json file name> <vector|sql>
    if len(sys.argv) < 3:
        print("请提供questions.json文件名和查询类型(vector 或 sql)")
        sys.exit(1)
    questions_file = sys.argv[1]
    # Lower-cased so the query type is accepted regardless of letter case.
    query_type = sys.argv[2].lower()

    # Only needed when the module is run as a script.
    from phoenix.trace import using_project

    # Run inside Phoenix's using_project context, passing the questions file
    # name as the project (presumably so traces are grouped per question
    # file - confirm against the Phoenix docs).
    with using_project(questions_file):
        main(questions_file, query_type)