Support multiple collections (#26)

* Allow passing the collection name in each request to override the default

* Allow getting the collection names in QdrantConnector

* get vector size from model description

* ruff format

* add isort

* apply pre-commit hooks

---------

Co-authored-by: generall <andrey@vasnetsov.com>
This commit is contained in:
Kacper Łukawski
2025-04-07 11:16:45 +02:00
committed by GitHub
parent 13cf930f8e
commit 7aad8ebb3c
8 changed files with 178 additions and 37 deletions

View File

@@ -3,7 +3,7 @@
[![smithery badge](https://smithery.ai/badge/mcp-server-qdrant)](https://smithery.ai/protocol/mcp-server-qdrant) [![smithery badge](https://smithery.ai/badge/mcp-server-qdrant)](https://smithery.ai/protocol/mcp-server-qdrant)
> The [Model Context Protocol (MCP)](https://modelcontextprotocol.io/introduction) is an open protocol that enables > The [Model Context Protocol (MCP)](https://modelcontextprotocol.io/introduction) is an open protocol that enables
> seamless integration between LLM applications and external data sources and tools. Whether youre building an > seamless integration between LLM applications and external data sources and tools. Whether you're building an
> AI-powered IDE, enhancing a chat interface, or creating custom AI workflows, MCP provides a standardized way to > AI-powered IDE, enhancing a chat interface, or creating custom AI workflows, MCP provides a standardized way to
> connect LLMs with the context they need. > connect LLMs with the context they need.
@@ -25,11 +25,15 @@ It acts as a semantic memory layer on top of the Qdrant database.
- Input: - Input:
- `information` (string): Information to store - `information` (string): Information to store
- `metadata` (JSON): Optional metadata to store - `metadata` (JSON): Optional metadata to store
- `collection_name` (string): Name of the collection to store the information in, optional. If not provided,
the default collection name will be used.
- Returns: Confirmation message - Returns: Confirmation message
2. `qdrant-find` 2. `qdrant-find`
- Retrieve relevant information from the Qdrant database - Retrieve relevant information from the Qdrant database
- Input: - Input:
- `query` (string): Query to use for searching - `query` (string): Query to use for searching
- `collection_name` (string): Name of the collection to store the information in, optional. If not provided,
the default collection name will be used.
- Returns: Information stored in the Qdrant database as separate messages - Returns: Information stored in the Qdrant database as separate messages
## Environment Variables ## Environment Variables
@@ -40,7 +44,7 @@ The configuration of the server is done using environment variables:
|--------------------------|---------------------------------------------------------------------|-------------------------------------------------------------------| |--------------------------|---------------------------------------------------------------------|-------------------------------------------------------------------|
| `QDRANT_URL` | URL of the Qdrant server | None | | `QDRANT_URL` | URL of the Qdrant server | None |
| `QDRANT_API_KEY` | API key for the Qdrant server | None | | `QDRANT_API_KEY` | API key for the Qdrant server | None |
| `COLLECTION_NAME` | Name of the collection to use | *Required* | | `COLLECTION_NAME` | Name of the default collection to use. | *Required* |
| `QDRANT_LOCAL_PATH` | Path to the local Qdrant database (alternative to `QDRANT_URL`) | None | | `QDRANT_LOCAL_PATH` | Path to the local Qdrant database (alternative to `QDRANT_URL`) | None |
| `EMBEDDING_PROVIDER` | Embedding provider to use (currently only "fastembed" is supported) | `fastembed` | | `EMBEDDING_PROVIDER` | Embedding provider to use (currently only "fastembed" is supported) | `fastembed` |
| `EMBEDDING_MODEL` | Name of the embedding model to use | `sentence-transformers/all-MiniLM-L6-v2` | | `EMBEDDING_MODEL` | Name of the embedding model to use | `sentence-transformers/all-MiniLM-L6-v2` |

View File

@@ -18,11 +18,12 @@ build-backend = "hatchling.build"
[tool.uv] [tool.uv]
dev-dependencies = [ dev-dependencies = [
"isort>=6.0.1",
"pre-commit>=4.1.0", "pre-commit>=4.1.0",
"pyright>=1.1.389", "pyright>=1.1.389",
"pytest>=8.3.3", "pytest>=8.3.3",
"pytest-asyncio>=0.23.0", "pytest-asyncio>=0.23.0",
"ruff>=0.8.0" "ruff>=0.8.0",
] ]
[project.scripts] [project.scripts]

View File

@@ -19,3 +19,8 @@ class EmbeddingProvider(ABC):
def get_vector_name(self) -> str: def get_vector_name(self) -> str:
"""Get the name of the vector for the Qdrant collection.""" """Get the name of the vector for the Qdrant collection."""
pass pass
@abstractmethod
def get_vector_size(self) -> int:
"""Get the size of the vector for the Qdrant collection."""
pass

View File

@@ -2,6 +2,7 @@ import asyncio
from typing import List from typing import List
from fastembed import TextEmbedding from fastembed import TextEmbedding
from fastembed.common.model_description import DenseModelDescription
from mcp_server_qdrant.embeddings.base import EmbeddingProvider from mcp_server_qdrant.embeddings.base import EmbeddingProvider
@@ -41,3 +42,10 @@ class FastEmbedProvider(EmbeddingProvider):
""" """
model_name = self.embedding_model.model_name.split("/")[-1].lower() model_name = self.embedding_model.model_name.split("/")[-1].lower()
return f"fast-{model_name}" return f"fast-{model_name}"
def get_vector_size(self) -> int:
"""Get the size of the vector for the Qdrant collection."""
model_description: DenseModelDescription = (
self.embedding_model._get_model_description(self.model_name)
)
return model_description.dim

View File

@@ -41,39 +41,29 @@ class QdrantConnector:
): ):
self._qdrant_url = qdrant_url.rstrip("/") if qdrant_url else None self._qdrant_url = qdrant_url.rstrip("/") if qdrant_url else None
self._qdrant_api_key = qdrant_api_key self._qdrant_api_key = qdrant_api_key
self._collection_name = collection_name self._default_collection_name = collection_name
self._embedding_provider = embedding_provider self._embedding_provider = embedding_provider
self._client = AsyncQdrantClient( self._client = AsyncQdrantClient(
location=qdrant_url, api_key=qdrant_api_key, path=qdrant_local_path location=qdrant_url, api_key=qdrant_api_key, path=qdrant_local_path
) )
async def _ensure_collection_exists(self): async def get_collection_names(self) -> list[str]:
"""Ensure that the collection exists, creating it if necessary.""" """
collection_exists = await self._client.collection_exists(self._collection_name) Get the names of all collections in the Qdrant server.
if not collection_exists: :return: A list of collection names.
# Create the collection with the appropriate vector size """
# We'll get the vector size by embedding a sample text response = await self._client.get_collections()
sample_vector = await self._embedding_provider.embed_query("sample text") return [collection.name for collection in response.collections]
vector_size = len(sample_vector)
# Use the vector name as defined in the embedding provider async def store(self, entry: Entry, *, collection_name: Optional[str] = None):
vector_name = self._embedding_provider.get_vector_name()
await self._client.create_collection(
collection_name=self._collection_name,
vectors_config={
vector_name: models.VectorParams(
size=vector_size,
distance=models.Distance.COSINE,
)
},
)
async def store(self, entry: Entry):
""" """
Store some information in the Qdrant collection, along with the specified metadata. Store some information in the Qdrant collection, along with the specified metadata.
:param entry: The entry to store in the Qdrant collection. :param entry: The entry to store in the Qdrant collection.
:param collection_name: The name of the collection to store the information in, optional. If not provided,
the default collection is used.
""" """
await self._ensure_collection_exists() collection_name = collection_name or self._default_collection_name
await self._ensure_collection_exists(collection_name)
# Embed the document # Embed the document
embeddings = await self._embedding_provider.embed_documents([entry.content]) embeddings = await self._embedding_provider.embed_documents([entry.content])
@@ -82,7 +72,7 @@ class QdrantConnector:
vector_name = self._embedding_provider.get_vector_name() vector_name = self._embedding_provider.get_vector_name()
payload = {"document": entry.content, "metadata": entry.metadata} payload = {"document": entry.content, "metadata": entry.metadata}
await self._client.upsert( await self._client.upsert(
collection_name=self._collection_name, collection_name=collection_name,
points=[ points=[
models.PointStruct( models.PointStruct(
id=uuid.uuid4().hex, id=uuid.uuid4().hex,
@@ -92,13 +82,19 @@ class QdrantConnector:
], ],
) )
async def search(self, query: str) -> list[Entry]: async def search(
self, query: str, *, collection_name: Optional[str] = None, limit: int = 10
) -> list[Entry]:
""" """
Find points in the Qdrant collection. If there are no entries found, an empty list is returned. Find points in the Qdrant collection. If there are no entries found, an empty list is returned.
:param query: The query to use for the search. :param query: The query to use for the search.
:param collection_name: The name of the collection to search in, optional. If not provided,
the default collection is used.
:param limit: The maximum number of entries to return.
:return: A list of entries found. :return: A list of entries found.
""" """
collection_exists = await self._client.collection_exists(self._collection_name) collection_name = collection_name or self._default_collection_name
collection_exists = await self._client.collection_exists(collection_name)
if not collection_exists: if not collection_exists:
return [] return []
@@ -108,9 +104,9 @@ class QdrantConnector:
# Search in Qdrant # Search in Qdrant
search_results = await self._client.search( search_results = await self._client.search(
collection_name=self._collection_name, collection_name=collection_name,
query_vector=models.NamedVector(name=vector_name, vector=query_vector), query_vector=models.NamedVector(name=vector_name, vector=query_vector),
limit=10, limit=limit,
) )
return [ return [
@@ -120,3 +116,25 @@ class QdrantConnector:
) )
for result in search_results for result in search_results
] ]
async def _ensure_collection_exists(self, collection_name: str):
"""
Ensure that the collection exists, creating it if necessary.
:param collection_name: The name of the collection to ensure exists.
"""
collection_exists = await self._client.collection_exists(collection_name)
if not collection_exists:
# Create the collection with the appropriate vector size
vector_size = self._embedding_provider.get_vector_size()
# Use the vector name as defined in the embedding provider
vector_name = self._embedding_provider.get_vector_name()
await self._client.create_collection(
collection_name=collection_name,
vectors_config={
vector_name: models.VectorParams(
size=vector_size,
distance=models.Distance.COSINE,
)
},
)

View File

@@ -1,7 +1,7 @@
import json import json
import logging import logging
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import AsyncIterator, List from typing import AsyncIterator, List, Optional
from mcp.server import Server from mcp.server import Server
from mcp.server.fastmcp import Context, FastMCP from mcp.server.fastmcp import Context, FastMCP
@@ -75,12 +75,15 @@ async def store(
# If we set it to be optional, some of the MCP clients, like Cursor, cannot # If we set it to be optional, some of the MCP clients, like Cursor, cannot
# handle the optional parameter correctly. # handle the optional parameter correctly.
metadata: Metadata = None, metadata: Metadata = None,
collection_name: Optional[str] = None,
) -> str: ) -> str:
""" """
Store some information in Qdrant. Store some information in Qdrant.
:param ctx: The context for the request. :param ctx: The context for the request.
:param information: The information to store. :param information: The information to store.
:param metadata: JSON metadata to store with the information, optional. :param metadata: JSON metadata to store with the information, optional.
:param collection_name: The name of the collection to store the information in, optional. If not provided,
the default collection is used.
:return: A message indicating that the information was stored. :return: A message indicating that the information was stored.
""" """
await ctx.debug(f"Storing information {information} in Qdrant") await ctx.debug(f"Storing information {information} in Qdrant")
@@ -88,23 +91,37 @@ async def store(
"qdrant_connector" "qdrant_connector"
] ]
entry = Entry(content=information, metadata=metadata) entry = Entry(content=information, metadata=metadata)
await qdrant_connector.store(entry) await qdrant_connector.store(entry, collection_name=collection_name)
if collection_name:
return f"Remembered: {information} in collection {collection_name}"
return f"Remembered: {information}" return f"Remembered: {information}"
@mcp.tool(name="qdrant-find", description=tool_settings.tool_find_description) @mcp.tool(name="qdrant-find", description=tool_settings.tool_find_description)
async def find(ctx: Context, query: str) -> List[str]: async def find(
ctx: Context,
query: str,
collection_name: Optional[str] = None,
limit: int = 10,
) -> List[str]:
""" """
Find memories in Qdrant. Find memories in Qdrant.
:param ctx: The context for the request. :param ctx: The context for the request.
:param query: The query to use for the search. :param query: The query to use for the search.
:param collection_name: The name of the collection to search in, optional. If not provided,
the default collection is used.
:param limit: The maximum number of entries to return, optional. Default is 10.
:return: A list of entries found. :return: A list of entries found.
""" """
await ctx.debug(f"Finding results for query {query}") await ctx.debug(f"Finding results for query {query}")
if collection_name:
await ctx.debug(f"Overriding the collection name with {collection_name}")
qdrant_connector: QdrantConnector = ctx.request_context.lifespan_context[ qdrant_connector: QdrantConnector = ctx.request_context.lifespan_context[
"qdrant_connector" "qdrant_connector"
] ]
entries = await qdrant_connector.search(query) entries = await qdrant_connector.search(
query, collection_name=collection_name, limit=limit
)
if not entries: if not entries:
return [f"No information found for the query '{query}'"] return [f"No information found for the query '{query}'"]
content = [ content = [

View File

@@ -97,7 +97,7 @@ async def test_ensure_collection_exists(qdrant_connector):
"""Test that the collection is created if it doesn't exist.""" """Test that the collection is created if it doesn't exist."""
# The collection shouldn't exist yet # The collection shouldn't exist yet
assert not await qdrant_connector._client.collection_exists( assert not await qdrant_connector._client.collection_exists(
qdrant_connector._collection_name qdrant_connector._default_collection_name
) )
# Storing an entry should create the collection # Storing an entry should create the collection
@@ -106,7 +106,7 @@ async def test_ensure_collection_exists(qdrant_connector):
# Now the collection should exist # Now the collection should exist
assert await qdrant_connector._client.collection_exists( assert await qdrant_connector._client.collection_exists(
qdrant_connector._collection_name qdrant_connector._default_collection_name
) )
@@ -159,3 +159,80 @@ async def test_entry_without_metadata(qdrant_connector):
assert len(results) == 1 assert len(results) == 1
assert results[0].content == "Entry without metadata" assert results[0].content == "Entry without metadata"
assert results[0].metadata is None assert results[0].metadata is None
@pytest.mark.asyncio
async def test_custom_collection_store_and_search(qdrant_connector):
"""Test storing and searching in a custom collection."""
# Define a custom collection name
custom_collection = f"custom_collection_{uuid.uuid4().hex}"
# Store a test entry in the custom collection
test_entry = Entry(
content="This is stored in a custom collection",
metadata={"custom": True},
)
await qdrant_connector.store(test_entry, collection_name=custom_collection)
# Search in the custom collection
results = await qdrant_connector.search(
"custom collection", collection_name=custom_collection
)
# Verify results
assert len(results) == 1
assert results[0].content == test_entry.content
assert results[0].metadata == test_entry.metadata
# Verify the entry is not in the default collection
default_results = await qdrant_connector.search("custom collection")
assert len(default_results) == 0
@pytest.mark.asyncio
async def test_multiple_collections(qdrant_connector):
"""Test using multiple collections with the same connector."""
# Define two custom collection names
collection_a = f"collection_a_{uuid.uuid4().hex}"
collection_b = f"collection_b_{uuid.uuid4().hex}"
# Store entries in different collections
entry_a = Entry(
content="This belongs to collection A", metadata={"collection": "A"}
)
entry_b = Entry(
content="This belongs to collection B", metadata={"collection": "B"}
)
entry_default = Entry(content="This belongs to the default collection")
await qdrant_connector.store(entry_a, collection_name=collection_a)
await qdrant_connector.store(entry_b, collection_name=collection_b)
await qdrant_connector.store(entry_default)
# Search in collection A
results_a = await qdrant_connector.search("belongs", collection_name=collection_a)
assert len(results_a) == 1
assert results_a[0].content == entry_a.content
# Search in collection B
results_b = await qdrant_connector.search("belongs", collection_name=collection_b)
assert len(results_b) == 1
assert results_b[0].content == entry_b.content
# Search in default collection
results_default = await qdrant_connector.search("belongs")
assert len(results_default) == 1
assert results_default[0].content == entry_default.content
@pytest.mark.asyncio
async def test_nonexistent_collection_search(qdrant_connector):
"""Test searching in a collection that doesn't exist."""
# Search in a collection that doesn't exist
nonexistent_collection = f"nonexistent_{uuid.uuid4().hex}"
results = await qdrant_connector.search(
"test query", collection_name=nonexistent_collection
)
# Verify results
assert len(results) == 0

11
uv.lock generated
View File

@@ -441,6 +441,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/ef/a6/62565a6e1cf69e10f5727360368e451d4b7f58beeac6173dc9db836a5b46/iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374", size = 5892 }, { url = "https://files.pythonhosted.org/packages/ef/a6/62565a6e1cf69e10f5727360368e451d4b7f58beeac6173dc9db836a5b46/iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374", size = 5892 },
] ]
[[package]]
name = "isort"
version = "6.0.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/b8/21/1e2a441f74a653a144224d7d21afe8f4169e6c7c20bb13aec3a2dc3815e0/isort-6.0.1.tar.gz", hash = "sha256:1cb5df28dfbc742e490c5e41bad6da41b805b0a8be7bc93cd0fb2a8a890ac450", size = 821955 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/c1/11/114d0a5f4dabbdcedc1125dee0888514c3c3b16d3e9facad87ed96fad97c/isort-6.0.1-py3-none-any.whl", hash = "sha256:2dc5d7f65c9678d94c88dfc29161a320eec67328bc97aad576874cb4be1e9615", size = 94186 },
]
[[package]] [[package]]
name = "loguru" name = "loguru"
version = "0.7.3" version = "0.7.3"
@@ -504,6 +513,7 @@ dependencies = [
[package.dev-dependencies] [package.dev-dependencies]
dev = [ dev = [
{ name = "isort" },
{ name = "pre-commit" }, { name = "pre-commit" },
{ name = "pyright" }, { name = "pyright" },
{ name = "pytest" }, { name = "pytest" },
@@ -521,6 +531,7 @@ requires-dist = [
[package.metadata.requires-dev] [package.metadata.requires-dev]
dev = [ dev = [
{ name = "isort", specifier = ">=6.0.1" },
{ name = "pre-commit", specifier = ">=4.1.0" }, { name = "pre-commit", specifier = ">=4.1.0" },
{ name = "pyright", specifier = ">=1.1.389" }, { name = "pyright", specifier = ">=1.1.389" },
{ name = "pytest", specifier = ">=8.3.3" }, { name = "pytest", specifier = ">=8.3.3" },