dev #5

Closed
ly wants to merge 93 commits from dev into dev-db
2 changed files with 30 additions and 38 deletions
Showing only changes of commit 3ceb30c375 - Show all commits
+26 -2
View File
@@ -2,17 +2,41 @@ import os
from llama_index.core import SummaryIndex, SQLDatabase, VectorStoreIndex from llama_index.core import SummaryIndex, SQLDatabase, VectorStoreIndex
from llama_index.core.indices.struct_store import SQLTableRetrieverQueryEngine from llama_index.core.indices.struct_store import SQLTableRetrieverQueryEngine
from llama_index.core.objects import SQLTableNodeMapping, ObjectIndex from llama_index.core.objects import SQLTableNodeMapping, ObjectIndex, SQLTableSchema
from llama_index.core.query_engine import RetrieverQueryEngine from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.response_synthesizers import ResponseMode from llama_index.core.response_synthesizers import ResponseMode
from llama_index.readers.database import DatabaseReader
from sqlalchemy import create_engine from sqlalchemy import create_engine
from app.engine import makeDescriptionByEngine
from app.engine.prompt import text_qa_template, refine_template, summary_template, simple_template from app.engine.prompt import text_qa_template, refine_template, summary_template, simple_template
from app.engine.retriever.HybridRetriever import HybridRetriever from app.engine.retriever.HybridRetriever import HybridRetriever
from app.settings import get_node_postprocessors from app.settings import get_node_postprocessors
def makeDescriptionByEngine(sql_database:SQLDatabase):
reader = DatabaseReader(sql_database)
table_names = sql_database.get_usable_table_names()
table_schema_objs = []
for table_name in table_names:
columns = sql_database.get_table_columns(table_name)
if len(columns) > 150:
continue
stats_txt = ""
if table_name == 'gongchengshuxing':
stats_txt = '该表中有以下属性:'
documents = reader.load_data(query='select name from gongchengshuxing')
for index in range(len(documents) if len(documents) < 30 else 30):
if index == 0:
continue
elif index > 1:
stats_txt += ','
stats_txt += documents[index].text.split(':')[1]
tbSchema = (SQLTableSchema(table_name=table_name, context_str=stats_txt))
table_schema_objs.append(tbSchema)
return table_schema_objs
def get_Retriever(index,**kwargs): def get_Retriever(index,**kwargs):
bEnableHybrid = True if os.getenv("HYBRID_ENABLED",False).title() == 'True' else False bEnableHybrid = True if os.getenv("HYBRID_ENABLED",False).title() == 'True' else False
+4 -36
View File
@@ -1,20 +1,14 @@
import os
import logging import logging
from typing import List
from typing import Any, List, Optional from typing import Any, List, Optional
from llama_index.core.readers.base import BaseReader
from llama_index.core.schema import Document
from llama_index.core.utilities.sql_wrapper import SQLDatabase
from sqlalchemy import text
from sqlalchemy.engine import Engine
from llama_index.core import SQLDatabase, Document from llama_index.core import SQLDatabase, Document
from llama_index.core.objects import SQLTableSchema, SQLTableNodeMapping from llama_index.core.objects import SQLTableSchema
from llama_index.core.readers.base import BaseReader from llama_index.core.readers.base import BaseReader
from llama_index.readers.database import DatabaseReader from llama_index.readers.database import DatabaseReader
from pydantic import BaseModel, validator from pydantic import BaseModel
from llama_index.core.indices.vector_store import VectorStoreIndex
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy import text
from sqlalchemy.engine import Engine
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -119,32 +113,6 @@ class DBLoaderConfig(BaseModel):
uri: str uri: str
queries: List[str] queries: List[str]
def makeDescriptionByEngine(sql_database:SQLDatabase):
reader = DatabaseReader(sql_database)
table_names = sql_database.get_usable_table_names()
table_schema_objs = []
for table_name in table_names:
columns = sql_database.get_table_columns(table_name)
if len(columns) > 150:
continue
stats_txt = ""
if table_name == 'gongchengshuxing':
stats_txt = '该表中有以下属性:'
documents = reader.load_data(query='select name from gongchengshuxing')
for index in range(len(documents) if len(documents) < 30 else 30):
if index == 0:
continue
elif index > 1:
stats_txt += ','
stats_txt += documents[index].text.split(':')[1]
tbSchema = (SQLTableSchema(table_name=table_name, context_str=stats_txt))
table_schema_objs.append(tbSchema)
return table_schema_objs
def get_db_documents(configs: list[DBLoaderConfig]): def get_db_documents(configs: list[DBLoaderConfig]):
docs = [] docs = []