Implement the server with FastMCP

This commit is contained in:
Kacper Łukawski
2025-03-10 09:25:02 +01:00
parent f8d3cc474b
commit 18cc93f6c8
16 changed files with 1079 additions and 1064 deletions

View File

@@ -3,7 +3,7 @@ from . import server
def main():
"""Main entry point for the package."""
server.main()
server.mcp.run()
# Optionally expose other important items at package level

View File

@@ -1,137 +0,0 @@
import inspect
from typing import Any, Callable, Dict, List, Optional, get_type_hints
from mcp import types
from mcp.server import Server
class QdrantMCPServer(Server):
    """
    An MCP server that uses Qdrant to store and retrieve information.

    Tools are added with :meth:`register_tool`. The class keeps the tool
    metadata (returned by :meth:`handle_list_tool`) and the matching handlers
    (dispatched by :meth:`handle_tool_call`) side by side.
    """

    def __init__(self, name: str = "Qdrant"):
        """
        :param name: The server name forwarded to the base ``Server``.
        """
        super().__init__(name)

        # Tool name -> coroutine function that implements the tool.
        self._tool_handlers: Dict[str, Callable] = {}
        # Tool definitions, in registration order.
        self._tools: List[types.Tool] = []

        # This monkeypatching is required to make the server list the tools
        # and handle tool calls. It simplifies the process of registering
        # tool handlers. Please do not remove it.
        self.handle_list_tool = self.list_tools()(self.handle_list_tool)
        self.handle_tool_call = self.call_tool()(self.handle_tool_call)

    def register_tool(
        self,
        *,
        description: str,
        name: Optional[str] = None,
        input_schema: Optional[dict[str, Any]] = None,
    ):
        """
        A decorator to register a tool with the server. The description is used
        to generate the tool's metadata.

        Name is optional, and if not provided, the function's name will be used.

        :param description: The description of the tool.
        :param name: The name of the tool. If not provided, the function's name will be used.
        :param input_schema: The input schema for the tool. If not provided, it will be
                             automatically generated from the function's parameters.
        :return: A decorator that registers the function and returns it unchanged.
        """

        def decorator(func: Callable):
            def wrapper(fn):
                # Resolve the effective name/schema into locals rather than
                # rebinding the enclosing arguments: rebinding would make a
                # decorator object that is applied to a second function reuse
                # the first function's name and schema.
                tool_name = name if name is not None else fn.__name__
                if input_schema is not None:
                    schema = input_schema
                else:
                    schema = self.__parse_function_parameters(fn)

                # Create the tool definition
                tool = types.Tool(
                    name=tool_name,
                    description=description,
                    inputSchema=schema,
                )

                # Register in both collections
                self._tool_handlers[tool_name] = fn
                self._tools.append(tool)

                return fn

            # Handle both @register_tool and @register_tool() syntax
            if func is None:
                return wrapper
            return wrapper(func)

        return decorator

    async def handle_list_tool(self) -> List[types.Tool]:
        """Expose the list of tools to the server."""
        return self._tools

    async def handle_tool_call(
        self, name: str, arguments: dict | None
    ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
        """
        Handle any tool call.

        :param name: The name of the tool to invoke.
        :param arguments: Keyword arguments for the tool, or None when the
                          caller supplied none.
        :return: Whatever the registered handler returns.
        :raises ValueError: If no tool with the given name is registered.
        """
        if name not in self._tool_handlers:
            raise ValueError(f"Unknown tool: {name}")
        # ``arguments`` may legitimately be None; unpacking None would raise
        # a TypeError, so treat it as "no arguments".
        return await self._tool_handlers[name](**(arguments or {}))

    @staticmethod
    def __parse_function_parameters(func: Callable) -> Dict[str, Any]:
        """
        Parse the parameters of a function to create an input schema.

        Every parameter becomes a property whose type defaults to "string";
        int/float map to "number", bool to "boolean" and list (plain or
        parametrized) to "array". Parameters without a default are listed as
        required. A ``:param <name>:`` line in the docstring, when present,
        supplies the property description.

        :param func: The function to parse.
        :return: A dictionary representing the input schema.
        """
        signature = inspect.signature(func)
        type_hints = get_type_hints(func)

        properties = {}
        required = []

        for param_name, param in signature.parameters.items():
            # Skip self parameter for methods
            if param_name == "self":
                continue

            param_type = type_hints.get(param_name, Any)
            param_schema = {"type": "string"}  # Default to string

            # Map Python types to JSON Schema types
            if param_type in (int, float):
                param_schema["type"] = "number"
            elif param_type is bool:
                param_schema["type"] = "boolean"
            elif param_type is list or getattr(param_type, "__origin__", None) is list:
                param_schema["type"] = "array"

            # Get default value if any
            if param.default is not inspect.Parameter.empty:
                param_schema["default"] = param.default
            else:
                required.append(param_name)

            # Get description from docstring if available
            if func.__doc__:
                marker = f":param {param_name}:"
                param_docs = [
                    line.strip()
                    for line in func.__doc__.split("\n")
                    if marker in line
                ]
                if param_docs:
                    # Only the first matching line is used.
                    param_schema["description"] = (
                        param_docs[0].split(marker)[1].strip()
                    )

            properties[param_name] = param_schema

        return {"type": "object", "properties": properties, "required": required}

View File

@@ -1,17 +1,16 @@
from mcp_server_qdrant.embeddings import EmbeddingProvider
from mcp_server_qdrant.settings import EmbeddingProviderSettings
def create_embedding_provider(provider_type: str, model_name: str) -> EmbeddingProvider:
def create_embedding_provider(settings: EmbeddingProviderSettings) -> EmbeddingProvider:
"""
Create an embedding provider based on the specified type.
:param provider_type: The type of embedding provider to create.
:param model_name: The name of the model to use for embeddings, specific to the provider type.
:param settings: The settings for the embedding provider.
:return: An instance of the specified embedding provider.
"""
if provider_type.lower() == "fastembed":
from .fastembed import FastEmbedProvider
if settings.provider_type.lower() == "fastembed":
from mcp_server_qdrant.embeddings.fastembed import FastEmbedProvider
return FastEmbedProvider(model_name)
return FastEmbedProvider(settings.model_name)
else:
raise ValueError(f"Unsupported embedding provider: {provider_type}")
raise ValueError(f"Unsupported embedding provider: {settings.provider_type}")

View File

@@ -3,7 +3,7 @@ from typing import List
from fastembed import TextEmbedding
from .base import EmbeddingProvider
from mcp_server_qdrant.embeddings.base import EmbeddingProvider
class FastEmbedProvider(EmbeddingProvider):

View File

@@ -1,13 +0,0 @@
import functools
from typing import Callable
def register_task(func: Callable) -> Callable:
    """
    Wrap an async callable so that its start and successful finish are printed.

    :param func: The coroutine function to wrap.
    :return: A coroutine function with the same metadata as ``func`` that
             prints a message before awaiting it and another one after it
             returns; the awaited result is passed through unchanged.
    """

    async def _announced(*call_args, **call_kwargs):
        print(f"Starting task {func.__name__}")
        outcome = await func(*call_args, **call_kwargs)
        # Only reached when the wrapped coroutine did not raise.
        print(f"Finished task {func.__name__}")
        return outcome

    return functools.wraps(func)(_announced)

View File

@@ -1,10 +0,0 @@
import importlib.metadata
def get_package_version() -> str:
    """
    Return the installed version of the ``mcp-server-qdrant`` distribution.

    :return: The version reported by ``importlib.metadata``, or ``"0.0.0"``
             when the distribution is not installed.
    """
    fallback_version = "0.0.0"
    try:
        installed_version = importlib.metadata.version("mcp-server-qdrant")
    except importlib.metadata.PackageNotFoundError:
        # The package metadata is unavailable; report the placeholder instead.
        return fallback_version
    return installed_version

View File

@@ -53,9 +53,9 @@ class QdrantConnector:
},
)
async def store_memory(self, information: str):
async def store(self, information: str):
"""
Store a memory in the Qdrant collection.
Store some information in the Qdrant collection.
:param information: The information to store.
"""
await self._ensure_collection_exists()
@@ -76,11 +76,11 @@ class QdrantConnector:
],
)
async def find_memories(self, query: str) -> list[str]:
async def search(self, query: str) -> list[str]:
"""
Find memories in the Qdrant collection. If there are no memories found, an empty list is returned.
Find points in the Qdrant collection. If there are no entries found, an empty list is returned.
:param query: The query to use for the search.
:return: A list of memories found.
:return: A list of entries found.
"""
collection_exists = await self._client.collection_exists(self._collection_name)
if not collection_exists:

View File

@@ -1,170 +1,122 @@
import asyncio
from typing import Optional
import logging
import os
from contextlib import asynccontextmanager
from typing import AsyncIterator, List
import click
import mcp
import mcp.types as types
from mcp.server import NotificationOptions, Server
from mcp.server.models import InitializationOptions
from mcp.server import Server
from mcp.server.fastmcp import Context, FastMCP
from .custom_server import QdrantMCPServer
from .embeddings.factory import create_embedding_provider
from .helper import get_package_version
from .qdrant import QdrantConnector
from mcp_server_qdrant.embeddings.factory import create_embedding_provider
from mcp_server_qdrant.qdrant import QdrantConnector
from mcp_server_qdrant.settings import (
EmbeddingProviderSettings,
QdrantSettings,
parse_args,
)
logger = logging.getLogger(__name__)
# Parse command line arguments and set them as environment variables.
# This is done for backwards compatibility with the previous versions
# of the MCP server.
env_vars = parse_args()
for key, value in env_vars.items():
os.environ[key] = value
def serve(
qdrant_connector: QdrantConnector,
) -> Server:
@asynccontextmanager
async def server_lifespan(server: Server) -> AsyncIterator[dict]: # noqa
"""
Instantiate the server and configure tools to store and find memories in Qdrant.
:param qdrant_connector: An instance of QdrantConnector to use for storing and retrieving memories.
Context manager to handle the lifespan of the server.
This is used to configure the embedding provider and Qdrant connector.
"""
server = QdrantMCPServer("qdrant")
@server.register_tool(
name="qdrant-store-memory",
description=(
"Keep the memory for later use, when you are asked to remember something."
),
)
async def store_memory(information: str):
"""
Store a memory in Qdrant.
:param information: The information to store.
"""
nonlocal qdrant_connector
await qdrant_connector.store_memory(information)
return [types.TextContent(type="text", text=f"Remembered: {information}")]
@server.register_tool(
name="qdrant-find-memories",
description=(
"Look up memories in Qdrant. Use this tool when you need to: \n"
" - Find memories by their content \n"
" - Access memories for further analysis \n"
" - Get some personal information about the user"
),
)
async def find_memories(query: str):
"""
Find memories in Qdrant.
:param query: The query to use for the search.
:return: A list of memories found.
"""
nonlocal qdrant_connector
memories = await qdrant_connector.find_memories(query)
content = [
types.TextContent(type="text", text=f"Memories for the query '{query}'"),
]
for memory in memories:
content.append(
types.TextContent(type="text", text=f"<memory>{memory}</memory>")
)
return content
return server
@click.command()
@click.option(
"--qdrant-url",
envvar="QDRANT_URL",
required=False,
help="Qdrant URL",
)
@click.option(
"--qdrant-api-key",
envvar="QDRANT_API_KEY",
required=False,
help="Qdrant API key",
)
@click.option(
"--collection-name",
envvar="COLLECTION_NAME",
required=True,
help="Collection name",
)
@click.option(
"--fastembed-model-name",
envvar="FASTEMBED_MODEL_NAME",
required=False,
help="FastEmbed model name",
default="sentence-transformers/all-MiniLM-L6-v2",
)
@click.option(
"--embedding-provider",
envvar="EMBEDDING_PROVIDER",
required=False,
help="Embedding provider to use",
default="fastembed",
type=click.Choice(["fastembed"], case_sensitive=False),
)
@click.option(
"--embedding-model",
envvar="EMBEDDING_MODEL",
required=False,
help="Embedding model name",
default="sentence-transformers/all-MiniLM-L6-v2",
)
@click.option(
"--qdrant-local-path",
envvar="QDRANT_LOCAL_PATH",
required=False,
help="Qdrant local path",
)
def main(
qdrant_url: Optional[str],
qdrant_api_key: str,
collection_name: Optional[str],
fastembed_model_name: Optional[str],
embedding_provider: str,
embedding_model: str,
qdrant_local_path: Optional[str],
):
# XOR of url and local path, since we accept only one of them
if not (bool(qdrant_url) ^ bool(qdrant_local_path)):
raise ValueError(
"Exactly one of qdrant-url or qdrant-local-path must be provided"
try:
# Embedding provider is created with a factory function so we can add
# some more providers in the future. Currently, only FastEmbed is supported.
embedding_provider_settings = EmbeddingProviderSettings()
embedding_provider = create_embedding_provider(embedding_provider_settings)
logger.info(
f"Using embedding provider {embedding_provider_settings.provider_type} with "
f"model {embedding_provider_settings.model_name}"
)
# Warn if fastembed_model_name is provided, as this is going to be deprecated
if fastembed_model_name:
click.echo(
"Warning: --fastembed-model-name parameter is deprecated and will be removed in a future version. "
"Please use --embedding-provider and --embedding-model instead",
err=True,
qdrant_configuration = QdrantSettings()
qdrant_connector = QdrantConnector(
qdrant_configuration.location,
qdrant_configuration.api_key,
qdrant_configuration.collection_name,
embedding_provider,
qdrant_configuration.local_path,
)
logger.info(
f"Connecting to Qdrant at {qdrant_configuration.get_qdrant_location()}"
)
async def _run():
async with mcp.server.stdio.stdio_server() as (read_stream, write_stream):
# Create the embedding provider
provider = create_embedding_provider(
provider_type=embedding_provider, model_name=embedding_model
)
yield {
"embedding_provider": embedding_provider,
"qdrant_connector": qdrant_connector,
}
except Exception as e:
logger.error(e)
raise e
finally:
pass
# Create the Qdrant connector
qdrant_connector = QdrantConnector(
qdrant_url,
qdrant_api_key,
collection_name,
provider,
qdrant_local_path,
)
# Create and run the server
server = serve(qdrant_connector)
await server.run(
read_stream,
write_stream,
InitializationOptions(
server_name="qdrant",
server_version=get_package_version(),
capabilities=server.get_capabilities(
notification_options=NotificationOptions(),
experimental_capabilities={},
),
),
)
mcp = FastMCP("Qdrant", lifespan=server_lifespan)
asyncio.run(_run())
@mcp.tool(
    name="qdrant-store-memory",
    description=(
        "Keep the memory for later use, when you are asked to remember something."
    ),
)
async def store(information: str, ctx: Context) -> str:
    """
    Persist a piece of information in Qdrant.

    :param information: The text to store.
    :param ctx: The request context; it provides debug logging and access to
                the connector placed in the lifespan context.
    :return: A confirmation message echoing the stored information.
    """
    await ctx.debug(f"Storing information {information} in Qdrant")
    # The connector is looked up from the lifespan context on every call.
    lifespan_state = ctx.request_context.lifespan_context
    connector: QdrantConnector = lifespan_state["qdrant_connector"]
    await connector.store(information)
    return f"Remembered: {information}"
@mcp.tool(
    name="qdrant-find-memories",
    description=(
        "Look up memories in Qdrant. Use this tool when you need to: \n"
        " - Find memories by their content \n"
        " - Access memories for further analysis \n"
        " - Get some personal information about the user"
    ),
)
async def find(query: str, ctx: Context) -> List[str]:
    """
    Search Qdrant for entries matching a query.

    :param query: The query to use for the search.
    :param ctx: The request context; it provides debug logging and access to
                the connector placed in the lifespan context.
    :return: A single "no memories" message when the search yields nothing;
             otherwise a header line followed by one ``<entry>`` element per
             result.
    """
    await ctx.debug(f"Finding points for query {query}")
    lifespan_state = ctx.request_context.lifespan_context
    connector: QdrantConnector = lifespan_state["qdrant_connector"]
    matches = await connector.search(query)

    if matches:
        return [
            f"Memories for the query '{query}'",
            *(f"<entry>{entry}</entry>" for entry in matches),
        ]
    return [f"No memories found for the query '{query}'"]
if __name__ == "__main__":
mcp.run()

View File

@@ -0,0 +1,101 @@
import argparse
from typing import Any, Dict, Optional
from pydantic import Field
from pydantic_settings import BaseSettings
class EmbeddingProviderSettings(BaseSettings):
    """
    Configuration for the embedding provider.

    Both fields declare defaults, so the settings can be instantiated without
    passing any values.
    """

    # Identifier of the embedding provider; defaults to "fastembed".
    # Populated through the EMBEDDING_PROVIDER validation alias (presumably
    # resolved from the environment variable of that name by BaseSettings).
    provider_type: str = Field(
        default="fastembed", validation_alias="EMBEDDING_PROVIDER"
    )
    # Name of the embedding model; populated through the EMBEDDING_MODEL
    # validation alias.
    # NOTE(review): pydantic v2 reserves the "model_" prefix as a protected
    # namespace and may warn about this field name - confirm.
    model_name: str = Field(
        default="sentence-transformers/all-MiniLM-L6-v2",
        validation_alias="EMBEDDING_MODEL",
    )
class QdrantSettings(BaseSettings):
    """
    Configuration for the Qdrant connector.

    Only ``collection_name`` is mandatory; every other field defaults to None.
    """

    # Qdrant server URL (validation alias QDRANT_URL).
    location: Optional[str] = Field(default=None, validation_alias="QDRANT_URL")
    # API key for the Qdrant server (validation alias QDRANT_API_KEY).
    api_key: Optional[str] = Field(default=None, validation_alias="QDRANT_API_KEY")
    # Name of the collection to use; no default, so it must be supplied.
    collection_name: str = Field(validation_alias="COLLECTION_NAME")
    # Path to a local Qdrant database (validation alias QDRANT_LOCAL_PATH).
    local_path: Optional[str] = Field(
        default=None, validation_alias="QDRANT_LOCAL_PATH"
    )

    def get_qdrant_location(self) -> Optional[str]:
        """
        Get the Qdrant location, either the URL or the local path.

        :return: ``location`` when it is set (truthy), otherwise
                 ``local_path``. Both fields are optional, so the result is
                 None when neither has been configured.
        """
        return self.location or self.local_path
def parse_args(argv: Optional[list[str]] = None) -> Dict[str, Any]:
    """
    Parse command line arguments for the MCP server.

    Args:
        argv: Argument strings to parse. When None (the default), argparse
            falls back to the process arguments (``sys.argv[1:]``).

    Returns:
        Dict[str, Any]: Dictionary mapping environment variable names to the
        values of the options that were actually provided. Options that were
        not given on the command line are omitted.

    Raises:
        SystemExit: Raised by argparse on invalid input, e.g. when both
            ``--qdrant-url`` and ``--qdrant-local-path`` are supplied.
    """
    parser = argparse.ArgumentParser(description="Qdrant MCP Server")

    # Qdrant connection options: a server URL and a local path are mutually
    # exclusive ways of locating the database.
    connection_group = parser.add_mutually_exclusive_group()
    connection_group.add_argument(
        "--qdrant-url",
        help="URL of the Qdrant server, e.g. http://localhost:6333",
    )
    connection_group.add_argument(
        "--qdrant-local-path",
        help="Path to the local Qdrant database",
    )

    # Other Qdrant settings
    parser.add_argument(
        "--qdrant-api-key",
        help="API key for the Qdrant server",
    )
    parser.add_argument(
        "--collection-name",
        help="Name of the collection to use",
    )

    # Embedding settings
    parser.add_argument(
        "--embedding-provider",
        help="Embedding provider to use (currently only 'fastembed' is supported)",
    )
    parser.add_argument(
        "--embedding-model",
        help="Name of the embedding model to use",
    )

    parsed = vars(parser.parse_args(argv))

    # argparse destination -> environment variable name. The order matches
    # the order in which the resulting dictionary is populated.
    env_names = (
        ("qdrant_url", "QDRANT_URL"),
        ("qdrant_api_key", "QDRANT_API_KEY"),
        ("collection_name", "COLLECTION_NAME"),
        ("embedding_model", "EMBEDDING_MODEL"),
        ("embedding_provider", "EMBEDDING_PROVIDER"),
        ("qdrant_local_path", "QDRANT_LOCAL_PATH"),
    )

    # Options that were not supplied are left as None by argparse; skip them
    # so that they do not override values coming from elsewhere.
    return {
        env_name: parsed[dest]
        for dest, env_name in env_names
        if parsed.get(dest) is not None
    }