dev #1
@@ -0,0 +1,19 @@
|
|||||||
|
import chromadb
|
||||||
|
|
||||||
|
# 创建 ChromaDB 客户端
|
||||||
|
chroma_client = chromadb.PersistentClient(path="/home/bw/ctr/zjdataai-app/backend/storage_vector-1/")
|
||||||
|
|
||||||
|
# 获取已存在的 "default" 集合
|
||||||
|
collection = chroma_client.get_collection(name="default")
|
||||||
|
|
||||||
|
# 获取集合中的所有数据
|
||||||
|
results = collection.get(
|
||||||
|
include=['documents', 'metadatas', 'embeddings'] # 只包含允许的选项
|
||||||
|
)
|
||||||
|
|
||||||
|
# 将结果转换为字符串并保存到txt文件中
|
||||||
|
with open('/home/bw/ctr/zjdataai-app/backend/test1/query_results-1.txt', 'w', encoding='utf-8') as file:
|
||||||
|
file.write(str(results))
|
||||||
|
|
||||||
|
# 打印结果
|
||||||
|
print("查询结果已保存到 query_results.txt 文件中。")
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
|
||||||
@@ -0,0 +1,139 @@
|
|||||||
|
import re
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from ctypes import cast
|
||||||
|
|
||||||
|
from llama_index.core import VectorStoreIndex, SQLDatabase
|
||||||
|
from llama_index.core.indices.struct_store import SQLTableRetrieverQueryEngine
|
||||||
|
from llama_index.core.objects import SQLTableNodeMapping, ObjectIndex
|
||||||
|
from llama_index.readers.database import DatabaseReader
|
||||||
|
from sqlalchemy import create_engine
|
||||||
|
|
||||||
|
from app.api.routers.chat import generate_filters
|
||||||
|
from app.engine import get_index, makeDescriptionByEngine
|
||||||
|
from app.engine.loaders.db import CustomDatabaseReader
|
||||||
|
from app.engine.vectordb import get_vector_store
|
||||||
|
from app.observability import init_observability
|
||||||
|
from app.settings import init_settings
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
|
def read_questions_and_answers(file_path):
|
||||||
|
questions_and_answers = []
|
||||||
|
with open(file_path, 'r', encoding='utf-8') as file:
|
||||||
|
for line in file:
|
||||||
|
if "question" in line and "answer" in line:
|
||||||
|
question_part = line.split(":")[1].strip() # 提取 question
|
||||||
|
answer_part = re.search(r"answer:.*?(\d+)", line) # 使用正则提取 answer 中的数字
|
||||||
|
if answer_part:
|
||||||
|
answer_value = answer_part.group(1)
|
||||||
|
questions_and_answers.append((question_part, answer_value))
|
||||||
|
return questions_and_answers
|
||||||
|
|
||||||
|
def save_results_to_file(question, result, correct_answer, file_path):
|
||||||
|
with open(file_path, 'a', encoding='utf-8') as file:
|
||||||
|
file.write(f"问题: {question}\n")
|
||||||
|
file.write(f"查询结果: {result}\n")
|
||||||
|
file.write(f"正确答案: {correct_answer}\n\n")
|
||||||
|
|
||||||
|
def log_incorrect_answers(question, correct_answer, result, log_file_path):
|
||||||
|
with open(log_file_path, 'a', encoding='utf-8') as file:
|
||||||
|
file.write(f"错误问题: {question}\n")
|
||||||
|
file.write(f"正确答案: {correct_answer}\n")
|
||||||
|
file.write(f"查询结果: {result}\n\n")
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# 从命令行读取questions_file_path和查询类型
|
||||||
|
if len(sys.argv) < 3:
|
||||||
|
print("请提供questions.txt文件的路径和查询类型(vector 或 sql)")
|
||||||
|
sys.exit(1)
|
||||||
|
questions_file_path = sys.argv[1]
|
||||||
|
query_type = sys.argv[2].lower()
|
||||||
|
|
||||||
|
# 获取脚本所在的目录
|
||||||
|
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
|
||||||
|
# 设置结果文件和日志文件的路径
|
||||||
|
results_file_path = os.path.join(script_dir, "query_results.txt")
|
||||||
|
log_file_path = os.path.join(script_dir, "incorrect_answers_log.txt")
|
||||||
|
|
||||||
|
# 更新环境变量
|
||||||
|
os.environ['TOP_K'] = str(5) # 向量的TOP_K值
|
||||||
|
os.environ['LLM_TEMPERATURE'] = str(0.1) # 温度值
|
||||||
|
os.environ['similarity_top_k'] = str(5) # SQL的TOP_K值
|
||||||
|
|
||||||
|
init_settings()
|
||||||
|
init_observability()
|
||||||
|
|
||||||
|
index = get_index()
|
||||||
|
|
||||||
|
top_k = int(os.getenv("TOP_K")) # 向量的TOP_K值
|
||||||
|
temperature = float(os.getenv("LLM_TEMPERATURE")) # 温度值
|
||||||
|
similarity_top_k = float(os.getenv("similarity_top_k")) # SQL的TOP_K值
|
||||||
|
filters = generate_filters([])
|
||||||
|
|
||||||
|
engine = create_engine(os.getenv("SQL_DATABASE_URL", ""))
|
||||||
|
sql_database = SQLDatabase(engine)
|
||||||
|
table_schema_objs = makeDescriptionByEngine(sql_database)
|
||||||
|
table_node_mapping = SQLTableNodeMapping(sql_database)
|
||||||
|
# 创建SQL查询工具
|
||||||
|
sql_obj_index = ObjectIndex.from_objects(
|
||||||
|
table_schema_objs,
|
||||||
|
table_node_mapping,
|
||||||
|
index_cls=VectorStoreIndex,
|
||||||
|
)
|
||||||
|
sql_query_engine = SQLTableRetrieverQueryEngine(sql_database,
|
||||||
|
sql_obj_index.as_retriever(similarity_top_k=similarity_top_k))
|
||||||
|
|
||||||
|
questions_and_answers = read_questions_and_answers(questions_file_path)
|
||||||
|
|
||||||
|
# 如果文件为空,则写入参数值
|
||||||
|
if os.path.getsize(results_file_path) == 0:
|
||||||
|
with open(results_file_path, 'w', encoding='utf-8') as file:
|
||||||
|
file.write(f"TOP_K: {top_k}\n")
|
||||||
|
file.write(f"LLM_TEMPERATURE: {temperature}\n")
|
||||||
|
file.write(f"similarity_top_k: {similarity_top_k}\n\n")
|
||||||
|
|
||||||
|
# 循环执行查询
|
||||||
|
for i, (question, correct_answer) in enumerate(questions_and_answers):
|
||||||
|
print(f"Executing query {i+1}: {question}")
|
||||||
|
|
||||||
|
if query_type == "vector":
|
||||||
|
query_engine = index.as_query_engine(
|
||||||
|
similarity_top_k=top_k, filters=filters
|
||||||
|
)
|
||||||
|
query_result = query_engine.query(question)
|
||||||
|
print(f"向量查询结果: {query_result}\n")
|
||||||
|
save_results_to_file(question, f"向量查询结果: {query_result}", correct_answer, results_file_path)
|
||||||
|
|
||||||
|
# 提取向量查询结果中的数字进行匹配
|
||||||
|
query_result_number = re.search(r"(\d+)", str(query_result))
|
||||||
|
elif query_type == "sql":
|
||||||
|
sql_query_result = sql_query_engine.query(question)
|
||||||
|
print(f"SQL查询结果: {sql_query_result}\n")
|
||||||
|
save_results_to_file(question, f"SQL查询结果: {sql_query_result}", correct_answer, results_file_path)
|
||||||
|
|
||||||
|
# 提取SQL查询结果中的数字进行匹配
|
||||||
|
query_result_number = re.search(r"(\d+)", str(sql_query_result))
|
||||||
|
else:
|
||||||
|
print("无效的查询类型,请选择 'vector' 或 'sql'")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
if query_result_number:
|
||||||
|
query_number = query_result_number.group(1)
|
||||||
|
|
||||||
|
# 判断查询结果是否与正确答案匹配
|
||||||
|
if query_number == correct_answer:
|
||||||
|
save_results_to_file(question, query_number, correct_answer, results_file_path)
|
||||||
|
else:
|
||||||
|
log_incorrect_answers(question, correct_answer, query_number, log_file_path)
|
||||||
|
else:
|
||||||
|
log_incorrect_answers(question, correct_answer, "未找到有效数字", log_file_path)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
from phoenix.trace import using_project
|
||||||
|
|
||||||
|
with using_project("ly_zjapp_test") as obj:
|
||||||
|
main()
|
||||||
@@ -0,0 +1,56 @@
|
|||||||
|
import os
|
||||||
|
import random
|
||||||
|
from sqlalchemy import create_engine, MetaData, Table, select, func
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
def generate_questions(file_path, num_questions_per_table=10):
|
||||||
|
engine = create_engine(os.getenv("SQL_DATABASE_URL", ""))
|
||||||
|
metadata = MetaData()
|
||||||
|
metadata.reflect(bind=engine)
|
||||||
|
|
||||||
|
# 定义表名及其对应的列索引和问题模板
|
||||||
|
tables_info = {
|
||||||
|
"ProjectProperties": (0, "Attribute_Value", "{name_value}的属性值是多少?"),
|
||||||
|
"OtherFee": (0, "Amount", "{name_value}的金额是多少?"),
|
||||||
|
"FeeCollectionTable": (0, "Rate", "{name_value}的费率是多少?"),
|
||||||
|
"ProjectDivision": (0, "Total_Price", "{name_value}的合价是多少?"),
|
||||||
|
"ProjectDivisions_CostPreview": (0, "Direct_Fee", "{name_value}的直接费是多少?"),
|
||||||
|
"TotalCalculateTable": (0, "Amount", "{name_value}的金额是多少?"),
|
||||||
|
"ProjectQuantities": (0, "Code", "{name_value}的编码是多少?")
|
||||||
|
}
|
||||||
|
|
||||||
|
questions = []
|
||||||
|
|
||||||
|
for table_name, (name_index, value_column, question_template) in tables_info.items():
|
||||||
|
# 加载这张表
|
||||||
|
table = Table(table_name, metadata, autoload_with=engine)
|
||||||
|
|
||||||
|
# 创建会话
|
||||||
|
Session = sessionmaker(bind=engine)
|
||||||
|
session = Session()
|
||||||
|
|
||||||
|
# 获取列名
|
||||||
|
name_column = table.columns.keys()[name_index]
|
||||||
|
|
||||||
|
# 对于每个表生成num_questions_per_table个问题
|
||||||
|
for _ in range(num_questions_per_table):
|
||||||
|
# 查询表中的随机一行,并获取名称列的值
|
||||||
|
row = session.query(table).order_by(func.random()).first()
|
||||||
|
name_value = getattr(row, name_column)
|
||||||
|
|
||||||
|
# 构造问题
|
||||||
|
question = question_template.format(name_value=name_value)
|
||||||
|
questions.append(question)
|
||||||
|
|
||||||
|
# 写入文件
|
||||||
|
with open(file_path, 'w', encoding='utf-8') as file:
|
||||||
|
for question in questions:
|
||||||
|
file.write(question + '\n')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
questions_file_path = "/home/bw/ctr/zjdataai-app/backend/test1/questions.txt"
|
||||||
|
generate_questions(questions_file_path)
|
||||||
@@ -0,0 +1,14 @@
|
|||||||
|
question:线路参数_转角次数的属性值是多少? answer:线路参数_转角次数的属性值是64
|
||||||
|
question:接地土石方量的属性值是多少? answer:接地土石方量的属性值是16
|
||||||
|
question:工程监理费的金额是多少? answer:工程监理费的金额是131009.92
|
||||||
|
question:矿产压覆评估费用的金额是多少? answer:矿产压覆评估费用的金额是0
|
||||||
|
question:线路取费表(余物清理)的费率是多少? answer:线路取费表(余物清理)的费率是100
|
||||||
|
question:线路取费表(拆除)的费率是多少? answer:线路取费表(拆除)的费率是100
|
||||||
|
question:一般线路本体工程的合价是多少? answer:一般线路本体工程的合价是55105688268.5176
|
||||||
|
question:基础工程的合价是多少? answer:基础工程的合价是49051649642.9667
|
||||||
|
question:线路取费表(调试工程)aa的直接费是多少? answer:线路取费表(调试工程)aa的直接费是22411207942.4858
|
||||||
|
question:线路取费表的直接费是多少? answer:线路取费表的直接费是7314300665.34141
|
||||||
|
question:一般线路本体工程的金额是多少? answer:一般线路本体工程的金额是55105688268.5176
|
||||||
|
question:架空输电线路本体工程的金额是多少? answer:架空输电线路本体工程的金额是55105688268.5176
|
||||||
|
question:截止阀的编码是多少? answer:截止阀的编码是F01010101
|
||||||
|
question:自定义主材的编码是多少? answer:自定义主材的编码是asd
|
||||||
@@ -0,0 +1,101 @@
|
|||||||
|
import os
|
||||||
|
from ctypes import cast
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from llama_index.core import VectorStoreIndex, SQLDatabase
|
||||||
|
from llama_index.core.indices.struct_store import SQLTableRetrieverQueryEngine
|
||||||
|
from llama_index.core.objects import SQLTableNodeMapping, ObjectIndex
|
||||||
|
from llama_index.readers.database import DatabaseReader
|
||||||
|
from sqlalchemy import create_engine
|
||||||
|
|
||||||
|
from app.api.routers.chat import generate_filters
|
||||||
|
from app.engine import get_index, makeDescriptionByEngine
|
||||||
|
from app.engine.loaders.db import CustomDatabaseReader
|
||||||
|
from app.engine.vectordb import get_vector_store
|
||||||
|
from app.observability import init_observability
|
||||||
|
from app.settings import init_settings
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
def read_questions(file_path):
|
||||||
|
questions = []
|
||||||
|
with open(file_path, 'r', encoding='utf-8') as file:
|
||||||
|
for line in file:
|
||||||
|
if "question" in line:
|
||||||
|
question_part = line.split(":")[1].strip() # 提取 "question" 后的内容
|
||||||
|
questions.append(question_part)
|
||||||
|
return questions
|
||||||
|
|
||||||
|
def save_results_to_file(question, result, file_path):
|
||||||
|
with open(file_path, 'a', encoding='utf-8') as file:
|
||||||
|
file.write(f"问题: {question}\n")
|
||||||
|
file.write(f"结果: {result}\n\n")
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# 从命令行读取questions_file_path
|
||||||
|
if len(sys.argv) < 2:
|
||||||
|
print("请提供questions.txt文件的路径")
|
||||||
|
sys.exit(1)
|
||||||
|
questions_file_path = sys.argv[1]
|
||||||
|
# 更新环境变量
|
||||||
|
os.environ['TOP_K'] = str(5) # 向量的TOP_K值
|
||||||
|
os.environ['LLM_TEMPERATURE'] = str(0.1) # 温度值
|
||||||
|
os.environ['similarity_top_k'] = str(5) # SQL的TOP_K值
|
||||||
|
|
||||||
|
init_settings()
|
||||||
|
init_observability()
|
||||||
|
|
||||||
|
index = get_index()
|
||||||
|
|
||||||
|
top_k = int(os.getenv("TOP_K")) # 向量的TOP_K值
|
||||||
|
temperature = float(os.getenv("LLM_TEMPERATURE")) # 温度值
|
||||||
|
similarity_top_k = float(os.getenv("similarity_top_k")) # SQL的TOP_K值
|
||||||
|
filters = generate_filters([])
|
||||||
|
|
||||||
|
engine = create_engine(os.getenv("SQL_DATABASE_URL", ""))
|
||||||
|
sql_database = SQLDatabase(engine)
|
||||||
|
table_schema_objs = makeDescriptionByEngine(sql_database)
|
||||||
|
table_node_mapping = SQLTableNodeMapping(sql_database)
|
||||||
|
# 创建SQL查询工具
|
||||||
|
sql_obj_index = ObjectIndex.from_objects(
|
||||||
|
table_schema_objs,
|
||||||
|
table_node_mapping,
|
||||||
|
index_cls=VectorStoreIndex,
|
||||||
|
)
|
||||||
|
sql_query_engine = SQLTableRetrieverQueryEngine(sql_database,
|
||||||
|
sql_obj_index.as_retriever(similarity_top_k=similarity_top_k))
|
||||||
|
|
||||||
|
questions = read_questions(questions_file_path)
|
||||||
|
|
||||||
|
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
|
||||||
|
results_file_path = os.path.join(script_dir, "query_results.txt")
|
||||||
|
|
||||||
|
# 如果文件为空,则写入参数值
|
||||||
|
if os.path.getsize(results_file_path) == 0:
|
||||||
|
with open(results_file_path, 'w', encoding='utf-8') as file:
|
||||||
|
file.write(f"TOP_K: {top_k}\n")
|
||||||
|
file.write(f"LLM_TEMPERATURE: {temperature}\n")
|
||||||
|
file.write(f"similarity_top_k: {similarity_top_k}\n\n")
|
||||||
|
|
||||||
|
# 循环执行查询
|
||||||
|
for i, question in enumerate(questions):
|
||||||
|
print(f"Executing query {i+1}: {question}")
|
||||||
|
# query_engine = index.as_query_engine(
|
||||||
|
# similarity_top_k=top_k, filters=filters
|
||||||
|
# )
|
||||||
|
# query_result = query_engine.query(question)
|
||||||
|
|
||||||
|
# print(f"向量查询结果: {query_result}\n")
|
||||||
|
# save_results_to_file(question, f"向量查询结果: {query_result}", results_file_path)
|
||||||
|
|
||||||
|
sql_query_result = sql_query_engine.query(question)
|
||||||
|
print(f"SQL查询结果: {sql_query_result}\n")
|
||||||
|
save_results_to_file(question, f"SQL查询结果: {sql_query_result}", results_file_path)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
from phoenix.trace import using_project
|
||||||
|
|
||||||
|
with using_project("ly_zjapp_test") as obj:
|
||||||
|
main()
|
||||||
Reference in New Issue
Block a user