dev #5
@@ -2,17 +2,41 @@ import os
|
|||||||
|
|
||||||
from llama_index.core import SummaryIndex, SQLDatabase, VectorStoreIndex
|
from llama_index.core import SummaryIndex, SQLDatabase, VectorStoreIndex
|
||||||
from llama_index.core.indices.struct_store import SQLTableRetrieverQueryEngine
|
from llama_index.core.indices.struct_store import SQLTableRetrieverQueryEngine
|
||||||
from llama_index.core.objects import SQLTableNodeMapping, ObjectIndex
|
from llama_index.core.objects import SQLTableNodeMapping, ObjectIndex, SQLTableSchema
|
||||||
from llama_index.core.query_engine import RetrieverQueryEngine
|
from llama_index.core.query_engine import RetrieverQueryEngine
|
||||||
from llama_index.core.response_synthesizers import ResponseMode
|
from llama_index.core.response_synthesizers import ResponseMode
|
||||||
|
from llama_index.readers.database import DatabaseReader
|
||||||
from sqlalchemy import create_engine
|
from sqlalchemy import create_engine
|
||||||
|
|
||||||
from app.engine import makeDescriptionByEngine
|
|
||||||
from app.engine.prompt import text_qa_template, refine_template, summary_template, simple_template
|
from app.engine.prompt import text_qa_template, refine_template, summary_template, simple_template
|
||||||
from app.engine.retriever.HybridRetriever import HybridRetriever
|
from app.engine.retriever.HybridRetriever import HybridRetriever
|
||||||
from app.settings import get_node_postprocessors
|
from app.settings import get_node_postprocessors
|
||||||
|
|
||||||
|
def makeDescriptionByEngine(sql_database:SQLDatabase):
|
||||||
|
reader = DatabaseReader(sql_database)
|
||||||
|
|
||||||
|
table_names = sql_database.get_usable_table_names()
|
||||||
|
table_schema_objs = []
|
||||||
|
for table_name in table_names:
|
||||||
|
columns = sql_database.get_table_columns(table_name)
|
||||||
|
if len(columns) > 150:
|
||||||
|
continue
|
||||||
|
stats_txt = ""
|
||||||
|
|
||||||
|
if table_name == 'gongchengshuxing':
|
||||||
|
stats_txt = '该表中有以下属性:'
|
||||||
|
documents = reader.load_data(query='select name from gongchengshuxing')
|
||||||
|
for index in range(len(documents) if len(documents) < 30 else 30):
|
||||||
|
if index == 0:
|
||||||
|
continue
|
||||||
|
elif index > 1:
|
||||||
|
stats_txt += ','
|
||||||
|
stats_txt += documents[index].text.split(':')[1]
|
||||||
|
|
||||||
|
tbSchema = (SQLTableSchema(table_name=table_name, context_str=stats_txt))
|
||||||
|
table_schema_objs.append(tbSchema)
|
||||||
|
|
||||||
|
return table_schema_objs
|
||||||
|
|
||||||
def get_Retriever(index,**kwargs):
|
def get_Retriever(index,**kwargs):
|
||||||
bEnableHybrid = True if os.getenv("HYBRID_ENABLED",False).title() == 'True' else False
|
bEnableHybrid = True if os.getenv("HYBRID_ENABLED",False).title() == 'True' else False
|
||||||
|
|||||||
@@ -1,20 +1,14 @@
|
|||||||
import os
|
|
||||||
import logging
|
import logging
|
||||||
from typing import List
|
|
||||||
from typing import Any, List, Optional
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
from llama_index.core.readers.base import BaseReader
|
|
||||||
from llama_index.core.schema import Document
|
|
||||||
from llama_index.core.utilities.sql_wrapper import SQLDatabase
|
|
||||||
from sqlalchemy import text
|
|
||||||
from sqlalchemy.engine import Engine
|
|
||||||
from llama_index.core import SQLDatabase, Document
|
from llama_index.core import SQLDatabase, Document
|
||||||
from llama_index.core.objects import SQLTableSchema, SQLTableNodeMapping
|
from llama_index.core.objects import SQLTableSchema
|
||||||
from llama_index.core.readers.base import BaseReader
|
from llama_index.core.readers.base import BaseReader
|
||||||
from llama_index.readers.database import DatabaseReader
|
from llama_index.readers.database import DatabaseReader
|
||||||
from pydantic import BaseModel, validator
|
from pydantic import BaseModel
|
||||||
from llama_index.core.indices.vector_store import VectorStoreIndex
|
|
||||||
from sqlalchemy import create_engine
|
from sqlalchemy import create_engine
|
||||||
|
from sqlalchemy import text
|
||||||
|
from sqlalchemy.engine import Engine
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -119,32 +113,6 @@ class DBLoaderConfig(BaseModel):
|
|||||||
uri: str
|
uri: str
|
||||||
queries: List[str]
|
queries: List[str]
|
||||||
|
|
||||||
def makeDescriptionByEngine(sql_database:SQLDatabase):
|
|
||||||
reader = DatabaseReader(sql_database)
|
|
||||||
|
|
||||||
table_names = sql_database.get_usable_table_names()
|
|
||||||
table_schema_objs = []
|
|
||||||
for table_name in table_names:
|
|
||||||
columns = sql_database.get_table_columns(table_name)
|
|
||||||
if len(columns) > 150:
|
|
||||||
continue
|
|
||||||
stats_txt = ""
|
|
||||||
|
|
||||||
if table_name == 'gongchengshuxing':
|
|
||||||
stats_txt = '该表中有以下属性:'
|
|
||||||
documents = reader.load_data(query='select name from gongchengshuxing')
|
|
||||||
for index in range(len(documents) if len(documents) < 30 else 30):
|
|
||||||
if index == 0:
|
|
||||||
continue
|
|
||||||
elif index > 1:
|
|
||||||
stats_txt += ','
|
|
||||||
stats_txt += documents[index].text.split(':')[1]
|
|
||||||
|
|
||||||
tbSchema = (SQLTableSchema(table_name=table_name, context_str=stats_txt))
|
|
||||||
table_schema_objs.append(tbSchema)
|
|
||||||
|
|
||||||
return table_schema_objs
|
|
||||||
|
|
||||||
def get_db_documents(configs: list[DBLoaderConfig]):
|
def get_db_documents(configs: list[DBLoaderConfig]):
|
||||||
docs = []
|
docs = []
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user