Abstract the embedding providers

This commit is contained in:
Kacper Łukawski
2025-03-05 22:40:58 +01:00
parent 7c1d05ae17
commit ab110da65e
11 changed files with 253 additions and 30 deletions

5
.gitignore vendored
View File

@@ -159,4 +159,7 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.idea/
# Project-specific settings
.aider*

View File

@@ -36,7 +36,7 @@ uv run mcp-server-qdrant \
--qdrant-url "http://localhost:6333" \
--qdrant-api-key "your_api_key" \
--collection-name "my_collection" \
--fastembed-model-name "sentence-transformers/all-MiniLM-L6-v2"
--embedding-model "sentence-transformers/all-MiniLM-L6-v2"
```
### Installing via Smithery
@@ -76,7 +76,7 @@ This MCP server will automatically create a collection with the specified name i
By default, the server will use the `sentence-transformers/all-MiniLM-L6-v2` embedding model to encode memories.
For the time being, only [FastEmbed](https://qdrant.github.io/fastembed/) models are supported, and you can change it
by passing the `--fastembed-model-name` argument to the server.
by passing the `--embedding-model` argument to the server.
### Using the local mode of Qdrant
@@ -106,11 +106,31 @@ The configuration of the server can be also done using environment variables:
- `QDRANT_URL`: URL of the Qdrant server, e.g. `http://localhost:6333`
- `QDRANT_API_KEY`: API key for the Qdrant server
- `COLLECTION_NAME`: Name of the collection to use
- `FASTEMBED_MODEL_NAME`: Name of the FastEmbed model to use
- `EMBEDDING_MODEL`: Name of the embedding model to use
- `EMBEDDING_PROVIDER`: Embedding provider to use (currently only "fastembed" is supported)
- `QDRANT_LOCAL_PATH`: Path to the local Qdrant database
You cannot provide `QDRANT_URL` and `QDRANT_LOCAL_PATH` at the same time.
## Contributing
If you have suggestions for how mcp-server-qdrant could be improved, or want to report a bug, open an issue!
We'd love all and any contributions.
### Testing `mcp-server-qdrant` locally
The [MCP inspector](https://github.com/modelcontextprotocol/inspector) is a developer tool for testing and debugging MCP
servers. It runs both a client UI (default port 5173) and an MCP proxy server (default port 3000). Open the client UI in
your browser to use the inspector.
```shell
npx @modelcontextprotocol/inspector uv run mcp-server-qdrant \
--collection-name test \
--qdrant-local-path /tmp/qdrant-local-test
```
Once started, open your browser to http://localhost:5173 to access the inspector interface.
## License
This MCP server is licensed under the MIT License. This means you are free to use, modify, and distribute the software,

View File

@@ -14,7 +14,18 @@ requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.uv]
dev-dependencies = ["pyright>=1.1.389", "pytest>=8.3.3", "ruff>=0.8.0"]
dev-dependencies = [
"pyright>=1.1.389",
"pytest>=8.3.3",
"pytest-asyncio>=0.23.0",
"ruff>=0.8.0"
]
[project.scripts]
mcp-server-qdrant = "mcp_server_qdrant:main"
[tool.pytest.ini_options]
testpaths = ["tests"]
python_files = "test_*.py"
python_functions = "test_*"
asyncio_mode = "auto"

View File

@@ -0,0 +1,5 @@
from .base import EmbeddingProvider
from .factory import create_embedding_provider
from .fastembed import FastEmbedProvider
__all__ = ["EmbeddingProvider", "FastEmbedProvider", "create_embedding_provider"]

View File

@@ -0,0 +1,16 @@
from abc import ABC, abstractmethod
from typing import List
class EmbeddingProvider(ABC):
"""Abstract base class for embedding providers."""
@abstractmethod
async def embed_documents(self, documents: List[str]) -> List[List[float]]:
"""Embed a list of documents into vectors."""
pass
@abstractmethod
async def embed_query(self, query: str) -> List[float]:
"""Embed a query into a vector."""
pass

View File

@@ -0,0 +1,67 @@
from mcp_server_qdrant.embeddings import EmbeddingProvider
def create_embedding_provider(provider_type: str, **kwargs) -> EmbeddingProvider:
"""
Create an embedding provider based on the specified type.
:param provider_type: The type of embedding provider to create.
:param kwargs: Additional arguments to pass to the provider constructor.
:return: An instance of the specified embedding provider.
"""
if provider_type.lower() == "fastembed":
from .fastembed import FastEmbedProvider
model_name = kwargs.get("model_name", "sentence-transformers/all-MiniLM-L6-v2")
return FastEmbedProvider(model_name)
else:
raise ValueError(f"Unsupported embedding provider: {provider_type}")
from typing import Optional
from .fastembed import FastEmbedProvider
from .base import EmbeddingProvider
def create_embedding_provider(provider_type: str, model_name: Optional[str] = None) -> EmbeddingProvider:
"""
Create an embedding provider based on the provider type.
Args:
provider_type: The type of embedding provider to create.
model_name: The name of the model to use.
Returns:
An instance of EmbeddingProvider.
Raises:
ValueError: If the provider type is not supported.
"""
if provider_type.lower() == "fastembed":
return FastEmbedProvider(model_name)
else:
raise ValueError(f"Unsupported embedding provider: {provider_type}")
from typing import Literal
from .fastembed import FastEmbedProvider
def create_embedding_provider(
provider_type: Literal["fastembed"],
**kwargs
) -> FastEmbedProvider:
"""
Factory function to create an embedding provider.
Args:
provider_type: The type of embedding provider to create.
**kwargs: Additional arguments to pass to the provider constructor.
Returns:
An instance of the requested embedding provider.
Raises:
ValueError: If the provider type is not supported.
"""
if provider_type == "fastembed":
model_name = kwargs.get("model_name", "sentence-transformers/all-MiniLM-L6-v2")
return FastEmbedProvider(model_name)
else:
raise ValueError(f"Unsupported embedding provider: {provider_type}")

View File

@@ -0,0 +1,36 @@
from typing import List
import asyncio
from fastembed import TextEmbedding
from .base import EmbeddingProvider
class FastEmbedProvider(EmbeddingProvider):
"""FastEmbed implementation of the embedding provider."""
def __init__(self, model_name: str):
"""
Initialize the FastEmbed provider.
:param model_name: The name of the FastEmbed model to use.
"""
self.model_name = model_name
self.embedding_model = TextEmbedding(model_name)
async def embed_documents(self, documents: List[str]) -> List[List[float]]:
"""Embed a list of documents into vectors."""
# Run in a thread pool since FastEmbed is synchronous
loop = asyncio.get_event_loop()
embeddings = await loop.run_in_executor(
None, lambda: list(self.embedding_model.passage_embed(documents))
)
return [embedding.tolist() for embedding in embeddings]
async def embed_query(self, query: str) -> List[float]:
"""Embed a query into a vector."""
# Run in a thread pool since FastEmbed is synchronous
loop = asyncio.get_event_loop()
embeddings = await loop.run_in_executor(
None, lambda: list(self.embedding_model.query_embed([query]))
)
return embeddings[0].tolist()

View File

@@ -1,5 +1,6 @@
from typing import Optional
from typing import Optional, List
from qdrant_client import AsyncQdrantClient, models
from .embeddings.base import EmbeddingProvider
class QdrantConnector:
@@ -8,7 +9,7 @@ class QdrantConnector:
:param qdrant_url: The URL of the Qdrant server.
:param qdrant_api_key: The API key to use for the Qdrant server.
:param collection_name: The name of the collection to use.
:param fastembed_model_name: The name of the FastEmbed model to use.
:param embedding_provider: The embedding provider to use.
:param qdrant_local_path: The path to the storage directory for the Qdrant client, if local mode is used.
"""
@@ -17,27 +18,52 @@ class QdrantConnector:
qdrant_url: Optional[str],
qdrant_api_key: Optional[str],
collection_name: str,
fastembed_model_name: str,
embedding_provider: EmbeddingProvider,
qdrant_local_path: Optional[str] = None,
):
self._qdrant_url = qdrant_url.rstrip("/") if qdrant_url else None
self._qdrant_api_key = qdrant_api_key
self._collection_name = collection_name
self._fastembed_model_name = fastembed_model_name
# For the time being, FastEmbed models are the only supported ones.
# A list of all available models can be found here:
# https://qdrant.github.io/fastembed/examples/Supported_Models/
self._embedding_provider = embedding_provider
self._client = AsyncQdrantClient(location=qdrant_url, api_key=qdrant_api_key, path=qdrant_local_path)
self._client.set_model(fastembed_model_name)
async def _ensure_collection_exists(self):
"""Ensure that the collection exists, creating it if necessary."""
collection_exists = await self._client.collection_exists(self._collection_name)
if not collection_exists:
# Create the collection with the appropriate vector size
# We'll get the vector size by embedding a sample text
sample_vector = await self._embedding_provider.embed_query("sample text")
vector_size = len(sample_vector)
await self._client.create_collection(
collection_name=self._collection_name,
vectors_config=models.VectorParams(
size=vector_size,
distance=models.Distance.COSINE,
),
)
async def store_memory(self, information: str):
"""
Store a memory in the Qdrant collection.
:param information: The information to store.
"""
await self._client.add(
self._collection_name,
documents=[information],
await self._ensure_collection_exists()
# Embed the document
embeddings = await self._embedding_provider.embed_documents([information])
# Add to Qdrant
await self._client.upsert(
collection_name=self._collection_name,
points=[
models.PointStruct(
id=hash(information), # Simple hash as ID
vector=embeddings[0],
payload={"document": information},
)
],
)
async def find_memories(self, query: str) -> list[str]:
@@ -50,9 +76,14 @@ class QdrantConnector:
if not collection_exists:
return []
search_results = await self._client.query(
self._collection_name,
query_text=query,
# Embed the query
query_vector = await self._embedding_provider.embed_query(query)
# Search in Qdrant
search_results = await self._client.search(
collection_name=self._collection_name,
query_vector=query_vector,
limit=10,
)
return [result.document for result in search_results]
return [result.payload["document"] for result in search_results]

View File

@@ -9,13 +9,15 @@ import asyncio
import mcp
from .qdrant import QdrantConnector
from .embeddings.factory import create_embedding_provider
def serve(
qdrant_url: Optional[str],
qdrant_api_key: Optional[str],
collection_name: str,
fastembed_model_name: str,
embedding_provider_type: str,
embedding_model_name: str,
qdrant_local_path: Optional[str] = None,
) -> Server:
"""
@@ -23,13 +25,20 @@ def serve(
:param qdrant_url: The URL of the Qdrant server.
:param qdrant_api_key: The API key to use for the Qdrant server.
:param collection_name: The name of the collection to use.
:param fastembed_model_name: The name of the FastEmbed model to use.
:param embedding_provider_type: The type of embedding provider to use.
:param embedding_model_name: The name of the embedding model to use.
:param qdrant_local_path: The path to the storage directory for the Qdrant client, if local mode is used.
"""
server = Server("qdrant")
# Create the embedding provider
embedding_provider = create_embedding_provider(
embedding_provider_type,
model_name=embedding_model_name
)
qdrant = QdrantConnector(
qdrant_url, qdrant_api_key, collection_name, fastembed_model_name, qdrant_local_path
qdrant_url, qdrant_api_key, collection_name, embedding_provider, qdrant_local_path
)
@server.list_tools()
@@ -130,10 +139,18 @@ def serve(
help="Collection name",
)
@click.option(
"--fastembed-model-name",
envvar="FASTEMBED_MODEL_NAME",
required=True,
help="FastEmbed model name",
"--embedding-provider",
envvar="EMBEDDING_PROVIDER",
required=False,
help="Embedding provider to use",
default="fastembed",
type=click.Choice(["fastembed"], case_sensitive=False),
)
@click.option(
"--embedding-model",
envvar="EMBEDDING_MODEL",
required=False,
help="Embedding model name",
default="sentence-transformers/all-MiniLM-L6-v2",
)
@click.option(
@@ -146,7 +163,8 @@ def main(
qdrant_url: Optional[str],
qdrant_api_key: str,
collection_name: Optional[str],
fastembed_model_name: str,
embedding_provider: str,
embedding_model: str,
qdrant_local_path: Optional[str],
):
# XOR of url and local path, since we accept only one of them
@@ -159,7 +177,8 @@ def main(
qdrant_url,
qdrant_api_key,
collection_name,
fastembed_model_name,
embedding_provider,
embedding_model,
qdrant_local_path,
)
await server.run(

1
tests/__init__.py Normal file
View File

@@ -0,0 +1 @@
# This file can be empty, it just marks the directory as a Python package

16
uv.lock generated
View File

@@ -455,7 +455,7 @@ wheels = [
[[package]]
name = "mcp-server-qdrant"
version = "0.5.1"
version = "0.5.2"
source = { editable = "." }
dependencies = [
{ name = "mcp" },
@@ -466,6 +466,7 @@ dependencies = [
dev = [
{ name = "pyright" },
{ name = "pytest" },
{ name = "pytest-asyncio" },
{ name = "ruff" },
]
@@ -479,6 +480,7 @@ requires-dist = [
dev = [
{ name = "pyright", specifier = ">=1.1.389" },
{ name = "pytest", specifier = ">=8.3.3" },
{ name = "pytest-asyncio", specifier = ">=0.23.0" },
{ name = "ruff", specifier = ">=0.8.0" },
]
@@ -927,6 +929,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/6b/77/7440a06a8ead44c7757a64362dd22df5760f9b12dc5f11b6188cd2fc27a0/pytest-8.3.3-py3-none-any.whl", hash = "sha256:a6853c7375b2663155079443d2e45de913a911a11d669df02a50814944db57b2", size = 342341 },
]
[[package]]
name = "pytest-asyncio"
version = "0.25.3"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "pytest" },
]
sdist = { url = "https://files.pythonhosted.org/packages/f2/a8/ecbc8ede70921dd2f544ab1cadd3ff3bf842af27f87bbdea774c7baa1d38/pytest_asyncio-0.25.3.tar.gz", hash = "sha256:fc1da2cf9f125ada7e710b4ddad05518d4cee187ae9412e9ac9271003497f07a", size = 54239 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/67/17/3493c5624e48fd97156ebaec380dcaafee9506d7e2c46218ceebbb57d7de/pytest_asyncio-0.25.3-py3-none-any.whl", hash = "sha256:9e89518e0f9bd08928f97a3482fdc4e244df17529460bc038291ccaf8f85c7c3", size = 19467 },
]
[[package]]
name = "pywin32"
version = "308"