Files
2024-08-13 13:11:17 +08:00

75 lines
2.8 KiB
Python

import os
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.vector_stores.qdrant import QdrantVectorStore
from qdrant_client import qdrant_client
# Most recently used Qdrant client. Kept as a module-level name for backward
# compatibility; clients are actually cached per location in _qdrant_clients.
qclient = None

# QdrantClient instances keyed by their constructor arguments, so each local
# storage directory (or remote host/port pair) gets one client per process.
# NOTE(review): presumably reused because a local Qdrant storage directory
# cannot be opened by two clients at once -- confirm against qdrant_client.
_qdrant_clients: dict = {}


def _qdrant_client_kwargs(docType: str) -> dict:
    """Return the ``QdrantClient`` constructor keyword arguments for *docType*.

    If ``VECTOR_STORE_PATH`` is set and non-empty, a local store located at
    ``<VECTOR_STORE_PATH>/<docType>`` is used. Otherwise the remote server at
    ``VECTOR_STORE_HOST`` (default ``"127.0.0.1"``) and ``VECTOR_STORE_PORT``
    (default ``"6333"``) is used.

    Raises:
        ValueError: if no path is set and VECTOR_STORE_HOST is explicitly empty.
    """
    base_path = os.getenv("VECTOR_STORE_PATH")
    if base_path:
        # Only join when a base path exists: os.path.join(None, ...) raises.
        return {"path": os.path.join(base_path, docType)}
    host = os.getenv("VECTOR_STORE_HOST", "127.0.0.1")
    if not host:
        raise ValueError(
            "Please provide either VECTOR_STORE_PATH or VECTOR_STORE_HOST and VECTOR_STORE_PORT"
        )
    return {"host": host, "port": int(os.getenv("VECTOR_STORE_PORT", "6333"))}


def get_qdrant_vector_store(docType: str):
    """Return a ``QdrantVectorStore`` whose collection is named *docType*.

    A local on-disk client (``VECTOR_STORE_PATH/<docType>``) is used when
    ``VECTOR_STORE_PATH`` is set; otherwise a remote client at
    ``VECTOR_STORE_HOST``/``VECTOR_STORE_PORT`` is used. Clients are created
    once per location and reused on later calls.

    Raises:
        ValueError: if neither a path nor a host is configured.
    """
    global qclient
    kwargs = _qdrant_client_kwargs(docType)
    # Cache per location. A single global client would silently keep writing
    # every docType into the directory of whichever docType was requested first.
    key = tuple(sorted(kwargs.items()))
    client = _qdrant_clients.get(key)
    if client is None:
        client = qdrant_client.QdrantClient(**kwargs)
        _qdrant_clients[key] = client
    qclient = client
    return QdrantVectorStore(client=client, collection_name=docType)
def _chroma_store_params(docType: str) -> dict:
    """Build the keyword arguments for ``ChromaVectorStore.from_params``.

    The collection is named *docType* and created with ``hnsw:space`` set to
    ``"cosine"``. If ``VECTOR_STORE_PATH`` is set and non-empty, a local store
    persisted at ``<VECTOR_STORE_PATH>/<docType>`` is used. Otherwise both
    ``VECTOR_STORE_HOST`` and ``VECTOR_STORE_PORT`` must be set for a remote
    ChromaDB server (ChromaDB Cloud is not supported yet).

    Raises:
        ValueError: if no path is set and host or port is missing.
    """
    params = {
        "collection_name": docType,
        "collection_kwargs": {"metadata": {"hnsw:space": "cosine"}},
    }
    base_path = os.getenv("VECTOR_STORE_PATH")
    if base_path:
        # Only join when a base path exists: os.path.join(None, ...) raises,
        # which previously made the remote branch below unreachable.
        params["persist_dir"] = os.path.join(base_path, docType)
        return params
    host = os.getenv("VECTOR_STORE_HOST")
    port = os.getenv("VECTOR_STORE_PORT")
    if not host or not port:
        raise ValueError(
            "Please provide either VECTOR_STORE_PATH or VECTOR_STORE_HOST and VECTOR_STORE_PORT"
        )
    params["host"] = host
    params["port"] = int(port)
    return params


def get_chroma_vector_store(docType: str):
    """Return a ``ChromaVectorStore`` whose collection is named *docType*.

    Uses a local persisted store when ``VECTOR_STORE_PATH`` is set, otherwise
    a remote server given by ``VECTOR_STORE_HOST``/``VECTOR_STORE_PORT``.

    Raises:
        ValueError: if neither a path nor both host and port are configured.
    """
    return ChromaVectorStore.from_params(**_chroma_store_params(docType))
def get_vector_store(docType: str):
    """Return the vector store for *docType* selected by ``VECTOR_STORE_TYPE``.

    ``"chroma"`` delegates to ``get_chroma_vector_store`` and ``"qdrant"`` to
    ``get_qdrant_vector_store``.

    Raises:
        ValueError: for any other (or missing) ``VECTOR_STORE_TYPE`` value.
    """
    backend = os.getenv("VECTOR_STORE_TYPE")
    if backend == "chroma":
        return get_chroma_vector_store(docType)
    if backend == "qdrant":
        return get_qdrant_vector_store(docType)
    raise ValueError(f"Invalid vector store type: {backend}")