diff --git a/src/mcp_server_qdrant/custom_server.py b/src/mcp_server_qdrant/custom_server.py
new file mode 100644
index 0000000..dcddfa2
--- /dev/null
+++ b/src/mcp_server_qdrant/custom_server.py
@@ -0,0 +1,137 @@
+import inspect
+from typing import Any, Callable, Dict, List, Optional, get_type_hints
+
+from mcp import types
+from mcp.server import Server
+
+
+class QdrantMCPServer(Server):
+ """
+ An MCP server that uses Qdrant to store and retrieve information.
+ """
+
+ def __init__(self, name: str = "Qdrant"):
+ super().__init__(name)
+ self._tool_handlers: Dict[str, Callable] = {}
+ self._tools: List[types.Tool] = []
+ # This monkeypatching is required to make the server list the tools
+ # and handle tool calls. It simplifies the process of registering
+ # tool handlers. Please do not remove it.
+ self.handle_list_tool = self.list_tools()(self.handle_list_tool)
+ self.handle_tool_call = self.call_tool()(self.handle_tool_call)
+
+ def register_tool(
+ self,
+ *,
+ description: str,
+ name: Optional[str] = None,
+ input_schema: Optional[dict[str, Any]] = None,
+ ):
+ """
+ A decorator to register a tool with the server. The description is used
+ to generate the tool's metadata.
+
+ Name is optional, and if not provided, the function's name will be used.
+
+ :param description: The description of the tool.
+ :param name: The name of the tool. If not provided, the function's name will be used.
+ :param input_schema: The input schema for the tool. If not provided, it will be
+ automatically generated from the function's parameters.
+ """
+
+ def decorator(func: Callable):
+ def wrapper(fn):
+ nonlocal name, input_schema
+
+ # Use function name if name not provided
+ if name is None:
+ name = fn.__name__
+
+ # If no input schema is provided, generate one from the function parameters
+ if input_schema is None:
+ input_schema = self.__parse_function_parameters(fn)
+
+ # Create the tool definition
+ tool = types.Tool(
+ name=name,
+ description=description,
+ inputSchema=input_schema,
+ )
+
+ # Register in both collections
+ self._tool_handlers[name] = fn
+ self._tools.append(tool)
+
+ return fn
+
+ # Handle both @register_tool and @register_tool() syntax
+ if func is None:
+ return wrapper
+ return wrapper(func)
+
+ return decorator
+
+ async def handle_list_tool(self) -> List[types.Tool]:
+ """Expose the list of tools to the server."""
+ return self._tools
+
+ async def handle_tool_call(
+ self, name: str, arguments: dict | None
+ ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
+ """Handle any tool call."""
+ if name not in self._tool_handlers:
+ raise ValueError(f"Unknown tool: {name}")
+ return await self._tool_handlers[name](**arguments)
+
+ @staticmethod
+ def __parse_function_parameters(func: Callable) -> Dict[str, Any]:
+ """
+ Parse the parameters of a function to create an input schema.
+
+ :param func: The function to parse.
+ :return: A dictionary representing the input schema.
+ """
+ signature = inspect.signature(func)
+ type_hints = get_type_hints(func)
+
+ properties = {}
+ required = []
+
+ for param_name, param in signature.parameters.items():
+ # Skip self parameter for methods
+ if param_name == "self":
+ continue
+
+ param_type = type_hints.get(param_name, Any)
+ param_schema = {"type": "string"} # Default to string
+
+ # Map Python types to JSON Schema types
+ if param_type in (int, float):
+ param_schema["type"] = "number"
+ elif param_type is bool:
+ param_schema["type"] = "boolean"
+ elif param_type is list or getattr(param_type, "__origin__", None) is list:
+ param_schema["type"] = "array"
+
+ # Get default value if any
+ if param.default is not inspect.Parameter.empty:
+ param_schema["default"] = param.default
+ else:
+ required.append(param_name)
+
+ # Get description from docstring if available
+ if func.__doc__:
+ param_docs = [
+ line.strip()
+ for line in func.__doc__.split("\n")
+ if f":param {param_name}:" in line
+ ]
+ if param_docs:
+ description = (
+ param_docs[0].split(f":param {param_name}:")[1].strip()
+ )
+ param_schema["description"] = description
+
+ properties[param_name] = param_schema
+
+ return {"type": "object", "properties": properties, "required": required}
diff --git a/src/mcp_server_qdrant/embeddings/fastembed.py b/src/mcp_server_qdrant/embeddings/fastembed.py
index 90a9d9e..acfeedf 100644
--- a/src/mcp_server_qdrant/embeddings/fastembed.py
+++ b/src/mcp_server_qdrant/embeddings/fastembed.py
@@ -7,14 +7,12 @@ from .base import EmbeddingProvider
class FastEmbedProvider(EmbeddingProvider):
- """FastEmbed implementation of the embedding provider."""
+ """
+ FastEmbed implementation of the embedding provider.
+ :param model_name: The name of the FastEmbed model to use.
+ """
def __init__(self, model_name: str):
- """
- Initialize the FastEmbed provider.
-
- :param model_name: The name of the FastEmbed model to use.
- """
self.model_name = model_name
self.embedding_model = TextEmbedding(model_name)
diff --git a/src/mcp_server_qdrant/handlers/__init__.py b/src/mcp_server_qdrant/handlers/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/mcp_server_qdrant/handlers/decorators.py b/src/mcp_server_qdrant/handlers/decorators.py
new file mode 100644
index 0000000..d04de89
--- /dev/null
+++ b/src/mcp_server_qdrant/handlers/decorators.py
@@ -0,0 +1,13 @@
+import functools
+from typing import Callable
+
+
+def register_task(func: Callable) -> Callable:
+ @functools.wraps(func)
+ async def wrapper(*args, **kwargs):
+ print(f"Starting task {func.__name__}")
+ result = await func(*args, **kwargs)
+ print(f"Finished task {func.__name__}")
+ return result
+
+ return wrapper
diff --git a/src/mcp_server_qdrant/handlers/memory.py b/src/mcp_server_qdrant/handlers/memory.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/mcp_server_qdrant/helper.py b/src/mcp_server_qdrant/helper.py
new file mode 100644
index 0000000..659a81b
--- /dev/null
+++ b/src/mcp_server_qdrant/helper.py
@@ -0,0 +1,10 @@
+import importlib.metadata
+
+
+def get_package_version() -> str:
+ """Get the package version using importlib.metadata."""
+ try:
+ return importlib.metadata.version("mcp-server-qdrant")
+ except importlib.metadata.PackageNotFoundError:
+ # Fall back to a default version if package is not installed
+ return "0.0.0"
diff --git a/src/mcp_server_qdrant/server.py b/src/mcp_server_qdrant/server.py
index 95db682..1725b17 100644
--- a/src/mcp_server_qdrant/server.py
+++ b/src/mcp_server_qdrant/server.py
@@ -1,5 +1,4 @@
import asyncio
-import importlib.metadata
from typing import Optional
import click
@@ -8,19 +7,12 @@ import mcp.types as types
from mcp.server import NotificationOptions, Server
from mcp.server.models import InitializationOptions
+from .custom_server import QdrantMCPServer
from .embeddings.factory import create_embedding_provider
+from .helper import get_package_version
from .qdrant import QdrantConnector
-def get_package_version() -> str:
- """Get the package version using importlib.metadata."""
- try:
- return importlib.metadata.version("mcp-server-qdrant")
- except importlib.metadata.PackageNotFoundError:
- # Fall back to a default version if package is not installed
- return "0.0.0"
-
-
def serve(
qdrant_connector: QdrantConnector,
) -> Server:
@@ -28,84 +20,48 @@ def serve(
Instantiate the server and configure tools to store and find memories in Qdrant.
:param qdrant_connector: An instance of QdrantConnector to use for storing and retrieving memories.
"""
- server = Server("qdrant")
+ server = QdrantMCPServer("qdrant")
- @server.list_tools()
- async def handle_list_tools() -> list[types.Tool]:
+ @server.register_tool(
+ name="qdrant-store-memory",
+ description=(
+ "Keep the memory for later use, when you are asked to remember something."
+ ),
+ )
+ async def store_memory(information: str):
"""
- Return the list of tools that the server provides. By default, there are two
- tools: one to store memories and another to find them. Finding the memories is not
- implemented as a resource, as it requires a query to be passed and resources point
- to a very specific piece of data.
+ Store a memory in Qdrant.
+ :param information: The information to store.
"""
- return [
- types.Tool(
- name="qdrant-store-memory",
- description=(
- "Keep the memory for later use, when you are asked to remember something."
- ),
- inputSchema={
- "type": "object",
- "properties": {
- "information": {
- "type": "string",
- },
- },
- "required": ["information"],
- },
- ),
- types.Tool(
- name="qdrant-find-memories",
- description=(
- "Look up memories in Qdrant. Use this tool when you need to: \n"
- " - Find memories by their content \n"
- " - Access memories for further analysis \n"
- " - Get some personal information about the user"
- ),
- inputSchema={
- "type": "object",
- "properties": {
- "query": {
- "type": "string",
- "description": "The query to search for",
- }
- },
- "required": ["query"],
- },
- ),
+ nonlocal qdrant_connector
+ await qdrant_connector.store_memory(information)
+ return [types.TextContent(type="text", text=f"Remembered: {information}")]
+
+ @server.register_tool(
+ name="qdrant-find-memories",
+ description=(
+ "Look up memories in Qdrant. Use this tool when you need to: \n"
+ " - Find memories by their content \n"
+ " - Access memories for further analysis \n"
+ " - Get some personal information about the user"
+ ),
+ )
+ async def find_memories(query: str):
+ """
+ Find memories in Qdrant.
+ :param query: The query to use for the search.
+ :return: A list of memories found.
+ """
+ nonlocal qdrant_connector
+ memories = await qdrant_connector.find_memories(query)
+ content = [
+ types.TextContent(type="text", text=f"Memories for the query '{query}'"),
]
-
- @server.call_tool()
- async def handle_tool_call(
- name: str, arguments: dict | None
- ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
- if name not in ["qdrant-store-memory", "qdrant-find-memories"]:
- raise ValueError(f"Unknown tool: {name}")
-
- if name == "qdrant-store-memory":
- if not arguments or "information" not in arguments:
- raise ValueError("Missing required argument 'information'")
- information = arguments["information"]
- await qdrant_connector.store_memory(information)
- return [types.TextContent(type="text", text=f"Remembered: {information}")]
-
- if name == "qdrant-find-memories":
- if not arguments or "query" not in arguments:
- raise ValueError("Missing required argument 'query'")
- query = arguments["query"]
- memories = await qdrant_connector.find_memories(query)
- content = [
- types.TextContent(
- type="text", text=f"Memories for the query '{query}'"
- ),
- ]
- for memory in memories:
- content.append(
- types.TextContent(type="text", text=f"{memory}")
- )
- return content
-
- raise ValueError(f"Unknown tool: {name}")
+ for memory in memories:
+ content.append(
+ types.TextContent(type="text", text=f"{memory}")
+ )
+ return content
return server
diff --git a/tests/test_custom_server.py b/tests/test_custom_server.py
new file mode 100644
index 0000000..b506c25
--- /dev/null
+++ b/tests/test_custom_server.py
@@ -0,0 +1,100 @@
+import pytest
+
+from src.mcp_server_qdrant.custom_server import QdrantMCPServer
+
+
+def test_register_tool_decorator():
+ """Test that the register_tool method works as a decorator and correctly parses parameters."""
+ server = QdrantMCPServer()
+
+ @server.register_tool(description="Test function with different parameter types")
+ def test_function(text_param: str, number_param: int, flag_param: bool = False):
+ """
+ A test function with different parameter types.
+
+ :param text_param: A string parameter
+ :param number_param: An integer parameter
+ :param flag_param: A boolean parameter with default
+ """
+ return f"{text_param} {number_param} {flag_param}"
+
+ # Check that the function was registered in tool_handlers
+ assert "test_function" in server._tool_handlers
+ assert server._tool_handlers["test_function"] == test_function
+
+ # Check that the tool was added to tools list
+ assert len(server._tools) == 1
+ tool = server._tools[0]
+ assert tool.name == "test_function"
+ assert tool.description == "Test function with different parameter types"
+
+ # Check the generated schema
+ schema = tool.inputSchema
+ assert schema["type"] == "object"
+
+ # Check properties
+ properties = schema["properties"]
+ assert "text_param" in properties
+ assert properties["text_param"]["type"] == "string"
+ assert "description" in properties["text_param"]
+
+ assert "number_param" in properties
+ assert properties["number_param"]["type"] == "number"
+
+ assert "flag_param" in properties
+ assert properties["flag_param"]["type"] == "boolean"
+ assert "default" in properties["flag_param"]
+ assert properties["flag_param"]["default"] is False
+
+ # Check required fields
+ assert "required" in schema
+ assert "text_param" in schema["required"]
+ assert "number_param" in schema["required"]
+ assert "flag_param" not in schema["required"] # Has default value
+
+
+@pytest.mark.asyncio
+async def test_handle_list_tool():
+ """Test that handle_list_tool returns all registered tools."""
+ server = QdrantMCPServer()
+
+ # Register multiple tools
+ @server.register_tool(description="First test tool")
+ def tool_one(param1: str):
+ """First tool."""
+ return param1
+
+ @server.register_tool(description="Second test tool")
+ def tool_two(param1: int, param2: bool = True):
+ """Second tool."""
+ return param1, param2
+
+ @server.register_tool(name="custom_name", description="Tool with custom name")
+ def tool_three(param1: str):
+ """Third tool with custom name."""
+ return param1
+
+ # Get the list of tools
+ tools = await server.handle_list_tool()
+
+ # Check that all tools are returned
+ assert len(tools) == 3
+
+ # Check tool names
+ tool_names = [tool.name for tool in tools]
+ assert "tool_one" in tool_names
+ assert "tool_two" in tool_names
+ assert "custom_name" in tool_names
+ assert "tool_three" not in tool_names # Should use custom name instead
+
+ # Check tool descriptions
+ descriptions = {tool.name: tool.description for tool in tools}
+ assert descriptions["tool_one"] == "First test tool"
+ assert descriptions["tool_two"] == "Second test tool"
+ assert descriptions["custom_name"] == "Tool with custom name"
+
+ # Check schemas are properly generated
+ for tool in tools:
+ assert tool.inputSchema is not None
+ assert tool.inputSchema["type"] == "object"
+ assert "properties" in tool.inputSchema