Merge pull request #22 from qdrant/refactor/fastmcp

Refactor: use FastMCP
This commit is contained in:
Kacper Łukawski
2025-03-10 14:34:40 +01:00
committed by GitHub
11 changed files with 1085 additions and 856 deletions

View File

@@ -15,7 +15,7 @@ repos:
- id: check-added-large-files
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.5.0
rev: v0.9.10
hooks:
- id: ruff
args: [ --fix ]

View File

@@ -6,8 +6,10 @@ readme = "README.md"
requires-python = ">=3.10"
license = "Apache-2.0"
dependencies = [
"mcp>=0.9.1",
"qdrant-client[fastembed]>=1.12.0",
"mcp[cli]>=1.3.0",
"fastembed>=0.6.0",
"qdrant-client>=1.12.0",
"typer>=0.15.2",
]
[build-system]

View File

@@ -3,7 +3,7 @@ from . import server
def main():
"""Main entry point for the package."""
server.main()
server.mcp.run()
# Optionally expose other important items at package level

View File

@@ -1,17 +1,16 @@
from mcp_server_qdrant.embeddings import EmbeddingProvider
from mcp_server_qdrant.settings import EmbeddingProviderSettings
def create_embedding_provider(provider_type: str, model_name: str) -> EmbeddingProvider:
def create_embedding_provider(settings: EmbeddingProviderSettings) -> EmbeddingProvider:
"""
Create an embedding provider based on the specified type.
:param provider_type: The type of embedding provider to create.
:param model_name: The name of the model to use for embeddings, specific to the provider type.
:param settings: The settings for the embedding provider.
:return: An instance of the specified embedding provider.
"""
if provider_type.lower() == "fastembed":
from .fastembed import FastEmbedProvider
if settings.provider_type.lower() == "fastembed":
from mcp_server_qdrant.embeddings.fastembed import FastEmbedProvider
return FastEmbedProvider(model_name)
return FastEmbedProvider(settings.model_name)
else:
raise ValueError(f"Unsupported embedding provider: {provider_type}")
raise ValueError(f"Unsupported embedding provider: {settings.provider_type}")

View File

@@ -3,18 +3,16 @@ from typing import List
from fastembed import TextEmbedding
from .base import EmbeddingProvider
from mcp_server_qdrant.embeddings.base import EmbeddingProvider
class FastEmbedProvider(EmbeddingProvider):
"""FastEmbed implementation of the embedding provider."""
def __init__(self, model_name: str):
"""
Initialize the FastEmbed provider.
FastEmbed implementation of the embedding provider.
:param model_name: The name of the FastEmbed model to use.
"""
def __init__(self, model_name: str):
self.model_name = model_name
self.embedding_model = TextEmbedding(model_name)

View File

@@ -53,9 +53,9 @@ class QdrantConnector:
},
)
async def store_memory(self, information: str):
async def store(self, information: str):
"""
Store a memory in the Qdrant collection.
Store some information in the Qdrant collection.
:param information: The information to store.
"""
await self._ensure_collection_exists()
@@ -76,11 +76,11 @@ class QdrantConnector:
],
)
async def find_memories(self, query: str) -> list[str]:
async def search(self, query: str) -> list[str]:
"""
Find memories in the Qdrant collection. If there are no memories found, an empty list is returned.
Find points in the Qdrant collection. If there are no entries found, an empty list is returned.
:param query: The query to use for the search.
:return: A list of memories found.
:return: A list of entries found.
"""
collection_exists = await self._client.collection_exists(self._collection_name)
if not collection_exists:

View File

@@ -1,60 +1,93 @@
import asyncio
import importlib.metadata
from typing import Optional
import logging
import os
from contextlib import asynccontextmanager
from typing import AsyncIterator, List
import click
import mcp
import mcp.types as types
from mcp.server import NotificationOptions, Server
from mcp.server.models import InitializationOptions
from mcp.server import Server
from mcp.server.fastmcp import Context, FastMCP
from .embeddings.factory import create_embedding_provider
from .qdrant import QdrantConnector
from mcp_server_qdrant.embeddings.factory import create_embedding_provider
from mcp_server_qdrant.qdrant import QdrantConnector
from mcp_server_qdrant.settings import (
EmbeddingProviderSettings,
QdrantSettings,
parse_args,
)
logger = logging.getLogger(__name__)
# Parse command line arguments and set them as environment variables.
# This is done for backwards compatibility with the previous versions
# of the MCP server.
env_vars = parse_args()
for key, value in env_vars.items():
os.environ[key] = value
def get_package_version() -> str:
"""Get the package version using importlib.metadata."""
@asynccontextmanager
async def server_lifespan(server: Server) -> AsyncIterator[dict]: # noqa
"""
Context manager to handle the lifespan of the server.
This is used to configure the embedding provider and Qdrant connector.
"""
try:
return importlib.metadata.version("mcp-server-qdrant")
except importlib.metadata.PackageNotFoundError:
# Fall back to a default version if package is not installed
return "0.0.0"
# Embedding provider is created with a factory function so we can add
# some more providers in the future. Currently, only FastEmbed is supported.
embedding_provider_settings = EmbeddingProviderSettings()
embedding_provider = create_embedding_provider(embedding_provider_settings)
logger.info(
f"Using embedding provider {embedding_provider_settings.provider_type} with "
f"model {embedding_provider_settings.model_name}"
)
qdrant_configuration = QdrantSettings()
qdrant_connector = QdrantConnector(
qdrant_configuration.location,
qdrant_configuration.api_key,
qdrant_configuration.collection_name,
embedding_provider,
qdrant_configuration.local_path,
)
logger.info(
f"Connecting to Qdrant at {qdrant_configuration.get_qdrant_location()}"
)
yield {
"embedding_provider": embedding_provider,
"qdrant_connector": qdrant_connector,
}
except Exception as e:
logger.error(e)
raise e
finally:
pass
def serve(
qdrant_connector: QdrantConnector,
) -> Server:
"""
Instantiate the server and configure tools to store and find memories in Qdrant.
:param qdrant_connector: An instance of QdrantConnector to use for storing and retrieving memories.
"""
server = Server("qdrant")
mcp = FastMCP("Qdrant", lifespan=server_lifespan)
@server.list_tools()
async def handle_list_tools() -> list[types.Tool]:
"""
Return the list of tools that the server provides. By default, there are two
tools: one to store memories and another to find them. Finding the memories is not
implemented as a resource, as it requires a query to be passed and resources point
to a very specific piece of data.
"""
return [
types.Tool(
@mcp.tool(
name="qdrant-store-memory",
description=(
"Keep the memory for later use, when you are asked to remember something."
),
inputSchema={
"type": "object",
"properties": {
"information": {
"type": "string",
},
},
"required": ["information"],
},
),
types.Tool(
)
async def store(information: str, ctx: Context) -> str:
"""
Store a memory in Qdrant.
:param information: The information to store.
:param ctx: The context for the request.
:return: A message indicating that the information was stored.
"""
await ctx.debug(f"Storing information {information} in Qdrant")
qdrant_connector: QdrantConnector = ctx.request_context.lifespan_context[
"qdrant_connector"
]
await qdrant_connector.store(information)
return f"Remembered: {information}"
@mcp.tool(
name="qdrant-find-memories",
description=(
"Look up memories in Qdrant. Use this tool when you need to: \n"
@@ -62,153 +95,28 @@ def serve(
" - Access memories for further analysis \n"
" - Get some personal information about the user"
),
inputSchema={
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "The query to search for",
}
},
"required": ["query"],
},
),
)
async def find(query: str, ctx: Context) -> List[str]:
"""
Find memories in Qdrant.
:param query: The query to use for the search.
:param ctx: The context for the request.
:return: A list of entries found.
"""
await ctx.debug(f"Finding points for query {query}")
qdrant_connector: QdrantConnector = ctx.request_context.lifespan_context[
"qdrant_connector"
]
@server.call_tool()
async def handle_tool_call(
name: str, arguments: dict | None
) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
if name not in ["qdrant-store-memory", "qdrant-find-memories"]:
raise ValueError(f"Unknown tool: {name}")
if name == "qdrant-store-memory":
if not arguments or "information" not in arguments:
raise ValueError("Missing required argument 'information'")
information = arguments["information"]
await qdrant_connector.store_memory(information)
return [types.TextContent(type="text", text=f"Remembered: {information}")]
if name == "qdrant-find-memories":
if not arguments or "query" not in arguments:
raise ValueError("Missing required argument 'query'")
query = arguments["query"]
memories = await qdrant_connector.find_memories(query)
entries = await qdrant_connector.search(query)
if not entries:
return [f"No memories found for the query '{query}'"]
content = [
types.TextContent(
type="text", text=f"Memories for the query '{query}'"
),
f"Memories for the query '{query}'",
]
for memory in memories:
content.append(
types.TextContent(type="text", text=f"<memory>{memory}</memory>")
)
for entry in entries:
content.append(f"<entry>{entry}</entry>")
return content
raise ValueError(f"Unknown tool: {name}")
return server
@click.command()
@click.option(
"--qdrant-url",
envvar="QDRANT_URL",
required=False,
help="Qdrant URL",
)
@click.option(
"--qdrant-api-key",
envvar="QDRANT_API_KEY",
required=False,
help="Qdrant API key",
)
@click.option(
"--collection-name",
envvar="COLLECTION_NAME",
required=True,
help="Collection name",
)
@click.option(
"--fastembed-model-name",
envvar="FASTEMBED_MODEL_NAME",
required=False,
help="FastEmbed model name",
default="sentence-transformers/all-MiniLM-L6-v2",
)
@click.option(
"--embedding-provider",
envvar="EMBEDDING_PROVIDER",
required=False,
help="Embedding provider to use",
default="fastembed",
type=click.Choice(["fastembed"], case_sensitive=False),
)
@click.option(
"--embedding-model",
envvar="EMBEDDING_MODEL",
required=False,
help="Embedding model name",
default="sentence-transformers/all-MiniLM-L6-v2",
)
@click.option(
"--qdrant-local-path",
envvar="QDRANT_LOCAL_PATH",
required=False,
help="Qdrant local path",
)
def main(
qdrant_url: Optional[str],
qdrant_api_key: str,
collection_name: Optional[str],
fastembed_model_name: Optional[str],
embedding_provider: str,
embedding_model: str,
qdrant_local_path: Optional[str],
):
# XOR of url and local path, since we accept only one of them
if not (bool(qdrant_url) ^ bool(qdrant_local_path)):
raise ValueError(
"Exactly one of qdrant-url or qdrant-local-path must be provided"
)
# Warn if fastembed_model_name is provided, as this is going to be deprecated
if fastembed_model_name:
click.echo(
"Warning: --fastembed-model-name parameter is deprecated and will be removed in a future version. "
"Please use --embedding-provider and --embedding-model instead",
err=True,
)
async def _run():
async with mcp.server.stdio.stdio_server() as (read_stream, write_stream):
# Create the embedding provider
provider = create_embedding_provider(
provider_type=embedding_provider, model_name=embedding_model
)
# Create the Qdrant connector
qdrant_connector = QdrantConnector(
qdrant_url,
qdrant_api_key,
collection_name,
provider,
qdrant_local_path,
)
# Create and run the server
server = serve(qdrant_connector)
await server.run(
read_stream,
write_stream,
InitializationOptions(
server_name="qdrant",
server_version=get_package_version(),
capabilities=server.get_capabilities(
notification_options=NotificationOptions(),
experimental_capabilities={},
),
),
)
asyncio.run(_run())
if __name__ == "__main__":
mcp.run()

View File

@@ -0,0 +1,101 @@
import argparse
from typing import Any, Dict, Optional
from pydantic import Field
from pydantic_settings import BaseSettings
class EmbeddingProviderSettings(BaseSettings):
"""
Configuration for the embedding provider.
"""
provider_type: str = Field(
default="fastembed", validation_alias="EMBEDDING_PROVIDER"
)
model_name: str = Field(
default="sentence-transformers/all-MiniLM-L6-v2",
validation_alias="EMBEDDING_MODEL",
)
class QdrantSettings(BaseSettings):
"""
Configuration for the Qdrant connector.
"""
location: Optional[str] = Field(default=None, validation_alias="QDRANT_URL")
api_key: Optional[str] = Field(default=None, validation_alias="QDRANT_API_KEY")
collection_name: str = Field(validation_alias="COLLECTION_NAME")
local_path: Optional[str] = Field(
default=None, validation_alias="QDRANT_LOCAL_PATH"
)
def get_qdrant_location(self) -> str:
"""
Get the Qdrant location, either the URL or the local path.
"""
return self.location or self.local_path
def parse_args() -> Dict[str, Any]:
"""
Parse command line arguments for the MCP server.
Returns:
Dict[str, Any]: Dictionary of parsed arguments
"""
parser = argparse.ArgumentParser(description="Qdrant MCP Server")
# Qdrant connection options
connection_group = parser.add_mutually_exclusive_group()
connection_group.add_argument(
"--qdrant-url",
help="URL of the Qdrant server, e.g. http://localhost:6333",
)
connection_group.add_argument(
"--qdrant-local-path",
help="Path to the local Qdrant database",
)
# Other Qdrant settings
parser.add_argument(
"--qdrant-api-key",
help="API key for the Qdrant server",
)
parser.add_argument(
"--collection-name",
help="Name of the collection to use",
)
# Embedding settings
parser.add_argument(
"--embedding-provider",
help="Embedding provider to use (currently only 'fastembed' is supported)",
)
parser.add_argument(
"--embedding-model",
help="Name of the embedding model to use",
)
args = parser.parse_args()
# Convert to dictionary and filter out None values
args_dict = {k: v for k, v in vars(args).items() if v is not None}
# Convert argument names to environment variable format
env_vars = {}
if "qdrant_url" in args_dict:
env_vars["QDRANT_URL"] = args_dict["qdrant_url"]
if "qdrant_api_key" in args_dict:
env_vars["QDRANT_API_KEY"] = args_dict["qdrant_api_key"]
if "collection_name" in args_dict:
env_vars["COLLECTION_NAME"] = args_dict["collection_name"]
if "embedding_model" in args_dict:
env_vars["EMBEDDING_MODEL"] = args_dict["embedding_model"]
if "embedding_provider" in args_dict:
env_vars["EMBEDDING_PROVIDER"] = args_dict["embedding_provider"]
if "qdrant_local_path" in args_dict:
env_vars["QDRANT_LOCAL_PATH"] = args_dict["qdrant_local_path"]
return env_vars

61
tests/test_config.py Normal file
View File

@@ -0,0 +1,61 @@
import os
from unittest.mock import patch
import pytest
from mcp_server_qdrant.settings import EmbeddingProviderSettings, QdrantSettings
class TestQdrantSettings:
def test_default_values(self):
"""Test that required fields raise errors when not provided."""
with pytest.raises(ValueError):
# Should raise error because required fields are missing
QdrantSettings()
@patch.dict(
os.environ,
{"QDRANT_URL": "http://localhost:6333", "COLLECTION_NAME": "test_collection"},
)
def test_minimal_config(self):
"""Test loading minimal configuration from environment variables."""
settings = QdrantSettings()
assert settings.location == "http://localhost:6333"
assert settings.collection_name == "test_collection"
assert settings.api_key is None
assert settings.local_path is None
@patch.dict(
os.environ,
{
"QDRANT_URL": "http://qdrant.example.com:6333",
"QDRANT_API_KEY": "test_api_key",
"COLLECTION_NAME": "my_memories",
"QDRANT_LOCAL_PATH": "/tmp/qdrant",
},
)
def test_full_config(self):
"""Test loading full configuration from environment variables."""
settings = QdrantSettings()
assert settings.location == "http://qdrant.example.com:6333"
assert settings.api_key == "test_api_key"
assert settings.collection_name == "my_memories"
assert settings.local_path == "/tmp/qdrant"
class TestEmbeddingProviderSettings:
def test_default_values(self):
"""Test default values are set correctly."""
settings = EmbeddingProviderSettings()
assert settings.provider_type == "fastembed"
assert settings.model_name == "sentence-transformers/all-MiniLM-L6-v2"
@patch.dict(
os.environ,
{"EMBEDDING_PROVIDER": "custom_provider", "EMBEDDING_MODEL": "custom_model"},
)
def test_custom_values(self):
"""Test loading custom values from environment variables."""
settings = EmbeddingProviderSettings()
assert settings.provider_type == "custom_provider"
assert settings.model_name == "custom_model"

View File

@@ -0,0 +1,63 @@
import numpy as np
import pytest
from fastembed import TextEmbedding
from mcp_server_qdrant.embeddings.fastembed import FastEmbedProvider
@pytest.mark.asyncio
class TestFastEmbedProviderIntegration:
"""Integration tests for FastEmbedProvider."""
async def test_initialization(self):
"""Test that the provider can be initialized with a valid model."""
provider = FastEmbedProvider("sentence-transformers/all-MiniLM-L6-v2")
assert provider.model_name == "sentence-transformers/all-MiniLM-L6-v2"
assert isinstance(provider.embedding_model, TextEmbedding)
async def test_embed_documents(self):
"""Test that documents can be embedded."""
provider = FastEmbedProvider("sentence-transformers/all-MiniLM-L6-v2")
documents = ["This is a test document.", "This is another test document."]
embeddings = await provider.embed_documents(documents)
# Check that we got the right number of embeddings
assert len(embeddings) == len(documents)
# Check that embeddings have the expected shape
# The exact dimension depends on the model, but should be consistent
assert len(embeddings[0]) > 0
assert all(len(embedding) == len(embeddings[0]) for embedding in embeddings)
# Check that embeddings are different for different documents
# Convert to numpy arrays for easier comparison
embedding1 = np.array(embeddings[0])
embedding2 = np.array(embeddings[1])
assert not np.array_equal(embedding1, embedding2)
async def test_embed_query(self):
"""Test that queries can be embedded."""
provider = FastEmbedProvider("sentence-transformers/all-MiniLM-L6-v2")
query = "This is a test query."
embedding = await provider.embed_query(query)
# Check that embedding has the expected shape
assert len(embedding) > 0
# Embed the same query again to check consistency
embedding2 = await provider.embed_query(query)
assert len(embedding) == len(embedding2)
# The embeddings should be identical for the same input
np.testing.assert_array_almost_equal(np.array(embedding), np.array(embedding2))
def test_get_vector_name(self):
"""Test that the vector name is generated correctly."""
provider = FastEmbedProvider("sentence-transformers/all-MiniLM-L6-v2")
vector_name = provider.get_vector_name()
# Check that the vector name follows the expected format
assert vector_name.startswith("fast-")
assert "minilm" in vector_name.lower()

1357
uv.lock generated

File diff suppressed because it is too large Load Diff