Add new files and update existing files
This commit is contained in:
+36
-21
@@ -1,12 +1,12 @@
|
||||
import re
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
from ctypes import cast
|
||||
|
||||
from llama_index.core import VectorStoreIndex, SQLDatabase
|
||||
from llama_index.core.indices.struct_store import SQLTableRetrieverQueryEngine
|
||||
from llama_index.core.objects import SQLTableNodeMapping, ObjectIndex
|
||||
from llama_index.readers.database import DatabaseReader
|
||||
from sqlalchemy import create_engine
|
||||
|
||||
from app.api.routers.chat import generate_filters
|
||||
@@ -23,15 +23,16 @@ load_dotenv()
|
||||
def read_questions_and_answers(file_path):
|
||||
questions_and_answers = []
|
||||
with open(file_path, 'r', encoding='utf-8') as file:
|
||||
for line in file:
|
||||
if "question" in line and "answer" in line:
|
||||
question_part = line.split(":")[1].strip() # 提取 question
|
||||
answer_part = re.search(r"answer:.*?(\d+)", line) # 使用正则提取 answer 中的数字
|
||||
if answer_part:
|
||||
answer_value = answer_part.group(1)
|
||||
questions_and_answers.append((question_part, answer_value))
|
||||
data = json.load(file) # 读取 JSON 数据
|
||||
for entry in data:
|
||||
question = entry.get("question", "").strip() # 获取 question
|
||||
answer_match = re.search(r"(\d+\.?\d*)", entry.get("answer", "")) # 使用正则提取 answer 中的数字部分
|
||||
if question and answer_match:
|
||||
answer_value = answer_match.group(1) # 获取匹配的数字
|
||||
questions_and_answers.append((question, answer_value))
|
||||
return questions_and_answers
|
||||
|
||||
|
||||
def save_results_to_file(question, result, correct_answer, file_path):
|
||||
with open(file_path, 'a', encoding='utf-8') as file:
|
||||
file.write(f"问题: {question}\n")
|
||||
@@ -44,20 +45,27 @@ def log_incorrect_answers(question, correct_answer, result, log_file_path):
|
||||
file.write(f"正确答案: {correct_answer}\n")
|
||||
file.write(f"查询结果: {result}\n\n")
|
||||
|
||||
def main():
|
||||
# 从命令行读取questions_file_path和查询类型
|
||||
if len(sys.argv) < 3:
|
||||
print("请提供questions.txt文件的路径和查询类型(vector 或 sql)")
|
||||
sys.exit(1)
|
||||
questions_file_path = sys.argv[1]
|
||||
query_type = sys.argv[2].lower()
|
||||
|
||||
def main(questions_file, query_type):
|
||||
# 获取脚本所在的目录
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
# 设置结果文件和日志文件的路径
|
||||
results_file_path = os.path.join(script_dir, "query_results.txt")
|
||||
log_file_path = os.path.join(script_dir, "incorrect_answers_log.txt")
|
||||
# 将文件扩展名更改为 .json
|
||||
questions_file_path = os.path.join(script_dir, questions_file)
|
||||
results_file_path = os.path.join(script_dir, "query_results.json")
|
||||
log_file_path = os.path.join(script_dir, "incorrect_answers_log.json")
|
||||
|
||||
# 如果 .json 文件不存在,则生成一个空的 JSON 文件
|
||||
if not os.path.exists(questions_file_path):
|
||||
with open(questions_file_path, 'w', encoding='utf-8') as file:
|
||||
json.dump([], file) # 写入空数组
|
||||
|
||||
if not os.path.exists(results_file_path):
|
||||
with open(results_file_path, 'w', encoding='utf-8') as file:
|
||||
json.dump([], file) # 写入空数组
|
||||
|
||||
if not os.path.exists(log_file_path):
|
||||
with open(log_file_path, 'w', encoding='utf-8') as file:
|
||||
json.dump([], file) # 写入空数组
|
||||
|
||||
# 更新环境变量
|
||||
os.environ['TOP_K'] = str(5) # 向量的TOP_K值
|
||||
@@ -133,7 +141,14 @@ def main():
|
||||
log_incorrect_answers(question, correct_answer, "未找到有效数字", log_file_path)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
if len(sys.argv) < 3:
|
||||
print("请提供questions.json文件名和查询类型(vector 或 sql)")
|
||||
sys.exit(1)
|
||||
questions_file = sys.argv[1]
|
||||
query_type = sys.argv[2].lower()
|
||||
|
||||
from phoenix.trace import using_project
|
||||
|
||||
with using_project("ly_zjapp_test") as obj:
|
||||
main()
|
||||
with using_project(questions_file) as obj:
|
||||
main(questions_file, query_type)
|
||||
|
||||
Reference in New Issue
Block a user