Implement the server with FastMCP
This commit is contained in:
@@ -6,8 +6,10 @@ readme = "README.md"
|
|||||||
requires-python = ">=3.10"
|
requires-python = ">=3.10"
|
||||||
license = "Apache-2.0"
|
license = "Apache-2.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"mcp>=0.9.1",
|
"mcp[cli]>=1.3.0",
|
||||||
"qdrant-client[fastembed]>=1.12.0",
|
"fastembed>=0.6.0",
|
||||||
|
"qdrant-client>=1.12.0",
|
||||||
|
"typer>=0.15.2",
|
||||||
]
|
]
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from . import server
|
|||||||
|
|
||||||
def main():
|
def main():
|
||||||
"""Main entry point for the package."""
|
"""Main entry point for the package."""
|
||||||
server.main()
|
server.mcp.run()
|
||||||
|
|
||||||
|
|
||||||
# Optionally expose other important items at package level
|
# Optionally expose other important items at package level
|
||||||
|
|||||||
@@ -1,137 +0,0 @@
|
|||||||
import inspect
|
|
||||||
from typing import Any, Callable, Dict, List, Optional, get_type_hints
|
|
||||||
|
|
||||||
from mcp import types
|
|
||||||
from mcp.server import Server
|
|
||||||
|
|
||||||
|
|
||||||
class QdrantMCPServer(Server):
|
|
||||||
"""
|
|
||||||
An MCP server that uses Qdrant to store and retrieve information.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, name: str = "Qdrant"):
|
|
||||||
super().__init__(name)
|
|
||||||
self._tool_handlers: Dict[str, Callable] = {}
|
|
||||||
self._tools: List[types.Tool] = []
|
|
||||||
# This monkeypatching is required to make the server list the tools
|
|
||||||
# and handle tool calls. It simplifies the process of registering
|
|
||||||
# tool handlers. Please do not remove it.
|
|
||||||
self.handle_list_tool = self.list_tools()(self.handle_list_tool)
|
|
||||||
self.handle_tool_call = self.call_tool()(self.handle_tool_call)
|
|
||||||
|
|
||||||
def register_tool(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
description: str,
|
|
||||||
name: Optional[str] = None,
|
|
||||||
input_schema: Optional[dict[str, Any]] = None,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
A decorator to register a tool with the server. The description is used
|
|
||||||
to generate the tool's metadata.
|
|
||||||
|
|
||||||
Name is optional, and if not provided, the function's name will be used.
|
|
||||||
|
|
||||||
:param description: The description of the tool.
|
|
||||||
:param name: The name of the tool. If not provided, the function's name will be used.
|
|
||||||
:param input_schema: The input schema for the tool. If not provided, it will be
|
|
||||||
automatically generated from the function's parameters.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def decorator(func: Callable):
|
|
||||||
def wrapper(fn):
|
|
||||||
nonlocal name, input_schema
|
|
||||||
|
|
||||||
# Use function name if name not provided
|
|
||||||
if name is None:
|
|
||||||
name = fn.__name__
|
|
||||||
|
|
||||||
# If no input schema is provided, generate one from the function parameters
|
|
||||||
if input_schema is None:
|
|
||||||
input_schema = self.__parse_function_parameters(fn)
|
|
||||||
|
|
||||||
# Create the tool definition
|
|
||||||
tool = types.Tool(
|
|
||||||
name=name,
|
|
||||||
description=description,
|
|
||||||
inputSchema=input_schema,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Register in both collections
|
|
||||||
self._tool_handlers[name] = fn
|
|
||||||
self._tools.append(tool)
|
|
||||||
|
|
||||||
return fn
|
|
||||||
|
|
||||||
# Handle both @register_tool and @register_tool() syntax
|
|
||||||
if func is None:
|
|
||||||
return wrapper
|
|
||||||
return wrapper(func)
|
|
||||||
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
async def handle_list_tool(self) -> List[types.Tool]:
|
|
||||||
"""Expose the list of tools to the server."""
|
|
||||||
return self._tools
|
|
||||||
|
|
||||||
async def handle_tool_call(
|
|
||||||
self, name: str, arguments: dict | None
|
|
||||||
) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
|
|
||||||
"""Handle any tool call."""
|
|
||||||
if name not in self._tool_handlers:
|
|
||||||
raise ValueError(f"Unknown tool: {name}")
|
|
||||||
return await self._tool_handlers[name](**arguments)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def __parse_function_parameters(func: Callable) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Parse the parameters of a function to create an input schema.
|
|
||||||
|
|
||||||
:param func: The function to parse.
|
|
||||||
:return: A dictionary representing the input schema.
|
|
||||||
"""
|
|
||||||
signature = inspect.signature(func)
|
|
||||||
type_hints = get_type_hints(func)
|
|
||||||
|
|
||||||
properties = {}
|
|
||||||
required = []
|
|
||||||
|
|
||||||
for param_name, param in signature.parameters.items():
|
|
||||||
# Skip self parameter for methods
|
|
||||||
if param_name == "self":
|
|
||||||
continue
|
|
||||||
|
|
||||||
param_type = type_hints.get(param_name, Any)
|
|
||||||
param_schema = {"type": "string"} # Default to string
|
|
||||||
|
|
||||||
# Map Python types to JSON Schema types
|
|
||||||
if param_type in (int, float):
|
|
||||||
param_schema["type"] = "number"
|
|
||||||
elif param_type is bool:
|
|
||||||
param_schema["type"] = "boolean"
|
|
||||||
elif param_type is list or getattr(param_type, "__origin__", None) is list:
|
|
||||||
param_schema["type"] = "array"
|
|
||||||
|
|
||||||
# Get default value if any
|
|
||||||
if param.default is not inspect.Parameter.empty:
|
|
||||||
param_schema["default"] = param.default
|
|
||||||
else:
|
|
||||||
required.append(param_name)
|
|
||||||
|
|
||||||
# Get description from docstring if available
|
|
||||||
if func.__doc__:
|
|
||||||
param_docs = [
|
|
||||||
line.strip()
|
|
||||||
for line in func.__doc__.split("\n")
|
|
||||||
if f":param {param_name}:" in line
|
|
||||||
]
|
|
||||||
if param_docs:
|
|
||||||
description = (
|
|
||||||
param_docs[0].split(f":param {param_name}:")[1].strip()
|
|
||||||
)
|
|
||||||
param_schema["description"] = description
|
|
||||||
|
|
||||||
properties[param_name] = param_schema
|
|
||||||
|
|
||||||
return {"type": "object", "properties": properties, "required": required}
|
|
||||||
@@ -1,17 +1,16 @@
|
|||||||
from mcp_server_qdrant.embeddings import EmbeddingProvider
|
from mcp_server_qdrant.embeddings import EmbeddingProvider
|
||||||
|
from mcp_server_qdrant.settings import EmbeddingProviderSettings
|
||||||
|
|
||||||
|
|
||||||
def create_embedding_provider(provider_type: str, model_name: str) -> EmbeddingProvider:
|
def create_embedding_provider(settings: EmbeddingProviderSettings) -> EmbeddingProvider:
|
||||||
"""
|
"""
|
||||||
Create an embedding provider based on the specified type.
|
Create an embedding provider based on the specified type.
|
||||||
|
:param settings: The settings for the embedding provider.
|
||||||
:param provider_type: The type of embedding provider to create.
|
|
||||||
:param model_name: The name of the model to use for embeddings, specific to the provider type.
|
|
||||||
:return: An instance of the specified embedding provider.
|
:return: An instance of the specified embedding provider.
|
||||||
"""
|
"""
|
||||||
if provider_type.lower() == "fastembed":
|
if settings.provider_type.lower() == "fastembed":
|
||||||
from .fastembed import FastEmbedProvider
|
from mcp_server_qdrant.embeddings.fastembed import FastEmbedProvider
|
||||||
|
|
||||||
return FastEmbedProvider(model_name)
|
return FastEmbedProvider(settings.model_name)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported embedding provider: {provider_type}")
|
raise ValueError(f"Unsupported embedding provider: {settings.provider_type}")
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from typing import List
|
|||||||
|
|
||||||
from fastembed import TextEmbedding
|
from fastembed import TextEmbedding
|
||||||
|
|
||||||
from .base import EmbeddingProvider
|
from mcp_server_qdrant.embeddings.base import EmbeddingProvider
|
||||||
|
|
||||||
|
|
||||||
class FastEmbedProvider(EmbeddingProvider):
|
class FastEmbedProvider(EmbeddingProvider):
|
||||||
|
|||||||
@@ -1,13 +0,0 @@
|
|||||||
import functools
|
|
||||||
from typing import Callable
|
|
||||||
|
|
||||||
|
|
||||||
def register_task(func: Callable) -> Callable:
|
|
||||||
@functools.wraps(func)
|
|
||||||
async def wrapper(*args, **kwargs):
|
|
||||||
print(f"Starting task {func.__name__}")
|
|
||||||
result = await func(*args, **kwargs)
|
|
||||||
print(f"Finished task {func.__name__}")
|
|
||||||
return result
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
@@ -1,10 +0,0 @@
|
|||||||
import importlib.metadata
|
|
||||||
|
|
||||||
|
|
||||||
def get_package_version() -> str:
|
|
||||||
"""Get the package version using importlib.metadata."""
|
|
||||||
try:
|
|
||||||
return importlib.metadata.version("mcp-server-qdrant")
|
|
||||||
except importlib.metadata.PackageNotFoundError:
|
|
||||||
# Fall back to a default version if package is not installed
|
|
||||||
return "0.0.0"
|
|
||||||
@@ -53,9 +53,9 @@ class QdrantConnector:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
async def store_memory(self, information: str):
|
async def store(self, information: str):
|
||||||
"""
|
"""
|
||||||
Store a memory in the Qdrant collection.
|
Store some information in the Qdrant collection.
|
||||||
:param information: The information to store.
|
:param information: The information to store.
|
||||||
"""
|
"""
|
||||||
await self._ensure_collection_exists()
|
await self._ensure_collection_exists()
|
||||||
@@ -76,11 +76,11 @@ class QdrantConnector:
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
async def find_memories(self, query: str) -> list[str]:
|
async def search(self, query: str) -> list[str]:
|
||||||
"""
|
"""
|
||||||
Find memories in the Qdrant collection. If there are no memories found, an empty list is returned.
|
Find points in the Qdrant collection. If there are no entries found, an empty list is returned.
|
||||||
:param query: The query to use for the search.
|
:param query: The query to use for the search.
|
||||||
:return: A list of memories found.
|
:return: A list of entries found.
|
||||||
"""
|
"""
|
||||||
collection_exists = await self._client.collection_exists(self._collection_name)
|
collection_exists = await self._client.collection_exists(self._collection_name)
|
||||||
if not collection_exists:
|
if not collection_exists:
|
||||||
|
|||||||
@@ -1,43 +1,93 @@
|
|||||||
import asyncio
|
import logging
|
||||||
from typing import Optional
|
import os
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from typing import AsyncIterator, List
|
||||||
|
|
||||||
import click
|
from mcp.server import Server
|
||||||
import mcp
|
from mcp.server.fastmcp import Context, FastMCP
|
||||||
import mcp.types as types
|
|
||||||
from mcp.server import NotificationOptions, Server
|
|
||||||
from mcp.server.models import InitializationOptions
|
|
||||||
|
|
||||||
from .custom_server import QdrantMCPServer
|
from mcp_server_qdrant.embeddings.factory import create_embedding_provider
|
||||||
from .embeddings.factory import create_embedding_provider
|
from mcp_server_qdrant.qdrant import QdrantConnector
|
||||||
from .helper import get_package_version
|
from mcp_server_qdrant.settings import (
|
||||||
from .qdrant import QdrantConnector
|
EmbeddingProviderSettings,
|
||||||
|
QdrantSettings,
|
||||||
|
parse_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Parse command line arguments and set them as environment variables.
|
||||||
|
# This is done for backwards compatibility with the previous versions
|
||||||
|
# of the MCP server.
|
||||||
|
env_vars = parse_args()
|
||||||
|
for key, value in env_vars.items():
|
||||||
|
os.environ[key] = value
|
||||||
|
|
||||||
|
|
||||||
def serve(
|
@asynccontextmanager
|
||||||
qdrant_connector: QdrantConnector,
|
async def server_lifespan(server: Server) -> AsyncIterator[dict]: # noqa
|
||||||
) -> Server:
|
|
||||||
"""
|
"""
|
||||||
Instantiate the server and configure tools to store and find memories in Qdrant.
|
Context manager to handle the lifespan of the server.
|
||||||
:param qdrant_connector: An instance of QdrantConnector to use for storing and retrieving memories.
|
This is used to configure the embedding provider and Qdrant connector.
|
||||||
"""
|
"""
|
||||||
server = QdrantMCPServer("qdrant")
|
try:
|
||||||
|
# Embedding provider is created with a factory function so we can add
|
||||||
|
# some more providers in the future. Currently, only FastEmbed is supported.
|
||||||
|
embedding_provider_settings = EmbeddingProviderSettings()
|
||||||
|
embedding_provider = create_embedding_provider(embedding_provider_settings)
|
||||||
|
logger.info(
|
||||||
|
f"Using embedding provider {embedding_provider_settings.provider_type} with "
|
||||||
|
f"model {embedding_provider_settings.model_name}"
|
||||||
|
)
|
||||||
|
|
||||||
@server.register_tool(
|
qdrant_configuration = QdrantSettings()
|
||||||
|
qdrant_connector = QdrantConnector(
|
||||||
|
qdrant_configuration.location,
|
||||||
|
qdrant_configuration.api_key,
|
||||||
|
qdrant_configuration.collection_name,
|
||||||
|
embedding_provider,
|
||||||
|
qdrant_configuration.local_path,
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Connecting to Qdrant at {qdrant_configuration.get_qdrant_location()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
yield {
|
||||||
|
"embedding_provider": embedding_provider,
|
||||||
|
"qdrant_connector": qdrant_connector,
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(e)
|
||||||
|
raise e
|
||||||
|
finally:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
mcp = FastMCP("Qdrant", lifespan=server_lifespan)
|
||||||
|
|
||||||
|
|
||||||
|
@mcp.tool(
|
||||||
name="qdrant-store-memory",
|
name="qdrant-store-memory",
|
||||||
description=(
|
description=(
|
||||||
"Keep the memory for later use, when you are asked to remember something."
|
"Keep the memory for later use, when you are asked to remember something."
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
async def store_memory(information: str):
|
async def store(information: str, ctx: Context) -> str:
|
||||||
"""
|
"""
|
||||||
Store a memory in Qdrant.
|
Store a memory in Qdrant.
|
||||||
:param information: The information to store.
|
:param information: The information to store.
|
||||||
|
:param ctx: The context for the request.
|
||||||
|
:return: A message indicating that the information was stored.
|
||||||
"""
|
"""
|
||||||
nonlocal qdrant_connector
|
await ctx.debug(f"Storing information {information} in Qdrant")
|
||||||
await qdrant_connector.store_memory(information)
|
qdrant_connector: QdrantConnector = ctx.request_context.lifespan_context[
|
||||||
return [types.TextContent(type="text", text=f"Remembered: {information}")]
|
"qdrant_connector"
|
||||||
|
]
|
||||||
|
await qdrant_connector.store(information)
|
||||||
|
return f"Remembered: {information}"
|
||||||
|
|
||||||
@server.register_tool(
|
|
||||||
|
@mcp.tool(
|
||||||
name="qdrant-find-memories",
|
name="qdrant-find-memories",
|
||||||
description=(
|
description=(
|
||||||
"Look up memories in Qdrant. Use this tool when you need to: \n"
|
"Look up memories in Qdrant. Use this tool when you need to: \n"
|
||||||
@@ -46,125 +96,27 @@ def serve(
|
|||||||
" - Get some personal information about the user"
|
" - Get some personal information about the user"
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
async def find_memories(query: str):
|
async def find(query: str, ctx: Context) -> List[str]:
|
||||||
"""
|
"""
|
||||||
Find memories in Qdrant.
|
Find memories in Qdrant.
|
||||||
:param query: The query to use for the search.
|
:param query: The query to use for the search.
|
||||||
:return: A list of memories found.
|
:param ctx: The context for the request.
|
||||||
|
:return: A list of entries found.
|
||||||
"""
|
"""
|
||||||
nonlocal qdrant_connector
|
await ctx.debug(f"Finding points for query {query}")
|
||||||
memories = await qdrant_connector.find_memories(query)
|
qdrant_connector: QdrantConnector = ctx.request_context.lifespan_context[
|
||||||
content = [
|
"qdrant_connector"
|
||||||
types.TextContent(type="text", text=f"Memories for the query '{query}'"),
|
|
||||||
]
|
]
|
||||||
for memory in memories:
|
entries = await qdrant_connector.search(query)
|
||||||
content.append(
|
if not entries:
|
||||||
types.TextContent(type="text", text=f"<memory>{memory}</memory>")
|
return [f"No memories found for the query '{query}'"]
|
||||||
)
|
content = [
|
||||||
|
f"Memories for the query '{query}'",
|
||||||
|
]
|
||||||
|
for entry in entries:
|
||||||
|
content.append(f"<entry>{entry}</entry>")
|
||||||
return content
|
return content
|
||||||
|
|
||||||
return server
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
@click.command()
|
mcp.run()
|
||||||
@click.option(
|
|
||||||
"--qdrant-url",
|
|
||||||
envvar="QDRANT_URL",
|
|
||||||
required=False,
|
|
||||||
help="Qdrant URL",
|
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"--qdrant-api-key",
|
|
||||||
envvar="QDRANT_API_KEY",
|
|
||||||
required=False,
|
|
||||||
help="Qdrant API key",
|
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"--collection-name",
|
|
||||||
envvar="COLLECTION_NAME",
|
|
||||||
required=True,
|
|
||||||
help="Collection name",
|
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"--fastembed-model-name",
|
|
||||||
envvar="FASTEMBED_MODEL_NAME",
|
|
||||||
required=False,
|
|
||||||
help="FastEmbed model name",
|
|
||||||
default="sentence-transformers/all-MiniLM-L6-v2",
|
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"--embedding-provider",
|
|
||||||
envvar="EMBEDDING_PROVIDER",
|
|
||||||
required=False,
|
|
||||||
help="Embedding provider to use",
|
|
||||||
default="fastembed",
|
|
||||||
type=click.Choice(["fastembed"], case_sensitive=False),
|
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"--embedding-model",
|
|
||||||
envvar="EMBEDDING_MODEL",
|
|
||||||
required=False,
|
|
||||||
help="Embedding model name",
|
|
||||||
default="sentence-transformers/all-MiniLM-L6-v2",
|
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"--qdrant-local-path",
|
|
||||||
envvar="QDRANT_LOCAL_PATH",
|
|
||||||
required=False,
|
|
||||||
help="Qdrant local path",
|
|
||||||
)
|
|
||||||
def main(
|
|
||||||
qdrant_url: Optional[str],
|
|
||||||
qdrant_api_key: str,
|
|
||||||
collection_name: Optional[str],
|
|
||||||
fastembed_model_name: Optional[str],
|
|
||||||
embedding_provider: str,
|
|
||||||
embedding_model: str,
|
|
||||||
qdrant_local_path: Optional[str],
|
|
||||||
):
|
|
||||||
# XOR of url and local path, since we accept only one of them
|
|
||||||
if not (bool(qdrant_url) ^ bool(qdrant_local_path)):
|
|
||||||
raise ValueError(
|
|
||||||
"Exactly one of qdrant-url or qdrant-local-path must be provided"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Warn if fastembed_model_name is provided, as this is going to be deprecated
|
|
||||||
if fastembed_model_name:
|
|
||||||
click.echo(
|
|
||||||
"Warning: --fastembed-model-name parameter is deprecated and will be removed in a future version. "
|
|
||||||
"Please use --embedding-provider and --embedding-model instead",
|
|
||||||
err=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _run():
|
|
||||||
async with mcp.server.stdio.stdio_server() as (read_stream, write_stream):
|
|
||||||
# Create the embedding provider
|
|
||||||
provider = create_embedding_provider(
|
|
||||||
provider_type=embedding_provider, model_name=embedding_model
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create the Qdrant connector
|
|
||||||
qdrant_connector = QdrantConnector(
|
|
||||||
qdrant_url,
|
|
||||||
qdrant_api_key,
|
|
||||||
collection_name,
|
|
||||||
provider,
|
|
||||||
qdrant_local_path,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create and run the server
|
|
||||||
server = serve(qdrant_connector)
|
|
||||||
await server.run(
|
|
||||||
read_stream,
|
|
||||||
write_stream,
|
|
||||||
InitializationOptions(
|
|
||||||
server_name="qdrant",
|
|
||||||
server_version=get_package_version(),
|
|
||||||
capabilities=server.get_capabilities(
|
|
||||||
notification_options=NotificationOptions(),
|
|
||||||
experimental_capabilities={},
|
|
||||||
),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
asyncio.run(_run())
|
|
||||||
|
|||||||
101
src/mcp_server_qdrant/settings.py
Normal file
101
src/mcp_server_qdrant/settings.py
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
import argparse
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
from pydantic_settings import BaseSettings
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingProviderSettings(BaseSettings):
|
||||||
|
"""
|
||||||
|
Configuration for the embedding provider.
|
||||||
|
"""
|
||||||
|
|
||||||
|
provider_type: str = Field(
|
||||||
|
default="fastembed", validation_alias="EMBEDDING_PROVIDER"
|
||||||
|
)
|
||||||
|
model_name: str = Field(
|
||||||
|
default="sentence-transformers/all-MiniLM-L6-v2",
|
||||||
|
validation_alias="EMBEDDING_MODEL",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class QdrantSettings(BaseSettings):
|
||||||
|
"""
|
||||||
|
Configuration for the Qdrant connector.
|
||||||
|
"""
|
||||||
|
|
||||||
|
location: Optional[str] = Field(default=None, validation_alias="QDRANT_URL")
|
||||||
|
api_key: Optional[str] = Field(default=None, validation_alias="QDRANT_API_KEY")
|
||||||
|
collection_name: str = Field(validation_alias="COLLECTION_NAME")
|
||||||
|
local_path: Optional[str] = Field(
|
||||||
|
default=None, validation_alias="QDRANT_LOCAL_PATH"
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_qdrant_location(self) -> str:
|
||||||
|
"""
|
||||||
|
Get the Qdrant location, either the URL or the local path.
|
||||||
|
"""
|
||||||
|
return self.location or self.local_path
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args() -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Parse command line arguments for the MCP server.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, Any]: Dictionary of parsed arguments
|
||||||
|
"""
|
||||||
|
parser = argparse.ArgumentParser(description="Qdrant MCP Server")
|
||||||
|
|
||||||
|
# Qdrant connection options
|
||||||
|
connection_group = parser.add_mutually_exclusive_group()
|
||||||
|
connection_group.add_argument(
|
||||||
|
"--qdrant-url",
|
||||||
|
help="URL of the Qdrant server, e.g. http://localhost:6333",
|
||||||
|
)
|
||||||
|
connection_group.add_argument(
|
||||||
|
"--qdrant-local-path",
|
||||||
|
help="Path to the local Qdrant database",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Other Qdrant settings
|
||||||
|
parser.add_argument(
|
||||||
|
"--qdrant-api-key",
|
||||||
|
help="API key for the Qdrant server",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--collection-name",
|
||||||
|
help="Name of the collection to use",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Embedding settings
|
||||||
|
parser.add_argument(
|
||||||
|
"--embedding-provider",
|
||||||
|
help="Embedding provider to use (currently only 'fastembed' is supported)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--embedding-model",
|
||||||
|
help="Name of the embedding model to use",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Convert to dictionary and filter out None values
|
||||||
|
args_dict = {k: v for k, v in vars(args).items() if v is not None}
|
||||||
|
|
||||||
|
# Convert argument names to environment variable format
|
||||||
|
env_vars = {}
|
||||||
|
if "qdrant_url" in args_dict:
|
||||||
|
env_vars["QDRANT_URL"] = args_dict["qdrant_url"]
|
||||||
|
if "qdrant_api_key" in args_dict:
|
||||||
|
env_vars["QDRANT_API_KEY"] = args_dict["qdrant_api_key"]
|
||||||
|
if "collection_name" in args_dict:
|
||||||
|
env_vars["COLLECTION_NAME"] = args_dict["collection_name"]
|
||||||
|
if "embedding_model" in args_dict:
|
||||||
|
env_vars["EMBEDDING_MODEL"] = args_dict["embedding_model"]
|
||||||
|
if "embedding_provider" in args_dict:
|
||||||
|
env_vars["EMBEDDING_PROVIDER"] = args_dict["embedding_provider"]
|
||||||
|
if "qdrant_local_path" in args_dict:
|
||||||
|
env_vars["QDRANT_LOCAL_PATH"] = args_dict["qdrant_local_path"]
|
||||||
|
|
||||||
|
return env_vars
|
||||||
61
tests/test_config.py
Normal file
61
tests/test_config.py
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
import os
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from mcp_server_qdrant.settings import EmbeddingProviderSettings, QdrantSettings
|
||||||
|
|
||||||
|
|
||||||
|
class TestQdrantSettings:
|
||||||
|
def test_default_values(self):
|
||||||
|
"""Test that required fields raise errors when not provided."""
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
# Should raise error because required fields are missing
|
||||||
|
QdrantSettings()
|
||||||
|
|
||||||
|
@patch.dict(
|
||||||
|
os.environ,
|
||||||
|
{"QDRANT_URL": "http://localhost:6333", "COLLECTION_NAME": "test_collection"},
|
||||||
|
)
|
||||||
|
def test_minimal_config(self):
|
||||||
|
"""Test loading minimal configuration from environment variables."""
|
||||||
|
settings = QdrantSettings()
|
||||||
|
assert settings.location == "http://localhost:6333"
|
||||||
|
assert settings.collection_name == "test_collection"
|
||||||
|
assert settings.api_key is None
|
||||||
|
assert settings.local_path is None
|
||||||
|
|
||||||
|
@patch.dict(
|
||||||
|
os.environ,
|
||||||
|
{
|
||||||
|
"QDRANT_URL": "http://qdrant.example.com:6333",
|
||||||
|
"QDRANT_API_KEY": "test_api_key",
|
||||||
|
"COLLECTION_NAME": "my_memories",
|
||||||
|
"QDRANT_LOCAL_PATH": "/tmp/qdrant",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
def test_full_config(self):
|
||||||
|
"""Test loading full configuration from environment variables."""
|
||||||
|
settings = QdrantSettings()
|
||||||
|
assert settings.location == "http://qdrant.example.com:6333"
|
||||||
|
assert settings.api_key == "test_api_key"
|
||||||
|
assert settings.collection_name == "my_memories"
|
||||||
|
assert settings.local_path == "/tmp/qdrant"
|
||||||
|
|
||||||
|
|
||||||
|
class TestEmbeddingProviderSettings:
|
||||||
|
def test_default_values(self):
|
||||||
|
"""Test default values are set correctly."""
|
||||||
|
settings = EmbeddingProviderSettings()
|
||||||
|
assert settings.provider_type == "fastembed"
|
||||||
|
assert settings.model_name == "sentence-transformers/all-MiniLM-L6-v2"
|
||||||
|
|
||||||
|
@patch.dict(
|
||||||
|
os.environ,
|
||||||
|
{"EMBEDDING_PROVIDER": "custom_provider", "EMBEDDING_MODEL": "custom_model"},
|
||||||
|
)
|
||||||
|
def test_custom_values(self):
|
||||||
|
"""Test loading custom values from environment variables."""
|
||||||
|
settings = EmbeddingProviderSettings()
|
||||||
|
assert settings.provider_type == "custom_provider"
|
||||||
|
assert settings.model_name == "custom_model"
|
||||||
@@ -1,100 +0,0 @@
|
|||||||
import pytest
|
|
||||||
|
|
||||||
from src.mcp_server_qdrant.custom_server import QdrantMCPServer
|
|
||||||
|
|
||||||
|
|
||||||
def test_register_tool_decorator():
|
|
||||||
"""Test that the register_tool method works as a decorator and correctly parses parameters."""
|
|
||||||
server = QdrantMCPServer()
|
|
||||||
|
|
||||||
@server.register_tool(description="Test function with different parameter types")
|
|
||||||
def test_function(text_param: str, number_param: int, flag_param: bool = False):
|
|
||||||
"""
|
|
||||||
A test function with different parameter types.
|
|
||||||
|
|
||||||
:param text_param: A string parameter
|
|
||||||
:param number_param: An integer parameter
|
|
||||||
:param flag_param: A boolean parameter with default
|
|
||||||
"""
|
|
||||||
return f"{text_param} {number_param} {flag_param}"
|
|
||||||
|
|
||||||
# Check that the function was registered in tool_handlers
|
|
||||||
assert "test_function" in server._tool_handlers
|
|
||||||
assert server._tool_handlers["test_function"] == test_function
|
|
||||||
|
|
||||||
# Check that the tool was added to tools list
|
|
||||||
assert len(server._tools) == 1
|
|
||||||
tool = server._tools[0]
|
|
||||||
assert tool.name == "test_function"
|
|
||||||
assert tool.description == "Test function with different parameter types"
|
|
||||||
|
|
||||||
# Check the generated schema
|
|
||||||
schema = tool.inputSchema
|
|
||||||
assert schema["type"] == "object"
|
|
||||||
|
|
||||||
# Check properties
|
|
||||||
properties = schema["properties"]
|
|
||||||
assert "text_param" in properties
|
|
||||||
assert properties["text_param"]["type"] == "string"
|
|
||||||
assert "description" in properties["text_param"]
|
|
||||||
|
|
||||||
assert "number_param" in properties
|
|
||||||
assert properties["number_param"]["type"] == "number"
|
|
||||||
|
|
||||||
assert "flag_param" in properties
|
|
||||||
assert properties["flag_param"]["type"] == "boolean"
|
|
||||||
assert "default" in properties["flag_param"]
|
|
||||||
assert properties["flag_param"]["default"] is False
|
|
||||||
|
|
||||||
# Check required fields
|
|
||||||
assert "required" in schema
|
|
||||||
assert "text_param" in schema["required"]
|
|
||||||
assert "number_param" in schema["required"]
|
|
||||||
assert "flag_param" not in schema["required"] # Has default value
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_handle_list_tool():
|
|
||||||
"""Test that handle_list_tool returns all registered tools."""
|
|
||||||
server = QdrantMCPServer()
|
|
||||||
|
|
||||||
# Register multiple tools
|
|
||||||
@server.register_tool(description="First test tool")
|
|
||||||
def tool_one(param1: str):
|
|
||||||
"""First tool."""
|
|
||||||
return param1
|
|
||||||
|
|
||||||
@server.register_tool(description="Second test tool")
|
|
||||||
def tool_two(param1: int, param2: bool = True):
|
|
||||||
"""Second tool."""
|
|
||||||
return param1, param2
|
|
||||||
|
|
||||||
@server.register_tool(name="custom_name", description="Tool with custom name")
|
|
||||||
def tool_three(param1: str):
|
|
||||||
"""Third tool with custom name."""
|
|
||||||
return param1
|
|
||||||
|
|
||||||
# Get the list of tools
|
|
||||||
tools = await server.handle_list_tool()
|
|
||||||
|
|
||||||
# Check that all tools are returned
|
|
||||||
assert len(tools) == 3
|
|
||||||
|
|
||||||
# Check tool names
|
|
||||||
tool_names = [tool.name for tool in tools]
|
|
||||||
assert "tool_one" in tool_names
|
|
||||||
assert "tool_two" in tool_names
|
|
||||||
assert "custom_name" in tool_names
|
|
||||||
assert "tool_three" not in tool_names # Should use custom name instead
|
|
||||||
|
|
||||||
# Check tool descriptions
|
|
||||||
descriptions = {tool.name: tool.description for tool in tools}
|
|
||||||
assert descriptions["tool_one"] == "First test tool"
|
|
||||||
assert descriptions["tool_two"] == "Second test tool"
|
|
||||||
assert descriptions["custom_name"] == "Tool with custom name"
|
|
||||||
|
|
||||||
# Check schemas are properly generated
|
|
||||||
for tool in tools:
|
|
||||||
assert tool.inputSchema is not None
|
|
||||||
assert tool.inputSchema["type"] == "object"
|
|
||||||
assert "properties" in tool.inputSchema
|
|
||||||
63
tests/test_fastembed_integration.py
Normal file
63
tests/test_fastembed_integration.py
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
from fastembed import TextEmbedding
|
||||||
|
|
||||||
|
from mcp_server_qdrant.embeddings.fastembed import FastEmbedProvider
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
class TestFastEmbedProviderIntegration:
|
||||||
|
"""Integration tests for FastEmbedProvider."""
|
||||||
|
|
||||||
|
async def test_initialization(self):
|
||||||
|
"""Test that the provider can be initialized with a valid model."""
|
||||||
|
provider = FastEmbedProvider("sentence-transformers/all-MiniLM-L6-v2")
|
||||||
|
assert provider.model_name == "sentence-transformers/all-MiniLM-L6-v2"
|
||||||
|
assert isinstance(provider.embedding_model, TextEmbedding)
|
||||||
|
|
||||||
|
async def test_embed_documents(self):
|
||||||
|
"""Test that documents can be embedded."""
|
||||||
|
provider = FastEmbedProvider("sentence-transformers/all-MiniLM-L6-v2")
|
||||||
|
documents = ["This is a test document.", "This is another test document."]
|
||||||
|
|
||||||
|
embeddings = await provider.embed_documents(documents)
|
||||||
|
|
||||||
|
# Check that we got the right number of embeddings
|
||||||
|
assert len(embeddings) == len(documents)
|
||||||
|
|
||||||
|
# Check that embeddings have the expected shape
|
||||||
|
# The exact dimension depends on the model, but should be consistent
|
||||||
|
assert len(embeddings[0]) > 0
|
||||||
|
assert all(len(embedding) == len(embeddings[0]) for embedding in embeddings)
|
||||||
|
|
||||||
|
# Check that embeddings are different for different documents
|
||||||
|
# Convert to numpy arrays for easier comparison
|
||||||
|
embedding1 = np.array(embeddings[0])
|
||||||
|
embedding2 = np.array(embeddings[1])
|
||||||
|
assert not np.array_equal(embedding1, embedding2)
|
||||||
|
|
||||||
|
async def test_embed_query(self):
|
||||||
|
"""Test that queries can be embedded."""
|
||||||
|
provider = FastEmbedProvider("sentence-transformers/all-MiniLM-L6-v2")
|
||||||
|
query = "This is a test query."
|
||||||
|
|
||||||
|
embedding = await provider.embed_query(query)
|
||||||
|
|
||||||
|
# Check that embedding has the expected shape
|
||||||
|
assert len(embedding) > 0
|
||||||
|
|
||||||
|
# Embed the same query again to check consistency
|
||||||
|
embedding2 = await provider.embed_query(query)
|
||||||
|
assert len(embedding) == len(embedding2)
|
||||||
|
|
||||||
|
# The embeddings should be identical for the same input
|
||||||
|
np.testing.assert_array_almost_equal(np.array(embedding), np.array(embedding2))
|
||||||
|
|
||||||
|
def test_get_vector_name(self):
|
||||||
|
"""Test that the vector name is generated correctly."""
|
||||||
|
provider = FastEmbedProvider("sentence-transformers/all-MiniLM-L6-v2")
|
||||||
|
vector_name = provider.get_vector_name()
|
||||||
|
|
||||||
|
# Check that the vector name follows the expected format
|
||||||
|
assert vector_name.startswith("fast-")
|
||||||
|
assert "minilm" in vector_name.lower()
|
||||||
Reference in New Issue
Block a user