From 3ceb30c375e00a7898fb95c7be530a10755273c7 Mon Sep 17 00:00:00 2001 From: paituo <330435863@qq.com> Date: Thu, 22 Aug 2024 16:17:10 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E7=BC=BA=E9=99=B7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/app/engine/engine.py | 28 ++++++++++++++++++++-- backend/app/engine/loaders/db.py | 40 ++++---------------------------- 2 files changed, 30 insertions(+), 38 deletions(-) diff --git a/backend/app/engine/engine.py b/backend/app/engine/engine.py index 69428e0..cd36de2 100644 --- a/backend/app/engine/engine.py +++ b/backend/app/engine/engine.py @@ -2,17 +2,41 @@ import os from llama_index.core import SummaryIndex, SQLDatabase, VectorStoreIndex from llama_index.core.indices.struct_store import SQLTableRetrieverQueryEngine -from llama_index.core.objects import SQLTableNodeMapping, ObjectIndex +from llama_index.core.objects import SQLTableNodeMapping, ObjectIndex, SQLTableSchema from llama_index.core.query_engine import RetrieverQueryEngine from llama_index.core.response_synthesizers import ResponseMode +from llama_index.readers.database import DatabaseReader from sqlalchemy import create_engine -from app.engine import makeDescriptionByEngine from app.engine.prompt import text_qa_template, refine_template, summary_template, simple_template from app.engine.retriever.HybridRetriever import HybridRetriever from app.settings import get_node_postprocessors +def makeDescriptionByEngine(sql_database:SQLDatabase): + reader = DatabaseReader(sql_database) + table_names = sql_database.get_usable_table_names() + table_schema_objs = [] + for table_name in table_names: + columns = sql_database.get_table_columns(table_name) + if len(columns) > 150: + continue + stats_txt = "" + + if table_name == 'gongchengshuxing': + stats_txt = '该表中有以下属性:' + documents = reader.load_data(query='select name from gongchengshuxing') + for index in range(len(documents) if len(documents) < 30 else 30): + if index == 0: + continue + elif index > 1: + stats_txt += ',' + stats_txt += documents[index].text.split(':')[1] + + tbSchema = (SQLTableSchema(table_name=table_name, context_str=stats_txt)) + table_schema_objs.append(tbSchema) + + return table_schema_objs def get_Retriever(index,**kwargs): bEnableHybrid = True if os.getenv("HYBRID_ENABLED",False).title() == 'True' else False diff --git a/backend/app/engine/loaders/db.py b/backend/app/engine/loaders/db.py index 63a7c02..d6310e2 100644 --- a/backend/app/engine/loaders/db.py +++ b/backend/app/engine/loaders/db.py @@ -1,20 +1,14 @@ -import os import logging -from typing import List from typing import Any, List, Optional -from llama_index.core.readers.base import BaseReader -from llama_index.core.schema import Document -from llama_index.core.utilities.sql_wrapper import SQLDatabase -from sqlalchemy import text -from sqlalchemy.engine import Engine from llama_index.core import SQLDatabase, Document -from llama_index.core.objects import SQLTableSchema, SQLTableNodeMapping +from llama_index.core.objects import SQLTableSchema from llama_index.core.readers.base import BaseReader from llama_index.readers.database import DatabaseReader -from pydantic import BaseModel, validator -from llama_index.core.indices.vector_store import VectorStoreIndex +from pydantic import BaseModel from sqlalchemy import create_engine +from sqlalchemy import text +from sqlalchemy.engine import Engine logger = logging.getLogger(__name__) @@ -119,32 +113,6 @@ class DBLoaderConfig(BaseModel): uri: str queries: List[str] -def makeDescriptionByEngine(sql_database:SQLDatabase): - reader = DatabaseReader(sql_database) - - table_names = sql_database.get_usable_table_names() - table_schema_objs = [] - for table_name in table_names: - columns = sql_database.get_table_columns(table_name) - if len(columns) > 150: - continue - stats_txt = "" - - if table_name == 'gongchengshuxing': - stats_txt = '该表中有以下属性:' - documents = reader.load_data(query='select name from gongchengshuxing') - for index in range(len(documents) if len(documents) < 30 else 30): - if index == 0: - continue - elif index > 1: - stats_txt += ',' - stats_txt += documents[index].text.split(':')[1] - - tbSchema = (SQLTableSchema(table_name=table_name, context_str=stats_txt)) - table_schema_objs.append(tbSchema) - - return table_schema_objs - def get_db_documents(configs: list[DBLoaderConfig]): docs = []