Add custom server implementation to ease the development of the tools
This commit is contained in:
137
src/mcp_server_qdrant/custom_server.py
Normal file
137
src/mcp_server_qdrant/custom_server.py
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
import inspect
|
||||||
|
from typing import Any, Callable, Dict, List, Optional, get_type_hints
|
||||||
|
|
||||||
|
from mcp import types
|
||||||
|
from mcp.server import Server
|
||||||
|
|
||||||
|
|
||||||
|
class QdrantMCPServer(Server):
|
||||||
|
"""
|
||||||
|
An MCP server that uses Qdrant to store and retrieve information.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, name: str = "Qdrant"):
|
||||||
|
super().__init__(name)
|
||||||
|
self._tool_handlers: Dict[str, Callable] = {}
|
||||||
|
self._tools: List[types.Tool] = []
|
||||||
|
# This monkeypatching is required to make the server list the tools
|
||||||
|
# and handle tool calls. It simplifies the process of registering
|
||||||
|
# tool handlers. Please do not remove it.
|
||||||
|
self.handle_list_tool = self.list_tools()(self.handle_list_tool)
|
||||||
|
self.handle_tool_call = self.call_tool()(self.handle_tool_call)
|
||||||
|
|
||||||
|
def register_tool(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
description: str,
|
||||||
|
name: Optional[str] = None,
|
||||||
|
input_schema: Optional[dict[str, Any]] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
A decorator to register a tool with the server. The description is used
|
||||||
|
to generate the tool's metadata.
|
||||||
|
|
||||||
|
Name is optional, and if not provided, the function's name will be used.
|
||||||
|
|
||||||
|
:param description: The description of the tool.
|
||||||
|
:param name: The name of the tool. If not provided, the function's name will be used.
|
||||||
|
:param input_schema: The input schema for the tool. If not provided, it will be
|
||||||
|
automatically generated from the function's parameters.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def decorator(func: Callable):
|
||||||
|
def wrapper(fn):
|
||||||
|
nonlocal name, input_schema
|
||||||
|
|
||||||
|
# Use function name if name not provided
|
||||||
|
if name is None:
|
||||||
|
name = fn.__name__
|
||||||
|
|
||||||
|
# If no input schema is provided, generate one from the function parameters
|
||||||
|
if input_schema is None:
|
||||||
|
input_schema = self.__parse_function_parameters(fn)
|
||||||
|
|
||||||
|
# Create the tool definition
|
||||||
|
tool = types.Tool(
|
||||||
|
name=name,
|
||||||
|
description=description,
|
||||||
|
inputSchema=input_schema,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Register in both collections
|
||||||
|
self._tool_handlers[name] = fn
|
||||||
|
self._tools.append(tool)
|
||||||
|
|
||||||
|
return fn
|
||||||
|
|
||||||
|
# Handle both @register_tool and @register_tool() syntax
|
||||||
|
if func is None:
|
||||||
|
return wrapper
|
||||||
|
return wrapper(func)
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
async def handle_list_tool(self) -> List[types.Tool]:
|
||||||
|
"""Expose the list of tools to the server."""
|
||||||
|
return self._tools
|
||||||
|
|
||||||
|
async def handle_tool_call(
|
||||||
|
self, name: str, arguments: dict | None
|
||||||
|
) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
|
||||||
|
"""Handle any tool call."""
|
||||||
|
if name not in self._tool_handlers:
|
||||||
|
raise ValueError(f"Unknown tool: {name}")
|
||||||
|
return await self._tool_handlers[name](**arguments)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def __parse_function_parameters(func: Callable) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Parse the parameters of a function to create an input schema.
|
||||||
|
|
||||||
|
:param func: The function to parse.
|
||||||
|
:return: A dictionary representing the input schema.
|
||||||
|
"""
|
||||||
|
signature = inspect.signature(func)
|
||||||
|
type_hints = get_type_hints(func)
|
||||||
|
|
||||||
|
properties = {}
|
||||||
|
required = []
|
||||||
|
|
||||||
|
for param_name, param in signature.parameters.items():
|
||||||
|
# Skip self parameter for methods
|
||||||
|
if param_name == "self":
|
||||||
|
continue
|
||||||
|
|
||||||
|
param_type = type_hints.get(param_name, Any)
|
||||||
|
param_schema = {"type": "string"} # Default to string
|
||||||
|
|
||||||
|
# Map Python types to JSON Schema types
|
||||||
|
if param_type in (int, float):
|
||||||
|
param_schema["type"] = "number"
|
||||||
|
elif param_type is bool:
|
||||||
|
param_schema["type"] = "boolean"
|
||||||
|
elif param_type is list or getattr(param_type, "__origin__", None) is list:
|
||||||
|
param_schema["type"] = "array"
|
||||||
|
|
||||||
|
# Get default value if any
|
||||||
|
if param.default is not inspect.Parameter.empty:
|
||||||
|
param_schema["default"] = param.default
|
||||||
|
else:
|
||||||
|
required.append(param_name)
|
||||||
|
|
||||||
|
# Get description from docstring if available
|
||||||
|
if func.__doc__:
|
||||||
|
param_docs = [
|
||||||
|
line.strip()
|
||||||
|
for line in func.__doc__.split("\n")
|
||||||
|
if f":param {param_name}:" in line
|
||||||
|
]
|
||||||
|
if param_docs:
|
||||||
|
description = (
|
||||||
|
param_docs[0].split(f":param {param_name}:")[1].strip()
|
||||||
|
)
|
||||||
|
param_schema["description"] = description
|
||||||
|
|
||||||
|
properties[param_name] = param_schema
|
||||||
|
|
||||||
|
return {"type": "object", "properties": properties, "required": required}
|
||||||
@@ -7,14 +7,12 @@ from .base import EmbeddingProvider
|
|||||||
|
|
||||||
|
|
||||||
class FastEmbedProvider(EmbeddingProvider):
|
class FastEmbedProvider(EmbeddingProvider):
|
||||||
"""FastEmbed implementation of the embedding provider."""
|
"""
|
||||||
|
FastEmbed implementation of the embedding provider.
|
||||||
|
:param model_name: The name of the FastEmbed model to use.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, model_name: str):
|
def __init__(self, model_name: str):
|
||||||
"""
|
|
||||||
Initialize the FastEmbed provider.
|
|
||||||
|
|
||||||
:param model_name: The name of the FastEmbed model to use.
|
|
||||||
"""
|
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self.embedding_model = TextEmbedding(model_name)
|
self.embedding_model = TextEmbedding(model_name)
|
||||||
|
|
||||||
|
|||||||
0
src/mcp_server_qdrant/handlers/__init__.py
Normal file
0
src/mcp_server_qdrant/handlers/__init__.py
Normal file
13
src/mcp_server_qdrant/handlers/decorators.py
Normal file
13
src/mcp_server_qdrant/handlers/decorators.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
import functools
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
|
||||||
|
def register_task(func: Callable) -> Callable:
|
||||||
|
@functools.wraps(func)
|
||||||
|
async def wrapper(*args, **kwargs):
|
||||||
|
print(f"Starting task {func.__name__}")
|
||||||
|
result = await func(*args, **kwargs)
|
||||||
|
print(f"Finished task {func.__name__}")
|
||||||
|
return result
|
||||||
|
|
||||||
|
return wrapper
|
||||||
0
src/mcp_server_qdrant/handlers/memory.py
Normal file
0
src/mcp_server_qdrant/handlers/memory.py
Normal file
10
src/mcp_server_qdrant/helper.py
Normal file
10
src/mcp_server_qdrant/helper.py
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
import importlib.metadata
|
||||||
|
|
||||||
|
|
||||||
|
def get_package_version() -> str:
|
||||||
|
"""Get the package version using importlib.metadata."""
|
||||||
|
try:
|
||||||
|
return importlib.metadata.version("mcp-server-qdrant")
|
||||||
|
except importlib.metadata.PackageNotFoundError:
|
||||||
|
# Fall back to a default version if package is not installed
|
||||||
|
return "0.0.0"
|
||||||
@@ -1,5 +1,4 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import importlib.metadata
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import click
|
import click
|
||||||
@@ -8,19 +7,12 @@ import mcp.types as types
|
|||||||
from mcp.server import NotificationOptions, Server
|
from mcp.server import NotificationOptions, Server
|
||||||
from mcp.server.models import InitializationOptions
|
from mcp.server.models import InitializationOptions
|
||||||
|
|
||||||
|
from .custom_server import QdrantMCPServer
|
||||||
from .embeddings.factory import create_embedding_provider
|
from .embeddings.factory import create_embedding_provider
|
||||||
|
from .helper import get_package_version
|
||||||
from .qdrant import QdrantConnector
|
from .qdrant import QdrantConnector
|
||||||
|
|
||||||
|
|
||||||
def get_package_version() -> str:
|
|
||||||
"""Get the package version using importlib.metadata."""
|
|
||||||
try:
|
|
||||||
return importlib.metadata.version("mcp-server-qdrant")
|
|
||||||
except importlib.metadata.PackageNotFoundError:
|
|
||||||
# Fall back to a default version if package is not installed
|
|
||||||
return "0.0.0"
|
|
||||||
|
|
||||||
|
|
||||||
def serve(
|
def serve(
|
||||||
qdrant_connector: QdrantConnector,
|
qdrant_connector: QdrantConnector,
|
||||||
) -> Server:
|
) -> Server:
|
||||||
@@ -28,84 +20,48 @@ def serve(
|
|||||||
Instantiate the server and configure tools to store and find memories in Qdrant.
|
Instantiate the server and configure tools to store and find memories in Qdrant.
|
||||||
:param qdrant_connector: An instance of QdrantConnector to use for storing and retrieving memories.
|
:param qdrant_connector: An instance of QdrantConnector to use for storing and retrieving memories.
|
||||||
"""
|
"""
|
||||||
server = Server("qdrant")
|
server = QdrantMCPServer("qdrant")
|
||||||
|
|
||||||
@server.list_tools()
|
@server.register_tool(
|
||||||
async def handle_list_tools() -> list[types.Tool]:
|
name="qdrant-store-memory",
|
||||||
|
description=(
|
||||||
|
"Keep the memory for later use, when you are asked to remember something."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
async def store_memory(information: str):
|
||||||
"""
|
"""
|
||||||
Return the list of tools that the server provides. By default, there are two
|
Store a memory in Qdrant.
|
||||||
tools: one to store memories and another to find them. Finding the memories is not
|
:param information: The information to store.
|
||||||
implemented as a resource, as it requires a query to be passed and resources point
|
|
||||||
to a very specific piece of data.
|
|
||||||
"""
|
"""
|
||||||
return [
|
nonlocal qdrant_connector
|
||||||
types.Tool(
|
await qdrant_connector.store_memory(information)
|
||||||
name="qdrant-store-memory",
|
return [types.TextContent(type="text", text=f"Remembered: {information}")]
|
||||||
description=(
|
|
||||||
"Keep the memory for later use, when you are asked to remember something."
|
@server.register_tool(
|
||||||
),
|
name="qdrant-find-memories",
|
||||||
inputSchema={
|
description=(
|
||||||
"type": "object",
|
"Look up memories in Qdrant. Use this tool when you need to: \n"
|
||||||
"properties": {
|
" - Find memories by their content \n"
|
||||||
"information": {
|
" - Access memories for further analysis \n"
|
||||||
"type": "string",
|
" - Get some personal information about the user"
|
||||||
},
|
),
|
||||||
},
|
)
|
||||||
"required": ["information"],
|
async def find_memories(query: str):
|
||||||
},
|
"""
|
||||||
),
|
Find memories in Qdrant.
|
||||||
types.Tool(
|
:param query: The query to use for the search.
|
||||||
name="qdrant-find-memories",
|
:return: A list of memories found.
|
||||||
description=(
|
"""
|
||||||
"Look up memories in Qdrant. Use this tool when you need to: \n"
|
nonlocal qdrant_connector
|
||||||
" - Find memories by their content \n"
|
memories = await qdrant_connector.find_memories(query)
|
||||||
" - Access memories for further analysis \n"
|
content = [
|
||||||
" - Get some personal information about the user"
|
types.TextContent(type="text", text=f"Memories for the query '{query}'"),
|
||||||
),
|
|
||||||
inputSchema={
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"query": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "The query to search for",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["query"],
|
|
||||||
},
|
|
||||||
),
|
|
||||||
]
|
]
|
||||||
|
for memory in memories:
|
||||||
@server.call_tool()
|
content.append(
|
||||||
async def handle_tool_call(
|
types.TextContent(type="text", text=f"<memory>{memory}</memory>")
|
||||||
name: str, arguments: dict | None
|
)
|
||||||
) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
|
return content
|
||||||
if name not in ["qdrant-store-memory", "qdrant-find-memories"]:
|
|
||||||
raise ValueError(f"Unknown tool: {name}")
|
|
||||||
|
|
||||||
if name == "qdrant-store-memory":
|
|
||||||
if not arguments or "information" not in arguments:
|
|
||||||
raise ValueError("Missing required argument 'information'")
|
|
||||||
information = arguments["information"]
|
|
||||||
await qdrant_connector.store_memory(information)
|
|
||||||
return [types.TextContent(type="text", text=f"Remembered: {information}")]
|
|
||||||
|
|
||||||
if name == "qdrant-find-memories":
|
|
||||||
if not arguments or "query" not in arguments:
|
|
||||||
raise ValueError("Missing required argument 'query'")
|
|
||||||
query = arguments["query"]
|
|
||||||
memories = await qdrant_connector.find_memories(query)
|
|
||||||
content = [
|
|
||||||
types.TextContent(
|
|
||||||
type="text", text=f"Memories for the query '{query}'"
|
|
||||||
),
|
|
||||||
]
|
|
||||||
for memory in memories:
|
|
||||||
content.append(
|
|
||||||
types.TextContent(type="text", text=f"<memory>{memory}</memory>")
|
|
||||||
)
|
|
||||||
return content
|
|
||||||
|
|
||||||
raise ValueError(f"Unknown tool: {name}")
|
|
||||||
|
|
||||||
return server
|
return server
|
||||||
|
|
||||||
|
|||||||
100
tests/test_custom_server.py
Normal file
100
tests/test_custom_server.py
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from src.mcp_server_qdrant.custom_server import QdrantMCPServer
|
||||||
|
|
||||||
|
|
||||||
|
def test_register_tool_decorator():
|
||||||
|
"""Test that the register_tool method works as a decorator and correctly parses parameters."""
|
||||||
|
server = QdrantMCPServer()
|
||||||
|
|
||||||
|
@server.register_tool(description="Test function with different parameter types")
|
||||||
|
def test_function(text_param: str, number_param: int, flag_param: bool = False):
|
||||||
|
"""
|
||||||
|
A test function with different parameter types.
|
||||||
|
|
||||||
|
:param text_param: A string parameter
|
||||||
|
:param number_param: An integer parameter
|
||||||
|
:param flag_param: A boolean parameter with default
|
||||||
|
"""
|
||||||
|
return f"{text_param} {number_param} {flag_param}"
|
||||||
|
|
||||||
|
# Check that the function was registered in tool_handlers
|
||||||
|
assert "test_function" in server._tool_handlers
|
||||||
|
assert server._tool_handlers["test_function"] == test_function
|
||||||
|
|
||||||
|
# Check that the tool was added to tools list
|
||||||
|
assert len(server._tools) == 1
|
||||||
|
tool = server._tools[0]
|
||||||
|
assert tool.name == "test_function"
|
||||||
|
assert tool.description == "Test function with different parameter types"
|
||||||
|
|
||||||
|
# Check the generated schema
|
||||||
|
schema = tool.inputSchema
|
||||||
|
assert schema["type"] == "object"
|
||||||
|
|
||||||
|
# Check properties
|
||||||
|
properties = schema["properties"]
|
||||||
|
assert "text_param" in properties
|
||||||
|
assert properties["text_param"]["type"] == "string"
|
||||||
|
assert "description" in properties["text_param"]
|
||||||
|
|
||||||
|
assert "number_param" in properties
|
||||||
|
assert properties["number_param"]["type"] == "number"
|
||||||
|
|
||||||
|
assert "flag_param" in properties
|
||||||
|
assert properties["flag_param"]["type"] == "boolean"
|
||||||
|
assert "default" in properties["flag_param"]
|
||||||
|
assert properties["flag_param"]["default"] is False
|
||||||
|
|
||||||
|
# Check required fields
|
||||||
|
assert "required" in schema
|
||||||
|
assert "text_param" in schema["required"]
|
||||||
|
assert "number_param" in schema["required"]
|
||||||
|
assert "flag_param" not in schema["required"] # Has default value
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_list_tool():
|
||||||
|
"""Test that handle_list_tool returns all registered tools."""
|
||||||
|
server = QdrantMCPServer()
|
||||||
|
|
||||||
|
# Register multiple tools
|
||||||
|
@server.register_tool(description="First test tool")
|
||||||
|
def tool_one(param1: str):
|
||||||
|
"""First tool."""
|
||||||
|
return param1
|
||||||
|
|
||||||
|
@server.register_tool(description="Second test tool")
|
||||||
|
def tool_two(param1: int, param2: bool = True):
|
||||||
|
"""Second tool."""
|
||||||
|
return param1, param2
|
||||||
|
|
||||||
|
@server.register_tool(name="custom_name", description="Tool with custom name")
|
||||||
|
def tool_three(param1: str):
|
||||||
|
"""Third tool with custom name."""
|
||||||
|
return param1
|
||||||
|
|
||||||
|
# Get the list of tools
|
||||||
|
tools = await server.handle_list_tool()
|
||||||
|
|
||||||
|
# Check that all tools are returned
|
||||||
|
assert len(tools) == 3
|
||||||
|
|
||||||
|
# Check tool names
|
||||||
|
tool_names = [tool.name for tool in tools]
|
||||||
|
assert "tool_one" in tool_names
|
||||||
|
assert "tool_two" in tool_names
|
||||||
|
assert "custom_name" in tool_names
|
||||||
|
assert "tool_three" not in tool_names # Should use custom name instead
|
||||||
|
|
||||||
|
# Check tool descriptions
|
||||||
|
descriptions = {tool.name: tool.description for tool in tools}
|
||||||
|
assert descriptions["tool_one"] == "First test tool"
|
||||||
|
assert descriptions["tool_two"] == "Second test tool"
|
||||||
|
assert descriptions["custom_name"] == "Tool with custom name"
|
||||||
|
|
||||||
|
# Check schemas are properly generated
|
||||||
|
for tool in tools:
|
||||||
|
assert tool.inputSchema is not None
|
||||||
|
assert tool.inputSchema["type"] == "object"
|
||||||
|
assert "properties" in tool.inputSchema
|
||||||
Reference in New Issue
Block a user