Support multiple collections (#26)

* Allow passing the collection name in each request to override the default

* Allow getting the collection names in QdrantConnector

* get vector size from model description

* ruff format

* add isort

* apply pre-commit hooks

---------

Co-authored-by: generall <andrey@vasnetsov.com>
This commit is contained in:
Kacper Łukawski
2025-04-07 11:16:45 +02:00
committed by GitHub
parent 13cf930f8e
commit 7aad8ebb3c
8 changed files with 178 additions and 37 deletions

View File

@@ -19,3 +19,8 @@ class EmbeddingProvider(ABC):
def get_vector_name(self) -> str:
"""Get the name of the vector for the Qdrant collection."""
pass
@abstractmethod
def get_vector_size(self) -> int:
"""Get the size of the vector for the Qdrant collection."""
pass

View File

@@ -2,6 +2,7 @@ import asyncio
from typing import List
from fastembed import TextEmbedding
from fastembed.common.model_description import DenseModelDescription
from mcp_server_qdrant.embeddings.base import EmbeddingProvider
@@ -41,3 +42,10 @@ class FastEmbedProvider(EmbeddingProvider):
"""
model_name = self.embedding_model.model_name.split("/")[-1].lower()
return f"fast-{model_name}"
def get_vector_size(self) -> int:
"""Get the size of the vector for the Qdrant collection."""
model_description: DenseModelDescription = (
self.embedding_model._get_model_description(self.model_name)
)
return model_description.dim

View File

@@ -41,39 +41,29 @@ class QdrantConnector:
):
self._qdrant_url = qdrant_url.rstrip("/") if qdrant_url else None
self._qdrant_api_key = qdrant_api_key
self._collection_name = collection_name
self._default_collection_name = collection_name
self._embedding_provider = embedding_provider
self._client = AsyncQdrantClient(
location=qdrant_url, api_key=qdrant_api_key, path=qdrant_local_path
)
async def _ensure_collection_exists(self):
"""Ensure that the collection exists, creating it if necessary."""
collection_exists = await self._client.collection_exists(self._collection_name)
if not collection_exists:
# Create the collection with the appropriate vector size
# We'll get the vector size by embedding a sample text
sample_vector = await self._embedding_provider.embed_query("sample text")
vector_size = len(sample_vector)
async def get_collection_names(self) -> list[str]:
"""
Get the names of all collections in the Qdrant server.
:return: A list of collection names.
"""
response = await self._client.get_collections()
return [collection.name for collection in response.collections]
# Use the vector name as defined in the embedding provider
vector_name = self._embedding_provider.get_vector_name()
await self._client.create_collection(
collection_name=self._collection_name,
vectors_config={
vector_name: models.VectorParams(
size=vector_size,
distance=models.Distance.COSINE,
)
},
)
async def store(self, entry: Entry):
async def store(self, entry: Entry, *, collection_name: Optional[str] = None):
"""
Store some information in the Qdrant collection, along with the specified metadata.
:param entry: The entry to store in the Qdrant collection.
:param collection_name: The name of the collection to store the information in, optional. If not provided,
the default collection is used.
"""
await self._ensure_collection_exists()
collection_name = collection_name or self._default_collection_name
await self._ensure_collection_exists(collection_name)
# Embed the document
embeddings = await self._embedding_provider.embed_documents([entry.content])
@@ -82,7 +72,7 @@ class QdrantConnector:
vector_name = self._embedding_provider.get_vector_name()
payload = {"document": entry.content, "metadata": entry.metadata}
await self._client.upsert(
collection_name=self._collection_name,
collection_name=collection_name,
points=[
models.PointStruct(
id=uuid.uuid4().hex,
@@ -92,13 +82,19 @@ class QdrantConnector:
],
)
async def search(self, query: str) -> list[Entry]:
async def search(
self, query: str, *, collection_name: Optional[str] = None, limit: int = 10
) -> list[Entry]:
"""
Find points in the Qdrant collection. If there are no entries found, an empty list is returned.
:param query: The query to use for the search.
:param collection_name: The name of the collection to search in, optional. If not provided,
the default collection is used.
:param limit: The maximum number of entries to return.
:return: A list of entries found.
"""
collection_exists = await self._client.collection_exists(self._collection_name)
collection_name = collection_name or self._default_collection_name
collection_exists = await self._client.collection_exists(collection_name)
if not collection_exists:
return []
@@ -108,9 +104,9 @@ class QdrantConnector:
# Search in Qdrant
search_results = await self._client.search(
collection_name=self._collection_name,
collection_name=collection_name,
query_vector=models.NamedVector(name=vector_name, vector=query_vector),
limit=10,
limit=limit,
)
return [
@@ -120,3 +116,25 @@ class QdrantConnector:
)
for result in search_results
]
async def _ensure_collection_exists(self, collection_name: str):
"""
Ensure that the collection exists, creating it if necessary.
:param collection_name: The name of the collection to ensure exists.
"""
collection_exists = await self._client.collection_exists(collection_name)
if not collection_exists:
# Create the collection with the appropriate vector size
vector_size = self._embedding_provider.get_vector_size()
# Use the vector name as defined in the embedding provider
vector_name = self._embedding_provider.get_vector_name()
await self._client.create_collection(
collection_name=collection_name,
vectors_config={
vector_name: models.VectorParams(
size=vector_size,
distance=models.Distance.COSINE,
)
},
)

View File

@@ -1,7 +1,7 @@
import json
import logging
from contextlib import asynccontextmanager
from typing import AsyncIterator, List
from typing import AsyncIterator, List, Optional
from mcp.server import Server
from mcp.server.fastmcp import Context, FastMCP
@@ -75,12 +75,15 @@ async def store(
# If we set it to be optional, some of the MCP clients, like Cursor, cannot
# handle the optional parameter correctly.
metadata: Metadata = None,
collection_name: Optional[str] = None,
) -> str:
"""
Store some information in Qdrant.
:param ctx: The context for the request.
:param information: The information to store.
:param metadata: JSON metadata to store with the information, optional.
:param collection_name: The name of the collection to store the information in, optional. If not provided,
the default collection is used.
:return: A message indicating that the information was stored.
"""
await ctx.debug(f"Storing information {information} in Qdrant")
@@ -88,23 +91,37 @@ async def store(
"qdrant_connector"
]
entry = Entry(content=information, metadata=metadata)
await qdrant_connector.store(entry)
await qdrant_connector.store(entry, collection_name=collection_name)
if collection_name:
return f"Remembered: {information} in collection {collection_name}"
return f"Remembered: {information}"
@mcp.tool(name="qdrant-find", description=tool_settings.tool_find_description)
async def find(ctx: Context, query: str) -> List[str]:
async def find(
ctx: Context,
query: str,
collection_name: Optional[str] = None,
limit: int = 10,
) -> List[str]:
"""
Find memories in Qdrant.
:param ctx: The context for the request.
:param query: The query to use for the search.
:param collection_name: The name of the collection to search in, optional. If not provided,
the default collection is used.
:param limit: The maximum number of entries to return, optional. Default is 10.
:return: A list of entries found.
"""
await ctx.debug(f"Finding results for query {query}")
if collection_name:
await ctx.debug(f"Overriding the collection name with {collection_name}")
qdrant_connector: QdrantConnector = ctx.request_context.lifespan_context[
"qdrant_connector"
]
entries = await qdrant_connector.search(query)
entries = await qdrant_connector.search(
query, collection_name=collection_name, limit=limit
)
if not entries:
return [f"No information found for the query '{query}'"]
content = [