# Source: zjdataai-app/backend/app/engine/generate.py
# Last commit: 2024-09-14 09:51:11 +08:00 (133 lines, 4.7 KiB, Python)

from dotenv import load_dotenv
load_dotenv()
import logging
import os
from app.engine.loaders import get_document_Types, get_documents,getProjectInfos
from app.engine.vectordb import get_vector_store,get_Neo4j_Graph_Store
from app.settings import init_settings
from app.engine.retriever.CHBM25Retriever import CHBM25Retriever
from llama_index.core.ingestion import IngestionPipeline
from llama_index.core.node_parser import SentenceSplitter,MarkdownNodeParser
from llama_index.core.settings import Settings
from llama_index.core.storage import StorageContext
from llama_index.core.storage.docstore import SimpleDocumentStore
from llama_index.core import PropertyGraphIndex
from app.engine.graph.extractor import PrjGraphExtractor
logging.basicConfig(level=logging.INFO)
# Module-scoped logger (instead of the root logger) so records carry this
# module's name; they still propagate to the handler basicConfig installs.
logger = logging.getLogger(__name__)
# Root directory for persisted per-document-type stores.
STORAGE_DIR = os.getenv("STORAGE_DIR", "storage")
def get_doc_store(docType: str):
    """Return the document store for ``docType``.

    If ``STORAGE_DIR/<docType>`` exists, the store persisted there is loaded.
    Otherwise a fresh, empty in-memory ``SimpleDocumentStore`` is returned,
    because there is nothing on disk to load from.
    """
    persist_dir = os.path.join(STORAGE_DIR, docType)
    if not os.path.exists(persist_dir):
        return SimpleDocumentStore()
    return SimpleDocumentStore.from_persist_dir(persist_dir)
def run_pipeline(docstore, vector_store, documents):
    """Embed ``documents`` and upsert them into ``docstore`` and ``vector_store``.

    The only transformation is ``Settings.embed_model``, so documents are
    embedded as-is without a node-splitting step. The ``"upserts_and_delete"``
    docstore strategy de-duplicates against ``docstore``:
    - unchanged documents are skipped;
    - changed documents are re-ingested;
    - entries for documents missing from the input are removed.

    Returns the nodes produced by ``IngestionPipeline.run``.
    """
    # Node splitting is currently disabled. To enable it, put a splitter such as
    # SentenceSplitter(chunk_size=Settings.chunk_size,
    #                  chunk_overlap=Settings.chunk_overlap)
    # or MarkdownNodeParser() in front of the embed model below.
    steps = [Settings.embed_model]
    ingestion = IngestionPipeline(
        transformations=steps,
        docstore=docstore,
        docstore_strategy="upserts_and_delete",
        vector_store=vector_store,
    )
    return ingestion.run(show_progress=True, documents=documents)
def persist_storage(docstore, vector_store, docType=None):
    """Persist ``docstore`` and ``vector_store`` through a ``StorageContext``.

    Args:
        docstore: document store to write out.
        vector_store: vector store to write out (stores that keep their data
            elsewhere may have nothing to persist locally).
        docType: optional document type. When given, data is written to
            ``STORAGE_DIR/<docType>``, the directory ``get_doc_store(docType)``
            loads from. When omitted, everything is written directly to
            ``STORAGE_DIR`` as before.
    """
    persist_dir = os.path.join(STORAGE_DIR, docType) if docType else STORAGE_DIR
    storage_context = StorageContext.from_defaults(
        docstore=docstore,
        vector_store=vector_store,
    )
    storage_context.persist(persist_dir)


def persist_BMRetriever(vector_store):
    """Build a ``CHBM25Retriever`` over the nodes of ``vector_store`` and persist it.

    - Output directory: ``BM_RETRIEVER_PATH`` (default ``"storage_bm"``).
    - ``similarity_top_k``: ``TOP_K`` (default 3), capped at the node count.
    - If the store yields no nodes, nothing is built and a warning is logged.
      Otherwise the retriever would be created with ``similarity_top_k=0`` over
      an empty corpus.
    """
    # Distinct name so the module-level STORAGE_DIR is not shadowed.
    bm_dir = os.getenv("BM_RETRIEVER_PATH", "storage_bm")
    # NOTE(review): presumably an empty id list returns every node for this
    # vector store implementation — confirm against get_vector_store().
    nodes = vector_store.get_nodes([])
    if not nodes:
        logger.warning("No nodes found in vector store; skipping BM25 retriever persistence")
        return
    top_k = min(int(os.getenv("TOP_K", "3")), len(nodes))
    bm_retriever = CHBM25Retriever.from_defaults(similarity_top_k=top_k, nodes=nodes)
    # NOTE(review): every document type persists to the same bm_dir, so in a
    # multi-type run only the last type's retriever survives — confirm intended.
    bm_retriever.persist(bm_dir)


def generate_datasource():
    """Build and persist the docstore, vector store and BM25 retriever per document type.

    For each type from ``get_document_Types()``:
    1. Load its documents and mark them public.
    2. Ingest them into that type's docstore and vector store.
    3. Persist both stores under ``STORAGE_DIR/<docType>``.
    4. Persist a BM25 retriever built from the vector store.
    """
    logger.info("Generate index for the provided data")
    for docType in get_document_Types():
        documents = get_documents(docType)
        # Set private=false to mark the document as public (required for filtering)
        for doc in documents:
            doc.metadata["private"] = "false"
        docstore = get_doc_store(docType)
        vector_store = get_vector_store(docType)
        run_pipeline(docstore, vector_store, documents)
        # Persist into the same per-type directory get_doc_store() reads from.
        # This lets the next run upsert against the saved docstore, and stops
        # document types from overwriting each other's docstore.
        persist_storage(docstore, vector_store, docType)
        persist_BMRetriever(vector_store)
    logger.info("Finished generating the index")
class PropertyGraphChache:
    """Build and store a property-graph index for every project.

    The backend is chosen by the ``GRAPH_STORE_TYPE`` environment variable:
    - ``"neo4j"``: graphs are written into a per-project Neo4j graph store.
    - anything else: an in-memory graph is built and persisted under
      ``GRAPH_STORAGE_PATH/<project flag>`` (default ``"storage_graph"``).
    """

    def generate(self):
        """Build the property graph for each project returned by ``getProjectInfos()``."""
        graph_store_type = os.getenv("GRAPH_STORE_TYPE", "")
        graph_storage_dir = os.getenv("GRAPH_STORAGE_PATH", "storage_graph")
        for prjInfo in getProjectInfos():
            prjFlag = prjInfo["flag"]
            prjName = prjInfo["name"]
            # str() keeps the original f-string behaviour for non-str flags.
            cache_path = os.path.join(graph_storage_dir, str(prjFlag))
            if graph_store_type == "neo4j":
                # Pass the project arguments: the method requires them, and
                # calling it with none raised TypeError.
                self.neo4jProertyGraph(prjName, prjFlag, cache_path)
            else:
                self.simplePropertyGraph(prjName, prjFlag, cache_path)

    def simplePropertyGraph(self, prjName: str, prjFlag: str, filePath: str):
        """Build an in-memory property graph for one project and persist it.

        Args:
            prjName: project name, handed to ``PrjGraphExtractor``.
            prjFlag: project flag used to load the project's documents.
            filePath: directory the index's storage context is persisted to
                (created if missing).

        The documents are passed to ``PropertyGraphIndex`` directly as nodes,
        so no node-splitting transformation runs before extraction.
        """
        documents = get_documents(prjFlag)
        index = PropertyGraphIndex(
            nodes=documents,
            kg_extractors=[PrjGraphExtractor(prjName)],
            embed_model=Settings.embed_model,
            show_progress=True,
        )
        os.makedirs(filePath, exist_ok=True)
        index.storage_context.persist(persist_dir=filePath)

    def neo4jProertyGraph(self, prjName: str, prjFlag: str, filePath: str):
        """Build one project's property graph directly into its Neo4j graph store.

        Args:
            prjName: project name, handed to ``PrjGraphExtractor``.
            prjFlag: project flag used to pick the Neo4j store and load documents.
            filePath: accepted for signature parity with ``simplePropertyGraph``
                but unused, since nothing is persisted locally.
        """
        neo4j_store = get_Neo4j_Graph_Store(prjFlag)
        documents = get_documents(prjFlag)
        # Constructing the index performs the extraction into neo4j_store;
        # the index object itself is not kept.
        PropertyGraphIndex(
            nodes=documents,
            property_graph_store=neo4j_store,
            kg_extractors=[PrjGraphExtractor(prjName)],
            embed_model=Settings.embed_model,
            show_progress=True,
        )
if __name__ == "__main__":
    init_settings()
    # Phoenix is imported only when this module runs as a script.
    from phoenix.trace import using_project

    # Fall back to "default" so an unset PHOENIX_PROJECT_NAME no longer fails
    # with TypeError (None + str) before any work starts.
    project_name = os.getenv("PHOENIX_PROJECT_NAME", "default")
    with using_project(f"{project_name}_generate"):
        generate_datasource()
        PropertyGraphChache().generate()