From 56f1c797fcf9d02826163baf66d4c8cc5ae2f602 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kacper=20=C5=81ukawski?= Date: Wed, 5 Mar 2025 22:40:58 +0100 Subject: [PATCH] Abstract the embedding providers --- .gitignore | 5 +- README.md | 26 ++++++- pyproject.toml | 9 ++- src/mcp_server_qdrant/embeddings/__init__.py | 5 ++ src/mcp_server_qdrant/embeddings/base.py | 16 +++++ src/mcp_server_qdrant/embeddings/factory.py | 67 ++++++++++++++++++ src/mcp_server_qdrant/embeddings/fastembed.py | 36 ++++++++++ src/mcp_server_qdrant/qdrant.py | 68 +++++++++++++------ src/mcp_server_qdrant/server.py | 56 +++++++++------ tests/__init__.py | 1 + uv.lock | 14 ++++ 11 files changed, 257 insertions(+), 46 deletions(-) create mode 100644 src/mcp_server_qdrant/embeddings/__init__.py create mode 100644 src/mcp_server_qdrant/embeddings/base.py create mode 100644 src/mcp_server_qdrant/embeddings/factory.py create mode 100644 src/mcp_server_qdrant/embeddings/fastembed.py create mode 100644 tests/__init__.py diff --git a/.gitignore b/.gitignore index 82f9275..66620d0 100644 --- a/.gitignore +++ b/.gitignore @@ -159,4 +159,7 @@ cython_debug/ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. -#.idea/ +.idea/ + +# Project-specific settings +.aider* diff --git a/README.md b/README.md index 1ee679b..f430e0c 100644 --- a/README.md +++ b/README.md @@ -36,7 +36,7 @@ uv run mcp-server-qdrant \ --qdrant-url "http://localhost:6333" \ --qdrant-api-key "your_api_key" \ --collection-name "my_collection" \ - --fastembed-model-name "sentence-transformers/all-MiniLM-L6-v2" + --embedding-model "sentence-transformers/all-MiniLM-L6-v2" ``` ### Installing via Smithery @@ -76,7 +76,7 @@ This MCP server will automatically create a collection with the specified name i By default, the server will use the `sentence-transformers/all-MiniLM-L6-v2` embedding model to encode memories. For the time being, only [FastEmbed](https://qdrant.github.io/fastembed/) models are supported, and you can change it -by passing the `--fastembed-model-name` argument to the server. +by passing the `--embedding-model` argument to the server. ### Using the local mode of Qdrant @@ -106,11 +106,31 @@ The configuration of the server can be also done using environment variables: - `QDRANT_URL`: URL of the Qdrant server, e.g. `http://localhost:6333` - `QDRANT_API_KEY`: API key for the Qdrant server - `COLLECTION_NAME`: Name of the collection to use -- `FASTEMBED_MODEL_NAME`: Name of the FastEmbed model to use +- `EMBEDDING_MODEL`: Name of the embedding model to use +- `EMBEDDING_PROVIDER`: Embedding provider to use (currently only "fastembed" is supported) - `QDRANT_LOCAL_PATH`: Path to the local Qdrant database You cannot provide `QDRANT_URL` and `QDRANT_LOCAL_PATH` at the same time. +## Contributing + +If you have suggestions for how mcp-server-qdrant could be improved, or want to report a bug, open an issue! +We'd love all and any contributions. + +### Testing `mcp-server-qdrant` locally + +The [MCP inspector](https://github.com/modelcontextprotocol/inspector) is a developer tool for testing and debugging MCP +servers. It runs both a client UI (default port 5173) and an MCP proxy server (default port 3000). Open the client UI in +your browser to use the inspector. + +```shell +npx @modelcontextprotocol/inspector uv run mcp-server-qdrant \ + --collection-name test \ + --qdrant-local-path /tmp/qdrant-local-test +``` + +Once started, open your browser to http://localhost:5173 to access the inspector interface. + ## License This MCP server is licensed under the MIT License. This means you are free to use, modify, and distribute the software, diff --git a/pyproject.toml b/pyproject.toml index f5c5569..ff08249 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,8 +18,15 @@ dev-dependencies = [ "pre-commit>=4.1.0", "pyright>=1.1.389", "pytest>=8.3.3", - "ruff>=0.8.0", + "pytest-asyncio>=0.23.0", + "ruff>=0.8.0" ] [project.scripts] mcp-server-qdrant = "mcp_server_qdrant:main" + +[tool.pytest.ini_options] +testpaths = ["tests"] +python_files = "test_*.py" +python_functions = "test_*" +asyncio_mode = "auto" diff --git a/src/mcp_server_qdrant/embeddings/__init__.py b/src/mcp_server_qdrant/embeddings/__init__.py new file mode 100644 index 0000000..94f0a4c --- /dev/null +++ b/src/mcp_server_qdrant/embeddings/__init__.py @@ -0,0 +1,5 @@ +from .base import EmbeddingProvider +from .factory import create_embedding_provider +from .fastembed import FastEmbedProvider + +__all__ = ["EmbeddingProvider", "FastEmbedProvider", "create_embedding_provider"] diff --git a/src/mcp_server_qdrant/embeddings/base.py b/src/mcp_server_qdrant/embeddings/base.py new file mode 100644 index 0000000..7440772 --- /dev/null +++ b/src/mcp_server_qdrant/embeddings/base.py @@ -0,0 +1,16 @@ +from abc import ABC, abstractmethod +from typing import List + + +class EmbeddingProvider(ABC): + """Abstract base class for embedding providers.""" + + @abstractmethod + async def embed_documents(self, documents: List[str]) -> List[List[float]]: + """Embed a list of documents into vectors.""" + pass + + @abstractmethod + async def embed_query(self, query: str) -> List[float]: + """Embed a query into a vector.""" + pass diff --git a/src/mcp_server_qdrant/embeddings/factory.py b/src/mcp_server_qdrant/embeddings/factory.py new file mode 100644 index 0000000..a160dcf --- /dev/null +++ b/src/mcp_server_qdrant/embeddings/factory.py @@ -0,0 +1,67 @@ +from mcp_server_qdrant.embeddings import EmbeddingProvider + + +def create_embedding_provider(provider_type: str, **kwargs) -> EmbeddingProvider: + """ + Create an embedding provider based on the specified type. + + :param provider_type: The type of embedding provider to create. + :param kwargs: Additional arguments to pass to the provider constructor. + :return: An instance of the specified embedding provider. + """ + if provider_type.lower() == "fastembed": + from .fastembed import FastEmbedProvider + model_name = kwargs.get("model_name", "sentence-transformers/all-MiniLM-L6-v2") + return FastEmbedProvider(model_name) + else: + raise ValueError(f"Unsupported embedding provider: {provider_type}") +from typing import Optional +from .fastembed import FastEmbedProvider +from .base import EmbeddingProvider + + +def create_embedding_provider(provider_type: str, model_name: Optional[str] = None) -> EmbeddingProvider: + """ + Create an embedding provider based on the provider type. + + Args: + provider_type: The type of embedding provider to create. + model_name: The name of the model to use. + + Returns: + An instance of EmbeddingProvider. + + Raises: + ValueError: If the provider type is not supported. + """ + if provider_type.lower() == "fastembed": + return FastEmbedProvider(model_name) + else: + raise ValueError(f"Unsupported embedding provider: {provider_type}") +from typing import Literal + +from .fastembed import FastEmbedProvider + + +def create_embedding_provider( + provider_type: Literal["fastembed"], + **kwargs +) -> FastEmbedProvider: + """ + Factory function to create an embedding provider. + + Args: + provider_type: The type of embedding provider to create. + **kwargs: Additional arguments to pass to the provider constructor. + + Returns: + An instance of the requested embedding provider. + + Raises: + ValueError: If the provider type is not supported. + """ + if provider_type == "fastembed": + model_name = kwargs.get("model_name", "sentence-transformers/all-MiniLM-L6-v2") + return FastEmbedProvider(model_name) + else: + raise ValueError(f"Unsupported embedding provider: {provider_type}") diff --git a/src/mcp_server_qdrant/embeddings/fastembed.py b/src/mcp_server_qdrant/embeddings/fastembed.py new file mode 100644 index 0000000..c54bbb1 --- /dev/null +++ b/src/mcp_server_qdrant/embeddings/fastembed.py @@ -0,0 +1,36 @@ +from typing import List +import asyncio +from fastembed import TextEmbedding + +from .base import EmbeddingProvider + + +class FastEmbedProvider(EmbeddingProvider): + """FastEmbed implementation of the embedding provider.""" + + def __init__(self, model_name: str): + """ + Initialize the FastEmbed provider. + + :param model_name: The name of the FastEmbed model to use. + """ + self.model_name = model_name + self.embedding_model = TextEmbedding(model_name) + + async def embed_documents(self, documents: List[str]) -> List[List[float]]: + """Embed a list of documents into vectors.""" + # Run in a thread pool since FastEmbed is synchronous + loop = asyncio.get_event_loop() + embeddings = await loop.run_in_executor( + None, lambda: list(self.embedding_model.passage_embed(documents)) + ) + return [embedding.tolist() for embedding in embeddings] + + async def embed_query(self, query: str) -> List[float]: + """Embed a query into a vector.""" + # Run in a thread pool since FastEmbed is synchronous + loop = asyncio.get_event_loop() + embeddings = await loop.run_in_executor( + None, lambda: list(self.embedding_model.query_embed([query])) + ) + return embeddings[0].tolist() diff --git a/src/mcp_server_qdrant/qdrant.py b/src/mcp_server_qdrant/qdrant.py index bbba0a3..c281aa3 100644 --- a/src/mcp_server_qdrant/qdrant.py +++ b/src/mcp_server_qdrant/qdrant.py @@ -1,6 +1,6 @@ -from typing import Optional - -from qdrant_client import AsyncQdrantClient +from typing import Optional, List +from qdrant_client import AsyncQdrantClient, models +from .embeddings.base import EmbeddingProvider class QdrantConnector: @@ -9,7 +9,7 @@ class QdrantConnector: :param qdrant_url: The URL of the Qdrant server. :param qdrant_api_key: The API key to use for the Qdrant server. :param collection_name: The name of the collection to use. - :param fastembed_model_name: The name of the FastEmbed model to use. + :param embedding_provider: The embedding provider to use. :param qdrant_local_path: The path to the storage directory for the Qdrant client, if local mode is used. """ @@ -18,29 +18,52 @@ class QdrantConnector: qdrant_url: Optional[str], qdrant_api_key: Optional[str], collection_name: str, - fastembed_model_name: str, + embedding_provider: EmbeddingProvider, qdrant_local_path: Optional[str] = None, ): self._qdrant_url = qdrant_url.rstrip("/") if qdrant_url else None self._qdrant_api_key = qdrant_api_key self._collection_name = collection_name - self._fastembed_model_name = fastembed_model_name - # For the time being, FastEmbed models are the only supported ones. - # A list of all available models can be found here: - # https://qdrant.github.io/fastembed/examples/Supported_Models/ - self._client = AsyncQdrantClient( - location=qdrant_url, api_key=qdrant_api_key, path=qdrant_local_path - ) - self._client.set_model(fastembed_model_name) + self._embedding_provider = embedding_provider + self._client = AsyncQdrantClient(location=qdrant_url, api_key=qdrant_api_key, path=qdrant_local_path) + + async def _ensure_collection_exists(self): + """Ensure that the collection exists, creating it if necessary.""" + collection_exists = await self._client.collection_exists(self._collection_name) + if not collection_exists: + # Create the collection with the appropriate vector size + # We'll get the vector size by embedding a sample text + sample_vector = await self._embedding_provider.embed_query("sample text") + vector_size = len(sample_vector) + + await self._client.create_collection( + collection_name=self._collection_name, + vectors_config=models.VectorParams( + size=vector_size, + distance=models.Distance.COSINE, + ), + ) async def store_memory(self, information: str): """ Store a memory in the Qdrant collection. :param information: The information to store. """ - await self._client.add( - self._collection_name, - documents=[information], + await self._ensure_collection_exists() + + # Embed the document + embeddings = await self._embedding_provider.embed_documents([information]) + + # Add to Qdrant + await self._client.upsert( + collection_name=self._collection_name, + points=[ + models.PointStruct( + id=hash(information), # Simple hash as ID + vector=embeddings[0], + payload={"document": information}, + ) + ], ) async def find_memories(self, query: str) -> list[str]: @@ -53,9 +76,14 @@ class QdrantConnector: if not collection_exists: return [] - search_results = await self._client.query( - self._collection_name, - query_text=query, + # Embed the query + query_vector = await self._embedding_provider.embed_query(query) + + # Search in Qdrant + search_results = await self._client.search( + collection_name=self._collection_name, + query_vector=query_vector, limit=10, ) - return [result.document for result in search_results] + + return [result.payload["document"] for result in search_results] diff --git a/src/mcp_server_qdrant/server.py b/src/mcp_server_qdrant/server.py index 66814bc..96921fd 100644 --- a/src/mcp_server_qdrant/server.py +++ b/src/mcp_server_qdrant/server.py @@ -1,20 +1,23 @@ -import asyncio from typing import Optional -import click -import mcp -import mcp.types as types -from mcp.server import NotificationOptions, Server +from mcp.server import Server, NotificationOptions from mcp.server.models import InitializationOptions +import click +import mcp.types as types +import asyncio +import mcp + from .qdrant import QdrantConnector +from .embeddings.factory import create_embedding_provider def serve( qdrant_url: Optional[str], qdrant_api_key: Optional[str], collection_name: str, - fastembed_model_name: str, + embedding_provider_type: str, + embedding_model_name: str, qdrant_local_path: Optional[str] = None, ) -> Server: """ @@ -22,17 +25,20 @@ def serve( :param qdrant_url: The URL of the Qdrant server. :param qdrant_api_key: The API key to use for the Qdrant server. :param collection_name: The name of the collection to use. - :param fastembed_model_name: The name of the FastEmbed model to use. + :param embedding_provider_type: The type of embedding provider to use. + :param embedding_model_name: The name of the embedding model to use. :param qdrant_local_path: The path to the storage directory for the Qdrant client, if local mode is used. """ server = Server("qdrant") + # Create the embedding provider + embedding_provider = create_embedding_provider( + embedding_provider_type, + model_name=embedding_model_name + ) + qdrant = QdrantConnector( - qdrant_url, - qdrant_api_key, - collection_name, - fastembed_model_name, - qdrant_local_path, + qdrant_url, qdrant_api_key, collection_name, embedding_provider, qdrant_local_path ) @server.list_tools() @@ -133,10 +139,18 @@ def serve( help="Collection name", ) @click.option( - "--fastembed-model-name", - envvar="FASTEMBED_MODEL_NAME", - required=True, - help="FastEmbed model name", + "--embedding-provider", + envvar="EMBEDDING_PROVIDER", + required=False, + help="Embedding provider to use", + default="fastembed", + type=click.Choice(["fastembed"], case_sensitive=False), +) +@click.option( + "--embedding-model", + envvar="EMBEDDING_MODEL", + required=False, + help="Embedding model name", default="sentence-transformers/all-MiniLM-L6-v2", ) @click.option( @@ -149,14 +163,13 @@ def main( qdrant_url: Optional[str], qdrant_api_key: str, collection_name: Optional[str], - fastembed_model_name: str, + embedding_provider: str, + embedding_model: str, qdrant_local_path: Optional[str], ): # XOR of url and local path, since we accept only one of them if not (bool(qdrant_url) ^ bool(qdrant_local_path)): - raise ValueError( - "Exactly one of qdrant-url or qdrant-local-path must be provided" - ) + raise ValueError("Exactly one of qdrant-url or qdrant-local-path must be provided") async def _run(): async with mcp.server.stdio.stdio_server() as (read_stream, write_stream): @@ -164,7 +177,8 @@ def main( qdrant_url, qdrant_api_key, collection_name, - fastembed_model_name, + embedding_provider, + embedding_model, qdrant_local_path, ) await server.run( diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..a06d7db --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +# This file can be empty, it just marks the directory as a Python package diff --git a/uv.lock b/uv.lock index 218fb16..e91b996 100644 --- a/uv.lock +++ b/uv.lock @@ -494,6 +494,7 @@ dev = [ { name = "pre-commit" }, { name = "pyright" }, { name = "pytest" }, + { name = "pytest-asyncio" }, { name = "ruff" }, ] @@ -508,6 +509,7 @@ dev = [ { name = "pre-commit", specifier = ">=4.1.0" }, { name = "pyright", specifier = ">=1.1.389" }, { name = "pytest", specifier = ">=8.3.3" }, + { name = "pytest-asyncio", specifier = ">=0.23.0" }, { name = "ruff", specifier = ">=0.8.0" }, ] @@ -981,6 +983,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6b/77/7440a06a8ead44c7757a64362dd22df5760f9b12dc5f11b6188cd2fc27a0/pytest-8.3.3-py3-none-any.whl", hash = "sha256:a6853c7375b2663155079443d2e45de913a911a11d669df02a50814944db57b2", size = 342341 }, ] +[[package]] +name = "pytest-asyncio" +version = "0.25.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f2/a8/ecbc8ede70921dd2f544ab1cadd3ff3bf842af27f87bbdea774c7baa1d38/pytest_asyncio-0.25.3.tar.gz", hash = "sha256:fc1da2cf9f125ada7e710b4ddad05518d4cee187ae9412e9ac9271003497f07a", size = 54239 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/67/17/3493c5624e48fd97156ebaec380dcaafee9506d7e2c46218ceebbb57d7de/pytest_asyncio-0.25.3-py3-none-any.whl", hash = "sha256:9e89518e0f9bd08928f97a3482fdc4e244df17529460bc038291ccaf8f85c7c3", size = 19467 }, +] + [[package]] name = "pywin32" version = "308"