Implement the server with FastMCP
This commit is contained in:
61
tests/test_config.py
Normal file
61
tests/test_config.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from mcp_server_qdrant.settings import EmbeddingProviderSettings, QdrantSettings
|
||||
|
||||
|
||||
class TestQdrantSettings:
|
||||
def test_default_values(self):
|
||||
"""Test that required fields raise errors when not provided."""
|
||||
with pytest.raises(ValueError):
|
||||
# Should raise error because required fields are missing
|
||||
QdrantSettings()
|
||||
|
||||
@patch.dict(
|
||||
os.environ,
|
||||
{"QDRANT_URL": "http://localhost:6333", "COLLECTION_NAME": "test_collection"},
|
||||
)
|
||||
def test_minimal_config(self):
|
||||
"""Test loading minimal configuration from environment variables."""
|
||||
settings = QdrantSettings()
|
||||
assert settings.location == "http://localhost:6333"
|
||||
assert settings.collection_name == "test_collection"
|
||||
assert settings.api_key is None
|
||||
assert settings.local_path is None
|
||||
|
||||
@patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"QDRANT_URL": "http://qdrant.example.com:6333",
|
||||
"QDRANT_API_KEY": "test_api_key",
|
||||
"COLLECTION_NAME": "my_memories",
|
||||
"QDRANT_LOCAL_PATH": "/tmp/qdrant",
|
||||
},
|
||||
)
|
||||
def test_full_config(self):
|
||||
"""Test loading full configuration from environment variables."""
|
||||
settings = QdrantSettings()
|
||||
assert settings.location == "http://qdrant.example.com:6333"
|
||||
assert settings.api_key == "test_api_key"
|
||||
assert settings.collection_name == "my_memories"
|
||||
assert settings.local_path == "/tmp/qdrant"
|
||||
|
||||
|
||||
class TestEmbeddingProviderSettings:
|
||||
def test_default_values(self):
|
||||
"""Test default values are set correctly."""
|
||||
settings = EmbeddingProviderSettings()
|
||||
assert settings.provider_type == "fastembed"
|
||||
assert settings.model_name == "sentence-transformers/all-MiniLM-L6-v2"
|
||||
|
||||
@patch.dict(
|
||||
os.environ,
|
||||
{"EMBEDDING_PROVIDER": "custom_provider", "EMBEDDING_MODEL": "custom_model"},
|
||||
)
|
||||
def test_custom_values(self):
|
||||
"""Test loading custom values from environment variables."""
|
||||
settings = EmbeddingProviderSettings()
|
||||
assert settings.provider_type == "custom_provider"
|
||||
assert settings.model_name == "custom_model"
|
||||
@@ -1,100 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from src.mcp_server_qdrant.custom_server import QdrantMCPServer
|
||||
|
||||
|
||||
def test_register_tool_decorator():
|
||||
"""Test that the register_tool method works as a decorator and correctly parses parameters."""
|
||||
server = QdrantMCPServer()
|
||||
|
||||
@server.register_tool(description="Test function with different parameter types")
|
||||
def test_function(text_param: str, number_param: int, flag_param: bool = False):
|
||||
"""
|
||||
A test function with different parameter types.
|
||||
|
||||
:param text_param: A string parameter
|
||||
:param number_param: An integer parameter
|
||||
:param flag_param: A boolean parameter with default
|
||||
"""
|
||||
return f"{text_param} {number_param} {flag_param}"
|
||||
|
||||
# Check that the function was registered in tool_handlers
|
||||
assert "test_function" in server._tool_handlers
|
||||
assert server._tool_handlers["test_function"] == test_function
|
||||
|
||||
# Check that the tool was added to tools list
|
||||
assert len(server._tools) == 1
|
||||
tool = server._tools[0]
|
||||
assert tool.name == "test_function"
|
||||
assert tool.description == "Test function with different parameter types"
|
||||
|
||||
# Check the generated schema
|
||||
schema = tool.inputSchema
|
||||
assert schema["type"] == "object"
|
||||
|
||||
# Check properties
|
||||
properties = schema["properties"]
|
||||
assert "text_param" in properties
|
||||
assert properties["text_param"]["type"] == "string"
|
||||
assert "description" in properties["text_param"]
|
||||
|
||||
assert "number_param" in properties
|
||||
assert properties["number_param"]["type"] == "number"
|
||||
|
||||
assert "flag_param" in properties
|
||||
assert properties["flag_param"]["type"] == "boolean"
|
||||
assert "default" in properties["flag_param"]
|
||||
assert properties["flag_param"]["default"] is False
|
||||
|
||||
# Check required fields
|
||||
assert "required" in schema
|
||||
assert "text_param" in schema["required"]
|
||||
assert "number_param" in schema["required"]
|
||||
assert "flag_param" not in schema["required"] # Has default value
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_list_tool():
|
||||
"""Test that handle_list_tool returns all registered tools."""
|
||||
server = QdrantMCPServer()
|
||||
|
||||
# Register multiple tools
|
||||
@server.register_tool(description="First test tool")
|
||||
def tool_one(param1: str):
|
||||
"""First tool."""
|
||||
return param1
|
||||
|
||||
@server.register_tool(description="Second test tool")
|
||||
def tool_two(param1: int, param2: bool = True):
|
||||
"""Second tool."""
|
||||
return param1, param2
|
||||
|
||||
@server.register_tool(name="custom_name", description="Tool with custom name")
|
||||
def tool_three(param1: str):
|
||||
"""Third tool with custom name."""
|
||||
return param1
|
||||
|
||||
# Get the list of tools
|
||||
tools = await server.handle_list_tool()
|
||||
|
||||
# Check that all tools are returned
|
||||
assert len(tools) == 3
|
||||
|
||||
# Check tool names
|
||||
tool_names = [tool.name for tool in tools]
|
||||
assert "tool_one" in tool_names
|
||||
assert "tool_two" in tool_names
|
||||
assert "custom_name" in tool_names
|
||||
assert "tool_three" not in tool_names # Should use custom name instead
|
||||
|
||||
# Check tool descriptions
|
||||
descriptions = {tool.name: tool.description for tool in tools}
|
||||
assert descriptions["tool_one"] == "First test tool"
|
||||
assert descriptions["tool_two"] == "Second test tool"
|
||||
assert descriptions["custom_name"] == "Tool with custom name"
|
||||
|
||||
# Check schemas are properly generated
|
||||
for tool in tools:
|
||||
assert tool.inputSchema is not None
|
||||
assert tool.inputSchema["type"] == "object"
|
||||
assert "properties" in tool.inputSchema
|
||||
63
tests/test_fastembed_integration.py
Normal file
63
tests/test_fastembed_integration.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
from fastembed import TextEmbedding
|
||||
|
||||
from mcp_server_qdrant.embeddings.fastembed import FastEmbedProvider
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestFastEmbedProviderIntegration:
|
||||
"""Integration tests for FastEmbedProvider."""
|
||||
|
||||
async def test_initialization(self):
|
||||
"""Test that the provider can be initialized with a valid model."""
|
||||
provider = FastEmbedProvider("sentence-transformers/all-MiniLM-L6-v2")
|
||||
assert provider.model_name == "sentence-transformers/all-MiniLM-L6-v2"
|
||||
assert isinstance(provider.embedding_model, TextEmbedding)
|
||||
|
||||
async def test_embed_documents(self):
|
||||
"""Test that documents can be embedded."""
|
||||
provider = FastEmbedProvider("sentence-transformers/all-MiniLM-L6-v2")
|
||||
documents = ["This is a test document.", "This is another test document."]
|
||||
|
||||
embeddings = await provider.embed_documents(documents)
|
||||
|
||||
# Check that we got the right number of embeddings
|
||||
assert len(embeddings) == len(documents)
|
||||
|
||||
# Check that embeddings have the expected shape
|
||||
# The exact dimension depends on the model, but should be consistent
|
||||
assert len(embeddings[0]) > 0
|
||||
assert all(len(embedding) == len(embeddings[0]) for embedding in embeddings)
|
||||
|
||||
# Check that embeddings are different for different documents
|
||||
# Convert to numpy arrays for easier comparison
|
||||
embedding1 = np.array(embeddings[0])
|
||||
embedding2 = np.array(embeddings[1])
|
||||
assert not np.array_equal(embedding1, embedding2)
|
||||
|
||||
async def test_embed_query(self):
|
||||
"""Test that queries can be embedded."""
|
||||
provider = FastEmbedProvider("sentence-transformers/all-MiniLM-L6-v2")
|
||||
query = "This is a test query."
|
||||
|
||||
embedding = await provider.embed_query(query)
|
||||
|
||||
# Check that embedding has the expected shape
|
||||
assert len(embedding) > 0
|
||||
|
||||
# Embed the same query again to check consistency
|
||||
embedding2 = await provider.embed_query(query)
|
||||
assert len(embedding) == len(embedding2)
|
||||
|
||||
# The embeddings should be identical for the same input
|
||||
np.testing.assert_array_almost_equal(np.array(embedding), np.array(embedding2))
|
||||
|
||||
def test_get_vector_name(self):
|
||||
"""Test that the vector name is generated correctly."""
|
||||
provider = FastEmbedProvider("sentence-transformers/all-MiniLM-L6-v2")
|
||||
vector_name = provider.get_vector_name()
|
||||
|
||||
# Check that the vector name follows the expected format
|
||||
assert vector_name.startswith("fast-")
|
||||
assert "minilm" in vector_name.lower()
|
||||
Reference in New Issue
Block a user