Add new files and update existing files

This commit is contained in:
chentianrui
2024-08-16 11:17:27 +08:00
parent 3082ac5f3d
commit ae7e21768b
14 changed files with 1581 additions and 91 deletions
+27 -33
View File
@@ -1,16 +1,14 @@
import os
from ctypes import cast
import json
import sys
from llama_index.core import VectorStoreIndex, SQLDatabase
from llama_index.core.indices.struct_store import SQLTableRetrieverQueryEngine
from llama_index.core.objects import SQLTableNodeMapping, ObjectIndex
from llama_index.readers.database import DatabaseReader
from sqlalchemy import create_engine
from app.api.routers.chat import generate_filters
from app.engine import get_index, makeDescriptionByEngine
from app.engine.loaders.db import CustomDatabaseReader
from app.engine.vectordb import get_vector_store
from app.observability import init_observability
from app.settings import init_settings
@@ -19,25 +17,21 @@ from dotenv import load_dotenv
load_dotenv()
def read_questions(file_path):
    """Load the question strings from a JSON file.

    The file must contain a JSON array of objects, each of which has a
    "question" key; any other keys on the objects are ignored.

    Args:
        file_path: Path of the UTF-8 encoded JSON file to read.

    Returns:
        A list with the value of "question" from every item, in file order.

    Raises:
        json.JSONDecodeError: If the file content is not valid JSON.
        KeyError: If an item has no "question" key.
    """
    # The earlier line-by-line text parsing is dropped: it consumed the file
    # handle before json.load() could read it and split on an empty separator.
    with open(file_path, 'r', encoding='utf-8') as file:
        data = json.load(file)
    return [item["question"] for item in data]
def save_results_to_file(question, result, file_path):
    """Append one question/result pair to a file as a single JSON line.

    Args:
        question: The question text that was queried.
        result: The result to record; must be JSON-serializable.
        file_path: Path of the output file. It is opened in append mode, so
            existing content is kept and the file is created if missing.
    """
    result_data = {
        "question": question,
        "result": result,
    }
    with open(file_path, 'a', encoding='utf-8') as file:
        # ensure_ascii=False keeps non-ASCII text readable in the output file.
        json.dump(result_data, file, ensure_ascii=False)
        # One record per line, so the file can be parsed line by line.
        file.write('\n')
def main():
# 从命令行读取questions_file_path
if len(sys.argv) < 2:
print("请提供questions.txt文件的路径")
sys.exit(1)
questions_file_path = sys.argv[1]
def main(questions_file):
# 更新环境变量
os.environ['TOP_K'] = str(5) # 向量的TOP_K值
os.environ['LLM_TEMPERATURE'] = str(0.1) # 温度值
@@ -66,36 +60,36 @@ def main():
sql_query_engine = SQLTableRetrieverQueryEngine(sql_database,
sql_obj_index.as_retriever(similarity_top_k=similarity_top_k))
questions = read_questions(questions_file_path)
script_dir = os.path.dirname(os.path.abspath(__file__))
questions_file_path = os.path.join(script_dir, questions_file)
results_file_path = os.path.join(script_dir, "parameters_results.json")
results_file_path = os.path.join(script_dir, "query_results.txt")
questions = read_questions(questions_file_path)
# 如果文件为空,则写入参数值
if os.path.getsize(results_file_path) == 0:
if not os.path.isfile(results_file_path):
with open(results_file_path, 'w', encoding='utf-8') as file:
file.write(f"TOP_K: {top_k}\n")
file.write(f"LLM_TEMPERATURE: {temperature}\n")
file.write(f"similarity_top_k: {similarity_top_k}\n\n")
json.dump({
"TOP_K": top_k,
"LLM_TEMPERATURE": temperature,
"similarity_top_k": similarity_top_k
}, file, ensure_ascii=False)
file.write('\n')
# 循环执行查询
for i, question in enumerate(questions):
print(f"Executing query {i+1}: {question}")
# query_engine = index.as_query_engine(
# similarity_top_k=top_k, filters=filters
# )
# query_result = query_engine.query(question)
# print(f"向量查询结果: {query_result}\n")
# save_results_to_file(question, f"向量查询结果: {query_result}", results_file_path)
sql_query_result = sql_query_engine.query(question)
print(f"SQL查询结果: {sql_query_result}\n")
save_results_to_file(question, f"SQL查询结果: {sql_query_result}", results_file_path)
if __name__ == "__main__":
    # The questions JSON file is the only required command-line argument.
    if len(sys.argv) < 2:
        print("请提供questions.json文件的路径")
        sys.exit(1)
    questions_file = sys.argv[1]

    # Imported here rather than at the top of the file, as in the original.
    from phoenix.trace import using_project

    # The questions file name is what gets passed to using_project, replacing
    # the previously hard-coded "ly_zjapp_test" value.
    # NOTE(review): presumably this groups traces per questions file — confirm.
    with using_project(questions_file):
        main(questions_file)