"""Factories for LlamaIndex retrievers and query engines.

Behaviour is driven by environment variables read in this module:
RERANK_ENABLED, RERANK_PROVIDER, HYBRID_ENABLED, HYBRID_ALPHA,
SQL_DATABASE_URL and LLM_QUERY_WAY.
"""
# NOTE: import order is intentionally left as in the original file because
# the wildcard import below may rebind names depending on its position.
import os
from llama_index.core import SummaryIndex, SQLDatabase, VectorStoreIndex
from llama_index.core.indices.struct_store import SQLTableRetrieverQueryEngine
from llama_index.core.objects import SQLTableNodeMapping, ObjectIndex, SQLTableSchema
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.response_synthesizers import ResponseMode
from llama_index.readers.database import DatabaseReader
from sqlalchemy import create_engine
from util.register import *
from app.engine.prompt import text_qa_template, refine_template, summary_template, simple_template
from app.engine.retriever.HybridRetriever import HybridRetriever
from app.engine.response.treeSummResponse import CustomTreeResponse
from llama_index.core.settings import Settings
from llama_index.core.indices.property_graph import LLMSynonymRetriever, VectorContextRetriever
from llama_index.core import PropertyGraphIndex

# Category key used to look up model-platform classes in ClsRegister.
ModelPlateCategory = '模型平台'


def get_node_postprocessors():
    """Return the rerank post-processing configured through the environment.

    Returns:
        An empty list when RERANK_ENABLED is unset or equals "false"
        (case-insensitively, via ``str.title``); otherwise whatever the
        registered platform's ``rerank()`` returns.
        NOTE(review): presumably a list of node post-processors — confirm
        against the registered platform classes.

    Raises:
        ValueError: if no class is registered for RERANK_PROVIDER.
    """
    # Default to "False" so an unset variable disables reranking instead of
    # raising AttributeError on ``None.title()``.
    rerank_enabled = os.getenv("RERANK_ENABLED", "False").title()
    if rerank_enabled == 'False':
        return []

    rerank_provider = os.getenv("RERANK_PROVIDER")
    platform_cls = ClsRegister.get(ModelPlateCategory, rerank_provider)
    if platform_cls is None:
        raise ValueError(f"Invalid rerank provider: {rerank_provider}")
    return platform_cls().rerank()


def makeDescriptionByEngine(sql_database: SQLDatabase):
    """Build one ``SQLTableSchema`` per usable table of *sql_database*.

    Tables with more than 150 columns are skipped. For the table
    ``gongchengshuxing`` the schema context lists attribute names read from
    the table itself; every other table gets an empty context string.

    Args:
        sql_database: database wrapper whose tables are described.

    Returns:
        A list of ``SQLTableSchema`` objects.
    """
    reader = DatabaseReader(sql_database)
    table_schema_objs = []
    for table_name in sql_database.get_usable_table_names():
        columns = sql_database.get_table_columns(table_name)
        if len(columns) > 150:
            continue

        stats_txt = ""
        if table_name == 'gongchengshuxing':
            documents = reader.load_data(query='select name from gongchengshuxing')
            # Uses rows 1..29 only; row 0 is skipped exactly as before.
            # NOTE(review): confirm that excluding the first row is intended.
            names = [doc.text.split(':')[1] for doc in documents[1:30]]
            stats_txt = '该表中有以下属性:' + ','.join(names)

        table_schema_objs.append(
            SQLTableSchema(table_name=table_name, context_str=stats_txt)
        )
    return table_schema_objs


def get_Retriever(index, **kwargs):
    """Return a retriever for *index*, hybrid or plain depending on the env.

    When HYBRID_ENABLED equals "true" (case-insensitively) a
    ``HybridRetriever`` is built with ``alpha`` taken from HYBRID_ALPHA
    (default "0.5"); otherwise ``index.as_retriever`` is used.

    Args:
        index: the index to retrieve from.
        **kwargs: forwarded unchanged to the chosen retriever constructor.
    """
    hybrid_enabled = os.getenv("HYBRID_ENABLED", 'False').title() == 'True'
    if hybrid_enabled:
        alpha = float(os.getenv("HYBRID_ALPHA", "0.5"))
        return HybridRetriever(index, alpha=alpha, **kwargs)
    return index.as_retriever(**kwargs)


def get_synthesizer():
    """Return a ``CustomTreeResponse`` built from the global LLM settings."""
    return CustomTreeResponse(
        llm=Settings.llm,
        summary_template=summary_template,
        use_async=True,
        streaming=False,
    )


# Lazily initialised, module-wide SQL resources shared by all SQL engines.
sql_database = None
sql_obj_index = None


# Create a SQL query engine
def create_sql_query_engine(top_k=3, use_reranker=False, filters=None):
    """Return a ``SQLTableRetrieverQueryEngine`` over SQL_DATABASE_URL.

    This function was previously also named ``create_summary_query_engine``
    and was therefore shadowed by the later definition of that name; it is
    renamed so it can actually be called.

    The database connection and table object index are created on first use
    and cached in the module globals ``sql_database`` / ``sql_obj_index``.

    Args:
        top_k: ``similarity_top_k`` for the table retriever.
        use_reranker: accepted for signature parity; not used here.
        filters: accepted for signature parity; not used here.
    """
    global sql_obj_index
    global sql_database
    if sql_obj_index is None or sql_database is None:
        sqlengine = create_engine(os.getenv("SQL_DATABASE_URL", ""))
        sql_database = SQLDatabase(sqlengine)
        table_schema_objs = makeDescriptionByEngine(sql_database)
        table_node_mapping = SQLTableNodeMapping(sql_database)
        sql_obj_index = ObjectIndex.from_objects(
            table_schema_objs,
            table_node_mapping,
            index_cls=VectorStoreIndex,
        )

    # Create the SQL query tool
    sql_query_engine = SQLTableRetrieverQueryEngine(
        sql_database,
        sql_obj_index.as_retriever(similarity_top_k=top_k),
        verbose=True,
    )
    return sql_query_engine


# Create a summary query engine
def create_summary_query_engine(index, top_k=3, use_reranker=False, filters=None):
    """Return a tree-summarize query engine over every node of *index*.

    Args:
        index: index whose ``vector_store.get_nodes(node_ids=None)`` result
            is loaded into a ``SummaryIndex``.
        top_k: accepted for signature parity; not used here.
        use_reranker: accepted for signature parity; not used here.
        filters: accepted for signature parity; not used here.
    """
    summary_index = SummaryIndex(index.vector_store.get_nodes(node_ids=None))
    summary_query_engine = summary_index.as_query_engine(
        response_mode=ResponseMode.TREE_SUMMARIZE,
        use_async=True,
        streaming=False,
    )
    return summary_query_engine


# Create a query engine
def create_query_engine(index, top_k=3, use_reranker=False, filters=None, response_mode=None):
    """Return a ``RetrieverQueryEngine`` for *index*.

    The retriever is chosen by LLM_QUERY_WAY (default "rag"): the value
    "graph" combines a synonym retriever and a vector-context retriever on a
    property graph index; any other value uses ``get_Retriever``.

    Args:
        index: the index to query (a ``PropertyGraphIndex`` in graph mode).
        top_k: ``similarity_top_k`` for the retriever.
        use_reranker: when true, post-processors come from
            ``get_node_postprocessors``.
        filters: metadata filters forwarded to ``get_Retriever`` (non-graph
            mode only).
        response_mode: forwarded to ``RetrieverQueryEngine.from_args`` when
            not ``None``; otherwise the library default applies.
    """
    # Create the vector-retrieval query tool
    postprocess = get_node_postprocessors() if use_reranker else None

    llm_query = os.getenv('LLM_QUERY_WAY', 'rag')
    if llm_query == 'graph':
        graphIndex: PropertyGraphIndex = index
        graph_store = graphIndex.property_graph_store
        synonym_retriever = LLMSynonymRetriever(
            graph_store,
            llm=Settings.llm,
            include_text=False,
        )
        # A separate vector store is passed only when the graph store cannot
        # answer vector queries itself.
        if graph_store.supports_vector_queries:
            vector_store = None
        else:
            vector_store = graphIndex.vector_store
        vector_retriever = VectorContextRetriever(
            graph_store,
            vector_store=vector_store,
            embed_model=Settings.embed_model,
            similarity_top_k=top_k,
            include_text=False,
        )
        retriever = graphIndex.as_retriever(
            sub_retrievers=[synonym_retriever, vector_retriever]
        )
    else:
        # No trailing comma here: it previously turned the retriever into a
        # 1-tuple.
        retriever = get_Retriever(index, similarity_top_k=top_k, filters=filters)

    engine_kwargs = dict(
        retriever=retriever,
        text_qa_template=text_qa_template,
        refine_template=refine_template,
        summary_template=summary_template,
        simple_template=simple_template,
        node_postprocessors=postprocess,
        use_async=True,
        streaming=False,
    )
    # Forward response_mode only when supplied, so an explicit None is never
    # passed in place of the library default.
    if response_mode is not None:
        engine_kwargs["response_mode"] = response_mode

    query_engine = RetrieverQueryEngine.from_args(**engine_kwargs)
    return query_engine