Abstract the embedding providers

This commit is contained in:
Kacper Łukawski
2025-03-05 22:40:58 +01:00
parent f8ecab23be
commit 56f1c797fc
11 changed files with 257 additions and 46 deletions

5
.gitignore vendored
View File

@@ -159,4 +159,7 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear # and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder. # option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/ .idea/
# Project-specific settings
.aider*

View File

@@ -36,7 +36,7 @@ uv run mcp-server-qdrant \
--qdrant-url "http://localhost:6333" \ --qdrant-url "http://localhost:6333" \
--qdrant-api-key "your_api_key" \ --qdrant-api-key "your_api_key" \
--collection-name "my_collection" \ --collection-name "my_collection" \
--fastembed-model-name "sentence-transformers/all-MiniLM-L6-v2" --embedding-model "sentence-transformers/all-MiniLM-L6-v2"
``` ```
### Installing via Smithery ### Installing via Smithery
@@ -76,7 +76,7 @@ This MCP server will automatically create a collection with the specified name i
By default, the server will use the `sentence-transformers/all-MiniLM-L6-v2` embedding model to encode memories. By default, the server will use the `sentence-transformers/all-MiniLM-L6-v2` embedding model to encode memories.
For the time being, only [FastEmbed](https://qdrant.github.io/fastembed/) models are supported, and you can change it For the time being, only [FastEmbed](https://qdrant.github.io/fastembed/) models are supported, and you can change it
by passing the `--fastembed-model-name` argument to the server. by passing the `--embedding-model` argument to the server.
### Using the local mode of Qdrant ### Using the local mode of Qdrant
@@ -106,11 +106,31 @@ The configuration of the server can be also done using environment variables:
- `QDRANT_URL`: URL of the Qdrant server, e.g. `http://localhost:6333` - `QDRANT_URL`: URL of the Qdrant server, e.g. `http://localhost:6333`
- `QDRANT_API_KEY`: API key for the Qdrant server - `QDRANT_API_KEY`: API key for the Qdrant server
- `COLLECTION_NAME`: Name of the collection to use - `COLLECTION_NAME`: Name of the collection to use
- `FASTEMBED_MODEL_NAME`: Name of the FastEmbed model to use - `EMBEDDING_MODEL`: Name of the embedding model to use
- `EMBEDDING_PROVIDER`: Embedding provider to use (currently only "fastembed" is supported)
- `QDRANT_LOCAL_PATH`: Path to the local Qdrant database - `QDRANT_LOCAL_PATH`: Path to the local Qdrant database
You cannot provide `QDRANT_URL` and `QDRANT_LOCAL_PATH` at the same time. You cannot provide `QDRANT_URL` and `QDRANT_LOCAL_PATH` at the same time.
## Contributing
If you have suggestions for how mcp-server-qdrant could be improved, or want to report a bug, open an issue!
We'd love all and any contributions.
### Testing `mcp-server-qdrant` locally
The [MCP inspector](https://github.com/modelcontextprotocol/inspector) is a developer tool for testing and debugging MCP
servers. It runs both a client UI (default port 5173) and an MCP proxy server (default port 3000). Open the client UI in
your browser to use the inspector.
```shell
npx @modelcontextprotocol/inspector uv run mcp-server-qdrant \
--collection-name test \
--qdrant-local-path /tmp/qdrant-local-test
```
Once started, open your browser to http://localhost:5173 to access the inspector interface.
## License ## License
This MCP server is licensed under the MIT License. This means you are free to use, modify, and distribute the software, This MCP server is licensed under the MIT License. This means you are free to use, modify, and distribute the software,

View File

@@ -18,8 +18,15 @@ dev-dependencies = [
"pre-commit>=4.1.0", "pre-commit>=4.1.0",
"pyright>=1.1.389", "pyright>=1.1.389",
"pytest>=8.3.3", "pytest>=8.3.3",
"ruff>=0.8.0", "pytest-asyncio>=0.23.0",
"ruff>=0.8.0"
] ]
[project.scripts] [project.scripts]
mcp-server-qdrant = "mcp_server_qdrant:main" mcp-server-qdrant = "mcp_server_qdrant:main"
[tool.pytest.ini_options]
testpaths = ["tests"]
python_files = "test_*.py"
python_functions = "test_*"
asyncio_mode = "auto"

View File

@@ -0,0 +1,5 @@
from .base import EmbeddingProvider
from .factory import create_embedding_provider
from .fastembed import FastEmbedProvider
__all__ = ["EmbeddingProvider", "FastEmbedProvider", "create_embedding_provider"]

View File

@@ -0,0 +1,16 @@
from abc import ABC, abstractmethod
from typing import List
class EmbeddingProvider(ABC):
"""Abstract base class for embedding providers."""
@abstractmethod
async def embed_documents(self, documents: List[str]) -> List[List[float]]:
"""Embed a list of documents into vectors."""
pass
@abstractmethod
async def embed_query(self, query: str) -> List[float]:
"""Embed a query into a vector."""
pass

View File

@@ -0,0 +1,67 @@
from mcp_server_qdrant.embeddings import EmbeddingProvider
def create_embedding_provider(provider_type: str, **kwargs) -> EmbeddingProvider:
"""
Create an embedding provider based on the specified type.
:param provider_type: The type of embedding provider to create.
:param kwargs: Additional arguments to pass to the provider constructor.
:return: An instance of the specified embedding provider.
"""
if provider_type.lower() == "fastembed":
from .fastembed import FastEmbedProvider
model_name = kwargs.get("model_name", "sentence-transformers/all-MiniLM-L6-v2")
return FastEmbedProvider(model_name)
else:
raise ValueError(f"Unsupported embedding provider: {provider_type}")
from typing import Optional
from .fastembed import FastEmbedProvider
from .base import EmbeddingProvider
def create_embedding_provider(provider_type: str, model_name: Optional[str] = None) -> EmbeddingProvider:
"""
Create an embedding provider based on the provider type.
Args:
provider_type: The type of embedding provider to create.
model_name: The name of the model to use.
Returns:
An instance of EmbeddingProvider.
Raises:
ValueError: If the provider type is not supported.
"""
if provider_type.lower() == "fastembed":
return FastEmbedProvider(model_name)
else:
raise ValueError(f"Unsupported embedding provider: {provider_type}")
from typing import Literal
from .fastembed import FastEmbedProvider
def create_embedding_provider(
provider_type: Literal["fastembed"],
**kwargs
) -> FastEmbedProvider:
"""
Factory function to create an embedding provider.
Args:
provider_type: The type of embedding provider to create.
**kwargs: Additional arguments to pass to the provider constructor.
Returns:
An instance of the requested embedding provider.
Raises:
ValueError: If the provider type is not supported.
"""
if provider_type == "fastembed":
model_name = kwargs.get("model_name", "sentence-transformers/all-MiniLM-L6-v2")
return FastEmbedProvider(model_name)
else:
raise ValueError(f"Unsupported embedding provider: {provider_type}")

View File

@@ -0,0 +1,36 @@
from typing import List
import asyncio
from fastembed import TextEmbedding
from .base import EmbeddingProvider
class FastEmbedProvider(EmbeddingProvider):
"""FastEmbed implementation of the embedding provider."""
def __init__(self, model_name: str):
"""
Initialize the FastEmbed provider.
:param model_name: The name of the FastEmbed model to use.
"""
self.model_name = model_name
self.embedding_model = TextEmbedding(model_name)
async def embed_documents(self, documents: List[str]) -> List[List[float]]:
"""Embed a list of documents into vectors."""
# Run in a thread pool since FastEmbed is synchronous
loop = asyncio.get_event_loop()
embeddings = await loop.run_in_executor(
None, lambda: list(self.embedding_model.passage_embed(documents))
)
return [embedding.tolist() for embedding in embeddings]
async def embed_query(self, query: str) -> List[float]:
"""Embed a query into a vector."""
# Run in a thread pool since FastEmbed is synchronous
loop = asyncio.get_event_loop()
embeddings = await loop.run_in_executor(
None, lambda: list(self.embedding_model.query_embed([query]))
)
return embeddings[0].tolist()

View File

@@ -1,6 +1,6 @@
from typing import Optional from typing import Optional, List
from qdrant_client import AsyncQdrantClient, models
from qdrant_client import AsyncQdrantClient from .embeddings.base import EmbeddingProvider
class QdrantConnector: class QdrantConnector:
@@ -9,7 +9,7 @@ class QdrantConnector:
:param qdrant_url: The URL of the Qdrant server. :param qdrant_url: The URL of the Qdrant server.
:param qdrant_api_key: The API key to use for the Qdrant server. :param qdrant_api_key: The API key to use for the Qdrant server.
:param collection_name: The name of the collection to use. :param collection_name: The name of the collection to use.
:param fastembed_model_name: The name of the FastEmbed model to use. :param embedding_provider: The embedding provider to use.
:param qdrant_local_path: The path to the storage directory for the Qdrant client, if local mode is used. :param qdrant_local_path: The path to the storage directory for the Qdrant client, if local mode is used.
""" """
@@ -18,29 +18,52 @@ class QdrantConnector:
qdrant_url: Optional[str], qdrant_url: Optional[str],
qdrant_api_key: Optional[str], qdrant_api_key: Optional[str],
collection_name: str, collection_name: str,
fastembed_model_name: str, embedding_provider: EmbeddingProvider,
qdrant_local_path: Optional[str] = None, qdrant_local_path: Optional[str] = None,
): ):
self._qdrant_url = qdrant_url.rstrip("/") if qdrant_url else None self._qdrant_url = qdrant_url.rstrip("/") if qdrant_url else None
self._qdrant_api_key = qdrant_api_key self._qdrant_api_key = qdrant_api_key
self._collection_name = collection_name self._collection_name = collection_name
self._fastembed_model_name = fastembed_model_name self._embedding_provider = embedding_provider
# For the time being, FastEmbed models are the only supported ones. self._client = AsyncQdrantClient(location=qdrant_url, api_key=qdrant_api_key, path=qdrant_local_path)
# A list of all available models can be found here:
# https://qdrant.github.io/fastembed/examples/Supported_Models/ async def _ensure_collection_exists(self):
self._client = AsyncQdrantClient( """Ensure that the collection exists, creating it if necessary."""
location=qdrant_url, api_key=qdrant_api_key, path=qdrant_local_path collection_exists = await self._client.collection_exists(self._collection_name)
if not collection_exists:
# Create the collection with the appropriate vector size
# We'll get the vector size by embedding a sample text
sample_vector = await self._embedding_provider.embed_query("sample text")
vector_size = len(sample_vector)
await self._client.create_collection(
collection_name=self._collection_name,
vectors_config=models.VectorParams(
size=vector_size,
distance=models.Distance.COSINE,
),
) )
self._client.set_model(fastembed_model_name)
async def store_memory(self, information: str): async def store_memory(self, information: str):
""" """
Store a memory in the Qdrant collection. Store a memory in the Qdrant collection.
:param information: The information to store. :param information: The information to store.
""" """
await self._client.add( await self._ensure_collection_exists()
self._collection_name,
documents=[information], # Embed the document
embeddings = await self._embedding_provider.embed_documents([information])
# Add to Qdrant
await self._client.upsert(
collection_name=self._collection_name,
points=[
models.PointStruct(
id=hash(information), # Simple hash as ID
vector=embeddings[0],
payload={"document": information},
)
],
) )
async def find_memories(self, query: str) -> list[str]: async def find_memories(self, query: str) -> list[str]:
@@ -53,9 +76,14 @@ class QdrantConnector:
if not collection_exists: if not collection_exists:
return [] return []
search_results = await self._client.query( # Embed the query
self._collection_name, query_vector = await self._embedding_provider.embed_query(query)
query_text=query,
# Search in Qdrant
search_results = await self._client.search(
collection_name=self._collection_name,
query_vector=query_vector,
limit=10, limit=10,
) )
return [result.document for result in search_results]
return [result.payload["document"] for result in search_results]

View File

@@ -1,20 +1,23 @@
import asyncio
from typing import Optional from typing import Optional
import click from mcp.server import Server, NotificationOptions
import mcp
import mcp.types as types
from mcp.server import NotificationOptions, Server
from mcp.server.models import InitializationOptions from mcp.server.models import InitializationOptions
import click
import mcp.types as types
import asyncio
import mcp
from .qdrant import QdrantConnector from .qdrant import QdrantConnector
from .embeddings.factory import create_embedding_provider
def serve( def serve(
qdrant_url: Optional[str], qdrant_url: Optional[str],
qdrant_api_key: Optional[str], qdrant_api_key: Optional[str],
collection_name: str, collection_name: str,
fastembed_model_name: str, embedding_provider_type: str,
embedding_model_name: str,
qdrant_local_path: Optional[str] = None, qdrant_local_path: Optional[str] = None,
) -> Server: ) -> Server:
""" """
@@ -22,17 +25,20 @@ def serve(
:param qdrant_url: The URL of the Qdrant server. :param qdrant_url: The URL of the Qdrant server.
:param qdrant_api_key: The API key to use for the Qdrant server. :param qdrant_api_key: The API key to use for the Qdrant server.
:param collection_name: The name of the collection to use. :param collection_name: The name of the collection to use.
:param fastembed_model_name: The name of the FastEmbed model to use. :param embedding_provider_type: The type of embedding provider to use.
:param embedding_model_name: The name of the embedding model to use.
:param qdrant_local_path: The path to the storage directory for the Qdrant client, if local mode is used. :param qdrant_local_path: The path to the storage directory for the Qdrant client, if local mode is used.
""" """
server = Server("qdrant") server = Server("qdrant")
# Create the embedding provider
embedding_provider = create_embedding_provider(
embedding_provider_type,
model_name=embedding_model_name
)
qdrant = QdrantConnector( qdrant = QdrantConnector(
qdrant_url, qdrant_url, qdrant_api_key, collection_name, embedding_provider, qdrant_local_path
qdrant_api_key,
collection_name,
fastembed_model_name,
qdrant_local_path,
) )
@server.list_tools() @server.list_tools()
@@ -133,10 +139,18 @@ def serve(
help="Collection name", help="Collection name",
) )
@click.option( @click.option(
"--fastembed-model-name", "--embedding-provider",
envvar="FASTEMBED_MODEL_NAME", envvar="EMBEDDING_PROVIDER",
required=True, required=False,
help="FastEmbed model name", help="Embedding provider to use",
default="fastembed",
type=click.Choice(["fastembed"], case_sensitive=False),
)
@click.option(
"--embedding-model",
envvar="EMBEDDING_MODEL",
required=False,
help="Embedding model name",
default="sentence-transformers/all-MiniLM-L6-v2", default="sentence-transformers/all-MiniLM-L6-v2",
) )
@click.option( @click.option(
@@ -149,14 +163,13 @@ def main(
qdrant_url: Optional[str], qdrant_url: Optional[str],
qdrant_api_key: str, qdrant_api_key: str,
collection_name: Optional[str], collection_name: Optional[str],
fastembed_model_name: str, embedding_provider: str,
embedding_model: str,
qdrant_local_path: Optional[str], qdrant_local_path: Optional[str],
): ):
# XOR of url and local path, since we accept only one of them # XOR of url and local path, since we accept only one of them
if not (bool(qdrant_url) ^ bool(qdrant_local_path)): if not (bool(qdrant_url) ^ bool(qdrant_local_path)):
raise ValueError( raise ValueError("Exactly one of qdrant-url or qdrant-local-path must be provided")
"Exactly one of qdrant-url or qdrant-local-path must be provided"
)
async def _run(): async def _run():
async with mcp.server.stdio.stdio_server() as (read_stream, write_stream): async with mcp.server.stdio.stdio_server() as (read_stream, write_stream):
@@ -164,7 +177,8 @@ def main(
qdrant_url, qdrant_url,
qdrant_api_key, qdrant_api_key,
collection_name, collection_name,
fastembed_model_name, embedding_provider,
embedding_model,
qdrant_local_path, qdrant_local_path,
) )
await server.run( await server.run(

1
tests/__init__.py Normal file
View File

@@ -0,0 +1 @@
# This file can be empty, it just marks the directory as a Python package

14
uv.lock generated
View File

@@ -494,6 +494,7 @@ dev = [
{ name = "pre-commit" }, { name = "pre-commit" },
{ name = "pyright" }, { name = "pyright" },
{ name = "pytest" }, { name = "pytest" },
{ name = "pytest-asyncio" },
{ name = "ruff" }, { name = "ruff" },
] ]
@@ -508,6 +509,7 @@ dev = [
{ name = "pre-commit", specifier = ">=4.1.0" }, { name = "pre-commit", specifier = ">=4.1.0" },
{ name = "pyright", specifier = ">=1.1.389" }, { name = "pyright", specifier = ">=1.1.389" },
{ name = "pytest", specifier = ">=8.3.3" }, { name = "pytest", specifier = ">=8.3.3" },
{ name = "pytest-asyncio", specifier = ">=0.23.0" },
{ name = "ruff", specifier = ">=0.8.0" }, { name = "ruff", specifier = ">=0.8.0" },
] ]
@@ -981,6 +983,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/6b/77/7440a06a8ead44c7757a64362dd22df5760f9b12dc5f11b6188cd2fc27a0/pytest-8.3.3-py3-none-any.whl", hash = "sha256:a6853c7375b2663155079443d2e45de913a911a11d669df02a50814944db57b2", size = 342341 }, { url = "https://files.pythonhosted.org/packages/6b/77/7440a06a8ead44c7757a64362dd22df5760f9b12dc5f11b6188cd2fc27a0/pytest-8.3.3-py3-none-any.whl", hash = "sha256:a6853c7375b2663155079443d2e45de913a911a11d669df02a50814944db57b2", size = 342341 },
] ]
[[package]]
name = "pytest-asyncio"
version = "0.25.3"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "pytest" },
]
sdist = { url = "https://files.pythonhosted.org/packages/f2/a8/ecbc8ede70921dd2f544ab1cadd3ff3bf842af27f87bbdea774c7baa1d38/pytest_asyncio-0.25.3.tar.gz", hash = "sha256:fc1da2cf9f125ada7e710b4ddad05518d4cee187ae9412e9ac9271003497f07a", size = 54239 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/67/17/3493c5624e48fd97156ebaec380dcaafee9506d7e2c46218ceebbb57d7de/pytest_asyncio-0.25.3-py3-none-any.whl", hash = "sha256:9e89518e0f9bd08928f97a3482fdc4e244df17529460bc038291ccaf8f85c7c3", size = 19467 },
]
[[package]] [[package]]
name = "pywin32" name = "pywin32"
version = "308" version = "308"