新增属性图谱
This commit is contained in:
@@ -5,8 +5,8 @@ load_dotenv()
|
||||
import logging
|
||||
import os
|
||||
|
||||
from app.engine.loaders import get_document_Types, get_documents
|
||||
from app.engine.vectordb import get_vector_store
|
||||
from app.engine.loaders import get_document_Types, get_documents,getProjectInfos
|
||||
from app.engine.vectordb import get_vector_store,get_Neo4j_Graph_Store
|
||||
from app.settings import init_settings
|
||||
from app.engine.retriever.CHBM25Retriever import CHBM25Retriever
|
||||
from llama_index.core.ingestion import IngestionPipeline
|
||||
@@ -14,15 +14,15 @@ from llama_index.core.node_parser import SentenceSplitter,MarkdownNodeParser
|
||||
from llama_index.core.settings import Settings
|
||||
from llama_index.core.storage import StorageContext
|
||||
from llama_index.core.storage.docstore import SimpleDocumentStore
|
||||
from llama_index.core import PropertyGraphIndex
|
||||
from app.engine.graph.extractor import PrjGraphExtractor
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger()
|
||||
|
||||
STORAGE_DIR = os.getenv("STORAGE_DIR", "storage")
|
||||
|
||||
|
||||
def get_doc_store(docType:str):
|
||||
|
||||
# If the storage directory is there, load the document store from it.
|
||||
# If not, set up an in-memory document store since we can't load from a directory that doesn't exist.
|
||||
storeDir = os.path.join(STORAGE_DIR,docType)
|
||||
@@ -31,7 +31,6 @@ def get_doc_store(docType:str):
|
||||
else:
|
||||
return SimpleDocumentStore()
|
||||
|
||||
|
||||
def run_pipeline(docstore, vector_store, documents):
|
||||
pipeline = IngestionPipeline(
|
||||
transformations=[
|
||||
@@ -49,10 +48,8 @@ def run_pipeline(docstore, vector_store, documents):
|
||||
|
||||
# Run the ingestion pipeline and store the results
|
||||
nodes = pipeline.run(show_progress=True, documents=documents)
|
||||
|
||||
return nodes
|
||||
|
||||
|
||||
def persist_storage(docstore, vector_store):
|
||||
storage_context = StorageContext.from_defaults(
|
||||
docstore=docstore,
|
||||
@@ -60,7 +57,6 @@ def persist_storage(docstore, vector_store):
|
||||
)
|
||||
storage_context.persist(STORAGE_DIR)
|
||||
|
||||
|
||||
def persist_BMRetriever(vector_store):
|
||||
STORAGE_DIR = os.getenv("BM_RETRIEVER_PATH", "storage_bm")
|
||||
nodes = vector_store.get_nodes([])
|
||||
@@ -68,9 +64,7 @@ def persist_BMRetriever(vector_store):
|
||||
bmRetriver = CHBM25Retriever.from_defaults(similarity_top_k=top_k,nodes = nodes)
|
||||
bmRetriver.persist(STORAGE_DIR)
|
||||
|
||||
|
||||
def generate_datasource():
|
||||
init_settings()
|
||||
logger.info("Generate index for the provided data")
|
||||
|
||||
# Get the stores and documents or create new ones
|
||||
@@ -92,8 +86,47 @@ def generate_datasource():
|
||||
|
||||
logger.info("Finished generating the index")
|
||||
|
||||
class PropertyGraphChache:
    """Build a property graph index for every project returned by getProjectInfos().

    The backing store is selected with the ``GRAPH_STORE_TYPE`` environment
    variable: the value ``neo4j`` builds the graph in a Neo4j store, any other
    value (including unset) builds an index with the default store and
    persists it under ``GRAPH_STORAGE_DIR``.
    """

    def generate(self):
        """Build the property graph of each project with the configured backend.

        Reads ``GRAPH_STORE_TYPE`` (default ``""``) and ``GRAPH_STORAGE_DIR``
        (default ``"storage_graph"``) from the environment. Each project entry
        must provide the keys ``'flag'`` and ``'name'``.
        """
        GRAPH_STORE_TYPE = os.getenv("GRAPH_STORE_TYPE", "")
        GRAPH_STORAGE_DIR = os.getenv("GRAPH_STORAGE_DIR", "storage_graph")
        # The backend choice does not depend on the project, so decide it once.
        use_neo4j = GRAPH_STORE_TYPE == 'neo4j'

        for prjInfo in getProjectInfos():
            prjFlag = prjInfo['flag']
            prjName = prjInfo['name']
            # One sub-directory per project flag below the graph storage dir.
            cache_path = f'{GRAPH_STORAGE_DIR}/{prjFlag}'

            if use_neo4j:
                # The builder needs the project arguments; calling it without
                # them raised TypeError on every neo4j run.
                self.neo4jProertyGraph(prjName, prjFlag, cache_path)
            else:
                self.simplePropertyGraph(prjName, prjFlag, cache_path)

    def simplePropertyGraph(self, prjName: str, prjFlag: str, filePath: str):
        """Build a property graph index for one project and persist it to disk.

        Args:
            prjName: Passed to ``PrjGraphExtractor`` (presumably the project's
                display name; confirm against the extractor).
            prjFlag: Project identifier handed to ``get_documents``.
            filePath: Directory the index storage context is persisted to; it
                is created when missing.
        """
        documents = get_documents(prjFlag)
        index = PropertyGraphIndex(
            nodes=documents,
            kg_extractors=[PrjGraphExtractor(prjName)],
            embed_model=Settings.embed_model,
            show_progress=True,
        )
        os.makedirs(filePath, exist_ok=True)
        index.storage_context.persist(persist_dir=filePath)

    def neo4jProertyGraph(self, prjName: str, prjFlag: str, filePath: str):
        """Build a property graph index for one project in a Neo4j graph store.

        Args:
            prjName: Passed to ``PrjGraphExtractor`` (presumably the project's
                display name; confirm against the extractor).
            prjFlag: Project identifier handed to ``get_Neo4j_Graph_Store`` and
                ``get_documents``.
            filePath: Unused here; kept so both builders share one signature.
        """
        neo4jStore = get_Neo4j_Graph_Store(prjFlag)
        documents = get_documents(prjFlag)
        # Nothing is persisted locally: the index is constructed against the
        # Neo4j store and the resulting object is discarded.
        PropertyGraphIndex(
            nodes=documents,
            property_graph_store=neo4jStore,
            kg_extractors=[PrjGraphExtractor(prjName)],
            embed_model=Settings.embed_model,
            show_progress=True,
        )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: configure settings, then build the datasource index
    # and the property graphs inside one Phoenix tracing project.
    init_settings()
    from phoenix.trace import using_project

    # os.getenv returns None when the variable is unset, and None + str raises
    # TypeError; fall back to a fixed name so the script still runs untraced
    # configuration-wise.
    phoenix_project = os.getenv("PHOENIX_PROJECT_NAME", "default") + "_generate"
    with using_project(phoenix_project):
        generate_datasource()
        PropertyGraphChache().generate()
|
||||
|
||||
Reference in New Issue
Block a user