Add custom server implementation to ease the development of the tools

This commit is contained in:
Kacper Łukawski
2025-03-07 18:11:15 +01:00
parent b2c96ba7de
commit f8d3cc474b
8 changed files with 304 additions and 90 deletions

View File

@@ -0,0 +1,137 @@
import inspect
from typing import Any, Callable, Dict, List, Optional, get_type_hints
from mcp import types
from mcp.server import Server
class QdrantMCPServer(Server):
"""
An MCP server that uses Qdrant to store and retrieve information.
"""
def __init__(self, name: str = "Qdrant"):
super().__init__(name)
self._tool_handlers: Dict[str, Callable] = {}
self._tools: List[types.Tool] = []
# This monkeypatching is required to make the server list the tools
# and handle tool calls. It simplifies the process of registering
# tool handlers. Please do not remove it.
self.handle_list_tool = self.list_tools()(self.handle_list_tool)
self.handle_tool_call = self.call_tool()(self.handle_tool_call)
def register_tool(
self,
*,
description: str,
name: Optional[str] = None,
input_schema: Optional[dict[str, Any]] = None,
):
"""
A decorator to register a tool with the server. The description is used
to generate the tool's metadata.
Name is optional, and if not provided, the function's name will be used.
:param description: The description of the tool.
:param name: The name of the tool. If not provided, the function's name will be used.
:param input_schema: The input schema for the tool. If not provided, it will be
automatically generated from the function's parameters.
"""
def decorator(func: Callable):
def wrapper(fn):
nonlocal name, input_schema
# Use function name if name not provided
if name is None:
name = fn.__name__
# If no input schema is provided, generate one from the function parameters
if input_schema is None:
input_schema = self.__parse_function_parameters(fn)
# Create the tool definition
tool = types.Tool(
name=name,
description=description,
inputSchema=input_schema,
)
# Register in both collections
self._tool_handlers[name] = fn
self._tools.append(tool)
return fn
# Handle both @register_tool and @register_tool() syntax
if func is None:
return wrapper
return wrapper(func)
return decorator
async def handle_list_tool(self) -> List[types.Tool]:
"""Expose the list of tools to the server."""
return self._tools
async def handle_tool_call(
self, name: str, arguments: dict | None
) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
"""Handle any tool call."""
if name not in self._tool_handlers:
raise ValueError(f"Unknown tool: {name}")
return await self._tool_handlers[name](**arguments)
@staticmethod
def __parse_function_parameters(func: Callable) -> Dict[str, Any]:
"""
Parse the parameters of a function to create an input schema.
:param func: The function to parse.
:return: A dictionary representing the input schema.
"""
signature = inspect.signature(func)
type_hints = get_type_hints(func)
properties = {}
required = []
for param_name, param in signature.parameters.items():
# Skip self parameter for methods
if param_name == "self":
continue
param_type = type_hints.get(param_name, Any)
param_schema = {"type": "string"} # Default to string
# Map Python types to JSON Schema types
if param_type in (int, float):
param_schema["type"] = "number"
elif param_type is bool:
param_schema["type"] = "boolean"
elif param_type is list or getattr(param_type, "__origin__", None) is list:
param_schema["type"] = "array"
# Get default value if any
if param.default is not inspect.Parameter.empty:
param_schema["default"] = param.default
else:
required.append(param_name)
# Get description from docstring if available
if func.__doc__:
param_docs = [
line.strip()
for line in func.__doc__.split("\n")
if f":param {param_name}:" in line
]
if param_docs:
description = (
param_docs[0].split(f":param {param_name}:")[1].strip()
)
param_schema["description"] = description
properties[param_name] = param_schema
return {"type": "object", "properties": properties, "required": required}

View File

@@ -7,14 +7,12 @@ from .base import EmbeddingProvider
class FastEmbedProvider(EmbeddingProvider):
"""FastEmbed implementation of the embedding provider."""
"""
FastEmbed implementation of the embedding provider.
:param model_name: The name of the FastEmbed model to use.
"""
def __init__(self, model_name: str):
"""
Initialize the FastEmbed provider.
:param model_name: The name of the FastEmbed model to use.
"""
self.model_name = model_name
self.embedding_model = TextEmbedding(model_name)

View File

@@ -0,0 +1,13 @@
import functools
from typing import Callable
def register_task(func: Callable) -> Callable:
@functools.wraps(func)
async def wrapper(*args, **kwargs):
print(f"Starting task {func.__name__}")
result = await func(*args, **kwargs)
print(f"Finished task {func.__name__}")
return result
return wrapper

View File

View File

@@ -0,0 +1,10 @@
import importlib.metadata
def get_package_version() -> str:
"""Get the package version using importlib.metadata."""
try:
return importlib.metadata.version("mcp-server-qdrant")
except importlib.metadata.PackageNotFoundError:
# Fall back to a default version if package is not installed
return "0.0.0"

View File

@@ -1,5 +1,4 @@
import asyncio
import importlib.metadata
from typing import Optional
import click
@@ -8,19 +7,12 @@ import mcp.types as types
from mcp.server import NotificationOptions, Server
from mcp.server.models import InitializationOptions
from .custom_server import QdrantMCPServer
from .embeddings.factory import create_embedding_provider
from .helper import get_package_version
from .qdrant import QdrantConnector
def get_package_version() -> str:
"""Get the package version using importlib.metadata."""
try:
return importlib.metadata.version("mcp-server-qdrant")
except importlib.metadata.PackageNotFoundError:
# Fall back to a default version if package is not installed
return "0.0.0"
def serve(
qdrant_connector: QdrantConnector,
) -> Server:
@@ -28,84 +20,48 @@ def serve(
Instantiate the server and configure tools to store and find memories in Qdrant.
:param qdrant_connector: An instance of QdrantConnector to use for storing and retrieving memories.
"""
server = Server("qdrant")
server = QdrantMCPServer("qdrant")
@server.list_tools()
async def handle_list_tools() -> list[types.Tool]:
@server.register_tool(
name="qdrant-store-memory",
description=(
"Keep the memory for later use, when you are asked to remember something."
),
)
async def store_memory(information: str):
"""
Return the list of tools that the server provides. By default, there are two
tools: one to store memories and another to find them. Finding the memories is not
implemented as a resource, as it requires a query to be passed and resources point
to a very specific piece of data.
Store a memory in Qdrant.
:param information: The information to store.
"""
return [
types.Tool(
name="qdrant-store-memory",
description=(
"Keep the memory for later use, when you are asked to remember something."
),
inputSchema={
"type": "object",
"properties": {
"information": {
"type": "string",
},
},
"required": ["information"],
},
),
types.Tool(
name="qdrant-find-memories",
description=(
"Look up memories in Qdrant. Use this tool when you need to: \n"
" - Find memories by their content \n"
" - Access memories for further analysis \n"
" - Get some personal information about the user"
),
inputSchema={
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "The query to search for",
}
},
"required": ["query"],
},
),
nonlocal qdrant_connector
await qdrant_connector.store_memory(information)
return [types.TextContent(type="text", text=f"Remembered: {information}")]
@server.register_tool(
name="qdrant-find-memories",
description=(
"Look up memories in Qdrant. Use this tool when you need to: \n"
" - Find memories by their content \n"
" - Access memories for further analysis \n"
" - Get some personal information about the user"
),
)
async def find_memories(query: str):
"""
Find memories in Qdrant.
:param query: The query to use for the search.
:return: A list of memories found.
"""
nonlocal qdrant_connector
memories = await qdrant_connector.find_memories(query)
content = [
types.TextContent(type="text", text=f"Memories for the query '{query}'"),
]
@server.call_tool()
async def handle_tool_call(
name: str, arguments: dict | None
) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
if name not in ["qdrant-store-memory", "qdrant-find-memories"]:
raise ValueError(f"Unknown tool: {name}")
if name == "qdrant-store-memory":
if not arguments or "information" not in arguments:
raise ValueError("Missing required argument 'information'")
information = arguments["information"]
await qdrant_connector.store_memory(information)
return [types.TextContent(type="text", text=f"Remembered: {information}")]
if name == "qdrant-find-memories":
if not arguments or "query" not in arguments:
raise ValueError("Missing required argument 'query'")
query = arguments["query"]
memories = await qdrant_connector.find_memories(query)
content = [
types.TextContent(
type="text", text=f"Memories for the query '{query}'"
),
]
for memory in memories:
content.append(
types.TextContent(type="text", text=f"<memory>{memory}</memory>")
)
return content
raise ValueError(f"Unknown tool: {name}")
for memory in memories:
content.append(
types.TextContent(type="text", text=f"<memory>{memory}</memory>")
)
return content
return server