Support multiple collections (#26)
* Allow passing the collection name in each request to override the default * Allow getting the collection names in QdrantConnector * get vector size from model description * ruff format * add isort * apply pre-commit hooks --------- Co-authored-by: generall <andrey@vasnetsov.com>
This commit is contained in:
@@ -3,7 +3,7 @@
|
||||
[](https://smithery.ai/protocol/mcp-server-qdrant)
|
||||
|
||||
> The [Model Context Protocol (MCP)](https://modelcontextprotocol.io/introduction) is an open protocol that enables
|
||||
> seamless integration between LLM applications and external data sources and tools. Whether you’re building an
|
||||
> seamless integration between LLM applications and external data sources and tools. Whether you're building an
|
||||
> AI-powered IDE, enhancing a chat interface, or creating custom AI workflows, MCP provides a standardized way to
|
||||
> connect LLMs with the context they need.
|
||||
|
||||
@@ -25,11 +25,15 @@ It acts as a semantic memory layer on top of the Qdrant database.
|
||||
- Input:
|
||||
- `information` (string): Information to store
|
||||
- `metadata` (JSON): Optional metadata to store
|
||||
- `collection_name` (string): Name of the collection to store the information in, optional. If not provided,
|
||||
the default collection name will be used.
|
||||
- Returns: Confirmation message
|
||||
2. `qdrant-find`
|
||||
- Retrieve relevant information from the Qdrant database
|
||||
- Input:
|
||||
- `query` (string): Query to use for searching
|
||||
- `collection_name` (string): Name of the collection to store the information in, optional. If not provided,
|
||||
the default collection name will be used.
|
||||
- Returns: Information stored in the Qdrant database as separate messages
|
||||
|
||||
## Environment Variables
|
||||
@@ -40,7 +44,7 @@ The configuration of the server is done using environment variables:
|
||||
|--------------------------|---------------------------------------------------------------------|-------------------------------------------------------------------|
|
||||
| `QDRANT_URL` | URL of the Qdrant server | None |
|
||||
| `QDRANT_API_KEY` | API key for the Qdrant server | None |
|
||||
| `COLLECTION_NAME` | Name of the collection to use | *Required* |
|
||||
| `COLLECTION_NAME` | Name of the default collection to use. | *Required* |
|
||||
| `QDRANT_LOCAL_PATH` | Path to the local Qdrant database (alternative to `QDRANT_URL`) | None |
|
||||
| `EMBEDDING_PROVIDER` | Embedding provider to use (currently only "fastembed" is supported) | `fastembed` |
|
||||
| `EMBEDDING_MODEL` | Name of the embedding model to use | `sentence-transformers/all-MiniLM-L6-v2` |
|
||||
|
||||
@@ -18,11 +18,12 @@ build-backend = "hatchling.build"
|
||||
|
||||
[tool.uv]
|
||||
dev-dependencies = [
|
||||
"isort>=6.0.1",
|
||||
"pre-commit>=4.1.0",
|
||||
"pyright>=1.1.389",
|
||||
"pytest>=8.3.3",
|
||||
"pytest-asyncio>=0.23.0",
|
||||
"ruff>=0.8.0"
|
||||
"ruff>=0.8.0",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -19,3 +19,8 @@ class EmbeddingProvider(ABC):
|
||||
def get_vector_name(self) -> str:
|
||||
"""Get the name of the vector for the Qdrant collection."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_vector_size(self) -> int:
|
||||
"""Get the size of the vector for the Qdrant collection."""
|
||||
pass
|
||||
|
||||
@@ -2,6 +2,7 @@ import asyncio
|
||||
from typing import List
|
||||
|
||||
from fastembed import TextEmbedding
|
||||
from fastembed.common.model_description import DenseModelDescription
|
||||
|
||||
from mcp_server_qdrant.embeddings.base import EmbeddingProvider
|
||||
|
||||
@@ -41,3 +42,10 @@ class FastEmbedProvider(EmbeddingProvider):
|
||||
"""
|
||||
model_name = self.embedding_model.model_name.split("/")[-1].lower()
|
||||
return f"fast-{model_name}"
|
||||
|
||||
def get_vector_size(self) -> int:
|
||||
"""Get the size of the vector for the Qdrant collection."""
|
||||
model_description: DenseModelDescription = (
|
||||
self.embedding_model._get_model_description(self.model_name)
|
||||
)
|
||||
return model_description.dim
|
||||
|
||||
@@ -41,39 +41,29 @@ class QdrantConnector:
|
||||
):
|
||||
self._qdrant_url = qdrant_url.rstrip("/") if qdrant_url else None
|
||||
self._qdrant_api_key = qdrant_api_key
|
||||
self._collection_name = collection_name
|
||||
self._default_collection_name = collection_name
|
||||
self._embedding_provider = embedding_provider
|
||||
self._client = AsyncQdrantClient(
|
||||
location=qdrant_url, api_key=qdrant_api_key, path=qdrant_local_path
|
||||
)
|
||||
|
||||
async def _ensure_collection_exists(self):
|
||||
"""Ensure that the collection exists, creating it if necessary."""
|
||||
collection_exists = await self._client.collection_exists(self._collection_name)
|
||||
if not collection_exists:
|
||||
# Create the collection with the appropriate vector size
|
||||
# We'll get the vector size by embedding a sample text
|
||||
sample_vector = await self._embedding_provider.embed_query("sample text")
|
||||
vector_size = len(sample_vector)
|
||||
async def get_collection_names(self) -> list[str]:
|
||||
"""
|
||||
Get the names of all collections in the Qdrant server.
|
||||
:return: A list of collection names.
|
||||
"""
|
||||
response = await self._client.get_collections()
|
||||
return [collection.name for collection in response.collections]
|
||||
|
||||
# Use the vector name as defined in the embedding provider
|
||||
vector_name = self._embedding_provider.get_vector_name()
|
||||
await self._client.create_collection(
|
||||
collection_name=self._collection_name,
|
||||
vectors_config={
|
||||
vector_name: models.VectorParams(
|
||||
size=vector_size,
|
||||
distance=models.Distance.COSINE,
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
async def store(self, entry: Entry):
|
||||
async def store(self, entry: Entry, *, collection_name: Optional[str] = None):
|
||||
"""
|
||||
Store some information in the Qdrant collection, along with the specified metadata.
|
||||
:param entry: The entry to store in the Qdrant collection.
|
||||
:param collection_name: The name of the collection to store the information in, optional. If not provided,
|
||||
the default collection is used.
|
||||
"""
|
||||
await self._ensure_collection_exists()
|
||||
collection_name = collection_name or self._default_collection_name
|
||||
await self._ensure_collection_exists(collection_name)
|
||||
|
||||
# Embed the document
|
||||
embeddings = await self._embedding_provider.embed_documents([entry.content])
|
||||
@@ -82,7 +72,7 @@ class QdrantConnector:
|
||||
vector_name = self._embedding_provider.get_vector_name()
|
||||
payload = {"document": entry.content, "metadata": entry.metadata}
|
||||
await self._client.upsert(
|
||||
collection_name=self._collection_name,
|
||||
collection_name=collection_name,
|
||||
points=[
|
||||
models.PointStruct(
|
||||
id=uuid.uuid4().hex,
|
||||
@@ -92,13 +82,19 @@ class QdrantConnector:
|
||||
],
|
||||
)
|
||||
|
||||
async def search(self, query: str) -> list[Entry]:
|
||||
async def search(
|
||||
self, query: str, *, collection_name: Optional[str] = None, limit: int = 10
|
||||
) -> list[Entry]:
|
||||
"""
|
||||
Find points in the Qdrant collection. If there are no entries found, an empty list is returned.
|
||||
:param query: The query to use for the search.
|
||||
:param collection_name: The name of the collection to search in, optional. If not provided,
|
||||
the default collection is used.
|
||||
:param limit: The maximum number of entries to return.
|
||||
:return: A list of entries found.
|
||||
"""
|
||||
collection_exists = await self._client.collection_exists(self._collection_name)
|
||||
collection_name = collection_name or self._default_collection_name
|
||||
collection_exists = await self._client.collection_exists(collection_name)
|
||||
if not collection_exists:
|
||||
return []
|
||||
|
||||
@@ -108,9 +104,9 @@ class QdrantConnector:
|
||||
|
||||
# Search in Qdrant
|
||||
search_results = await self._client.search(
|
||||
collection_name=self._collection_name,
|
||||
collection_name=collection_name,
|
||||
query_vector=models.NamedVector(name=vector_name, vector=query_vector),
|
||||
limit=10,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
return [
|
||||
@@ -120,3 +116,25 @@ class QdrantConnector:
|
||||
)
|
||||
for result in search_results
|
||||
]
|
||||
|
||||
async def _ensure_collection_exists(self, collection_name: str):
|
||||
"""
|
||||
Ensure that the collection exists, creating it if necessary.
|
||||
:param collection_name: The name of the collection to ensure exists.
|
||||
"""
|
||||
collection_exists = await self._client.collection_exists(collection_name)
|
||||
if not collection_exists:
|
||||
# Create the collection with the appropriate vector size
|
||||
vector_size = self._embedding_provider.get_vector_size()
|
||||
|
||||
# Use the vector name as defined in the embedding provider
|
||||
vector_name = self._embedding_provider.get_vector_name()
|
||||
await self._client.create_collection(
|
||||
collection_name=collection_name,
|
||||
vectors_config={
|
||||
vector_name: models.VectorParams(
|
||||
size=vector_size,
|
||||
distance=models.Distance.COSINE,
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncIterator, List
|
||||
from typing import AsyncIterator, List, Optional
|
||||
|
||||
from mcp.server import Server
|
||||
from mcp.server.fastmcp import Context, FastMCP
|
||||
@@ -75,12 +75,15 @@ async def store(
|
||||
# If we set it to be optional, some of the MCP clients, like Cursor, cannot
|
||||
# handle the optional parameter correctly.
|
||||
metadata: Metadata = None,
|
||||
collection_name: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Store some information in Qdrant.
|
||||
:param ctx: The context for the request.
|
||||
:param information: The information to store.
|
||||
:param metadata: JSON metadata to store with the information, optional.
|
||||
:param collection_name: The name of the collection to store the information in, optional. If not provided,
|
||||
the default collection is used.
|
||||
:return: A message indicating that the information was stored.
|
||||
"""
|
||||
await ctx.debug(f"Storing information {information} in Qdrant")
|
||||
@@ -88,23 +91,37 @@ async def store(
|
||||
"qdrant_connector"
|
||||
]
|
||||
entry = Entry(content=information, metadata=metadata)
|
||||
await qdrant_connector.store(entry)
|
||||
await qdrant_connector.store(entry, collection_name=collection_name)
|
||||
if collection_name:
|
||||
return f"Remembered: {information} in collection {collection_name}"
|
||||
return f"Remembered: {information}"
|
||||
|
||||
|
||||
@mcp.tool(name="qdrant-find", description=tool_settings.tool_find_description)
|
||||
async def find(ctx: Context, query: str) -> List[str]:
|
||||
async def find(
|
||||
ctx: Context,
|
||||
query: str,
|
||||
collection_name: Optional[str] = None,
|
||||
limit: int = 10,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Find memories in Qdrant.
|
||||
:param ctx: The context for the request.
|
||||
:param query: The query to use for the search.
|
||||
:param collection_name: The name of the collection to search in, optional. If not provided,
|
||||
the default collection is used.
|
||||
:param limit: The maximum number of entries to return, optional. Default is 10.
|
||||
:return: A list of entries found.
|
||||
"""
|
||||
await ctx.debug(f"Finding results for query {query}")
|
||||
if collection_name:
|
||||
await ctx.debug(f"Overriding the collection name with {collection_name}")
|
||||
qdrant_connector: QdrantConnector = ctx.request_context.lifespan_context[
|
||||
"qdrant_connector"
|
||||
]
|
||||
entries = await qdrant_connector.search(query)
|
||||
entries = await qdrant_connector.search(
|
||||
query, collection_name=collection_name, limit=limit
|
||||
)
|
||||
if not entries:
|
||||
return [f"No information found for the query '{query}'"]
|
||||
content = [
|
||||
|
||||
@@ -97,7 +97,7 @@ async def test_ensure_collection_exists(qdrant_connector):
|
||||
"""Test that the collection is created if it doesn't exist."""
|
||||
# The collection shouldn't exist yet
|
||||
assert not await qdrant_connector._client.collection_exists(
|
||||
qdrant_connector._collection_name
|
||||
qdrant_connector._default_collection_name
|
||||
)
|
||||
|
||||
# Storing an entry should create the collection
|
||||
@@ -106,7 +106,7 @@ async def test_ensure_collection_exists(qdrant_connector):
|
||||
|
||||
# Now the collection should exist
|
||||
assert await qdrant_connector._client.collection_exists(
|
||||
qdrant_connector._collection_name
|
||||
qdrant_connector._default_collection_name
|
||||
)
|
||||
|
||||
|
||||
@@ -159,3 +159,80 @@ async def test_entry_without_metadata(qdrant_connector):
|
||||
assert len(results) == 1
|
||||
assert results[0].content == "Entry without metadata"
|
||||
assert results[0].metadata is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_custom_collection_store_and_search(qdrant_connector):
|
||||
"""Test storing and searching in a custom collection."""
|
||||
# Define a custom collection name
|
||||
custom_collection = f"custom_collection_{uuid.uuid4().hex}"
|
||||
|
||||
# Store a test entry in the custom collection
|
||||
test_entry = Entry(
|
||||
content="This is stored in a custom collection",
|
||||
metadata={"custom": True},
|
||||
)
|
||||
await qdrant_connector.store(test_entry, collection_name=custom_collection)
|
||||
|
||||
# Search in the custom collection
|
||||
results = await qdrant_connector.search(
|
||||
"custom collection", collection_name=custom_collection
|
||||
)
|
||||
|
||||
# Verify results
|
||||
assert len(results) == 1
|
||||
assert results[0].content == test_entry.content
|
||||
assert results[0].metadata == test_entry.metadata
|
||||
|
||||
# Verify the entry is not in the default collection
|
||||
default_results = await qdrant_connector.search("custom collection")
|
||||
assert len(default_results) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_collections(qdrant_connector):
|
||||
"""Test using multiple collections with the same connector."""
|
||||
# Define two custom collection names
|
||||
collection_a = f"collection_a_{uuid.uuid4().hex}"
|
||||
collection_b = f"collection_b_{uuid.uuid4().hex}"
|
||||
|
||||
# Store entries in different collections
|
||||
entry_a = Entry(
|
||||
content="This belongs to collection A", metadata={"collection": "A"}
|
||||
)
|
||||
entry_b = Entry(
|
||||
content="This belongs to collection B", metadata={"collection": "B"}
|
||||
)
|
||||
entry_default = Entry(content="This belongs to the default collection")
|
||||
|
||||
await qdrant_connector.store(entry_a, collection_name=collection_a)
|
||||
await qdrant_connector.store(entry_b, collection_name=collection_b)
|
||||
await qdrant_connector.store(entry_default)
|
||||
|
||||
# Search in collection A
|
||||
results_a = await qdrant_connector.search("belongs", collection_name=collection_a)
|
||||
assert len(results_a) == 1
|
||||
assert results_a[0].content == entry_a.content
|
||||
|
||||
# Search in collection B
|
||||
results_b = await qdrant_connector.search("belongs", collection_name=collection_b)
|
||||
assert len(results_b) == 1
|
||||
assert results_b[0].content == entry_b.content
|
||||
|
||||
# Search in default collection
|
||||
results_default = await qdrant_connector.search("belongs")
|
||||
assert len(results_default) == 1
|
||||
assert results_default[0].content == entry_default.content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_nonexistent_collection_search(qdrant_connector):
|
||||
"""Test searching in a collection that doesn't exist."""
|
||||
# Search in a collection that doesn't exist
|
||||
nonexistent_collection = f"nonexistent_{uuid.uuid4().hex}"
|
||||
results = await qdrant_connector.search(
|
||||
"test query", collection_name=nonexistent_collection
|
||||
)
|
||||
|
||||
# Verify results
|
||||
assert len(results) == 0
|
||||
|
||||
11
uv.lock
generated
11
uv.lock
generated
@@ -441,6 +441,15 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/ef/a6/62565a6e1cf69e10f5727360368e451d4b7f58beeac6173dc9db836a5b46/iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374", size = 5892 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "isort"
|
||||
version = "6.0.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/b8/21/1e2a441f74a653a144224d7d21afe8f4169e6c7c20bb13aec3a2dc3815e0/isort-6.0.1.tar.gz", hash = "sha256:1cb5df28dfbc742e490c5e41bad6da41b805b0a8be7bc93cd0fb2a8a890ac450", size = 821955 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/c1/11/114d0a5f4dabbdcedc1125dee0888514c3c3b16d3e9facad87ed96fad97c/isort-6.0.1-py3-none-any.whl", hash = "sha256:2dc5d7f65c9678d94c88dfc29161a320eec67328bc97aad576874cb4be1e9615", size = 94186 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "loguru"
|
||||
version = "0.7.3"
|
||||
@@ -504,6 +513,7 @@ dependencies = [
|
||||
|
||||
[package.dev-dependencies]
|
||||
dev = [
|
||||
{ name = "isort" },
|
||||
{ name = "pre-commit" },
|
||||
{ name = "pyright" },
|
||||
{ name = "pytest" },
|
||||
@@ -521,6 +531,7 @@ requires-dist = [
|
||||
|
||||
[package.metadata.requires-dev]
|
||||
dev = [
|
||||
{ name = "isort", specifier = ">=6.0.1" },
|
||||
{ name = "pre-commit", specifier = ">=4.1.0" },
|
||||
{ name = "pyright", specifier = ">=1.1.389" },
|
||||
{ name = "pytest", specifier = ">=8.3.3" },
|
||||
|
||||
Reference in New Issue
Block a user