1 Commits

Author SHA1 Message Date
wanyaokun 7e58a1a223 实现多工程数据存储支持 2024-08-13 13:11:17 +08:00
9 changed files with 97 additions and 51 deletions
+3 -1
View File
@@ -87,7 +87,9 @@ class PrivateFileService:
nodes = pipeline.run(documents=documents) nodes = pipeline.run(documents=documents)
# Add the nodes to the index and persist it # Add the nodes to the index and persist it
current_index = get_index() indexs = get_index()
if len(indexs) > 0:
current_index = list(indexs.values())[0]
# Insert the documents into the index # Insert the documents into the index
if isinstance(current_index, LlamaCloudIndex): if isinstance(current_index, LlamaCloudIndex):
+3 -1
View File
@@ -43,7 +43,9 @@ def get_chat_engine(filters=None, params=None):
description="来源于一个由博微公司电力造价软件编制的造价工程文件。该文件以多张表格的形式存储存储了整个工程的全部数据内容。适用于以详细的自然语言查询表格数据方式查询造价工程各项具体属性、费用的数值。请先使用“zj_query_tool”无法解决才使用本工具") description="来源于一个由博微公司电力造价软件编制的造价工程文件。该文件以多张表格的形式存储存储了整个工程的全部数据内容。适用于以详细的自然语言查询表格数据方式查询造价工程各项具体属性、费用的数值。请先使用“zj_query_tool”无法解决才使用本工具")
# Add query tool if index exists # Add query tool if index exists
index = get_index() indexs = get_index()
if len(indexs) > 0:
index = list(indexs.values())[0]
if index is not None: if index is not None:
summary_index = SummaryIndex(index.vector_store.get_nodes(node_ids=None)) summary_index = SummaryIndex(index.vector_store.get_nodes(node_ids=None))
summary_query_engine = summary_index.as_query_engine() summary_query_engine = summary_index.as_query_engine()
+17 -19
View File
@@ -5,7 +5,7 @@ load_dotenv()
import logging import logging
import os import os
from app.engine.loaders import get_documents from app.engine.loaders import get_document_Types, get_documents
from app.engine.vectordb import get_vector_store from app.engine.vectordb import get_vector_store
from app.settings import init_settings from app.settings import init_settings
from llama_index.core.ingestion import IngestionPipeline from llama_index.core.ingestion import IngestionPipeline
@@ -19,17 +19,16 @@ logger = logging.getLogger()
STORAGE_DIR = os.getenv("STORAGE_DIR", "storage") STORAGE_DIR = os.getenv("STORAGE_DIR", "storage")
def get_doc_store(docType: str):
    """Return the document store for one document type.

    The store is looked up under ``STORAGE_DIR/<docType>``.

    Args:
        docType: Name of the per-type sub-directory inside ``STORAGE_DIR``.

    Returns:
        A ``SimpleDocumentStore`` loaded from the persisted directory when it
        exists, otherwise a fresh, empty ``SimpleDocumentStore``.
    """
    persist_dir = os.path.join(STORAGE_DIR, docType)
    # Nothing persisted yet for this type: we cannot load from a directory
    # that doesn't exist, so start with an empty in-memory store.
    if not os.path.exists(persist_dir):
        return SimpleDocumentStore()
    return SimpleDocumentStore.from_persist_dir(persist_dir)
def run_pipeline(docstore, vector_store, documents): def run_pipeline(docstore, vector_store, documents):
pipeline = IngestionPipeline( pipeline = IngestionPipeline(
transformations=[ transformations=[
@@ -49,7 +48,6 @@ def run_pipeline(docstore, vector_store, documents):
return nodes return nodes
def persist_storage(docstore, vector_store): def persist_storage(docstore, vector_store):
storage_context = StorageContext.from_defaults( storage_context = StorageContext.from_defaults(
docstore=docstore, docstore=docstore,
@@ -57,28 +55,28 @@ def persist_storage(docstore, vector_store):
) )
storage_context.persist(STORAGE_DIR) storage_context.persist(STORAGE_DIR)
def generate_datasource():
    """Build and persist the index data for every document type.

    For each type returned by ``get_document_Types()`` the documents are
    loaded, marked as public, run through the ingestion pipeline with that
    type's document store and vector store, and then persisted.
    """
    init_settings()
    logger.info("Generate index for the provided data")

    for doc_type in get_document_Types():
        # Get the stores and documents (or create new ones) for this type.
        docs = get_documents(doc_type)
        # Set private=false to mark the document as public (required for filtering)
        for doc in docs:
            doc.metadata["private"] = "false"
        type_docstore = get_doc_store(doc_type)
        type_vector_store = get_vector_store(doc_type)

        # Run the ingestion pipeline; the returned nodes are not needed here.
        run_pipeline(type_docstore, type_vector_store, docs)

        # Build the index and persist storage.
        # NOTE(review): persist_storage() appears to write to STORAGE_DIR while
        # get_doc_store() reads from STORAGE_DIR/<docType> -- confirm that the
        # per-type stores do not overwrite each other.
        persist_storage(type_docstore, type_vector_store)

    logger.info("Finished generating the index")
if __name__ == "__main__": if __name__ == "__main__":
from phoenix.trace import using_project from phoenix.trace import using_project
with using_project(os.getenv("PHOENIX_PROJECT_NAME") + "_generate") as obj: with using_project(os.getenv("PHOENIX_PROJECT_NAME") + "_generate") as obj:
+14 -13
View File
@@ -1,22 +1,23 @@
import logging import logging
from llama_index.core.indices import VectorStoreIndex from llama_index.core.indices import VectorStoreIndex
from app.engine.vectordb import get_vector_store from app.engine.vectordb import get_vector_store
from app.engine.generate import get_document_Types
logger = logging.getLogger("uvicorn") logger = logging.getLogger("uvicorn")
index = None indexs = {}
def get_index(params=None):
    """Return the cached mapping of document type to ``VectorStoreIndex``.

    On the first call (while the module-level ``indexs`` cache is empty) one
    index is loaded per document type reported by ``get_document_Types()``,
    each from the vector store returned by ``get_vector_store(docType)``.
    Later calls return the same dict without reconnecting.

    Args:
        params: Accepted for interface compatibility; not used by this function.

    Returns:
        dict: ``{docType: VectorStoreIndex}``. Empty when no document types
        are found.
    """
    global indexs
    # Check the module-level cache dict. The previous guard read the local
    # name ``index`` before it was assigned, which raised UnboundLocalError
    # on every call.
    if not indexs:
        logger.info("Connecting vector store...")
        for docType in get_document_Types():
            store = get_vector_store(docType)
            # Load the index from the vector store
            # If you are using a vector store that doesn't store text,
            # you must load the index from both the vector store and the document store
            indexs[docType] = VectorStoreIndex.from_vector_store(store)
            logger.info("Finished load index from vector store.")
    return indexs
+42 -2
View File
@@ -13,8 +13,48 @@ def load_configs():
configs = yaml.safe_load(f) configs = yaml.safe_load(f)
return configs return configs
def path_difference(path1: str, path2: str):
    """Return the part of ``path2`` that lies beyond its common prefix with ``path1``.

    Both paths are made absolute, split on ``os.path.sep`` and compared
    component by component. The components of ``path2`` after the shared
    prefix are joined with ``'_'``.

    Args:
        path1: Base path.
        path2: Path whose trailing, non-shared components form the key.

    Returns:
        The ``'_'``-joined remainder of ``path2``; an empty string when
        ``path2`` has no components beyond the shared prefix (including when
        ``path2`` equals, or is an ancestor of, ``path1``).

    Note:
        The key is ambiguous if a directory name itself contains ``'_'``.
    """
    import os

    base_parts = os.path.abspath(path1).split(os.path.sep)
    target_parts = os.path.abspath(path2).split(os.path.sep)

    # zip() stops at the shorter path, so a base path deeper than the target
    # no longer indexes past the end of the target's components.
    shared = 0
    for base_part, target_part in zip(base_parts, target_parts):
        if base_part != target_part:
            break
        shared += 1

    return "_".join(target_parts[shared:])
def get_document_Types():
    """List the document-type keys found under the data root directory.

    The root defaults to ``'data'``; when the loader configuration has a
    ``"file"`` entry, that entry's ``data_dir`` is used instead. The tree is
    walked depth-first with an explicit stack, and every directory that has
    no sub-directories contributes one key, computed by
    ``path_difference(root, directory)``.

    Returns:
        list[str]: One key per directory without sub-directories. If the root
        itself has no sub-directories, the list holds the single key computed
        for the root.
    """
    import os

    root_dir = 'data'
    configs = load_configs()
    if configs is not None:
        for loader_type, loader_config in configs.items():
            if loader_type == "file":
                root_dir = FileLoaderConfig(**loader_config).data_dir
                break

    type_keys = []
    pending = [root_dir]
    while pending:
        current = pending.pop()
        children = []
        for entry in os.listdir(current):
            candidate = os.path.join(current, entry)
            if os.path.isdir(candidate):
                children.append(candidate)
        if children:
            # Not a leaf: descend into the sub-directories later.
            pending.extend(children)
        else:
            type_keys.append(path_difference(root_dir, current))
    return type_keys
def get_documents(docType:str):
documents = [] documents = []
config = load_configs() config = load_configs()
if config is None or len(config.items()) == 0: if config is None or len(config.items()) == 0:
@@ -28,7 +68,7 @@ def get_documents():
loader_config = loader_config or [] loader_config = loader_config or []
match loader_type: match loader_type:
case "file": case "file":
document = get_file_documents(FileLoaderConfig(**loader_config)) document = get_file_documents(FileLoaderConfig(**loader_config),docType)
case "web": case "web":
document = get_web_documents(WebLoaderConfig(**loader_config)) document = get_web_documents(WebLoaderConfig(**loader_config))
case "db": case "db":
+2 -5
View File
@@ -20,7 +20,6 @@ class FileLoaderConfig(BaseModel):
raise ValueError(f"Directory '{v}' does not exist") raise ValueError(f"Directory '{v}' does not exist")
return v return v
def llama_parse_parser(): def llama_parse_parser():
if os.getenv("LLAMA_CLOUD_API_KEY") is None: if os.getenv("LLAMA_CLOUD_API_KEY") is None:
raise ValueError( raise ValueError(
@@ -35,7 +34,6 @@ def llama_parse_parser():
) )
return parser return parser
def llama_parse_extractor() -> Dict[str, LlamaParse]: def llama_parse_extractor() -> Dict[str, LlamaParse]:
from llama_parse.utils import SUPPORTED_FILE_TYPES from llama_parse.utils import SUPPORTED_FILE_TYPES
@@ -45,8 +43,7 @@ def llama_parse_extractor() -> Dict[str, LlamaParse]:
def llama_local_extractor() -> Dict[str, BaseReader]: def llama_local_extractor() -> Dict[str, BaseReader]:
return {"json" : JSONReader} return {"json" : JSONReader}
def get_file_documents(config: FileLoaderConfig, childPath: str):
def get_file_documents(config: FileLoaderConfig):
from llama_index.core.readers import SimpleDirectoryReader from llama_index.core.readers import SimpleDirectoryReader
try: try:
@@ -63,7 +60,7 @@ def get_file_documents(config: FileLoaderConfig):
file_extractor = llama_local_extractor() file_extractor = llama_local_extractor()
reader = SimpleDirectoryReader( reader = SimpleDirectoryReader(
config.data_dir, os.path.join(config.data_dir,childPath.replace('_','\\')),
recursive=True, recursive=True,
filename_as_id=True, filename_as_id=True,
raise_on_error=True, raise_on_error=True,
+12 -8
View File
@@ -5,12 +5,14 @@ from qdrant_client import qdrant_client
qclient = None qclient = None
def get_qdrant_vector_store(): def get_qdrant_vector_store(docType:str):
collection_name = os.getenv("VECTOR_STORE_COLLECTION", "default") collection_name = docType
#collection_name = os.getenv("VECTOR_STORE_COLLECTION", "default")
vector_store_path = os.getenv("VECTOR_STORE_PATH") vector_store_path = os.getenv("VECTOR_STORE_PATH")
host=os.getenv("VECTOR_STORE_HOST", "127.0.0.1"), host=os.getenv("VECTOR_STORE_HOST", "127.0.0.1"),
port=int(os.getenv("VECTOR_STORE_PORT", "6333")), port=int(os.getenv("VECTOR_STORE_PORT", "6333")),
vector_store_path =os.path.join(vector_store_path,docType)
if not vector_store_path or not host: if not vector_store_path or not host:
raise ValueError( raise ValueError(
"Please provide either VECTOR_STORE_PATH or VECTOR_STORE_HOST and VECTOR_STORE_PORT" "Please provide either VECTOR_STORE_PATH or VECTOR_STORE_HOST and VECTOR_STORE_PORT"
@@ -32,9 +34,11 @@ def get_qdrant_vector_store():
vector_store = QdrantVectorStore(client=qclient, collection_name=collection_name) vector_store = QdrantVectorStore(client=qclient, collection_name=collection_name)
return vector_store return vector_store
def get_chroma_vector_store(): def get_chroma_vector_store(docType:str):
collection_name = os.getenv("VECTOR_STORE_COLLECTION", "default") #collection_name = os.getenv("VECTOR_STORE_COLLECTION", "default")
collection_name = docType
vector_store_path = os.getenv("VECTOR_STORE_PATH") vector_store_path = os.getenv("VECTOR_STORE_PATH")
vector_store_path =os.path.join(vector_store_path,docType)
# if VECTOR_STORE_PATH is set, use a local ChromaVectorStore from the path # if VECTOR_STORE_PATH is set, use a local ChromaVectorStore from the path
# otherwise, use a remote ChromaVectorStore (ChromaDB Cloud is not supported yet) # otherwise, use a remote ChromaVectorStore (ChromaDB Cloud is not supported yet)
if vector_store_path: if vector_store_path:
@@ -55,16 +59,16 @@ def get_chroma_vector_store():
) )
return store return store
def get_vector_store(): def get_vector_store(docType:str):
store_type=os.getenv("VECTOR_STORE_TYPE") store_type=os.getenv("VECTOR_STORE_TYPE")
store = None store = None
match store_type: match store_type:
case "chroma": case "chroma":
store = get_chroma_vector_store() store = get_chroma_vector_store(docType)
case "qdrant": case "qdrant":
store = get_qdrant_vector_store() store = get_qdrant_vector_store(docType)
case _: case _:
raise ValueError(f"Invalid vector store type: {store_type}") raise ValueError(f"Invalid vector store type: {store_type}")
+1 -1
View File
@@ -1,4 +1,4 @@
rmdir /S /Q storage_vector rmdir /S /Q storage_vector
rmdir /S /Q storage rmdir /S /Q storage
C:\Users\liuyue\AppData\Local\pypoetry\Cache\virtualenvs\app-laEO4lY0-py3.11\Scripts\python app/engine/generate.py python app/engine/generate.py
+3 -1
View File
@@ -19,7 +19,9 @@ def main():
init_settings() init_settings()
init_observability() init_observability()
index = get_index() indexs = get_index()
if len(indexs) > 0:
index = list(indexs.values())[0]
top_k = 5 top_k = 5
filters = generate_filters([]) filters = generate_filters([])