diff --git a/src/mcp_server_qdrant/custom_server.py b/src/mcp_server_qdrant/custom_server.py new file mode 100644 index 0000000..dcddfa2 --- /dev/null +++ b/src/mcp_server_qdrant/custom_server.py @@ -0,0 +1,137 @@ +import inspect +from typing import Any, Callable, Dict, List, Optional, get_type_hints + +from mcp import types +from mcp.server import Server + + +class QdrantMCPServer(Server): + """ + An MCP server that uses Qdrant to store and retrieve information. + """ + + def __init__(self, name: str = "Qdrant"): + super().__init__(name) + self._tool_handlers: Dict[str, Callable] = {} + self._tools: List[types.Tool] = [] + # This monkeypatching is required to make the server list the tools + # and handle tool calls. It simplifies the process of registering + # tool handlers. Please do not remove it. + self.handle_list_tool = self.list_tools()(self.handle_list_tool) + self.handle_tool_call = self.call_tool()(self.handle_tool_call) + + def register_tool( + self, + *, + description: str, + name: Optional[str] = None, + input_schema: Optional[dict[str, Any]] = None, + ): + """ + A decorator to register a tool with the server. The description is used + to generate the tool's metadata. + + Name is optional, and if not provided, the function's name will be used. + + :param description: The description of the tool. + :param name: The name of the tool. If not provided, the function's name will be used. + :param input_schema: The input schema for the tool. If not provided, it will be + automatically generated from the function's parameters. + """ + + def decorator(func: Callable): + def wrapper(fn): + nonlocal name, input_schema + + # Use function name if name not provided + if name is None: + name = fn.__name__ + + # If no input schema is provided, generate one from the function parameters + if input_schema is None: + input_schema = self.__parse_function_parameters(fn) + + # Create the tool definition + tool = types.Tool( + name=name, + description=description, + inputSchema=input_schema, + ) + + # Register in both collections + self._tool_handlers[name] = fn + self._tools.append(tool) + + return fn + + # Handle both @register_tool and @register_tool() syntax + if func is None: + return wrapper + return wrapper(func) + + return decorator + + async def handle_list_tool(self) -> List[types.Tool]: + """Expose the list of tools to the server.""" + return self._tools + + async def handle_tool_call( + self, name: str, arguments: dict | None + ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]: + """Handle any tool call.""" + if name not in self._tool_handlers: + raise ValueError(f"Unknown tool: {name}") + return await self._tool_handlers[name](**arguments) + + @staticmethod + def __parse_function_parameters(func: Callable) -> Dict[str, Any]: + """ + Parse the parameters of a function to create an input schema. + + :param func: The function to parse. + :return: A dictionary representing the input schema. + """ + signature = inspect.signature(func) + type_hints = get_type_hints(func) + + properties = {} + required = [] + + for param_name, param in signature.parameters.items(): + # Skip self parameter for methods + if param_name == "self": + continue + + param_type = type_hints.get(param_name, Any) + param_schema = {"type": "string"} # Default to string + + # Map Python types to JSON Schema types + if param_type in (int, float): + param_schema["type"] = "number" + elif param_type is bool: + param_schema["type"] = "boolean" + elif param_type is list or getattr(param_type, "__origin__", None) is list: + param_schema["type"] = "array" + + # Get default value if any + if param.default is not inspect.Parameter.empty: + param_schema["default"] = param.default + else: + required.append(param_name) + + # Get description from docstring if available + if func.__doc__: + param_docs = [ + line.strip() + for line in func.__doc__.split("\n") + if f":param {param_name}:" in line + ] + if param_docs: + description = ( + param_docs[0].split(f":param {param_name}:")[1].strip() + ) + param_schema["description"] = description + + properties[param_name] = param_schema + + return {"type": "object", "properties": properties, "required": required} diff --git a/src/mcp_server_qdrant/embeddings/fastembed.py b/src/mcp_server_qdrant/embeddings/fastembed.py index 90a9d9e..acfeedf 100644 --- a/src/mcp_server_qdrant/embeddings/fastembed.py +++ b/src/mcp_server_qdrant/embeddings/fastembed.py @@ -7,14 +7,12 @@ from .base import EmbeddingProvider class FastEmbedProvider(EmbeddingProvider): - """FastEmbed implementation of the embedding provider.""" + """ + FastEmbed implementation of the embedding provider. + :param model_name: The name of the FastEmbed model to use. + """ def __init__(self, model_name: str): - """ - Initialize the FastEmbed provider. - - :param model_name: The name of the FastEmbed model to use. - """ self.model_name = model_name self.embedding_model = TextEmbedding(model_name) diff --git a/src/mcp_server_qdrant/handlers/__init__.py b/src/mcp_server_qdrant/handlers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/mcp_server_qdrant/handlers/decorators.py b/src/mcp_server_qdrant/handlers/decorators.py new file mode 100644 index 0000000..d04de89 --- /dev/null +++ b/src/mcp_server_qdrant/handlers/decorators.py @@ -0,0 +1,13 @@ +import functools +from typing import Callable + + +def register_task(func: Callable) -> Callable: + @functools.wraps(func) + async def wrapper(*args, **kwargs): + print(f"Starting task {func.__name__}") + result = await func(*args, **kwargs) + print(f"Finished task {func.__name__}") + return result + + return wrapper diff --git a/src/mcp_server_qdrant/handlers/memory.py b/src/mcp_server_qdrant/handlers/memory.py new file mode 100644 index 0000000..e69de29 diff --git a/src/mcp_server_qdrant/helper.py b/src/mcp_server_qdrant/helper.py new file mode 100644 index 0000000..659a81b --- /dev/null +++ b/src/mcp_server_qdrant/helper.py @@ -0,0 +1,10 @@ +import importlib.metadata + + +def get_package_version() -> str: + """Get the package version using importlib.metadata.""" + try: + return importlib.metadata.version("mcp-server-qdrant") + except importlib.metadata.PackageNotFoundError: + # Fall back to a default version if package is not installed + return "0.0.0" diff --git a/src/mcp_server_qdrant/server.py b/src/mcp_server_qdrant/server.py index 95db682..1725b17 100644 --- a/src/mcp_server_qdrant/server.py +++ b/src/mcp_server_qdrant/server.py @@ -1,5 +1,4 @@ import asyncio -import importlib.metadata from typing import Optional import click @@ -8,19 +7,12 @@ import mcp.types as types from mcp.server import NotificationOptions, Server from mcp.server.models import InitializationOptions +from .custom_server import QdrantMCPServer from .embeddings.factory import create_embedding_provider +from .helper import get_package_version from .qdrant import QdrantConnector -def get_package_version() -> str: - """Get the package version using importlib.metadata.""" - try: - return importlib.metadata.version("mcp-server-qdrant") - except importlib.metadata.PackageNotFoundError: - # Fall back to a default version if package is not installed - return "0.0.0" - - def serve( qdrant_connector: QdrantConnector, ) -> Server: @@ -28,84 +20,48 @@ def serve( Instantiate the server and configure tools to store and find memories in Qdrant. :param qdrant_connector: An instance of QdrantConnector to use for storing and retrieving memories. """ - server = Server("qdrant") + server = QdrantMCPServer("qdrant") - @server.list_tools() - async def handle_list_tools() -> list[types.Tool]: + @server.register_tool( + name="qdrant-store-memory", + description=( + "Keep the memory for later use, when you are asked to remember something." + ), + ) + async def store_memory(information: str): """ - Return the list of tools that the server provides. By default, there are two - tools: one to store memories and another to find them. Finding the memories is not - implemented as a resource, as it requires a query to be passed and resources point - to a very specific piece of data. + Store a memory in Qdrant. + :param information: The information to store. """ - return [ - types.Tool( - name="qdrant-store-memory", - description=( - "Keep the memory for later use, when you are asked to remember something." - ), - inputSchema={ - "type": "object", - "properties": { - "information": { - "type": "string", - }, - }, - "required": ["information"], - }, - ), - types.Tool( - name="qdrant-find-memories", - description=( - "Look up memories in Qdrant. Use this tool when you need to: \n" - " - Find memories by their content \n" - " - Access memories for further analysis \n" - " - Get some personal information about the user" - ), - inputSchema={ - "type": "object", - "properties": { - "query": { - "type": "string", - "description": "The query to search for", - } - }, - "required": ["query"], - }, - ), + nonlocal qdrant_connector + await qdrant_connector.store_memory(information) + return [types.TextContent(type="text", text=f"Remembered: {information}")] + + @server.register_tool( + name="qdrant-find-memories", + description=( + "Look up memories in Qdrant. Use this tool when you need to: \n" + " - Find memories by their content \n" + " - Access memories for further analysis \n" + " - Get some personal information about the user" + ), + ) + async def find_memories(query: str): + """ + Find memories in Qdrant. + :param query: The query to use for the search. + :return: A list of memories found. + """ + nonlocal qdrant_connector + memories = await qdrant_connector.find_memories(query) + content = [ + types.TextContent(type="text", text=f"Memories for the query '{query}'"), ] - - @server.call_tool() - async def handle_tool_call( - name: str, arguments: dict | None - ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]: - if name not in ["qdrant-store-memory", "qdrant-find-memories"]: - raise ValueError(f"Unknown tool: {name}") - - if name == "qdrant-store-memory": - if not arguments or "information" not in arguments: - raise ValueError("Missing required argument 'information'") - information = arguments["information"] - await qdrant_connector.store_memory(information) - return [types.TextContent(type="text", text=f"Remembered: {information}")] - - if name == "qdrant-find-memories": - if not arguments or "query" not in arguments: - raise ValueError("Missing required argument 'query'") - query = arguments["query"] - memories = await qdrant_connector.find_memories(query) - content = [ - types.TextContent( - type="text", text=f"Memories for the query '{query}'" - ), - ] - for memory in memories: - content.append( - types.TextContent(type="text", text=f"{memory}") - ) - return content - - raise ValueError(f"Unknown tool: {name}") + for memory in memories: + content.append( + types.TextContent(type="text", text=f"{memory}") + ) + return content return server diff --git a/tests/test_custom_server.py b/tests/test_custom_server.py new file mode 100644 index 0000000..b506c25 --- /dev/null +++ b/tests/test_custom_server.py @@ -0,0 +1,100 @@ +import pytest + +from src.mcp_server_qdrant.custom_server import QdrantMCPServer + + +def test_register_tool_decorator(): + """Test that the register_tool method works as a decorator and correctly parses parameters.""" + server = QdrantMCPServer() + + @server.register_tool(description="Test function with different parameter types") + def test_function(text_param: str, number_param: int, flag_param: bool = False): + """ + A test function with different parameter types. + + :param text_param: A string parameter + :param number_param: An integer parameter + :param flag_param: A boolean parameter with default + """ + return f"{text_param} {number_param} {flag_param}" + + # Check that the function was registered in tool_handlers + assert "test_function" in server._tool_handlers + assert server._tool_handlers["test_function"] == test_function + + # Check that the tool was added to tools list + assert len(server._tools) == 1 + tool = server._tools[0] + assert tool.name == "test_function" + assert tool.description == "Test function with different parameter types" + + # Check the generated schema + schema = tool.inputSchema + assert schema["type"] == "object" + + # Check properties + properties = schema["properties"] + assert "text_param" in properties + assert properties["text_param"]["type"] == "string" + assert "description" in properties["text_param"] + + assert "number_param" in properties + assert properties["number_param"]["type"] == "number" + + assert "flag_param" in properties + assert properties["flag_param"]["type"] == "boolean" + assert "default" in properties["flag_param"] + assert properties["flag_param"]["default"] is False + + # Check required fields + assert "required" in schema + assert "text_param" in schema["required"] + assert "number_param" in schema["required"] + assert "flag_param" not in schema["required"] # Has default value + + +@pytest.mark.asyncio +async def test_handle_list_tool(): + """Test that handle_list_tool returns all registered tools.""" + server = QdrantMCPServer() + + # Register multiple tools + @server.register_tool(description="First test tool") + def tool_one(param1: str): + """First tool.""" + return param1 + + @server.register_tool(description="Second test tool") + def tool_two(param1: int, param2: bool = True): + """Second tool.""" + return param1, param2 + + @server.register_tool(name="custom_name", description="Tool with custom name") + def tool_three(param1: str): + """Third tool with custom name.""" + return param1 + + # Get the list of tools + tools = await server.handle_list_tool() + + # Check that all tools are returned + assert len(tools) == 3 + + # Check tool names + tool_names = [tool.name for tool in tools] + assert "tool_one" in tool_names + assert "tool_two" in tool_names + assert "custom_name" in tool_names + assert "tool_three" not in tool_names # Should use custom name instead + + # Check tool descriptions + descriptions = {tool.name: tool.description for tool in tools} + assert descriptions["tool_one"] == "First test tool" + assert descriptions["tool_two"] == "Second test tool" + assert descriptions["custom_name"] == "Tool with custom name" + + # Check schemas are properly generated + for tool in tools: + assert tool.inputSchema is not None + assert tool.inputSchema["type"] == "object" + assert "properties" in tool.inputSchema