38 lines
1.5 KiB
Python
38 lines
1.5 KiB
Python
import logging,os
|
|
from llama_index.core.indices import VectorStoreIndex
|
|
from app.engine.vectordb import get_vector_store,get_Neo4j_Graph_Store
|
|
from typing import Dict,Any
|
|
from llama_index.core import PropertyGraphIndex
|
|
from llama_index.core.storage.storage_context import StorageContext
|
|
from llama_index.core import load_index_from_storage
|
|
|
|
# Log through the logger named "uvicorn" (presumably configured by the uvicorn
# server that hosts this app) rather than a module-specific logger.
logger = logging.getLogger(name="uvicorn")
|
|
|
|
|
|
def get_index(prjFlag: str):
    """Load the query index for a project.

    The backend is chosen by the ``LLM_QUERY_WAY`` environment variable:
    ``"graph"`` loads a property-graph index through ``getPropertyGraphIndex``;
    any other value (including unset) wraps the project's vector store in a
    ``VectorStoreIndex``.

    Args:
        prjFlag: Project identifier used to select the project's store.

    Returns:
        The loaded index, or ``None`` when graph mode is selected and
        ``getPropertyGraphIndex`` finds no persisted graph for the project.

    Raises:
        ValueError: If ``prjFlag`` is ``None``, empty, or whitespace only.
    """
    # Whitespace-only identifiers are rejected as well: they would otherwise be
    # forwarded to the store helpers as a meaningless project key.
    if prjFlag is None or not prjFlag.strip():
        raise ValueError('无效的工程标识')  # message: "invalid project identifier"

    if os.getenv('LLM_QUERY_WAY') == 'graph':
        logger.info("Loading property graph index...")
        index = getPropertyGraphIndex(prjFlag)
        # getPropertyGraphIndex returns None when nothing is persisted; make that
        # visible instead of logging a success message.
        if index is None:
            logger.warning("No property graph index found for project %s", prjFlag)
        else:
            logger.info("Finished load index from property graph store.")
        return index

    logger.info("Connecting vector store...")
    store = get_vector_store(prjFlag)
    index = VectorStoreIndex.from_vector_store(store)
    logger.info("Finished load index from vector store.")
    return index
|
|
|
|
|
|
def getPropertyGraphIndex(prjFlag: str):
    """Load the property-graph index for a project.

    The source is chosen by the ``GRAPH_STORE_TYPE`` environment variable:

    * ``"neo4j"``: attach to the existing graph in the project's Neo4j store.
    * anything else (default): load an index persisted under
      ``<GRAPH_STORAGE_PATH>/<prjFlag>`` (``GRAPH_STORAGE_PATH`` defaults to
      ``"storage_graph"``), together with the project's vector store.

    Args:
        prjFlag: Project identifier; in local-storage mode it is also used as
            the sub-path under the graph storage root.

    Returns:
        The loaded index, or ``None`` in local-storage mode when the project's
        storage directory does not exist, is not a directory, or would resolve
        outside the storage root.
    """
    if os.getenv("GRAPH_STORE_TYPE", "") == 'neo4j':
        return PropertyGraphIndex.from_existing(
            property_graph_store=get_Neo4j_Graph_Store(prjFlag)
        )

    storage_root = os.path.abspath(os.getenv("GRAPH_STORAGE_PATH", "storage_graph"))
    prj_cache_path = os.path.abspath(os.path.join(storage_root, prjFlag))

    # prjFlag becomes part of a filesystem path; values such as "../x" or an
    # absolute path must not select a directory outside the storage root, and
    # an empty flag must not select the root itself.
    root_prefix = os.path.join(storage_root, "")  # root with a trailing separator
    if prj_cache_path == storage_root or not prj_cache_path.startswith(root_prefix):
        return None

    # isdir (not exists): a regular file with the project's name is not a
    # persisted index and would only make StorageContext fail later.
    if not os.path.isdir(prj_cache_path):
        return None

    storage_context = StorageContext.from_defaults(
        persist_dir=prj_cache_path,
        vector_store=get_vector_store(prjFlag),
    )
    return load_index_from_storage(storage_context)