88 lines
3.7 KiB
Python
88 lines
3.7 KiB
Python
from llama_index.core.indices.property_graph import LLMSynonymRetriever,VectorContextRetriever,PGRetriever
|
|
from llama_index.core.indices.property_graph.transformations.schema_llm import *
|
|
from llama_index.core import SimpleDirectoryReader
|
|
from llama_index.core import settings
|
|
from llama_index.core import PropertyGraphIndex
|
|
from typing import List,Tuple,Literal
|
|
from app.settings import init_settings
|
|
import os
|
|
from llama_index.core.storage.storage_context import StorageContext
|
|
from llama_index.core import load_index_from_storage
|
|
from app.observability import init_observability
|
|
from app.engine.vectordb import get_Neo4j_Graph_Store,get_vector_store
|
|
from llama_index.core.response_synthesizers import ResponseMode
|
|
from util.register import *
|
|
from llama_index.core.query_engine import RetrieverQueryEngine
|
|
from app.engine.prompt import text_qa_template, refine_template, summary_template, simple_template
|
|
from app.engine.engine import get_node_postprocessors
|
|
|
|
class PropertyGraph:
    """Answer questions against a per-project property-graph index.

    The graph backend is chosen by the ``GRAPH_STORE_TYPE`` environment
    variable: ``"neo4j"`` uses the project's Neo4j graph store; any other
    value loads an index persisted on local disk.
    """

    def __init__(self, prjFlag: str) -> None:
        """Create a query helper for one project.

        Args:
            prjFlag: Project identifier. It is passed to the store factories
                (``get_Neo4j_Graph_Store`` / ``get_vector_store``) and used as
                the sub-directory name under the local graph storage path.
        """
        self._prjFlag = prjFlag

    def create_query_engine(self, retriever):
        """Wrap *retriever* in a configured ``RetrieverQueryEngine``.

        The engine uses the project's prompt templates (QA, refine, summary,
        simple), the node postprocessors returned by
        ``get_node_postprocessors()``, ``use_async=True``, no streaming, and
        the ``TREE_SUMMARIZE`` response mode.

        Args:
            retriever: Retriever whose nodes feed response synthesis.

        Returns:
            The ``RetrieverQueryEngine`` built by ``from_args``.
        """
        postprocessors = get_node_postprocessors()
        return RetrieverQueryEngine.from_args(
            retriever=retriever,
            text_qa_template=text_qa_template,
            refine_template=refine_template,
            summary_template=summary_template,
            simple_template=simple_template,
            node_postprocessors=postprocessors,
            use_async=True,
            streaming=False,
            response_mode=ResponseMode.TREE_SUMMARIZE,
        )

    def getPropertyGraphIndex(self):
        """Load this project's property-graph index.

        If ``GRAPH_STORE_TYPE`` is ``"neo4j"``, the index is created with
        ``PropertyGraphIndex.from_existing`` over the project's Neo4j graph
        store. Otherwise it is loaded from ``<GRAPH_STORAGE_PATH>/<prjFlag>``
        (``GRAPH_STORAGE_PATH`` defaults to ``"storage_graph"``), together
        with the project's vector store.

        Returns:
            The loaded index, or ``None`` when the local storage directory
            for this project does not exist.
        """
        graph_store_type = os.getenv("GRAPH_STORE_TYPE", "")
        if graph_store_type == "neo4j":
            return PropertyGraphIndex.from_existing(
                property_graph_store=get_Neo4j_Graph_Store(self._prjFlag)
            )

        storage_dir = os.getenv("GRAPH_STORAGE_PATH", "storage_graph")
        project_cache_path = f"{storage_dir}/{self._prjFlag}"
        if not os.path.exists(project_cache_path):
            return None

        storage_context = StorageContext.from_defaults(
            persist_dir=project_cache_path,
            vector_store=get_vector_store(self._prjFlag),
        )
        return load_index_from_storage(storage_context)

    def query(self, query_str: str):
        """Answer *query_str* against this project's property graph.

        The retriever combines an ``LLMSynonymRetriever`` and a
        ``VectorContextRetriever``. The engine from ``create_query_engine``
        answers the query, and the response is printed and returned as a
        string.

        Args:
            query_str: The natural-language question.

        Returns:
            ``str`` of the query engine's response.

        Raises:
            ValueError: If no index exists for this project (the local
                storage directory is missing).
        """
        index = self.getPropertyGraphIndex()
        if index is None:
            # Fail with a clear message instead of an AttributeError on
            # ``None.property_graph_store`` further down.
            raise ValueError(
                f"No property graph index found for project {self._prjFlag!r}"
            )

        retriever = self._build_retriever(index)
        query_engine = self.create_query_engine(retriever)
        response = query_engine.query(query_str)
        print(response)
        return str(response)

    def _build_retriever(self, index):
        """Combine synonym- and vector-based graph retrievers over *index*.

        Both sub-retrievers use ``include_text=False``. The synonym retriever
        uses ``Settings.llm`` with ``max_keywords=10``. The vector retriever
        uses ``Settings.embed_model`` with ``similarity_top_k=5``.
        """
        graph_store = index.property_graph_store
        synonym_retriever = LLMSynonymRetriever(
            graph_store,
            llm=settings.Settings.llm,
            max_keywords=10,
            include_text=False,
        )
        # Hand over the index's separate vector store only when the graph
        # store cannot answer vector queries by itself.
        if graph_store.supports_vector_queries:
            vector_store = None
        else:
            vector_store = index.vector_store
        vector_retriever = VectorContextRetriever(
            graph_store,
            vector_store=vector_store,
            embed_model=settings.Settings.embed_model,
            similarity_top_k=5,
            include_text=False,
        )
        return index.as_retriever(
            sub_retrievers=[synonym_retriever, vector_retriever]
        )
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke run: initialise global settings and observability, then
    # ask one sample question against a fixed project's graph. The question
    # reads "What is the amount for code XLBT?".
    init_settings()
    init_observability()
    project_flag = 'projects_1b20bbf4-3243-4ac3-bcf0-8a91e9157521'
    PropertyGraph(project_flag).query('代码为XLBT的金额是')
|
|
|
|
|
|
|
|
|
|
|
|
|