import re import os import sys import json from sqlalchemy import create_engine from llama_index.core import VectorStoreIndex, SQLDatabase from llama_index.core.indices.struct_store import SQLTableRetrieverQueryEngine from llama_index.core.objects import SQLTableNodeMapping, ObjectIndex from app.api.routers.chat import generate_filters from app.engine import get_index from app.engine.engine import makeDescriptionByEngine from app.engine.loaders.db import CustomDatabaseReader from app.engine.vectordb import get_vector_store from app.observability import init_observability from app.settings import init_settings from dotenv import load_dotenv load_dotenv() def read_questions_and_answers(file_path): questions_and_answers = [] with open(file_path, 'r', encoding='utf-8') as file: data = json.load(file) # 读取 JSON 数据 for entry in data: question = entry.get("question", "").strip() # 获取 question answer = entry.get("answer", "").strip() # 直接获取 answer 而不是提取数字 if question and answer: questions_and_answers.append((question, answer)) return questions_and_answers def save_results_to_file(question, query_result, correct_answer, file_path): # 保存原始查询结果 result_data = { "问题": question, "查询结果": str(query_result), # 保存原始查询结果 "正确答案": correct_answer } with open(file_path, 'a', encoding='utf-8') as file: json.dump(result_data, file, ensure_ascii=False) file.write('\n') # 每个结果条目之间添加换行符 def log_incorrect_answers(question, correct_answer, query_result, log_file_path): # 保存原始查询结果 incorrect_data = { "错误问题": question, "正确答案": correct_answer, "查询结果": str(query_result) # 保存原始查询结果 } with open(log_file_path, 'a', encoding='utf-8') as file: json.dump(incorrect_data, file, ensure_ascii=False) file.write('\n') # 每个结果条目之间添加换行符 # 提取多个数字 def extract_all_numbers_from_result(result_str): """从查询结果字符串中提取所有数字""" # 使用正则表达式匹配所有数值(包含小数和科学计数法) numbers = re.findall(r"-?\d+,\d+(\.\d+)?|0E-\d+|\d+(\.\d+)?", result_str) # 移除逗号并返回所有数字的列表 return [num.replace(',', '') for num in numbers] # 判断两个浮点数是否接近 def is_close_enough(val1, val2, epsilon=1e-5): """判断两个数值是否在指定的误差范围内接近""" return abs(val1 - val2) < epsilon def is_answer_correct(query_result_str, correct_answer_str): """检查查询结果是否与正确答案匹配""" # 提取查询结果中的数字或编码 query_result_value = extract_number_or_code_from_result(query_result_str) # 提取正确答案中的数字或编码 correct_answer_value = extract_number_or_code_from_result(correct_answer_str) # 对比提取的数字或编码 if query_result_value and correct_answer_value: try: # 移除逗号,并转换为浮点数 query_result_float = float(query_result_value.replace(',', '')) correct_answer_float = float(correct_answer_value.replace(',', '')) # 处理科学计数法中的零值 if query_result_float == 0.0 and correct_answer_float == 0.0: return True # 四舍五入处理到小数点后5位 rounded_query_result = round(query_result_float, 5) rounded_correct_answer = round(correct_answer_float, 5) # 比较四舍五入后的浮点数值 return rounded_query_result == rounded_correct_answer except ValueError: # 如果无法转换为浮点数,则直接比较字符串 return query_result_value == correct_answer_value return False # 如果任何一方为空,则认为不匹配 def extract_number_or_code_from_result(result_str): """从查询结果字符串中提取数字或编码,并处理逗号、百分号和科学计数法""" # 使用正则表达式匹配浮点数,包括可能的多位小数、逗号、百分比形式和科学计数法 match = re.search(r"(\d{1,3}(,\d{3})*(\.\d+)?|0E-\d+)", result_str) if match: number_str = match.group(1).replace(',', '').replace('%', '') # 移除逗号和百分号 return number_str # 尝试从结果中提取所有可能的编码格式 potential_codes = re.findall(r"\b[A-Z][A-Za-z\d-]+\b", result_str) # 返回第一个匹配的编码 return potential_codes[0] if potential_codes else None def main(questions_file, query_type): # 获取脚本所在的目录 script_dir = os.path.dirname(os.path.abspath(__file__)) # 将文件扩展名更改为 .json questions_file_path = os.path.join(script_dir, questions_file) results_file_path = os.path.join(script_dir, "query_results.json") log_file_path = os.path.join(script_dir, "incorrect_answers_log.json") # 如果 .json 文件不存在,则生成一个空的 JSON 文件 if not os.path.exists(questions_file_path): with open(questions_file_path, 'w', encoding='utf-8') as file: json.dump([], file) # 写入空数组 # 更新环境变量 os.environ['TOP_K'] = str(5) # 向量的TOP_K值 os.environ['LLM_TEMPERATURE'] = str(0.1) # 温度值 os.environ['similarity_top_k'] = str(5) # SQL的TOP_K值 init_settings() init_observability() index = get_index() top_k = int(os.getenv("TOP_K")) # 向量的TOP_K值 temperature = float(os.getenv("LLM_TEMPERATURE")) # 温度值 similarity_top_k = int(os.getenv("similarity_top_k")) # SQL的TOP_K值 filters = generate_filters([]) engine = create_engine(os.getenv("SQL_DATABASE_URL", "")) sql_database = SQLDatabase(engine) table_schema_objs = makeDescriptionByEngine(sql_database) table_node_mapping = SQLTableNodeMapping(sql_database) # 创建SQL查询工具 sql_obj_index = ObjectIndex.from_objects( table_schema_objs, table_node_mapping, index_cls=VectorStoreIndex, ) sql_query_engine = SQLTableRetrieverQueryEngine(sql_database, sql_obj_index.as_retriever(similarity_top_k=similarity_top_k)) questions_and_answers = read_questions_and_answers(questions_file_path) # 如果文件为空,则写入参数值 if os.path.getsize(results_file_path) == 0: with open(results_file_path, 'w', encoding='utf-8') as file: json.dump({ "TOP_K": top_k, "LLM_TEMPERATURE": temperature, "similarity_top_k": similarity_top_k }, file, ensure_ascii=False) file.write('\n') # 循环执行查询 for i, (question, correct_answer) in enumerate(questions_and_answers): print(f"执行查询 {i+1}: {question}") if query_type == "vector": query_engine = index.as_query_engine( similarity_top_k=top_k, filters=filters ) query_result = query_engine.query(question) print(f"向量查询结果: {query_result}\n") # 提取向量查询结果中的数字或编码进行匹配 query_result_str = f"The encoding for the query \"{question}\" is {str(query_result)}" elif query_type == "sql": sql_query_result = sql_query_engine.query(question) print(f"SQL查询结果: {sql_query_result}\n") # 提取SQL查询结果中的数字或编码进行匹配 query_result_str = f"The encoding for the query \"{question}\" is {str(sql_query_result)}" else: print("无效的查询类型,请选择 'vector' 或 'sql'") sys.exit(1) if is_answer_correct(query_result_str, correct_answer): # 只在查询结果正确时记录结果 save_results_to_file(question, query_result_str, correct_answer, results_file_path) else: # 记录不正确的答案 log_incorrect_answers(question, correct_answer, query_result_str, log_file_path) if __name__ == "__main__": if len(sys.argv) < 3: print("请提供questions.json文件名和查询类型(vector 或 sql)") sys.exit(1) questions_file = sys.argv[1] query_type = sys.argv[2].lower() main(questions_file, query_type)