Add new files and update existing files

This commit is contained in:
chentianrui
2024-08-16 11:17:27 +08:00
parent 3082ac5f3d
commit ae7e21768b
14 changed files with 1581 additions and 91 deletions
+36 -21
View File
@@ -1,12 +1,12 @@
import re
import os
import sys
import json
from ctypes import cast
from llama_index.core import VectorStoreIndex, SQLDatabase
from llama_index.core.indices.struct_store import SQLTableRetrieverQueryEngine
from llama_index.core.objects import SQLTableNodeMapping, ObjectIndex
from llama_index.readers.database import DatabaseReader
from sqlalchemy import create_engine
from app.api.routers.chat import generate_filters
@@ -23,15 +23,16 @@ load_dotenv()
def read_questions_and_answers(file_path):
questions_and_answers = []
with open(file_path, 'r', encoding='utf-8') as file:
for line in file:
if "question" in line and "answer" in line:
question_part = line.split("")[1].strip() # 取 question
answer_part = re.search(r"answer.*?(\d+)", line) # 使用正则提取 answer 中的数字
if answer_part:
answer_value = answer_part.group(1)
questions_and_answers.append((question_part, answer_value))
data = json.load(file) # 读取 JSON 数据
for entry in data:
question = entry.get("question", "").strip() # 取 question
answer_match = re.search(r"(\d+\.?\d*)", entry.get("answer", "")) # 使用正则提取 answer 中的数字部分
if question and answer_match:
answer_value = answer_match.group(1) # 获取匹配的数字
questions_and_answers.append((question, answer_value))
return questions_and_answers
def save_results_to_file(question, result, correct_answer, file_path):
with open(file_path, 'a', encoding='utf-8') as file:
file.write(f"问题: {question}\n")
@@ -44,20 +45,27 @@ def log_incorrect_answers(question, correct_answer, result, log_file_path):
file.write(f"正确答案: {correct_answer}\n")
file.write(f"查询结果: {result}\n\n")
def main():
# 从命令行读取questions_file_path和查询类型
if len(sys.argv) < 3:
print("请提供questions.txt文件的路径和查询类型(vector 或 sql)")
sys.exit(1)
questions_file_path = sys.argv[1]
query_type = sys.argv[2].lower()
def main(questions_file, query_type):
# 获取脚本所在的目录
script_dir = os.path.dirname(os.path.abspath(__file__))
# 设置结果文件和日志文件的路径
results_file_path = os.path.join(script_dir, "query_results.txt")
log_file_path = os.path.join(script_dir, "incorrect_answers_log.txt")
# 将文件扩展名更改为 .json
questions_file_path = os.path.join(script_dir, questions_file)
results_file_path = os.path.join(script_dir, "query_results.json")
log_file_path = os.path.join(script_dir, "incorrect_answers_log.json")
# 如果 .json 文件不存在,则生成一个空的 JSON 文件
if not os.path.exists(questions_file_path):
with open(questions_file_path, 'w', encoding='utf-8') as file:
json.dump([], file) # 写入空数组
if not os.path.exists(results_file_path):
with open(results_file_path, 'w', encoding='utf-8') as file:
json.dump([], file) # 写入空数组
if not os.path.exists(log_file_path):
with open(log_file_path, 'w', encoding='utf-8') as file:
json.dump([], file) # 写入空数组
# 更新环境变量
os.environ['TOP_K'] = str(5) # 向量的TOP_K值
@@ -133,7 +141,14 @@ def main():
log_incorrect_answers(question, correct_answer, "未找到有效数字", log_file_path)
if __name__ == "__main__":
if len(sys.argv) < 3:
print("请提供questions.json文件名和查询类型(vector 或 sql)")
sys.exit(1)
questions_file = sys.argv[1]
query_type = sys.argv[2].lower()
from phoenix.trace import using_project
with using_project("ly_zjapp_test") as obj:
main()
with using_project(questions_file) as obj:
main(questions_file, query_type)