dev #1
+96
-50
@@ -2,13 +2,11 @@ import re
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
from ctypes import cast
|
||||
from sqlalchemy import create_engine
|
||||
|
||||
from llama_index.core import VectorStoreIndex, SQLDatabase
|
||||
from llama_index.core.indices.struct_store import SQLTableRetrieverQueryEngine
|
||||
from llama_index.core.objects import SQLTableNodeMapping, ObjectIndex
|
||||
from sqlalchemy import create_engine
|
||||
|
||||
from app.api.routers.chat import generate_filters
|
||||
from app.engine import get_index, makeDescriptionByEngine
|
||||
from app.engine.loaders.db import CustomDatabaseReader
|
||||
@@ -19,31 +17,94 @@ from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
def read_questions_and_answers(file_path):
|
||||
questions_and_answers = []
|
||||
with open(file_path, 'r', encoding='utf-8') as file:
|
||||
data = json.load(file) # 读取 JSON 数据
|
||||
for entry in data:
|
||||
question = entry.get("question", "").strip() # 获取 question
|
||||
answer_match = re.search(r"(\d+\.?\d*)", entry.get("answer", "")) # 使用正则提取 answer 中的数字部分
|
||||
if question and answer_match:
|
||||
answer_value = answer_match.group(1) # 获取匹配的数字
|
||||
questions_and_answers.append((question, answer_value))
|
||||
answer = entry.get("answer", "").strip() # 直接获取 answer 而不是提取数字
|
||||
if question and answer:
|
||||
questions_and_answers.append((question, answer))
|
||||
return questions_and_answers
|
||||
|
||||
|
||||
def save_results_to_file(question, result, correct_answer, file_path):
|
||||
def save_results_to_file(question, query_result, correct_answer, file_path):
|
||||
# 保存原始查询结果
|
||||
result_data = {
|
||||
"问题": question,
|
||||
"查询结果": str(query_result), # 保存原始查询结果
|
||||
"正确答案": correct_answer
|
||||
}
|
||||
with open(file_path, 'a', encoding='utf-8') as file:
|
||||
file.write(f"问题: {question}\n")
|
||||
file.write(f"查询结果: {result}\n")
|
||||
file.write(f"正确答案: {correct_answer}\n\n")
|
||||
json.dump(result_data, file, ensure_ascii=False)
|
||||
file.write('\n') # 每个结果条目之间添加换行符
|
||||
|
||||
def log_incorrect_answers(question, correct_answer, result, log_file_path):
|
||||
def log_incorrect_answers(question, correct_answer, query_result, log_file_path):
|
||||
# 保存原始查询结果
|
||||
incorrect_data = {
|
||||
"错误问题": question,
|
||||
"正确答案": correct_answer,
|
||||
"查询结果": str(query_result) # 保存原始查询结果
|
||||
}
|
||||
with open(log_file_path, 'a', encoding='utf-8') as file:
|
||||
file.write(f"错误问题: {question}\n")
|
||||
file.write(f"正确答案: {correct_answer}\n")
|
||||
file.write(f"查询结果: {result}\n\n")
|
||||
json.dump(incorrect_data, file, ensure_ascii=False)
|
||||
file.write('\n') # 每个结果条目之间添加换行符
|
||||
|
||||
# 提取多个数字
|
||||
def extract_all_numbers_from_result(result_str):
|
||||
"""从查询结果字符串中提取所有数字"""
|
||||
# 使用正则表达式匹配所有数值(包含小数和科学计数法)
|
||||
numbers = re.findall(r"-?\d+,\d+(\.\d+)?|0E-\d+|\d+(\.\d+)?", result_str)
|
||||
# 移除逗号并返回所有数字的列表
|
||||
return [num.replace(',', '') for num in numbers]
|
||||
|
||||
# 判断两个浮点数是否接近
|
||||
def is_close_enough(val1, val2, epsilon=1e-5):
|
||||
"""判断两个数值是否在指定的误差范围内接近"""
|
||||
return abs(val1 - val2) < epsilon
|
||||
|
||||
def is_answer_correct(query_result_str, correct_answer_str):
|
||||
"""检查查询结果是否与正确答案匹配"""
|
||||
# 提取查询结果中的数字或编码
|
||||
query_result_value = extract_number_or_code_from_result(query_result_str)
|
||||
# 提取正确答案中的数字或编码
|
||||
correct_answer_value = extract_number_or_code_from_result(correct_answer_str)
|
||||
|
||||
# 对比提取的数字或编码
|
||||
if query_result_value and correct_answer_value:
|
||||
try:
|
||||
# 移除逗号,并转换为浮点数
|
||||
query_result_float = float(query_result_value.replace(',', ''))
|
||||
correct_answer_float = float(correct_answer_value.replace(',', ''))
|
||||
|
||||
# 处理科学计数法中的零值
|
||||
if query_result_float == 0.0 and correct_answer_float == 0.0:
|
||||
return True
|
||||
|
||||
# 四舍五入处理到小数点后5位
|
||||
rounded_query_result = round(query_result_float, 5)
|
||||
rounded_correct_answer = round(correct_answer_float, 5)
|
||||
|
||||
# 比较四舍五入后的浮点数值
|
||||
return rounded_query_result == rounded_correct_answer
|
||||
except ValueError:
|
||||
# 如果无法转换为浮点数,则直接比较字符串
|
||||
return query_result_value == correct_answer_value
|
||||
return False # 如果任何一方为空,则认为不匹配
|
||||
|
||||
def extract_number_or_code_from_result(result_str):
|
||||
"""从查询结果字符串中提取数字或编码,并处理逗号、百分号和科学计数法"""
|
||||
# 使用正则表达式匹配浮点数,包括可能的多位小数、逗号、百分比形式和科学计数法
|
||||
match = re.search(r"(\d{1,3}(,\d{3})*(\.\d+)?|0E-\d+)", result_str)
|
||||
if match:
|
||||
number_str = match.group(1).replace(',', '').replace('%', '') # 移除逗号和百分号
|
||||
return number_str
|
||||
|
||||
# 尝试从结果中提取所有可能的编码格式
|
||||
potential_codes = re.findall(r"\b[A-Z][A-Za-z\d-]+\b", result_str)
|
||||
# 返回第一个匹配的编码
|
||||
return potential_codes[0] if potential_codes else None
|
||||
|
||||
|
||||
def main(questions_file, query_type):
|
||||
# 获取脚本所在的目录
|
||||
@@ -59,14 +120,6 @@ def main(questions_file, query_type):
|
||||
with open(questions_file_path, 'w', encoding='utf-8') as file:
|
||||
json.dump([], file) # 写入空数组
|
||||
|
||||
if not os.path.exists(results_file_path):
|
||||
with open(results_file_path, 'w', encoding='utf-8') as file:
|
||||
json.dump([], file) # 写入空数组
|
||||
|
||||
if not os.path.exists(log_file_path):
|
||||
with open(log_file_path, 'w', encoding='utf-8') as file:
|
||||
json.dump([], file) # 写入空数组
|
||||
|
||||
# 更新环境变量
|
||||
os.environ['TOP_K'] = str(5) # 向量的TOP_K值
|
||||
os.environ['LLM_TEMPERATURE'] = str(0.1) # 温度值
|
||||
@@ -79,7 +132,7 @@ def main(questions_file, query_type):
|
||||
|
||||
top_k = int(os.getenv("TOP_K")) # 向量的TOP_K值
|
||||
temperature = float(os.getenv("LLM_TEMPERATURE")) # 温度值
|
||||
similarity_top_k = float(os.getenv("similarity_top_k")) # SQL的TOP_K值
|
||||
similarity_top_k = int(os.getenv("similarity_top_k")) # SQL的TOP_K值
|
||||
filters = generate_filters([])
|
||||
|
||||
engine = create_engine(os.getenv("SQL_DATABASE_URL", ""))
|
||||
@@ -100,13 +153,16 @@ def main(questions_file, query_type):
|
||||
# 如果文件为空,则写入参数值
|
||||
if os.path.getsize(results_file_path) == 0:
|
||||
with open(results_file_path, 'w', encoding='utf-8') as file:
|
||||
file.write(f"TOP_K: {top_k}\n")
|
||||
file.write(f"LLM_TEMPERATURE: {temperature}\n")
|
||||
file.write(f"similarity_top_k: {similarity_top_k}\n\n")
|
||||
json.dump({
|
||||
"TOP_K": top_k,
|
||||
"LLM_TEMPERATURE": temperature,
|
||||
"similarity_top_k": similarity_top_k
|
||||
}, file, ensure_ascii=False)
|
||||
file.write('\n')
|
||||
|
||||
# 循环执行查询
|
||||
for i, (question, correct_answer) in enumerate(questions_and_answers):
|
||||
print(f"Executing query {i+1}: {question}")
|
||||
print(f"执行查询 {i+1}: {question}")
|
||||
|
||||
if query_type == "vector":
|
||||
query_engine = index.as_query_engine(
|
||||
@@ -114,41 +170,31 @@ def main(questions_file, query_type):
|
||||
)
|
||||
query_result = query_engine.query(question)
|
||||
print(f"向量查询结果: {query_result}\n")
|
||||
save_results_to_file(question, f"向量查询结果: {query_result}", correct_answer, results_file_path)
|
||||
|
||||
# 提取向量查询结果中的数字进行匹配
|
||||
query_result_number = re.search(r"(\d+)", str(query_result))
|
||||
# 提取向量查询结果中的数字或编码进行匹配
|
||||
query_result_str = f"The encoding for the query \"{question}\" is {str(query_result)}"
|
||||
elif query_type == "sql":
|
||||
sql_query_result = sql_query_engine.query(question)
|
||||
print(f"SQL查询结果: {sql_query_result}\n")
|
||||
save_results_to_file(question, f"SQL查询结果: {sql_query_result}", correct_answer, results_file_path)
|
||||
|
||||
# 提取SQL查询结果中的数字进行匹配
|
||||
query_result_number = re.search(r"(\d+)", str(sql_query_result))
|
||||
# 提取SQL查询结果中的数字或编码进行匹配
|
||||
query_result_str = f"The encoding for the query \"{question}\" is {str(sql_query_result)}"
|
||||
else:
|
||||
print("无效的查询类型,请选择 'vector' 或 'sql'")
|
||||
sys.exit(1)
|
||||
|
||||
if query_result_number:
|
||||
query_number = query_result_number.group(1)
|
||||
|
||||
# 判断查询结果是否与正确答案匹配
|
||||
if query_number == correct_answer:
|
||||
save_results_to_file(question, query_number, correct_answer, results_file_path)
|
||||
else:
|
||||
log_incorrect_answers(question, correct_answer, query_number, log_file_path)
|
||||
if is_answer_correct(query_result_str, correct_answer):
|
||||
# 只在查询结果正确时记录结果
|
||||
save_results_to_file(question, query_result_str, correct_answer, results_file_path)
|
||||
else:
|
||||
log_incorrect_answers(question, correct_answer, "未找到有效数字", log_file_path)
|
||||
# 记录不正确的答案
|
||||
log_incorrect_answers(question, correct_answer, query_result_str, log_file_path)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
if len(sys.argv) < 3:
|
||||
print("请提供questions.json文件名和查询类型(vector 或 sql)")
|
||||
sys.exit(1)
|
||||
questions_file = sys.argv[1]
|
||||
query_type = sys.argv[2].lower()
|
||||
|
||||
from phoenix.trace import using_project
|
||||
|
||||
with using_project(questions_file) as obj:
|
||||
main(questions_file, query_type)
|
||||
main(questions_file, query_type)
|
||||
|
||||
Reference in New Issue
Block a user