From 7aad8ebb3c6b3b50cb9f8d854d324c2cbdffbadf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kacper=20=C5=81ukawski?= Date: Mon, 7 Apr 2025 11:16:45 +0200 Subject: [PATCH] Support multiple collections (#26) * Allow passing the collection name in each request to override the default * Allow getting the collection names in QdrantConnector * get vector size from model description * ruff format * add isort * apply pre-commit hooks --------- Co-authored-by: generall --- README.md | 8 +- pyproject.toml | 3 +- src/mcp_server_qdrant/embeddings/base.py | 5 ++ src/mcp_server_qdrant/embeddings/fastembed.py | 8 ++ src/mcp_server_qdrant/qdrant.py | 74 ++++++++++------- src/mcp_server_qdrant/server.py | 25 +++++- tests/test_qdrant_integration.py | 81 ++++++++++++++++++- uv.lock | 11 +++ 8 files changed, 178 insertions(+), 37 deletions(-) diff --git a/README.md b/README.md index 106d70a..c9f5796 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ [![smithery badge](https://smithery.ai/badge/mcp-server-qdrant)](https://smithery.ai/protocol/mcp-server-qdrant) > The [Model Context Protocol (MCP)](https://modelcontextprotocol.io/introduction) is an open protocol that enables -> seamless integration between LLM applications and external data sources and tools. Whether you’re building an +> seamless integration between LLM applications and external data sources and tools. Whether you're building an > AI-powered IDE, enhancing a chat interface, or creating custom AI workflows, MCP provides a standardized way to > connect LLMs with the context they need. @@ -25,11 +25,15 @@ It acts as a semantic memory layer on top of the Qdrant database. - Input: - `information` (string): Information to store - `metadata` (JSON): Optional metadata to store + - `collection_name` (string): Name of the collection to store the information in, optional. If not provided, + the default collection name will be used. - Returns: Confirmation message 2. 
`qdrant-find` - Retrieve relevant information from the Qdrant database - Input: - `query` (string): Query to use for searching + - `collection_name` (string): Name of the collection to search in, optional. If not provided, + the default collection name will be used. - Returns: Information stored in the Qdrant database as separate messages ## Environment Variables @@ -40,7 +44,7 @@ The configuration of the server is done using environment variables: |--------------------------|---------------------------------------------------------------------|-------------------------------------------------------------------| | `QDRANT_URL` | URL of the Qdrant server | None | | `QDRANT_API_KEY` | API key for the Qdrant server | None | -| `COLLECTION_NAME` | Name of the collection to use | *Required* | +| `COLLECTION_NAME` | Name of the default collection to use. | *Required* | | `QDRANT_LOCAL_PATH` | Path to the local Qdrant database (alternative to `QDRANT_URL`) | None | | `EMBEDDING_PROVIDER` | Embedding provider to use (currently only "fastembed" is supported) | `fastembed` | | `EMBEDDING_MODEL` | Name of the embedding model to use | `sentence-transformers/all-MiniLM-L6-v2` | diff --git a/pyproject.toml b/pyproject.toml index 1c99f02..ef06bd5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,11 +18,12 @@ build-backend = "hatchling.build" [tool.uv] dev-dependencies = [ + "isort>=6.0.1", "pre-commit>=4.1.0", "pyright>=1.1.389", "pytest>=8.3.3", "pytest-asyncio>=0.23.0", - "ruff>=0.8.0" + "ruff>=0.8.0", ] [project.scripts] diff --git a/src/mcp_server_qdrant/embeddings/base.py b/src/mcp_server_qdrant/embeddings/base.py index e2fe34d..80c1d13 100644 --- a/src/mcp_server_qdrant/embeddings/base.py +++ b/src/mcp_server_qdrant/embeddings/base.py @@ -19,3 +19,8 @@ class EmbeddingProvider(ABC): def get_vector_name(self) -> str: """Get the name of the vector for the Qdrant collection.""" pass + + @abstractmethod + def get_vector_size(self) -> int: + """Get the size
of the vector for the Qdrant collection.""" + pass diff --git a/src/mcp_server_qdrant/embeddings/fastembed.py b/src/mcp_server_qdrant/embeddings/fastembed.py index 9eb6b51..628fe14 100644 --- a/src/mcp_server_qdrant/embeddings/fastembed.py +++ b/src/mcp_server_qdrant/embeddings/fastembed.py @@ -2,6 +2,7 @@ import asyncio from typing import List from fastembed import TextEmbedding +from fastembed.common.model_description import DenseModelDescription from mcp_server_qdrant.embeddings.base import EmbeddingProvider @@ -41,3 +42,10 @@ class FastEmbedProvider(EmbeddingProvider): """ model_name = self.embedding_model.model_name.split("/")[-1].lower() return f"fast-{model_name}" + + def get_vector_size(self) -> int: + """Get the size of the vector for the Qdrant collection.""" + model_description: DenseModelDescription = ( + self.embedding_model._get_model_description(self.model_name) + ) + return model_description.dim diff --git a/src/mcp_server_qdrant/qdrant.py b/src/mcp_server_qdrant/qdrant.py index b56c200..e12459d 100644 --- a/src/mcp_server_qdrant/qdrant.py +++ b/src/mcp_server_qdrant/qdrant.py @@ -41,39 +41,29 @@ class QdrantConnector: ): self._qdrant_url = qdrant_url.rstrip("/") if qdrant_url else None self._qdrant_api_key = qdrant_api_key - self._collection_name = collection_name + self._default_collection_name = collection_name self._embedding_provider = embedding_provider self._client = AsyncQdrantClient( location=qdrant_url, api_key=qdrant_api_key, path=qdrant_local_path ) - async def _ensure_collection_exists(self): - """Ensure that the collection exists, creating it if necessary.""" - collection_exists = await self._client.collection_exists(self._collection_name) - if not collection_exists: - # Create the collection with the appropriate vector size - # We'll get the vector size by embedding a sample text - sample_vector = await self._embedding_provider.embed_query("sample text") - vector_size = len(sample_vector) + async def get_collection_names(self) -> 
list[str]: + """ + Get the names of all collections in the Qdrant server. + :return: A list of collection names. + """ + response = await self._client.get_collections() + return [collection.name for collection in response.collections] - # Use the vector name as defined in the embedding provider - vector_name = self._embedding_provider.get_vector_name() - await self._client.create_collection( - collection_name=self._collection_name, - vectors_config={ - vector_name: models.VectorParams( - size=vector_size, - distance=models.Distance.COSINE, - ) - }, - ) - - async def store(self, entry: Entry): + async def store(self, entry: Entry, *, collection_name: Optional[str] = None): """ Store some information in the Qdrant collection, along with the specified metadata. :param entry: The entry to store in the Qdrant collection. + :param collection_name: The name of the collection to store the information in, optional. If not provided, + the default collection is used. """ - await self._ensure_collection_exists() + collection_name = collection_name or self._default_collection_name + await self._ensure_collection_exists(collection_name) # Embed the document embeddings = await self._embedding_provider.embed_documents([entry.content]) @@ -82,7 +72,7 @@ class QdrantConnector: vector_name = self._embedding_provider.get_vector_name() payload = {"document": entry.content, "metadata": entry.metadata} await self._client.upsert( - collection_name=self._collection_name, + collection_name=collection_name, points=[ models.PointStruct( id=uuid.uuid4().hex, @@ -92,13 +82,19 @@ class QdrantConnector: ], ) - async def search(self, query: str) -> list[Entry]: + async def search( + self, query: str, *, collection_name: Optional[str] = None, limit: int = 10 + ) -> list[Entry]: """ Find points in the Qdrant collection. If there are no entries found, an empty list is returned. :param query: The query to use for the search. + :param collection_name: The name of the collection to search in, optional. 
If not provided, + the default collection is used. + :param limit: The maximum number of entries to return. :return: A list of entries found. """ - collection_exists = await self._client.collection_exists(self._collection_name) + collection_name = collection_name or self._default_collection_name + collection_exists = await self._client.collection_exists(collection_name) if not collection_exists: return [] @@ -108,9 +104,9 @@ class QdrantConnector: # Search in Qdrant search_results = await self._client.search( - collection_name=self._collection_name, + collection_name=collection_name, query_vector=models.NamedVector(name=vector_name, vector=query_vector), - limit=10, + limit=limit, ) return [ @@ -120,3 +116,25 @@ class QdrantConnector: ) for result in search_results ] + + async def _ensure_collection_exists(self, collection_name: str): + """ + Ensure that the collection exists, creating it if necessary. + :param collection_name: The name of the collection to ensure exists. + """ + collection_exists = await self._client.collection_exists(collection_name) + if not collection_exists: + # Create the collection with the appropriate vector size + vector_size = self._embedding_provider.get_vector_size() + + # Use the vector name as defined in the embedding provider + vector_name = self._embedding_provider.get_vector_name() + await self._client.create_collection( + collection_name=collection_name, + vectors_config={ + vector_name: models.VectorParams( + size=vector_size, + distance=models.Distance.COSINE, + ) + }, + ) diff --git a/src/mcp_server_qdrant/server.py b/src/mcp_server_qdrant/server.py index 051a999..4a726c0 100644 --- a/src/mcp_server_qdrant/server.py +++ b/src/mcp_server_qdrant/server.py @@ -1,7 +1,7 @@ import json import logging from contextlib import asynccontextmanager -from typing import AsyncIterator, List +from typing import AsyncIterator, List, Optional from mcp.server import Server from mcp.server.fastmcp import Context, FastMCP @@ -75,12 +75,15 @@ async 
def store( # If we set it to be optional, some of the MCP clients, like Cursor, cannot # handle the optional parameter correctly. metadata: Metadata = None, + collection_name: Optional[str] = None, ) -> str: """ Store some information in Qdrant. :param ctx: The context for the request. :param information: The information to store. :param metadata: JSON metadata to store with the information, optional. + :param collection_name: The name of the collection to store the information in, optional. If not provided, + the default collection is used. :return: A message indicating that the information was stored. """ await ctx.debug(f"Storing information {information} in Qdrant") @@ -88,23 +91,37 @@ async def store( "qdrant_connector" ] entry = Entry(content=information, metadata=metadata) - await qdrant_connector.store(entry) + await qdrant_connector.store(entry, collection_name=collection_name) + if collection_name: + return f"Remembered: {information} in collection {collection_name}" return f"Remembered: {information}" @mcp.tool(name="qdrant-find", description=tool_settings.tool_find_description) -async def find(ctx: Context, query: str) -> List[str]: +async def find( + ctx: Context, + query: str, + collection_name: Optional[str] = None, + limit: int = 10, +) -> List[str]: """ Find memories in Qdrant. :param ctx: The context for the request. :param query: The query to use for the search. + :param collection_name: The name of the collection to search in, optional. If not provided, + the default collection is used. + :param limit: The maximum number of entries to return, optional. Default is 10. :return: A list of entries found. 
""" await ctx.debug(f"Finding results for query {query}") + if collection_name: + await ctx.debug(f"Overriding the collection name with {collection_name}") qdrant_connector: QdrantConnector = ctx.request_context.lifespan_context[ "qdrant_connector" ] - entries = await qdrant_connector.search(query) + entries = await qdrant_connector.search( + query, collection_name=collection_name, limit=limit + ) if not entries: return [f"No information found for the query '{query}'"] content = [ diff --git a/tests/test_qdrant_integration.py b/tests/test_qdrant_integration.py index f990b8b..7efe659 100644 --- a/tests/test_qdrant_integration.py +++ b/tests/test_qdrant_integration.py @@ -97,7 +97,7 @@ async def test_ensure_collection_exists(qdrant_connector): """Test that the collection is created if it doesn't exist.""" # The collection shouldn't exist yet assert not await qdrant_connector._client.collection_exists( - qdrant_connector._collection_name + qdrant_connector._default_collection_name ) # Storing an entry should create the collection @@ -106,7 +106,7 @@ async def test_ensure_collection_exists(qdrant_connector): # Now the collection should exist assert await qdrant_connector._client.collection_exists( - qdrant_connector._collection_name + qdrant_connector._default_collection_name ) @@ -159,3 +159,80 @@ async def test_entry_without_metadata(qdrant_connector): assert len(results) == 1 assert results[0].content == "Entry without metadata" assert results[0].metadata is None + + +@pytest.mark.asyncio +async def test_custom_collection_store_and_search(qdrant_connector): + """Test storing and searching in a custom collection.""" + # Define a custom collection name + custom_collection = f"custom_collection_{uuid.uuid4().hex}" + + # Store a test entry in the custom collection + test_entry = Entry( + content="This is stored in a custom collection", + metadata={"custom": True}, + ) + await qdrant_connector.store(test_entry, collection_name=custom_collection) + + # Search in the 
custom collection + results = await qdrant_connector.search( + "custom collection", collection_name=custom_collection + ) + + # Verify results + assert len(results) == 1 + assert results[0].content == test_entry.content + assert results[0].metadata == test_entry.metadata + + # Verify the entry is not in the default collection + default_results = await qdrant_connector.search("custom collection") + assert len(default_results) == 0 + + +@pytest.mark.asyncio +async def test_multiple_collections(qdrant_connector): + """Test using multiple collections with the same connector.""" + # Define two custom collection names + collection_a = f"collection_a_{uuid.uuid4().hex}" + collection_b = f"collection_b_{uuid.uuid4().hex}" + + # Store entries in different collections + entry_a = Entry( + content="This belongs to collection A", metadata={"collection": "A"} + ) + entry_b = Entry( + content="This belongs to collection B", metadata={"collection": "B"} + ) + entry_default = Entry(content="This belongs to the default collection") + + await qdrant_connector.store(entry_a, collection_name=collection_a) + await qdrant_connector.store(entry_b, collection_name=collection_b) + await qdrant_connector.store(entry_default) + + # Search in collection A + results_a = await qdrant_connector.search("belongs", collection_name=collection_a) + assert len(results_a) == 1 + assert results_a[0].content == entry_a.content + + # Search in collection B + results_b = await qdrant_connector.search("belongs", collection_name=collection_b) + assert len(results_b) == 1 + assert results_b[0].content == entry_b.content + + # Search in default collection + results_default = await qdrant_connector.search("belongs") + assert len(results_default) == 1 + assert results_default[0].content == entry_default.content + + +@pytest.mark.asyncio +async def test_nonexistent_collection_search(qdrant_connector): + """Test searching in a collection that doesn't exist.""" + # Search in a collection that doesn't exist + 
nonexistent_collection = f"nonexistent_{uuid.uuid4().hex}" + results = await qdrant_connector.search( + "test query", collection_name=nonexistent_collection + ) + + # Verify results + assert len(results) == 0 diff --git a/uv.lock b/uv.lock index 94575d4..2d3b53b 100644 --- a/uv.lock +++ b/uv.lock @@ -441,6 +441,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ef/a6/62565a6e1cf69e10f5727360368e451d4b7f58beeac6173dc9db836a5b46/iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374", size = 5892 }, ] +[[package]] +name = "isort" +version = "6.0.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b8/21/1e2a441f74a653a144224d7d21afe8f4169e6c7c20bb13aec3a2dc3815e0/isort-6.0.1.tar.gz", hash = "sha256:1cb5df28dfbc742e490c5e41bad6da41b805b0a8be7bc93cd0fb2a8a890ac450", size = 821955 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c1/11/114d0a5f4dabbdcedc1125dee0888514c3c3b16d3e9facad87ed96fad97c/isort-6.0.1-py3-none-any.whl", hash = "sha256:2dc5d7f65c9678d94c88dfc29161a320eec67328bc97aad576874cb4be1e9615", size = 94186 }, +] + [[package]] name = "loguru" version = "0.7.3" @@ -504,6 +513,7 @@ dependencies = [ [package.dev-dependencies] dev = [ + { name = "isort" }, { name = "pre-commit" }, { name = "pyright" }, { name = "pytest" }, @@ -521,6 +531,7 @@ requires-dist = [ [package.metadata.requires-dev] dev = [ + { name = "isort", specifier = ">=6.0.1" }, { name = "pre-commit", specifier = ">=4.1.0" }, { name = "pyright", specifier = ">=1.1.389" }, { name = "pytest", specifier = ">=8.3.3" },