155 lines
6.3 KiB
Python
155 lines
6.3 KiB
Python
import re
|
||
import os
|
||
import sys
|
||
import json
|
||
from ctypes import cast
|
||
|
||
from llama_index.core import VectorStoreIndex, SQLDatabase
|
||
from llama_index.core.indices.struct_store import SQLTableRetrieverQueryEngine
|
||
from llama_index.core.objects import SQLTableNodeMapping, ObjectIndex
|
||
from sqlalchemy import create_engine
|
||
|
||
from app.api.routers.chat import generate_filters
|
||
from app.engine import get_index, makeDescriptionByEngine
|
||
from app.engine.loaders.db import CustomDatabaseReader
|
||
from app.engine.vectordb import get_vector_store
|
||
from app.observability import init_observability
|
||
from app.settings import init_settings
|
||
from dotenv import load_dotenv
|
||
|
||
# Load environment variables (python-dotenv) at import time, before main()
# reads settings such as SQL_DATABASE_URL via os.getenv.
load_dotenv()
|
||
|
||
|
||
def read_questions_and_answers(file_path):
    """Load (question, expected_number) pairs from a JSON question file.

    The file must contain a JSON array of objects with ``"question"`` and
    ``"answer"`` keys. The expected answer is reduced to the first number
    found in the ``"answer"`` value (e.g. ``"about 42.5 units"`` -> ``"42.5"``).
    Entries whose question is blank/missing, or whose answer contains no
    number, are skipped.

    Args:
        file_path: Path of the JSON file, read as UTF-8.

    Returns:
        A list of ``(question, answer_number)`` tuples; both items are strings.
    """
    def as_text(value):
        # JSON may store numbers or null in these fields; re.search and
        # str.strip need text, so coerce (None -> "").
        return "" if value is None else str(value)

    with open(file_path, 'r', encoding='utf-8') as file:
        data = json.load(file)

    questions_and_answers = []
    for entry in data:
        question = as_text(entry.get("question")).strip()
        # Keep only the numeric part of the answer (integer or decimal).
        answer_match = re.search(r"(\d+\.?\d*)", as_text(entry.get("answer")))
        if question and answer_match:
            questions_and_answers.append((question, answer_match.group(1)))
    return questions_and_answers
|
||
|
||
|
||
def save_results_to_file(question, result, correct_answer, file_path):
    """Append one record (question, query result, expected answer) to *file_path*.

    The record is three labelled lines followed by a blank line. The file is
    opened in append mode (created if missing) and written as UTF-8 text.
    """
    record = (
        f"问题: {question}\n"
        f"查询结果: {result}\n"
        f"正确答案: {correct_answer}\n\n"
    )
    with open(file_path, mode='a', encoding='utf-8') as out:
        out.write(record)
|
||
|
||
def log_incorrect_answers(question, correct_answer, result, log_file_path):
    """Append a record for a wrongly answered question to *log_file_path*.

    The record lists the question, the expected answer and what the query
    produced, followed by a blank line. Appends (creating the file if
    missing) as UTF-8 text.
    """
    entry_lines = [
        f"错误问题: {question}\n",
        f"正确答案: {correct_answer}\n",
        f"查询结果: {result}\n\n",
    ]
    with open(log_file_path, mode='a', encoding='utf-8') as log_file:
        log_file.write("".join(entry_lines))
|
||
|
||
# Pattern for the numeric part of an answer. read_questions_and_answers uses
# the same pattern on the expected answers, so both sides of the comparison
# are extracted identically (decimals included).
_NUMBER_PATTERN = re.compile(r"(\d+\.?\d*)")


def _numbers_match(found, expected):
    """Return True when two numeric strings denote the same value ("42" vs "42.0")."""
    # Both strings come from _NUMBER_PATTERN-style matches, so float() cannot fail.
    return found == expected or float(found) == float(expected)


def _build_sql_query_engine(similarity_top_k):
    """Build a text-to-SQL query engine over the database at ``SQL_DATABASE_URL``.

    Table descriptions from ``makeDescriptionByEngine`` are put into a
    ``VectorStoreIndex``-backed ``ObjectIndex``; its retriever returns the
    ``similarity_top_k`` tables most similar to each question.
    """
    engine = create_engine(os.getenv("SQL_DATABASE_URL", ""))
    sql_database = SQLDatabase(engine)
    table_schema_objs = makeDescriptionByEngine(sql_database)
    table_node_mapping = SQLTableNodeMapping(sql_database)
    sql_obj_index = ObjectIndex.from_objects(
        table_schema_objs,
        table_node_mapping,
        index_cls=VectorStoreIndex,
    )
    return SQLTableRetrieverQueryEngine(
        sql_database,
        sql_obj_index.as_retriever(similarity_top_k=similarity_top_k),
    )


def main(questions_file, query_type):
    """Run every question in *questions_file* through a query engine and grade it.

    A question counts as correct when the first number in the engine's
    response equals (numerically) the number extracted from its expected answer.

    Args:
        questions_file: Questions JSON file name, resolved relative to this
            script's directory.
        query_type: ``"vector"`` to query the vector index, or ``"sql"`` to
            query the SQL database through ``SQLTableRetrieverQueryEngine``.

    Side effects:
        Appends to ``query_results.json`` and ``incorrect_answers_log.json``
        next to this script. Both are plain-text logs despite the extension.
        Prints and exits with status 1 on an unknown *query_type*.
    """
    # Validate before any expensive initialisation.
    if query_type not in ("vector", "sql"):
        print("无效的查询类型,请选择 'vector' 或 'sql'")
        sys.exit(1)

    # All files live next to this script.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    questions_file_path = os.path.join(script_dir, questions_file)
    results_file_path = os.path.join(script_dir, "query_results.json")
    log_file_path = os.path.join(script_dir, "incorrect_answers_log.json")

    # A missing questions file becomes an empty JSON array, i.e. zero questions.
    if not os.path.exists(questions_file_path):
        with open(questions_file_path, 'w', encoding='utf-8') as file:
            json.dump([], file)

    # The log is plain text; create it empty (not seeded with "[]") so that it
    # exists even when every answer is correct.
    if not os.path.exists(log_file_path):
        with open(log_file_path, 'w', encoding='utf-8'):
            pass

    # Query parameters, exported so init_settings() and friends can read them.
    os.environ['TOP_K'] = str(5)  # vector retrieval top-k
    os.environ['LLM_TEMPERATURE'] = str(0.1)  # LLM temperature
    os.environ['similarity_top_k'] = str(5)  # SQL table-retriever top-k

    init_settings()
    init_observability()

    top_k = int(os.getenv("TOP_K"))
    temperature = float(os.getenv("LLM_TEMPERATURE"))
    # A result count: must be an int, not a float, when handed to the retriever.
    similarity_top_k = int(os.getenv("similarity_top_k"))

    # Build only the engine that is actually used, once, outside the loop.
    # (The SQL engine needs SQL_DATABASE_URL, which vector-only runs may lack.)
    if query_type == "vector":
        index = get_index()
        filters = generate_filters([])
        query_engine = index.as_query_engine(
            similarity_top_k=top_k, filters=filters
        )
        result_label = "向量查询结果"
    else:
        query_engine = _build_sql_query_engine(similarity_top_k)
        result_label = "SQL查询结果"

    questions_and_answers = read_questions_and_answers(questions_file_path)

    # Write the parameter header only into a new/empty results file.
    if not os.path.exists(results_file_path) or os.path.getsize(results_file_path) == 0:
        with open(results_file_path, 'w', encoding='utf-8') as file:
            file.write(f"TOP_K: {top_k}\n")
            file.write(f"LLM_TEMPERATURE: {temperature}\n")
            file.write(f"similarity_top_k: {similarity_top_k}\n\n")

    for i, (question, correct_answer) in enumerate(questions_and_answers, start=1):
        print(f"Executing query {i}: {question}")

        query_result = query_engine.query(question)
        print(f"{result_label}: {query_result}\n")
        save_results_to_file(question, f"{result_label}: {query_result}", correct_answer, results_file_path)

        # Grade on the first number appearing in the response text.
        number_match = _NUMBER_PATTERN.search(str(query_result))
        if number_match is None:
            log_incorrect_answers(question, correct_answer, "未找到有效数字", log_file_path)
            continue

        query_number = number_match.group(1)
        if _numbers_match(query_number, correct_answer):
            save_results_to_file(question, query_number, correct_answer, results_file_path)
        else:
            log_incorrect_answers(question, correct_answer, query_number, log_file_path)
|
||
|
||
if __name__ == "__main__":
    # Usage: python <script> <questions.json> <vector|sql>
    if len(sys.argv) < 3:
        print("请提供questions.json文件名和查询类型(vector 或 sql)")
        sys.exit(1)
    questions_file, query_type = sys.argv[1], sys.argv[2].lower()

    # Imported here, so phoenix is only required when the script is executed.
    from phoenix.trace import using_project

    # Run inside phoenix's using_project context named after the question
    # file (presumably so each question set's traces are grouped together).
    with using_project(questions_file):
        main(questions_file, query_type)
|