From 3082ac5f3daab7df41f0dfeb0ba4c4918ec9ff4f Mon Sep 17 00:00:00 2001 From: chentianrui Date: Thu, 15 Aug 2024 19:08:58 +0800 Subject: [PATCH] Add new files and update existing files --- backend/test1/__init__.py | 0 backend/test1/chromedb.py | 19 ++++ backend/test1/incorrect_answers_log.txt | 1 + backend/test1/query_results.txt | 0 backend/test1/query_results1.txt | 0 backend/test1/query_test.py | 139 ++++++++++++++++++++++++ backend/test1/question.py | 56 ++++++++++ backend/test1/questions.txt | 14 +++ backend/test1/test_parameters.py | 101 +++++++++++++++++ 9 files changed, 330 insertions(+) create mode 100644 backend/test1/__init__.py create mode 100644 backend/test1/chromedb.py create mode 100644 backend/test1/incorrect_answers_log.txt create mode 100644 backend/test1/query_results.txt create mode 100644 backend/test1/query_results1.txt create mode 100644 backend/test1/query_test.py create mode 100644 backend/test1/question.py create mode 100644 backend/test1/questions.txt create mode 100644 backend/test1/test_parameters.py diff --git a/backend/test1/__init__.py b/backend/test1/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/test1/chromedb.py b/backend/test1/chromedb.py new file mode 100644 index 0000000..3e552e3 --- /dev/null +++ b/backend/test1/chromedb.py @@ -0,0 +1,19 @@ +import chromadb + +# 创建 ChromaDB 客户端 +chroma_client = chromadb.PersistentClient(path="/home/bw/ctr/zjdataai-app/backend/storage_vector-1/") + +# 获取已存在的 "default" 集合 +collection = chroma_client.get_collection(name="default") + +# 获取集合中的所有数据 +results = collection.get( + include=['documents', 'metadatas', 'embeddings'] # 只包含允许的选项 +) + +# 将结果转换为字符串并保存到txt文件中 +with open('/home/bw/ctr/zjdataai-app/backend/test1/query_results-1.txt', 'w', encoding='utf-8') as file: + file.write(str(results)) + +# 打印结果 +print("查询结果已保存到 query_results.txt 文件中。") diff --git a/backend/test1/incorrect_answers_log.txt b/backend/test1/incorrect_answers_log.txt new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/backend/test1/incorrect_answers_log.txt @@ -0,0 +1 @@ + diff --git a/backend/test1/query_results.txt b/backend/test1/query_results.txt new file mode 100644 index 0000000..e69de29 diff --git a/backend/test1/query_results1.txt b/backend/test1/query_results1.txt new file mode 100644 index 0000000..e69de29 diff --git a/backend/test1/query_test.py b/backend/test1/query_test.py new file mode 100644 index 0000000..60533b8 --- /dev/null +++ b/backend/test1/query_test.py @@ -0,0 +1,139 @@ +import re +import os +import sys +from ctypes import cast + +from llama_index.core import VectorStoreIndex, SQLDatabase +from llama_index.core.indices.struct_store import SQLTableRetrieverQueryEngine +from llama_index.core.objects import SQLTableNodeMapping, ObjectIndex +from llama_index.readers.database import DatabaseReader +from sqlalchemy import create_engine + +from app.api.routers.chat import generate_filters +from app.engine import get_index, makeDescriptionByEngine +from app.engine.loaders.db import CustomDatabaseReader +from app.engine.vectordb import get_vector_store +from app.observability import init_observability +from app.settings import init_settings +from dotenv import load_dotenv + +load_dotenv() + + +def read_questions_and_answers(file_path): + questions_and_answers = [] + with open(file_path, 'r', encoding='utf-8') as file: + for line in file: + if "question" in line and "answer" in line: + question_part = line.split(":")[1].strip() # 提取 question + answer_part = re.search(r"answer:.*?(\d+)", line) # 使用正则提取 answer 中的数字 + if answer_part: + answer_value = answer_part.group(1) + questions_and_answers.append((question_part, answer_value)) + return questions_and_answers + +def save_results_to_file(question, result, correct_answer, file_path): + with open(file_path, 'a', encoding='utf-8') as file: + file.write(f"问题: {question}\n") + file.write(f"查询结果: {result}\n") + file.write(f"正确答案: {correct_answer}\n\n") + +def log_incorrect_answers(question, correct_answer, result, log_file_path): + with open(log_file_path, 'a', encoding='utf-8') as file: + file.write(f"错误问题: {question}\n") + file.write(f"正确答案: {correct_answer}\n") + file.write(f"查询结果: {result}\n\n") + +def main(): + # 从命令行读取questions_file_path和查询类型 + if len(sys.argv) < 3: + print("请提供questions.txt文件的路径和查询类型(vector 或 sql)") + sys.exit(1) + questions_file_path = sys.argv[1] + query_type = sys.argv[2].lower() + + # 获取脚本所在的目录 + script_dir = os.path.dirname(os.path.abspath(__file__)) + + # 设置结果文件和日志文件的路径 + results_file_path = os.path.join(script_dir, "query_results.txt") + log_file_path = os.path.join(script_dir, "incorrect_answers_log.txt") + + # 更新环境变量 + os.environ['TOP_K'] = str(5) # 向量的TOP_K值 + os.environ['LLM_TEMPERATURE'] = str(0.1) # 温度值 + os.environ['similarity_top_k'] = str(5) # SQL的TOP_K值 + + init_settings() + init_observability() + + index = get_index() + + top_k = int(os.getenv("TOP_K")) # 向量的TOP_K值 + temperature = float(os.getenv("LLM_TEMPERATURE")) # 温度值 + similarity_top_k = float(os.getenv("similarity_top_k")) # SQL的TOP_K值 + filters = generate_filters([]) + + engine = create_engine(os.getenv("SQL_DATABASE_URL", "")) + sql_database = SQLDatabase(engine) + table_schema_objs = makeDescriptionByEngine(sql_database) + table_node_mapping = SQLTableNodeMapping(sql_database) + # 创建SQL查询工具 + sql_obj_index = ObjectIndex.from_objects( + table_schema_objs, + table_node_mapping, + index_cls=VectorStoreIndex, + ) + sql_query_engine = SQLTableRetrieverQueryEngine(sql_database, + sql_obj_index.as_retriever(similarity_top_k=similarity_top_k)) + + questions_and_answers = read_questions_and_answers(questions_file_path) + + # 如果文件为空,则写入参数值 + if os.path.getsize(results_file_path) == 0: + with open(results_file_path, 'w', encoding='utf-8') as file: + file.write(f"TOP_K: {top_k}\n") + file.write(f"LLM_TEMPERATURE: {temperature}\n") + file.write(f"similarity_top_k: {similarity_top_k}\n\n") + + # 循环执行查询 + for i, (question, correct_answer) in enumerate(questions_and_answers): + print(f"Executing query {i+1}: {question}") + + if query_type == "vector": + query_engine = index.as_query_engine( + similarity_top_k=top_k, filters=filters + ) + query_result = query_engine.query(question) + print(f"向量查询结果: {query_result}\n") + save_results_to_file(question, f"向量查询结果: {query_result}", correct_answer, results_file_path) + + # 提取向量查询结果中的数字进行匹配 + query_result_number = re.search(r"(\d+)", str(query_result)) + elif query_type == "sql": + sql_query_result = sql_query_engine.query(question) + print(f"SQL查询结果: {sql_query_result}\n") + save_results_to_file(question, f"SQL查询结果: {sql_query_result}", correct_answer, results_file_path) + + # 提取SQL查询结果中的数字进行匹配 + query_result_number = re.search(r"(\d+)", str(sql_query_result)) + else: + print("无效的查询类型,请选择 'vector' 或 'sql'") + sys.exit(1) + + if query_result_number: + query_number = query_result_number.group(1) + + # 判断查询结果是否与正确答案匹配 + if query_number == correct_answer: + save_results_to_file(question, query_number, correct_answer, results_file_path) + else: + log_incorrect_answers(question, correct_answer, query_number, log_file_path) + else: + log_incorrect_answers(question, correct_answer, "未找到有效数字", log_file_path) + +if __name__ == "__main__": + from phoenix.trace import using_project + + with using_project("ly_zjapp_test") as obj: + main() diff --git a/backend/test1/question.py b/backend/test1/question.py new file mode 100644 index 0000000..962621b --- /dev/null +++ b/backend/test1/question.py @@ -0,0 +1,56 @@ +import os +import random +from sqlalchemy import create_engine, MetaData, Table, select, func +from sqlalchemy.orm import sessionmaker +from dotenv import load_dotenv + +load_dotenv() + +def generate_questions(file_path, num_questions_per_table=10): + engine = create_engine(os.getenv("SQL_DATABASE_URL", "")) + metadata = MetaData() + metadata.reflect(bind=engine) + + # 定义表名及其对应的列索引和问题模板 + tables_info = { + "ProjectProperties": (0, "Attribute_Value", "{name_value}的属性值是多少?"), + "OtherFee": (0, "Amount", "{name_value}的金额是多少?"), + "FeeCollectionTable": (0, "Rate", "{name_value}的费率是多少?"), + "ProjectDivision": (0, "Total_Price", "{name_value}的合价是多少?"), + "ProjectDivisions_CostPreview": (0, "Direct_Fee", "{name_value}的直接费是多少?"), + "TotalCalculateTable": (0, "Amount", "{name_value}的金额是多少?"), + "ProjectQuantities": (0, "Code", "{name_value}的编码是多少?") + } + + questions = [] + + for table_name, (name_index, value_column, question_template) in tables_info.items(): + # 加载这张表 + table = Table(table_name, metadata, autoload_with=engine) + + # 创建会话 + Session = sessionmaker(bind=engine) + session = Session() + + # 获取列名 + name_column = table.columns.keys()[name_index] + + # 对于每个表生成num_questions_per_table个问题 + for _ in range(num_questions_per_table): + # 查询表中的随机一行,并获取名称列的值 + row = session.query(table).order_by(func.random()).first() + name_value = getattr(row, name_column) + + # 构造问题 + question = question_template.format(name_value=name_value) + questions.append(question) + + # 写入文件 + with open(file_path, 'w', encoding='utf-8') as file: + for question in questions: + file.write(question + '\n') + + +if __name__ == "__main__": + questions_file_path = "/home/bw/ctr/zjdataai-app/backend/test1/questions.txt" + generate_questions(questions_file_path) \ No newline at end of file diff --git a/backend/test1/questions.txt b/backend/test1/questions.txt new file mode 100644 index 0000000..0d4e731 --- /dev/null +++ b/backend/test1/questions.txt @@ -0,0 +1,14 @@ +question:线路参数_转角次数的属性值是多少? answer:线路参数_转角次数的属性值是64 +question:接地土石方量的属性值是多少? answer:接地土石方量的属性值是16 +question:工程监理费的金额是多少? answer:工程监理费的金额是131009.92 +question:矿产压覆评估费用的金额是多少? answer:矿产压覆评估费用的金额是0 +question:线路取费表(余物清理)的费率是多少? answer:线路取费表(余物清理)的费率是100 +question:线路取费表(拆除)的费率是多少? answer:线路取费表(拆除)的费率是100 +question:一般线路本体工程的合价是多少? answer:一般线路本体工程的合价是55105688268.5176 +question:基础工程的合价是多少? answer:基础工程的合价是49051649642.9667 +question:线路取费表(调试工程)aa的直接费是多少? answer:线路取费表(调试工程)aa的直接费是22411207942.4858 +question:线路取费表的直接费是多少? answer:线路取费表的直接费是7314300665.34141 +question:一般线路本体工程的金额是多少? answer:一般线路本体工程的金额是55105688268.5176 +question:架空输电线路本体工程的金额是多少? answer:架空输电线路本体工程的金额是55105688268.5176 +question:截止阀的编码是多少? answer:截止阀的编码是F01010101 +question:自定义主材的编码是多少? answer:自定义主材的编码是asd \ No newline at end of file diff --git a/backend/test1/test_parameters.py b/backend/test1/test_parameters.py new file mode 100644 index 0000000..fda5677 --- /dev/null +++ b/backend/test1/test_parameters.py @@ -0,0 +1,101 @@ +import os +from ctypes import cast +import sys + +from llama_index.core import VectorStoreIndex, SQLDatabase +from llama_index.core.indices.struct_store import SQLTableRetrieverQueryEngine +from llama_index.core.objects import SQLTableNodeMapping, ObjectIndex +from llama_index.readers.database import DatabaseReader +from sqlalchemy import create_engine + +from app.api.routers.chat import generate_filters +from app.engine import get_index, makeDescriptionByEngine +from app.engine.loaders.db import CustomDatabaseReader +from app.engine.vectordb import get_vector_store +from app.observability import init_observability +from app.settings import init_settings +from dotenv import load_dotenv + +load_dotenv() + +def read_questions(file_path): + questions = [] + with open(file_path, 'r', encoding='utf-8') as file: + for line in file: + if "question" in line: + question_part = line.split(":")[1].strip() # 提取 "question" 后的内容 + questions.append(question_part) + return questions + +def save_results_to_file(question, result, file_path): + with open(file_path, 'a', encoding='utf-8') as file: + file.write(f"问题: {question}\n") + file.write(f"结果: {result}\n\n") + +def main(): + # 从命令行读取questions_file_path + if len(sys.argv) < 2: + print("请提供questions.txt文件的路径") + sys.exit(1) + questions_file_path = sys.argv[1] + # 更新环境变量 + os.environ['TOP_K'] = str(5) # 向量的TOP_K值 + os.environ['LLM_TEMPERATURE'] = str(0.1) # 温度值 + os.environ['similarity_top_k'] = str(5) # SQL的TOP_K值 + + init_settings() + init_observability() + + index = get_index() + + top_k = int(os.getenv("TOP_K")) # 向量的TOP_K值 + temperature = float(os.getenv("LLM_TEMPERATURE")) # 温度值 + similarity_top_k = float(os.getenv("similarity_top_k")) # SQL的TOP_K值 + filters = generate_filters([]) + + engine = create_engine(os.getenv("SQL_DATABASE_URL", "")) + sql_database = SQLDatabase(engine) + table_schema_objs = makeDescriptionByEngine(sql_database) + table_node_mapping = SQLTableNodeMapping(sql_database) + # 创建SQL查询工具 + sql_obj_index = ObjectIndex.from_objects( + table_schema_objs, + table_node_mapping, + index_cls=VectorStoreIndex, + ) + sql_query_engine = SQLTableRetrieverQueryEngine(sql_database, + sql_obj_index.as_retriever(similarity_top_k=similarity_top_k)) + + questions = read_questions(questions_file_path) + + script_dir = os.path.dirname(os.path.abspath(__file__)) + + results_file_path = os.path.join(script_dir, "query_results.txt") + + # 如果文件为空,则写入参数值 + if os.path.getsize(results_file_path) == 0: + with open(results_file_path, 'w', encoding='utf-8') as file: + file.write(f"TOP_K: {top_k}\n") + file.write(f"LLM_TEMPERATURE: {temperature}\n") + file.write(f"similarity_top_k: {similarity_top_k}\n\n") + + # 循环执行查询 + for i, question in enumerate(questions): + print(f"Executing query {i+1}: {question}") + # query_engine = index.as_query_engine( + # similarity_top_k=top_k, filters=filters + # ) + # query_result = query_engine.query(question) + + # print(f"向量查询结果: {query_result}\n") + # save_results_to_file(question, f"向量查询结果: {query_result}", results_file_path) + + sql_query_result = sql_query_engine.query(question) + print(f"SQL查询结果: {sql_query_result}\n") + save_results_to_file(question, f"SQL查询结果: {sql_query_result}", results_file_path) + +if __name__ == "__main__": + from phoenix.trace import using_project + + with using_project("ly_zjapp_test") as obj: + main()