"""Query helper built on top of a LlamaIndex property graph index."""

from llama_index.core.indices.property_graph import LLMSynonymRetriever,VectorContextRetriever,PGRetriever
# NOTE(review): wildcard imports are kept as-is; this file does not show which
# names they provide, so narrowing them here could break other code.
from llama_index.core.indices.property_graph.transformations.schema_llm import *
from llama_index.core import SimpleDirectoryReader
from llama_index.core import settings
from llama_index.core import PropertyGraphIndex
from typing import List,Tuple,Literal
from app.settings import init_settings
import os
from llama_index.core.storage.storage_context import StorageContext
from llama_index.core import load_index_from_storage
from app.observability import init_observability
from app.engine.vectordb import get_Neo4j_Graph_Store
from llama_index.core.response_synthesizers import ResponseMode
from util.register import *
from llama_index.core.query_engine import RetrieverQueryEngine
from app.engine.prompt import text_qa_template, refine_template, summary_template, simple_template
from app.engine.engine import get_node_postprocessors


class PropertyGraph:
    """Answer questions against the property graph index of one project."""

    def __init__(self, prjFlag: str) -> None:
        """Remember the project identifier used to locate the graph.

        Args:
            prjFlag: Project identifier. It is passed to the Neo4j graph
                store factory, or used as the sub-directory name under the
                local graph storage directory.
        """
        self._prjFlag = prjFlag

    def create_query_engine(self, retriever):
        """Build a tree-summarize query engine around *retriever*.

        Args:
            retriever: Retriever that supplies nodes to the query engine.

        Returns:
            A ``RetrieverQueryEngine`` configured with the application's
            prompt templates and node postprocessors, non-streaming, with
            ``use_async=True``.
        """
        postprocessors = get_node_postprocessors()
        return RetrieverQueryEngine.from_args(
            retriever=retriever,
            text_qa_template=text_qa_template,
            refine_template=refine_template,
            summary_template=summary_template,
            simple_template=simple_template,
            node_postprocessors=postprocessors,
            use_async=True,
            streaming=False,
            response_mode=ResponseMode.TREE_SUMMARIZE,
        )

    def getPropertyGraphIndex(self):
        """Load the property graph index for this project.

        The backend is chosen by the ``GRAPH_STORE_TYPE`` environment
        variable: ``'neo4j'`` attaches to an existing Neo4j graph store; any
        other value loads a persisted index from
        ``$GRAPH_STORAGE_PATH/<prjFlag>`` (default base ``storage_graph``).

        Returns:
            The loaded index, or ``None`` when the local storage directory
            for this project does not exist.
        """
        graph_store_type = os.getenv("GRAPH_STORE_TYPE", "")
        if graph_store_type == 'neo4j':
            return PropertyGraphIndex.from_existing(
                property_graph_store=get_Neo4j_Graph_Store(self._prjFlag)
            )

        storage_root = os.getenv("GRAPH_STORAGE_PATH", "storage_graph")
        project_cache_path = storage_root + f"/{self._prjFlag}"
        if not os.path.exists(project_cache_path):
            return None
        storage_context = StorageContext.from_defaults(persist_dir=project_cache_path)
        return load_index_from_storage(storage_context)

    def query(self, query_str: str) -> str:
        """Answer *query_str* using synonym and vector graph retrieval.

        Args:
            query_str: Natural-language question to ask the graph.

        Returns:
            The synthesized response converted to ``str``. The response is
            also printed to stdout.

        Raises:
            FileNotFoundError: If no index could be loaded for this project
                (``getPropertyGraphIndex`` returned ``None``).
        """
        index = self.getPropertyGraphIndex()
        if index is None:
            # Fail with a clear message instead of an AttributeError on None.
            raise FileNotFoundError(
                f"No property graph index could be loaded for project "
                f"{self._prjFlag!r}; check that its storage directory exists."
            )

        graph_store = index.property_graph_store
        synonym_retriever = LLMSynonymRetriever(
            graph_store,
            llm=settings.Settings.llm,
            max_keywords=10,
            include_text=False,
        )

        # Only hand over the index's separate vector store when the graph
        # store reports that it cannot serve vector queries itself.
        vector_store = None if graph_store.supports_vector_queries else index.vector_store
        vector_retriever = VectorContextRetriever(
            graph_store,
            vector_store=vector_store,
            embed_model=settings.Settings.embed_model,
            similarity_top_k=5,
            include_text=False,
        )

        retriever = index.as_retriever(
            sub_retrievers=[synonym_retriever, vector_retriever]
        )
        query_engine = self.create_query_engine(retriever)
        response = query_engine.query(query_str)
        print(response)
        return str(response)


if __name__ == "__main__":
    init_settings()
    init_observability()
    # Example manual run (the original sample query was written in Chinese):
    #   graph = PropertyGraph('projects_1b20bbf4-3243-4ac3-bcf0-8a91e9157521')
    #   graph.query('What is the amount for code XLBT?')