dev #1

Merged
ly merged 41 commits from dev into main 2024-08-22 09:41:14 +08:00
Showing only changes of commit aa2cecc997 - Show all commits
+96 -50
View File
@@ -2,13 +2,11 @@ import re
import os import os
import sys import sys
import json import json
from ctypes import cast from sqlalchemy import create_engine
from llama_index.core import VectorStoreIndex, SQLDatabase from llama_index.core import VectorStoreIndex, SQLDatabase
from llama_index.core.indices.struct_store import SQLTableRetrieverQueryEngine from llama_index.core.indices.struct_store import SQLTableRetrieverQueryEngine
from llama_index.core.objects import SQLTableNodeMapping, ObjectIndex from llama_index.core.objects import SQLTableNodeMapping, ObjectIndex
from sqlalchemy import create_engine
from app.api.routers.chat import generate_filters from app.api.routers.chat import generate_filters
from app.engine import get_index, makeDescriptionByEngine from app.engine import get_index, makeDescriptionByEngine
from app.engine.loaders.db import CustomDatabaseReader from app.engine.loaders.db import CustomDatabaseReader
@@ -19,31 +17,94 @@ from dotenv import load_dotenv
load_dotenv() load_dotenv()
def read_questions_and_answers(file_path): def read_questions_and_answers(file_path):
questions_and_answers = [] questions_and_answers = []
with open(file_path, 'r', encoding='utf-8') as file: with open(file_path, 'r', encoding='utf-8') as file:
data = json.load(file) # 读取 JSON 数据 data = json.load(file) # 读取 JSON 数据
for entry in data: for entry in data:
question = entry.get("question", "").strip() # 获取 question question = entry.get("question", "").strip() # 获取 question
answer_match = re.search(r"(\d+\.?\d*)", entry.get("answer", "")) # 使用正则提取 answer 中的数字部分 answer = entry.get("answer", "").strip() # 直接获取 answer 而不是提取数字
if question and answer_match: if question and answer:
answer_value = answer_match.group(1) # 获取匹配的数字 questions_and_answers.append((question, answer))
questions_and_answers.append((question, answer_value))
return questions_and_answers return questions_and_answers
def save_results_to_file(question, query_result, correct_answer, file_path):
def save_results_to_file(question, result, correct_answer, file_path): # 保存原始查询结果
result_data = {
"问题": question,
"查询结果": str(query_result), # 保存原始查询结果
"正确答案": correct_answer
}
with open(file_path, 'a', encoding='utf-8') as file: with open(file_path, 'a', encoding='utf-8') as file:
file.write(f"问题: {question}\n") json.dump(result_data, file, ensure_ascii=False)
file.write(f"查询结果: {result}\n") file.write('\n') # 每个结果条目之间添加换行符
file.write(f"正确答案: {correct_answer}\n\n")
def log_incorrect_answers(question, correct_answer, result, log_file_path): def log_incorrect_answers(question, correct_answer, query_result, log_file_path):
# 保存原始查询结果
incorrect_data = {
"错误问题": question,
"正确答案": correct_answer,
"查询结果": str(query_result) # 保存原始查询结果
}
with open(log_file_path, 'a', encoding='utf-8') as file: with open(log_file_path, 'a', encoding='utf-8') as file:
file.write(f"错误问题: {question}\n") json.dump(incorrect_data, file, ensure_ascii=False)
file.write(f"正确答案: {correct_answer}\n") file.write('\n') # 每个结果条目之间添加换行符
file.write(f"查询结果: {result}\n\n")
# Extract every number that appears in a query result.
# Non-capturing groups only: with capturing groups ``re.findall`` would return
# tuples of group contents instead of the matched text.
_ALL_NUMBERS_RE = re.compile(
    r"(?:(?<![A-Za-z0-9])-)?"             # sign, but not a hyphen glued to a word/number
    r"(?:\d{1,3}(?:,\d{3})+|\d+)"         # 1,234,567 or a plain digit run
    r"(?:\.\d+)?"                         # optional fractional part
    r"(?:[eE][+-]?\d+(?![A-Za-z0-9]))?"   # optional exponent, e.g. 0E-8
)


def extract_all_numbers_from_result(result_str):
    """Return every number found in ``result_str`` as a list of strings.

    Recognises integers, decimals, comma-grouped thousands, a leading minus
    sign and scientific notation (such as ``0E-8``). Thousands separators are
    stripped from the returned strings, so each item can be passed to
    ``float()`` directly.

    Args:
        result_str: Text to scan. ``None`` or an empty string yields ``[]``.

    Returns:
        The matched numbers in order of appearance, without commas.
    """
    if not result_str:
        return []
    # finditer + group(0) always yields the full matched text.
    return [
        match.group(0).replace(',', '')
        for match in _ALL_NUMBERS_RE.finditer(result_str)
    ]
# Tolerance-based comparison of two numeric values.
def is_close_enough(val1, val2, epsilon=1e-5):
    """Tell whether ``val1`` and ``val2`` differ by strictly less than ``epsilon``."""
    difference = abs(val1 - val2)
    return difference < epsilon
def is_answer_correct(query_result_str, correct_answer_str):
    """Decide whether a query result agrees with the reference answer.

    One number or code is extracted from each string via
    ``extract_number_or_code_from_result``. When both extracted values parse
    as floats they are compared after rounding to 5 decimal places;
    otherwise the extracted strings are compared verbatim.

    Returns:
        True on a match, False on a mismatch or when either side yields
        nothing to compare.
    """
    extracted = extract_number_or_code_from_result(query_result_str)
    expected = extract_number_or_code_from_result(correct_answer_str)

    # Nothing usable on either side counts as a mismatch.
    if not (extracted and expected):
        return False

    try:
        # Drop thousands separators before the numeric conversion.
        extracted_num = float(extracted.replace(',', ''))
        expected_num = float(expected.replace(',', ''))
    except ValueError:
        # At least one side is not numeric: compare the raw strings.
        return extracted == expected

    # Zero written in scientific notation (e.g. 0E-8) equals plain zero.
    if extracted_num == 0.0 and expected_num == 0.0:
        return True

    # Compare at a precision of 5 decimal places.
    return round(extracted_num, 5) == round(expected_num, 5)
# First number in a string. The plain ``\d+`` alternative matters: requiring
# ``\d{1,3}`` alone would cut "1234.56" down to "123".
_FIRST_NUMBER_RE = re.compile(
    r"(?:(?<![A-Za-z0-9])-)?"             # sign, but not a hyphen glued to a word (COVID-19)
    r"(?:\d{1,3}(?:,\d{3})+|\d+)"         # 1,234,567 or a plain digit run
    r"(?:\.\d+)?"                         # optional fractional part
    r"(?:[eE][+-]?\d+(?![A-Za-z0-9]))?"   # optional exponent, e.g. 0E-8
)
# Code-like token: starts with an uppercase ASCII letter, then letters,
# digits or hyphens.
_CODE_RE = re.compile(r"\b[A-Z][A-Za-z\d-]+\b")


def extract_number_or_code_from_result(result_str):
    """Extract the first number, or failing that the first code, from text.

    Numbers may carry a leading minus sign, comma-grouped thousands, a
    fractional part and an exponent (``0E-8``). Commas are removed from the
    returned string. A trailing percent sign is never part of the match, so
    ``"85%"`` yields ``"85"``.

    If the text contains no number at all, the first token that starts with
    an uppercase ASCII letter followed by letters, digits or hyphens is
    returned instead.

    Args:
        result_str: Text to scan. ``None`` or an empty string yields ``None``.

    Returns:
        The extracted number or code as a string, or ``None`` when neither
        is present.
    """
    if not result_str:
        return None

    number_match = _FIRST_NUMBER_RE.search(result_str)
    if number_match:
        return number_match.group(0).replace(',', '')

    code_match = _CODE_RE.search(result_str)
    return code_match.group(0) if code_match else None
def main(questions_file, query_type): def main(questions_file, query_type):
# 获取脚本所在的目录 # 获取脚本所在的目录
@@ -59,14 +120,6 @@ def main(questions_file, query_type):
with open(questions_file_path, 'w', encoding='utf-8') as file: with open(questions_file_path, 'w', encoding='utf-8') as file:
json.dump([], file) # 写入空数组 json.dump([], file) # 写入空数组
if not os.path.exists(results_file_path):
with open(results_file_path, 'w', encoding='utf-8') as file:
json.dump([], file) # 写入空数组
if not os.path.exists(log_file_path):
with open(log_file_path, 'w', encoding='utf-8') as file:
json.dump([], file) # 写入空数组
# 更新环境变量 # 更新环境变量
os.environ['TOP_K'] = str(5) # 向量的TOP_K值 os.environ['TOP_K'] = str(5) # 向量的TOP_K值
os.environ['LLM_TEMPERATURE'] = str(0.1) # 温度值 os.environ['LLM_TEMPERATURE'] = str(0.1) # 温度值
@@ -79,7 +132,7 @@ def main(questions_file, query_type):
top_k = int(os.getenv("TOP_K")) # 向量的TOP_K值 top_k = int(os.getenv("TOP_K")) # 向量的TOP_K值
temperature = float(os.getenv("LLM_TEMPERATURE")) # 温度值 temperature = float(os.getenv("LLM_TEMPERATURE")) # 温度值
similarity_top_k = float(os.getenv("similarity_top_k")) # SQL的TOP_K值 similarity_top_k = int(os.getenv("similarity_top_k")) # SQL的TOP_K值
filters = generate_filters([]) filters = generate_filters([])
engine = create_engine(os.getenv("SQL_DATABASE_URL", "")) engine = create_engine(os.getenv("SQL_DATABASE_URL", ""))
@@ -100,13 +153,16 @@ def main(questions_file, query_type):
# 如果文件为空,则写入参数值 # 如果文件为空,则写入参数值
if os.path.getsize(results_file_path) == 0: if os.path.getsize(results_file_path) == 0:
with open(results_file_path, 'w', encoding='utf-8') as file: with open(results_file_path, 'w', encoding='utf-8') as file:
file.write(f"TOP_K: {top_k}\n") json.dump({
file.write(f"LLM_TEMPERATURE: {temperature}\n") "TOP_K": top_k,
file.write(f"similarity_top_k: {similarity_top_k}\n\n") "LLM_TEMPERATURE": temperature,
"similarity_top_k": similarity_top_k
}, file, ensure_ascii=False)
file.write('\n')
# 循环执行查询 # 循环执行查询
for i, (question, correct_answer) in enumerate(questions_and_answers): for i, (question, correct_answer) in enumerate(questions_and_answers):
print(f"Executing query {i+1}: {question}") print(f"执行查询 {i+1}: {question}")
if query_type == "vector": if query_type == "vector":
query_engine = index.as_query_engine( query_engine = index.as_query_engine(
@@ -114,41 +170,31 @@ def main(questions_file, query_type):
) )
query_result = query_engine.query(question) query_result = query_engine.query(question)
print(f"向量查询结果: {query_result}\n") print(f"向量查询结果: {query_result}\n")
save_results_to_file(question, f"向量查询结果: {query_result}", correct_answer, results_file_path)
# 提取向量查询结果中的数字进行匹配 # 提取向量查询结果中的数字或编码进行匹配
query_result_number = re.search(r"(\d+)", str(query_result)) query_result_str = f"The encoding for the query \"{question}\" is {str(query_result)}"
elif query_type == "sql": elif query_type == "sql":
sql_query_result = sql_query_engine.query(question) sql_query_result = sql_query_engine.query(question)
print(f"SQL查询结果: {sql_query_result}\n") print(f"SQL查询结果: {sql_query_result}\n")
save_results_to_file(question, f"SQL查询结果: {sql_query_result}", correct_answer, results_file_path)
# 提取SQL查询结果中的数字进行匹配 # 提取SQL查询结果中的数字或编码进行匹配
query_result_number = re.search(r"(\d+)", str(sql_query_result)) query_result_str = f"The encoding for the query \"{question}\" is {str(sql_query_result)}"
else: else:
print("无效的查询类型,请选择 'vector''sql'") print("无效的查询类型,请选择 'vector''sql'")
sys.exit(1) sys.exit(1)
if query_result_number: if is_answer_correct(query_result_str, correct_answer):
query_number = query_result_number.group(1) # 只在查询结果正确时记录结果
save_results_to_file(question, query_result_str, correct_answer, results_file_path)
# 判断查询结果是否与正确答案匹配
if query_number == correct_answer:
save_results_to_file(question, query_number, correct_answer, results_file_path)
else:
log_incorrect_answers(question, correct_answer, query_number, log_file_path)
else: else:
log_incorrect_answers(question, correct_answer, "未找到有效数字", log_file_path) # 记录不正确的答案
log_incorrect_answers(question, correct_answer, query_result_str, log_file_path)
if __name__ == "__main__": if __name__ == "__main__":
if len(sys.argv) < 3: if len(sys.argv) < 3:
print("请提供questions.json文件名和查询类型(vector 或 sql)") print("请提供questions.json文件名和查询类型(vector 或 sql)")
sys.exit(1) sys.exit(1)
questions_file = sys.argv[1] questions_file = sys.argv[1]
query_type = sys.argv[2].lower() query_type = sys.argv[2].lower()
from phoenix.trace import using_project main(questions_file, query_type)
with using_project(questions_file) as obj:
main(questions_file, query_type)