优化了提示词
This commit is contained in:
@@ -0,0 +1,108 @@
|
||||
import os
|
||||
|
||||
from llama_index.core import SummaryIndex, SQLDatabase, VectorStoreIndex
|
||||
from llama_index.core.indices.struct_store import SQLTableRetrieverQueryEngine
|
||||
from llama_index.core.objects import SQLTableNodeMapping, ObjectIndex, SQLTableSchema
|
||||
from llama_index.core.query_engine import RetrieverQueryEngine
|
||||
from llama_index.core.response_synthesizers import ResponseMode
|
||||
from llama_index.readers.database import DatabaseReader
|
||||
from sqlalchemy import create_engine
|
||||
|
||||
from app.engine.prompt import text_qa_template, refine_template, summary_template, simple_template
|
||||
from app.engine.retriever.HybridRetriever import HybridRetriever
|
||||
from app.settings import get_node_postprocessors
|
||||
|
||||
def makeDescriptionByEngine(sql_database:SQLDatabase):
|
||||
reader = DatabaseReader(sql_database)
|
||||
|
||||
table_names = sql_database.get_usable_table_names()
|
||||
table_schema_objs = []
|
||||
for table_name in table_names:
|
||||
columns = sql_database.get_table_columns(table_name)
|
||||
if len(columns) > 150:
|
||||
continue
|
||||
stats_txt = ""
|
||||
|
||||
if table_name == 'gongchengshuxing':
|
||||
stats_txt = '该表中有以下属性:'
|
||||
documents = reader.load_data(query='select name from gongchengshuxing')
|
||||
for index in range(len(documents) if len(documents) < 30 else 30):
|
||||
if index == 0:
|
||||
continue
|
||||
elif index > 1:
|
||||
stats_txt += ','
|
||||
stats_txt += documents[index].text.split(':')[1]
|
||||
|
||||
tbSchema = (SQLTableSchema(table_name=table_name, context_str=stats_txt))
|
||||
table_schema_objs.append(tbSchema)
|
||||
|
||||
return table_schema_objs
|
||||
|
||||
def get_Retriever(index,**kwargs):
|
||||
strEnableHybrid = os.getenv("HYBRID_ENABLED",'False')
|
||||
bEnableHybrid = True if strEnableHybrid is not None and strEnableHybrid.title() == 'True' else False
|
||||
if bEnableHybrid:
|
||||
alpha = float(os.getenv("HYBRID_ALPHA", "0.5"))
|
||||
retriever = HybridRetriever(index,alpha = alpha,**kwargs)
|
||||
else:
|
||||
retriever = index.as_retriever(**kwargs)
|
||||
return retriever
|
||||
|
||||
|
||||
sql_database = None
|
||||
sql_obj_index = None
|
||||
|
||||
# Create a summary query engine
|
||||
def create_summary_query_engine(top_k=3, use_reranker=False, filters=None):
|
||||
global sql_obj_index
|
||||
global sql_database
|
||||
if sql_obj_index is None or sql_database is None:
|
||||
sqlengine = create_engine(os.getenv("SQL_DATABASE_URL", ""))
|
||||
sql_database = SQLDatabase(sqlengine)
|
||||
table_schema_objs = makeDescriptionByEngine(sql_database)
|
||||
table_node_mapping = SQLTableNodeMapping(sql_database)
|
||||
|
||||
sql_obj_index = ObjectIndex.from_objects(
|
||||
table_schema_objs,
|
||||
table_node_mapping,
|
||||
index_cls=VectorStoreIndex,
|
||||
)
|
||||
|
||||
# 创建SQL查询工具
|
||||
sql_query_engine = SQLTableRetrieverQueryEngine(sql_database,
|
||||
sql_obj_index.as_retriever(similarity_top_k=top_k),
|
||||
verbose=True,
|
||||
)
|
||||
return sql_query_engine
|
||||
|
||||
# Create a summary query engine
|
||||
def create_summary_query_engine(index, top_k=3, use_reranker=False, filters=None):
|
||||
summary_index = SummaryIndex(index.vector_store.get_nodes(node_ids=None))
|
||||
summary_query_engine = summary_index.as_query_engine(
|
||||
response_mode=ResponseMode.TREE_SUMMARIZE,
|
||||
use_async=True,
|
||||
streaming=True,
|
||||
)
|
||||
return summary_query_engine
|
||||
|
||||
# Create a query engine
|
||||
def create_query_engine(index, top_k=3, use_reranker=False, filters=None):
|
||||
# 创建向量检索查询工具
|
||||
postprocess = None
|
||||
if use_reranker:
|
||||
postprocess = get_node_postprocessors()
|
||||
|
||||
query_engine = RetrieverQueryEngine.from_args(
|
||||
get_Retriever(index,
|
||||
similarity_top_k=top_k,
|
||||
filters=filters),
|
||||
text_qa_template=text_qa_template,
|
||||
refine_template=refine_template,
|
||||
summary_template = summary_template,
|
||||
simple_template = simple_template,
|
||||
node_postprocessors=postprocess,
|
||||
use_async=True,
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
return query_engine
|
||||
Reference in New Issue
Block a user