Support multiple collections (#26)
* Allow passing the collection name in each request to override the default * Allow getting the collection names in QdrantConnector * get vector size from model description * ruff format * add isort * apply pre-commit hooks --------- Co-authored-by: generall <andrey@vasnetsov.com>
This commit is contained in:
@@ -41,39 +41,29 @@ class QdrantConnector:
|
||||
):
|
||||
self._qdrant_url = qdrant_url.rstrip("/") if qdrant_url else None
|
||||
self._qdrant_api_key = qdrant_api_key
|
||||
self._collection_name = collection_name
|
||||
self._default_collection_name = collection_name
|
||||
self._embedding_provider = embedding_provider
|
||||
self._client = AsyncQdrantClient(
|
||||
location=qdrant_url, api_key=qdrant_api_key, path=qdrant_local_path
|
||||
)
|
||||
|
||||
async def _ensure_collection_exists(self):
|
||||
"""Ensure that the collection exists, creating it if necessary."""
|
||||
collection_exists = await self._client.collection_exists(self._collection_name)
|
||||
if not collection_exists:
|
||||
# Create the collection with the appropriate vector size
|
||||
# We'll get the vector size by embedding a sample text
|
||||
sample_vector = await self._embedding_provider.embed_query("sample text")
|
||||
vector_size = len(sample_vector)
|
||||
async def get_collection_names(self) -> list[str]:
|
||||
"""
|
||||
Get the names of all collections in the Qdrant server.
|
||||
:return: A list of collection names.
|
||||
"""
|
||||
response = await self._client.get_collections()
|
||||
return [collection.name for collection in response.collections]
|
||||
|
||||
# Use the vector name as defined in the embedding provider
|
||||
vector_name = self._embedding_provider.get_vector_name()
|
||||
await self._client.create_collection(
|
||||
collection_name=self._collection_name,
|
||||
vectors_config={
|
||||
vector_name: models.VectorParams(
|
||||
size=vector_size,
|
||||
distance=models.Distance.COSINE,
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
async def store(self, entry: Entry):
|
||||
async def store(self, entry: Entry, *, collection_name: Optional[str] = None):
|
||||
"""
|
||||
Store some information in the Qdrant collection, along with the specified metadata.
|
||||
:param entry: The entry to store in the Qdrant collection.
|
||||
:param collection_name: The name of the collection to store the information in, optional. If not provided,
|
||||
the default collection is used.
|
||||
"""
|
||||
await self._ensure_collection_exists()
|
||||
collection_name = collection_name or self._default_collection_name
|
||||
await self._ensure_collection_exists(collection_name)
|
||||
|
||||
# Embed the document
|
||||
embeddings = await self._embedding_provider.embed_documents([entry.content])
|
||||
@@ -82,7 +72,7 @@ class QdrantConnector:
|
||||
vector_name = self._embedding_provider.get_vector_name()
|
||||
payload = {"document": entry.content, "metadata": entry.metadata}
|
||||
await self._client.upsert(
|
||||
collection_name=self._collection_name,
|
||||
collection_name=collection_name,
|
||||
points=[
|
||||
models.PointStruct(
|
||||
id=uuid.uuid4().hex,
|
||||
@@ -92,13 +82,19 @@ class QdrantConnector:
|
||||
],
|
||||
)
|
||||
|
||||
async def search(self, query: str) -> list[Entry]:
|
||||
async def search(
|
||||
self, query: str, *, collection_name: Optional[str] = None, limit: int = 10
|
||||
) -> list[Entry]:
|
||||
"""
|
||||
Find points in the Qdrant collection. If there are no entries found, an empty list is returned.
|
||||
:param query: The query to use for the search.
|
||||
:param collection_name: The name of the collection to search in, optional. If not provided,
|
||||
the default collection is used.
|
||||
:param limit: The maximum number of entries to return.
|
||||
:return: A list of entries found.
|
||||
"""
|
||||
collection_exists = await self._client.collection_exists(self._collection_name)
|
||||
collection_name = collection_name or self._default_collection_name
|
||||
collection_exists = await self._client.collection_exists(collection_name)
|
||||
if not collection_exists:
|
||||
return []
|
||||
|
||||
@@ -108,9 +104,9 @@ class QdrantConnector:
|
||||
|
||||
# Search in Qdrant
|
||||
search_results = await self._client.search(
|
||||
collection_name=self._collection_name,
|
||||
collection_name=collection_name,
|
||||
query_vector=models.NamedVector(name=vector_name, vector=query_vector),
|
||||
limit=10,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
return [
|
||||
@@ -120,3 +116,25 @@ class QdrantConnector:
|
||||
)
|
||||
for result in search_results
|
||||
]
|
||||
|
||||
async def _ensure_collection_exists(self, collection_name: str):
|
||||
"""
|
||||
Ensure that the collection exists, creating it if necessary.
|
||||
:param collection_name: The name of the collection to ensure exists.
|
||||
"""
|
||||
collection_exists = await self._client.collection_exists(collection_name)
|
||||
if not collection_exists:
|
||||
# Create the collection with the appropriate vector size
|
||||
vector_size = self._embedding_provider.get_vector_size()
|
||||
|
||||
# Use the vector name as defined in the embedding provider
|
||||
vector_name = self._embedding_provider.get_vector_name()
|
||||
await self._client.create_collection(
|
||||
collection_name=collection_name,
|
||||
vectors_config={
|
||||
vector_name: models.VectorParams(
|
||||
size=vector_size,
|
||||
distance=models.Distance.COSINE,
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user