71 lines
2.6 KiB
Python
71 lines
2.6 KiB
Python
import os
|
|
from llama_index.vector_stores.chroma import ChromaVectorStore
|
|
from llama_index.vector_stores.qdrant import QdrantVectorStore
|
|
from qdrant_client import qdrant_client
|
|
|
|
# Module-level Qdrant client. It starts out unset, is created on the first
# call to get_qdrant_vector_store(), and is reused by every later call.
qclient = None
|
|
|
|
def get_qdrant_vector_store():
    """Build a ``QdrantVectorStore`` configured from environment variables.

    Environment variables read:
        VECTOR_STORE_COLLECTION: collection name (default ``"default"``).
        VECTOR_STORE_PATH: if set, a local Qdrant client is opened at this path.
        VECTOR_STORE_HOST: host of a Qdrant server (default ``"127.0.0.1"``),
            used only when VECTOR_STORE_PATH is not set.
        VECTOR_STORE_PORT: port of a Qdrant server (default ``"6333"``),
            used only when VECTOR_STORE_PATH is not set.

    The client is created once and kept in the module-level ``qclient``;
    subsequent calls reuse that client regardless of the current environment.

    Returns:
        A ``QdrantVectorStore`` bound to the shared client and the configured
        collection name.

    Raises:
        ValueError: if neither a path nor a host is configured, or if
            VECTOR_STORE_PORT is not a valid integer.
    """
    global qclient

    collection_name = os.getenv("VECTOR_STORE_COLLECTION", "default")
    vector_store_path = os.getenv("VECTOR_STORE_PATH")
    # NOTE: no trailing commas here -- they would silently turn these values
    # into one-element tuples and break the QdrantClient(host=..., port=...) call.
    host = os.getenv("VECTOR_STORE_HOST", "127.0.0.1")
    port = int(os.getenv("VECTOR_STORE_PORT", "6333"))

    # Only fail when *neither* connection mode is usable; a missing path alone
    # is fine because the host/port pair is the fallback.
    if not vector_store_path and not host:
        raise ValueError(
            "Please provide either VECTOR_STORE_PATH or VECTOR_STORE_HOST and VECTOR_STORE_PORT"
        )

    # if VECTOR_STORE_PATH is set, use a local QdrantVectorStore from the path
    # otherwise, use a remote QdrantVectorStore
    if qclient is None:
        if vector_store_path:
            qclient = qdrant_client.QdrantClient(path=vector_store_path)
        else:
            qclient = qdrant_client.QdrantClient(host=host, port=port)

    return QdrantVectorStore(client=qclient, collection_name=collection_name)
|
|
|
|
def get_chroma_vector_store():
    """Build a ``ChromaVectorStore`` configured from environment variables.

    Environment variables read:
        VECTOR_STORE_COLLECTION: collection name (default ``"default"``).
        VECTOR_STORE_PATH: if set, the store is created with this value as
            ``persist_dir``.
        VECTOR_STORE_HOST / VECTOR_STORE_PORT: used when VECTOR_STORE_PATH is
            not set; both are required in that case.

    In both modes the collection is requested with the metadata
    ``{"hnsw:space": "cosine"}``.

    Returns:
        The ``ChromaVectorStore`` produced by ``ChromaVectorStore.from_params``.

    Raises:
        ValueError: if VECTOR_STORE_PATH is unset and either VECTOR_STORE_HOST
            or VECTOR_STORE_PORT is missing or empty (or the port is not a
            valid integer).
    """
    # Arguments common to the local and the remote construction.
    shared_kwargs = {
        "collection_name": os.getenv("VECTOR_STORE_COLLECTION", "default"),
        "collection_kwargs": {"metadata": {"hnsw:space": "cosine"}},
    }

    # Local mode: a persistence directory wins over any host/port settings.
    persist_dir = os.getenv("VECTOR_STORE_PATH")
    if persist_dir:
        return ChromaVectorStore.from_params(persist_dir=persist_dir, **shared_kwargs)

    # Remote mode (ChromaDB Cloud is not supported yet): host and port are
    # both mandatory.
    remote_host = os.getenv("VECTOR_STORE_HOST")
    remote_port = os.getenv("VECTOR_STORE_PORT")
    if not (remote_host and remote_port):
        raise ValueError(
            "Please provide either VECTOR_STORE_PATH or VECTOR_STORE_HOST and VECTOR_STORE_PORT"
        )

    return ChromaVectorStore.from_params(
        host=remote_host,
        port=int(remote_port),
        **shared_kwargs,
    )
|
|
|
|
def get_vector_store():
    """Return the vector store selected by the ``VECTOR_STORE_TYPE`` env var.

    Recognised values (compared exactly, case-sensitive):
        ``"chroma"``: delegates to ``get_chroma_vector_store()``.
        ``"qdrant"``: delegates to ``get_qdrant_vector_store()``.

    Returns:
        Whatever the selected factory function returns.

    Raises:
        ValueError: if VECTOR_STORE_TYPE is unset or holds any other value.
    """
    backend = os.getenv("VECTOR_STORE_TYPE")

    if backend == "chroma":
        return get_chroma_vector_store()
    if backend == "qdrant":
        return get_qdrant_vector_store()

    # Anything else, including an unset variable (None), is rejected.
    raise ValueError(f"Invalid vector store type: {backend}")