"""Batch-evaluate question/answer pairs against a vector or SQL query engine.

Usage: ``python <this script> <questions_file> <vector|sql>``

Every query result is appended to ``query_results.txt`` and every mismatch to
``incorrect_answers_log.txt``; both files live next to this script.
"""
import os
import re
import sys
from ctypes import cast

from dotenv import load_dotenv
from llama_index.core import VectorStoreIndex, SQLDatabase
from llama_index.core.indices.struct_store import SQLTableRetrieverQueryEngine
from llama_index.core.objects import SQLTableNodeMapping, ObjectIndex
from llama_index.readers.database import DatabaseReader
from sqlalchemy import create_engine

from app.api.routers.chat import generate_filters
from app.engine import get_index, makeDescriptionByEngine
from app.engine.loaders.db import CustomDatabaseReader
from app.engine.vectordb import get_vector_store
from app.observability import init_observability
from app.settings import init_settings

load_dotenv()

# Compiled once; both patterns are applied to every line / every query result.
_ANSWER_RE = re.compile(r"answer:.*?(\d+)")
_FIRST_NUMBER_RE = re.compile(r"(\d+)")

_VALID_QUERY_TYPES = ("vector", "sql")

# Separator characters that may sit between the question text and the
# ``answer:`` label and must not be sent to the query engine.
_TRAILING_SEPARATORS = ",,;;::|"


def _extract_question(line):
    """Return the question text of *line*, or ``""`` if none can be found.

    The question is whatever lies between the first colon and the ``answer:``
    marker. Taking only ``line.split(":")[1]`` would raise ``IndexError`` on a
    line without a colon, truncate a question that itself contains a colon,
    and leave the literal ``answer`` label glued to the end of the question
    for lines shaped like ``question: ... answer: N``.
    """
    _, colon, rest = line.partition(":")
    if not colon:
        return ""
    question = rest.split("answer:", 1)[0]
    return question.strip().rstrip(_TRAILING_SEPARATORS).strip()


def read_questions_and_answers(file_path):
    """Parse ``(question, answer)`` pairs from a UTF-8 text file.

    Only lines containing both the words ``question`` and ``answer`` are
    considered. The answer is the first run of digits following ``answer:``
    and is returned as a string. Lines without a numeric answer or without an
    extractable question are skipped.

    Args:
        file_path: Path of the questions file.

    Returns:
        A list of ``(question, answer_digits)`` string tuples in file order.
    """
    questions_and_answers = []
    with open(file_path, 'r', encoding='utf-8') as file:
        for line in file:
            if "question" not in line or "answer" not in line:
                continue
            answer_match = _ANSWER_RE.search(line)
            if not answer_match:
                continue
            question = _extract_question(line)
            if not question:
                continue
            questions_and_answers.append((question, answer_match.group(1)))
    return questions_and_answers


def save_results_to_file(question, result, correct_answer, file_path):
    """Append one question / result / expected-answer record to *file_path*."""
    with open(file_path, 'a', encoding='utf-8') as file:
        file.write(f"问题: {question}\n")
        file.write(f"查询结果: {result}\n")
        file.write(f"正确答案: {correct_answer}\n\n")


def log_incorrect_answers(question, correct_answer, result, log_file_path):
    """Append one mismatch record (question, expected, actual) to the log."""
    with open(log_file_path, 'a', encoding='utf-8') as file:
        file.write(f"错误问题: {question}\n")
        file.write(f"正确答案: {correct_answer}\n")
        file.write(f"查询结果: {result}\n\n")


def _build_sql_query_engine(similarity_top_k):
    """Build the table-retrieving SQL query engine.

    The database URL is read from ``SQL_DATABASE_URL``. Table schema objects
    come from the project helper ``makeDescriptionByEngine`` and are indexed
    in a ``VectorStoreIndex`` so the engine can retrieve the
    ``similarity_top_k`` most relevant tables per question.
    """
    engine = create_engine(os.getenv("SQL_DATABASE_URL", ""))
    sql_database = SQLDatabase(engine)
    table_schema_objs = makeDescriptionByEngine(sql_database)
    table_node_mapping = SQLTableNodeMapping(sql_database)
    sql_obj_index = ObjectIndex.from_objects(
        table_schema_objs,
        table_node_mapping,
        index_cls=VectorStoreIndex,
    )
    return SQLTableRetrieverQueryEngine(
        sql_database,
        sql_obj_index.as_retriever(similarity_top_k=similarity_top_k),
    )


def main():
    """Run every question through the selected engine and record the outcome.

    Reads the questions file path and the query type (``vector`` or ``sql``)
    from ``sys.argv``. Exits with status 1 when arguments are missing or the
    query type is not recognised.
    """
    # Read the questions file path and the query type from the command line.
    if len(sys.argv) < 3:
        print("请提供questions.txt文件的路径和查询类型(vector 或 sql)")
        sys.exit(1)

    questions_file_path = sys.argv[1]
    query_type = sys.argv[2].lower()

    # Validate before any expensive setup so that a typo fails immediately,
    # and is reported even when the questions file turns out to be empty.
    if query_type not in _VALID_QUERY_TYPES:
        print("无效的查询类型,请选择 'vector' 或 'sql'")
        sys.exit(1)

    # Result and log files live in the directory containing this script.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    results_file_path = os.path.join(script_dir, "query_results.txt")
    log_file_path = os.path.join(script_dir, "incorrect_answers_log.txt")

    # Pin the run parameters in the environment before initialising settings.
    # NOTE(review): presumably init_settings() reads these values - confirm.
    os.environ['TOP_K'] = str(5)  # top-k for vector retrieval
    os.environ['LLM_TEMPERATURE'] = str(0.1)  # LLM temperature
    os.environ['similarity_top_k'] = str(5)  # top-k for SQL table retrieval

    init_settings()
    init_observability()

    top_k = int(os.getenv("TOP_K"))
    temperature = float(os.getenv("LLM_TEMPERATURE"))
    # A top-k is a count, so parse it as int; float() would hand 5.0 to the
    # retriever.
    similarity_top_k = int(os.getenv("similarity_top_k"))

    # Build only the engine that was requested, and build it once. A vector
    # run therefore no longer needs SQL_DATABASE_URL, and a SQL run no longer
    # loads the vector index.
    if query_type == "vector":
        query_engine = get_index().as_query_engine(
            similarity_top_k=top_k,
            filters=generate_filters([]),
        )
        result_label = "向量查询结果"
    else:
        query_engine = _build_sql_query_engine(similarity_top_k)
        result_label = "SQL查询结果"

    questions_and_answers = read_questions_and_answers(questions_file_path)

    # Write the parameter header when the results file is missing or empty.
    # os.path.getsize() alone raises FileNotFoundError on the very first run.
    if (not os.path.exists(results_file_path)
            or os.path.getsize(results_file_path) == 0):
        with open(results_file_path, 'w', encoding='utf-8') as file:
            file.write(f"TOP_K: {top_k}\n")
            file.write(f"LLM_TEMPERATURE: {temperature}\n")
            file.write(f"similarity_top_k: {similarity_top_k}\n\n")

    # Run the queries one by one.
    for number, (question, correct_answer) in enumerate(questions_and_answers, start=1):
        print(f"Executing query {number}: {question}")

        query_result = query_engine.query(question)
        print(f"{result_label}: {query_result}\n")
        save_results_to_file(
            question, f"{result_label}: {query_result}", correct_answer, results_file_path
        )

        # The first run of digits in the response is taken as the answer.
        number_match = _FIRST_NUMBER_RE.search(str(query_result))
        if not number_match:
            log_incorrect_answers(question, correct_answer, "未找到有效数字", log_file_path)
            continue

        query_number = number_match.group(1)
        # Compared as digit strings, exactly as parsed from the questions file.
        if query_number == correct_answer:
            save_results_to_file(question, query_number, correct_answer, results_file_path)
        else:
            log_incorrect_answers(question, correct_answer, query_number, log_file_path)


if __name__ == "__main__":
    from phoenix.trace import using_project

    with using_project("ly_zjapp_test"):
        main()