Files
zjdataai-app/backend/app/engine/engine.py
T
2024-09-19 11:38:35 +08:00

159 lines
6.2 KiB
Python

import os
from llama_index.core import SummaryIndex, SQLDatabase, VectorStoreIndex
from llama_index.core.indices.struct_store import SQLTableRetrieverQueryEngine
from llama_index.core.objects import SQLTableNodeMapping, ObjectIndex, SQLTableSchema
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.response_synthesizers import ResponseMode
from llama_index.readers.database import DatabaseReader
from sqlalchemy import create_engine
from util.register import *
from app.engine.prompt import text_qa_template, refine_template, summary_template, simple_template
from app.engine.retriever.HybridRetriever import HybridRetriever
from app.engine.response.treeSummResponse import CustomTreeResponse
from llama_index.core.settings import Settings
from llama_index.core.indices.property_graph import LLMSynonymRetriever,VectorContextRetriever
from llama_index.core import PropertyGraphIndex
# Category key passed to ClsRegister.get() when looking up the rerank provider
# class in get_node_postprocessors ('模型平台' means "model platform").
ModelPlateCategory = '模型平台'
def get_node_postprocessors():
    """Return the node postprocessors (reranker) configured through environment variables.

    Reranking is disabled when ``RERANK_ENABLED`` is unset or equals ``"false"``
    (case-insensitive); any other value enables it. When enabled,
    ``RERANK_PROVIDER`` names a class registered in ``ClsRegister`` under the
    ``ModelPlateCategory`` category; that class is instantiated and the result of
    its ``rerank()`` method is returned.

    Returns:
        ``[]`` when reranking is disabled, otherwise whatever the provider's
        ``rerank()`` returns (used as ``node_postprocessors`` by the query engine).

    Raises:
        ValueError: if no class is registered for ``RERANK_PROVIDER``.
    """
    # Read the raw value before normalising it: calling .title() directly on
    # os.getenv(...) raised AttributeError when RERANK_ENABLED was unset, so the
    # "is None" guard below could never take effect.
    rerank_enabled = os.getenv("RERANK_ENABLED")
    if rerank_enabled is None or rerank_enabled.title() == 'False':
        return []
    rerank_provider = os.getenv("RERANK_PROVIDER")
    model_platform_cls = ClsRegister.get(ModelPlateCategory, rerank_provider)
    if model_platform_cls is None:
        raise ValueError(f"Invalid rerank provider: {rerank_provider}")
    return model_platform_cls().rerank()
def makeDescriptionByEngine(sql_database: SQLDatabase):
    """Build one ``SQLTableSchema`` per usable table, for the table-retrieval index.

    Tables with more than 150 columns are skipped. Every other table gets an
    empty ``context_str``, except ``gongchengshuxing``. For that table, the context
    string lists the ``name`` values of rows 1-29 returned by
    ``select name from gongchengshuxing`` (row 0 is skipped). Retrieval can then
    match questions that mention those attribute names.

    Args:
        sql_database: the database whose tables are described.

    Returns:
        A list of ``SQLTableSchema`` objects in ``get_usable_table_names()`` order.
    """
    reader = DatabaseReader(sql_database)
    table_schema_objs = []
    for table_name in sql_database.get_usable_table_names():
        columns = sql_database.get_table_columns(table_name)
        # Very wide tables are left out entirely (presumably to keep the schema
        # text handed to the LLM manageable).
        if len(columns) > 150:
            continue
        stats_txt = ""
        if table_name == 'gongchengshuxing':
            documents = reader.load_data(query='select name from gongchengshuxing')
            # Each row's text is expected to look like "name: <value>". Keep
            # everything after the FIRST ':' so values that contain ':' are not
            # truncated (split(':')[1] dropped the remainder and raised
            # IndexError on a row without ':').
            # Row 0 is skipped and at most rows 1..29 are used, exactly as the
            # previous index loop did; why row 0 is excluded is not visible here.
            names = [doc.text.partition(':')[2] for doc in documents[1:30]]
            # Prefix means "This table has the following attributes:".
            stats_txt = '该表中有以下属性:' + ','.join(names)
        table_schema_objs.append(
            SQLTableSchema(table_name=table_name, context_str=stats_txt)
        )
    return table_schema_objs
def get_Retriever(index, **kwargs):
    """Return a retriever over ``index``, hybrid or plain depending on configuration.

    When ``HYBRID_ENABLED`` equals ``"true"`` (case-insensitive), a
    ``HybridRetriever`` is built with ``alpha`` taken from ``HYBRID_ALPHA``
    (default ``"0.5"``). Otherwise ``index.as_retriever`` is used. ``kwargs``
    (e.g. ``similarity_top_k``, ``filters``) are forwarded unchanged in both cases.

    Raises:
        ValueError: if hybrid mode is enabled and ``HYBRID_ALPHA`` is not a float.
    """
    # getenv has a default here, so the value is never None; the old
    # "is not None" guard and "True if ... else False" wrapper were redundant.
    if os.getenv("HYBRID_ENABLED", 'False').title() == 'True':
        alpha = float(os.getenv("HYBRID_ALPHA", "0.5"))
        return HybridRetriever(index, alpha=alpha, **kwargs)
    return index.as_retriever(**kwargs)
def get_synthesizer():
    """Build the project's custom tree-summary response synthesizer.

    Returns a ``CustomTreeResponse`` bound to the globally configured LLM
    (``Settings.llm``) and the project's ``summary_template``, with async
    enabled and streaming disabled.
    """
    synth_kwargs = {
        "llm": Settings.llm,
        "summary_template": summary_template,
        "use_async": True,
        "streaming": False,
    }
    return CustomTreeResponse(**synth_kwargs)
# Module-level cache for the text-to-SQL engine. Both start as None. The SQL
# query-engine factory below fills them in on its first call, and later calls
# reuse them. No locking is applied to this initialisation.
sql_database = None
sql_obj_index = None
# Create a text-to-SQL query engine (table-schema retrieval over SQL_DATABASE_URL)
def create_sql_query_engine(top_k=3, use_reranker=False, filters=None):
    """Create a text-to-SQL query engine over the database at ``SQL_DATABASE_URL``.

    This function used to be named ``create_summary_query_engine``. A later
    definition with the same name silently shadowed it, so it could never be
    called. It now has its own name; the module-level
    ``create_summary_query_engine`` binding is unaffected.

    On the first call, the ``SQLDatabase`` and the table-schema ``ObjectIndex`` are
    built and stored in the module globals ``sql_database`` / ``sql_obj_index``.
    Later calls reuse them and only build a new ``SQLTableRetrieverQueryEngine``.

    Args:
        top_k: number of table schemas retrieved per question.
        use_reranker: accepted but not used by this function.
        filters: accepted but not used by this function.

    Raises:
        ValueError: if ``SQL_DATABASE_URL`` is unset or empty.
    """
    global sql_obj_index
    global sql_database
    if sql_obj_index is None or sql_database is None:
        database_url = os.getenv("SQL_DATABASE_URL", "")
        if not database_url:
            # Fail with a clear message instead of an opaque URL-parse error
            # from create_engine("").
            raise ValueError("SQL_DATABASE_URL must be set to build the SQL query engine")
        sql_database = SQLDatabase(create_engine(database_url))
        table_schema_objs = makeDescriptionByEngine(sql_database)
        table_node_mapping = SQLTableNodeMapping(sql_database)
        sql_obj_index = ObjectIndex.from_objects(
            table_schema_objs,
            table_node_mapping,
            index_cls=VectorStoreIndex,
        )
    # Build the SQL query tool.
    return SQLTableRetrieverQueryEngine(
        sql_database,
        sql_obj_index.as_retriever(similarity_top_k=top_k),
        verbose=True,
    )
# Create a summary query engine
def create_summary_query_engine(index, top_k=3, use_reranker=False, filters=None):
    """Create a tree-summarize query engine over the nodes of ``index``'s vector store.

    The nodes returned by ``index.vector_store.get_nodes(node_ids=None)`` are
    wrapped in a fresh ``SummaryIndex`` (presumably every stored node; this depends
    on the vector store implementation). That index is queried with
    ``ResponseMode.TREE_SUMMARIZE``, async enabled and streaming disabled.
    ``top_k``, ``use_reranker`` and ``filters`` are accepted but not used.
    """
    all_nodes = index.vector_store.get_nodes(node_ids=None)
    engine_options = {
        "response_mode": ResponseMode.TREE_SUMMARIZE,
        "use_async": True,
        "streaming": False,
    }
    return SummaryIndex(all_nodes).as_query_engine(**engine_options)
# Create a query engine
def create_query_engine(index, top_k=3, use_reranker=False, filters=None, response_mode=None):
    """Create the main retrieval query engine over ``index``.

    ``LLM_QUERY_WAY`` selects the retrieval strategy:

    - ``"graph"``: ``index`` is treated as a ``PropertyGraphIndex``. An LLM synonym
      retriever is combined with a vector-context retriever.
    - anything else (default ``"rag"``): ``get_Retriever`` builds a plain or
      hybrid retriever.

    Args:
        index: the index to query (a ``PropertyGraphIndex`` on the graph path).
        top_k: ``similarity_top_k`` for vector retrieval.
        use_reranker: when true, attach the postprocessors from
            ``get_node_postprocessors()``.
        filters: metadata filters forwarded to ``get_Retriever``. Not applied on
            the graph path.
        response_mode: response synthesis mode. ``None`` falls back to
            ``ResponseMode.COMPACT``, the ``RetrieverQueryEngine.from_args`` default.

    Returns:
        A ``RetrieverQueryEngine`` using the project's QA / refine / summary /
        simple prompt templates, with async enabled and streaming disabled.
    """
    # Build the vector-retrieval query tool.
    postprocess = get_node_postprocessors() if use_reranker else None
    llm_query = os.getenv('LLM_QUERY_WAY', 'rag')
    if llm_query == 'graph':
        retriever = _build_graph_retriever(index, top_k)
    else:
        retriever = get_Retriever(index, similarity_top_k=top_k, filters=filters)
    if response_mode is None:
        # Passing None explicitly overrides from_args' ResponseMode.COMPACT
        # default, and llama_index's response-synthesizer factory rejects None as
        # an unknown mode, so substitute the library default here.
        response_mode = ResponseMode.COMPACT
    return RetrieverQueryEngine.from_args(
        retriever=retriever,
        text_qa_template=text_qa_template,
        refine_template=refine_template,
        summary_template=summary_template,
        simple_template=simple_template,
        node_postprocessors=postprocess,
        use_async=True,
        streaming=False,
        response_mode=response_mode,
    )


def _build_graph_retriever(graph_index, top_k):
    """Build a property-graph retriever: LLM synonym expansion plus vector context.

    ``graph_index`` is expected to be a ``PropertyGraphIndex``.
    """
    graph_store = graph_index.property_graph_store
    synonym_retriever = LLMSynonymRetriever(
        graph_store,
        llm=Settings.llm,
        include_text=False,
    )
    # A graph store that supports vector queries needs no separate vector store.
    # Otherwise, fall back to the index's own vector store.
    vector_store = None if graph_store.supports_vector_queries else graph_index.vector_store
    vector_retriever = VectorContextRetriever(
        graph_store,
        vector_store=vector_store,
        embed_model=Settings.embed_model,
        similarity_top_k=top_k,
        include_text=False,
    )
    return graph_index.as_retriever(sub_retrievers=[synonym_retriever, vector_retriever])