Files
2024-08-16 14:29:35 +08:00

201 lines
8.4 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import re
import os
import sys
import json
from sqlalchemy import create_engine
from llama_index.core import VectorStoreIndex, SQLDatabase
from llama_index.core.indices.struct_store import SQLTableRetrieverQueryEngine
from llama_index.core.objects import SQLTableNodeMapping, ObjectIndex
from app.api.routers.chat import generate_filters
from app.engine import get_index, makeDescriptionByEngine
from app.engine.loaders.db import CustomDatabaseReader
from app.engine.vectordb import get_vector_store
from app.observability import init_observability
from app.settings import init_settings
from dotenv import load_dotenv
# Populate os.environ from a .env file (python-dotenv) at import time, so that
# settings read later via os.getenv — e.g. SQL_DATABASE_URL in main() — can come from it.
load_dotenv()
def read_questions_and_answers(file_path):
    """Load (question, answer) pairs from a JSON file.

    The file must contain a JSON array of objects with "question" and "answer"
    keys. Both values are converted to text and stripped. Entries that are not
    objects, or whose question or answer is missing, null, or blank, are skipped.

    Args:
        file_path: Path of the JSON file. A UTF-8 byte-order mark is tolerated.

    Returns:
        A list of ``(question, answer)`` string tuples in file order.
    """
    def _as_text(value):
        # JSON answers are often numbers (e.g. 42 or 3.14); coerce to str so that
        # .strip() works, and treat a missing key / null as empty.
        return "" if value is None else str(value).strip()

    # utf-8-sig reads plain UTF-8 unchanged and also strips a BOM, which
    # json.load would otherwise reject.
    with open(file_path, 'r', encoding='utf-8-sig') as file:
        data = json.load(file)

    questions_and_answers = []
    for entry in data:
        if not isinstance(entry, dict):
            continue
        question = _as_text(entry.get("question"))
        answer = _as_text(entry.get("answer"))
        if question and answer:
            questions_and_answers.append((question, answer))
    return questions_and_answers
def save_results_to_file(question, query_result, correct_answer, file_path):
    """Append the record of a correctly answered question to a JSON Lines file.

    Each call opens *file_path* in append mode and writes exactly one JSON
    object (non-ASCII characters kept as-is) followed by a newline.

    Args:
        question: The question text.
        query_result: The query output; stored via ``str()`` so arbitrary
            response objects can be serialized.
        correct_answer: The expected answer.
        file_path: Path of the results file to append to.
    """
    fields = (
        ("问题", question),
        ("查询结果", str(query_result)),
        ("正确答案", correct_answer),
    )
    with open(file_path, mode='a', encoding='utf-8') as sink:
        json.dump(dict(fields), sink, ensure_ascii=False)
        sink.write('\n')  # one record per line
def log_incorrect_answers(question, correct_answer, query_result, log_file_path):
    """Append the record of an incorrectly answered question to a JSON Lines log.

    Each call opens *log_file_path* in append mode and writes exactly one JSON
    object (non-ASCII characters kept as-is) followed by a newline.

    Args:
        question: The question that was answered wrongly.
        correct_answer: The expected answer.
        query_result: The query output; stored via ``str()`` so arbitrary
            response objects can be serialized.
        log_file_path: Path of the log file to append to.
    """
    fields = (
        ("错误问题", question),
        ("正确答案", correct_answer),
        ("查询结果", str(query_result)),
    )
    with open(log_file_path, mode='a', encoding='utf-8') as sink:
        json.dump(dict(fields), sink, ensure_ascii=False)
        sink.write('\n')  # one record per line
# Extract every number from a string.
def extract_all_numbers_from_result(result_str):
    """Return every number found in *result_str*, in order, with commas removed.

    Recognised forms:
      * plain integers and decimals ("42", "0.5");
      * thousands-grouped numbers ("1,234,567.89" -> "1234567.89");
      * scientific notation ("0E-10", "1.5e3");
      * a leading minus sign, but only when it is not glued to a preceding ASCII
        letter or digit, so "2020-2021" yields "2020" and "2021" rather than
        "-2021", and "A-5" yields "5".

    Returns:
        A list of number strings (possibly empty).
    """
    # Every group is non-capturing so that re.findall returns the whole match
    # as a string rather than a tuple of group contents.
    numbers = re.findall(
        r"(?:(?<![0-9A-Za-z])-)?"             # optional sign, not part of "A-1" / "1-2"
        r"(?:\d{1,3}(?:,\d{3})+(?!\d)|\d+)"   # thousands-grouped or plain integer part
        r"(?:\.\d+)?"                         # optional fraction
        r"(?:[eE][-+]?\d+)?",                 # optional exponent
        result_str,
    )
    return [number.replace(',', '') for number in numbers]
# Decide whether two numbers lie within a tolerance of each other.
def is_close_enough(val1, val2, epsilon=1e-5):
    """Return True if *val1* and *val2* differ by strictly less than *epsilon*.

    The tolerance is absolute, not relative. A NaN operand always yields False.
    """
    distance = abs(val1 - val2)
    return distance < epsilon
def is_answer_correct(query_result_str, correct_answer_str):
    """Return True when the query result agrees with the expected answer.

    Both inputs are reduced with ``extract_number_or_code_from_result``:

    * If both sides yield values that parse as floats, they match when they are
      exactly equal or within the absolute tolerance of ``is_close_enough``
      (default 1e-5).
    * Otherwise (e.g. at least one side is a code), the extracted strings must
      be identical.
    * If either side yields nothing, the answer is treated as incorrect.

    Args:
        query_result_str: The query output (converted with ``str()``).
        correct_answer_str: The expected answer (converted with ``str()``).
    """
    # str() guards against non-string inputs such as a numeric answer loaded
    # from JSON or a raw response object.
    query_result_value = extract_number_or_code_from_result(str(query_result_str))
    correct_answer_value = extract_number_or_code_from_result(str(correct_answer_str))
    if not query_result_value or not correct_answer_value:
        return False

    try:
        # The extractor already strips thousands separators.
        query_result_float = float(query_result_value)
        correct_answer_float = float(correct_answer_value)
    except ValueError:
        # At least one side is a code rather than a number: compare verbatim.
        return query_result_value == correct_answer_value

    # Exact equality first (also covers matching infinities, whose difference is NaN).
    if query_result_float == correct_answer_float:
        return True
    # Tolerance rather than round(x, 5) == round(y, 5): rounding splits values
    # that straddle a rounding boundary, e.g. 1.000004999 vs 1.000005001.
    return is_close_enough(query_result_float, correct_answer_float)
def extract_number_or_code_from_result(result_str):
    """Return the first number in *result_str*, or else the first code-like token.

    Numbers take priority. A number may be a plain integer/decimal ("1234.56"),
    thousands-grouped ("1,234,567" -> "1234567"), or in scientific notation
    ("0E-10", "1.5E+03"). A leading minus is kept unless it is glued to a
    preceding ASCII letter or digit ("-5.5" -> "-5.5", "A-5" -> "5"). A
    trailing percent sign is not part of the match ("12.5%" -> "12.5").

    If the string contains no number, the first token matching
    ``\\b[A-Z][A-Za-z\\d-]+\\b`` (an uppercase letter followed by at least one
    letter, digit, or hyphen) is returned.

    Returns:
        The extracted string (commas removed for numbers), or None if nothing matched.
    """
    number_match = re.search(
        r"(?:(?<![0-9A-Za-z])-)?"             # optional sign, not part of "A-1" / "1-2"
        r"(?:\d{1,3}(?:,\d{3})+(?!\d)|\d+)"   # full integer part: grouped, else all digits
        r"(?:\.\d+)?"                         # optional fraction
        r"(?:[eE][-+]?\d+)?",                 # optional exponent
        result_str,
    )
    if number_match:
        return number_match.group(0).replace(',', '')

    code_match = re.search(r"\b[A-Z][A-Za-z\d-]+\b", result_str)
    return code_match.group(0) if code_match else None
def _build_sql_query_engine(similarity_top_k):
    """Create a table-retrieving SQL query engine for the database at SQL_DATABASE_URL."""
    # An unset SQL_DATABASE_URL yields an empty URL, which create_engine cannot parse.
    engine = create_engine(os.getenv("SQL_DATABASE_URL", ""))
    sql_database = SQLDatabase(engine)
    table_schema_objs = makeDescriptionByEngine(sql_database)
    table_node_mapping = SQLTableNodeMapping(sql_database)
    # Index the table schema objects so a retriever can select tables per question;
    # similarity_top_k is passed through to that retriever.
    sql_obj_index = ObjectIndex.from_objects(
        table_schema_objs,
        table_node_mapping,
        index_cls=VectorStoreIndex,
    )
    return SQLTableRetrieverQueryEngine(
        sql_database,
        sql_obj_index.as_retriever(similarity_top_k=similarity_top_k),
    )


def _write_run_parameters(results_file_path, top_k, temperature, similarity_top_k):
    """Write the run parameters as the first line of the results file if it is missing or empty."""
    # Check existence first: os.path.getsize raises when the file does not exist yet
    # (e.g. on the very first run).
    if os.path.exists(results_file_path) and os.path.getsize(results_file_path) > 0:
        return
    with open(results_file_path, 'w', encoding='utf-8') as file:
        json.dump({
            "TOP_K": top_k,
            "LLM_TEMPERATURE": temperature,
            "similarity_top_k": similarity_top_k
        }, file, ensure_ascii=False)
        file.write('\n')


def main(questions_file, query_type):
    """Run every question through the chosen query engine and grade the answers.

    Args:
        questions_file: File name or path of a JSON array of
            {"question": ..., "answer": ...} objects, resolved relative to this
            script's directory. Created as an empty array if it does not exist.
        query_type: "vector" to query the vector index, "sql" to query the SQL database.

    Correctly answered questions are appended to query_results.json and the rest to
    incorrect_answers_log.json, both next to this script. The results file starts
    with a line recording the run parameters when it is missing or empty.
    Exits with status 1 on an unknown query_type.
    """
    # Reject a bad query type before any expensive setup (and even when there are no questions).
    if query_type not in ("vector", "sql"):
        print("无效的查询类型,请选择 'vector' 或 'sql'")
        sys.exit(1)

    # All files live next to this script.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    questions_file_path = os.path.join(script_dir, questions_file)
    results_file_path = os.path.join(script_dir, "query_results.json")
    log_file_path = os.path.join(script_dir, "incorrect_answers_log.json")

    # If the questions file does not exist, create it holding an empty JSON array.
    if not os.path.exists(questions_file_path):
        with open(questions_file_path, 'w', encoding='utf-8') as file:
            json.dump([], file)

    # Set before init_settings(); presumably it reads these from the environment (TODO confirm).
    os.environ['TOP_K'] = str(5)  # vector retrieval top-k
    os.environ['LLM_TEMPERATURE'] = str(0.1)  # LLM temperature
    os.environ['similarity_top_k'] = str(5)  # table-retrieval top-k for SQL
    init_settings()
    init_observability()

    top_k = int(os.getenv("TOP_K"))
    temperature = float(os.getenv("LLM_TEMPERATURE"))
    similarity_top_k = int(os.getenv("similarity_top_k"))

    # Build only the engine this run needs (once, outside the loop): vector mode then
    # does not require SQL_DATABASE_URL, and SQL mode does not load the vector index.
    if query_type == "vector":
        query_engine = get_index().as_query_engine(
            similarity_top_k=top_k, filters=generate_filters([])
        )
        result_label = "向量查询结果"
    else:
        query_engine = _build_sql_query_engine(similarity_top_k)
        result_label = "SQL查询结果"

    questions_and_answers = read_questions_and_answers(questions_file_path)
    _write_run_parameters(results_file_path, top_k, temperature, similarity_top_k)

    for i, (question, correct_answer) in enumerate(questions_and_answers, start=1):
        print(f"执行查询 {i}: {question}")
        query_result = query_engine.query(question)
        print(f"{result_label}: {query_result}\n")

        result_text = str(query_result)
        # Record format written to the output files.
        query_result_str = f"The encoding for the query \"{question}\" is {result_text}"
        # Grade on the engine's answer alone. In the wrapped string above, digits in
        # the question would be extracted before the answer's number, and when there
        # are no digits the leading "The" always matches the code pattern.
        if is_answer_correct(result_text, correct_answer):
            save_results_to_file(question, query_result_str, correct_answer, results_file_path)
        else:
            log_incorrect_answers(question, correct_answer, query_result_str, log_file_path)
if __name__ == "__main__":
    # Usage: python <this script> <questions.json> <vector|sql>
    if len(sys.argv) < 3:
        print("请提供questions.json文件名和查询类型(vector 或 sql)")
        sys.exit(1)
    questions_file = sys.argv[1]
    # Case-insensitive query type ("SQL" == "sql").
    query_type = sys.argv[2].lower()
    main(questions_file, query_type)