new: update type hints (#64)

* new: update type hints

* fix: do not pass location and path to qdrant client, and do not accept them together

* new: update settings tests

* fix: revert removal of local path
This commit is contained in:
George
2025-06-12 00:55:07 +04:00
committed by GitHub
parent b657656363
commit 28bf298a32
6 changed files with 83 additions and 78 deletions

View File

@@ -1,17 +1,16 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List
class EmbeddingProvider(ABC): class EmbeddingProvider(ABC):
"""Abstract base class for embedding providers.""" """Abstract base class for embedding providers."""
@abstractmethod @abstractmethod
async def embed_documents(self, documents: List[str]) -> List[List[float]]: async def embed_documents(self, documents: list[str]) -> list[list[float]]:
"""Embed a list of documents into vectors.""" """Embed a list of documents into vectors."""
pass pass
@abstractmethod @abstractmethod
async def embed_query(self, query: str) -> List[float]: async def embed_query(self, query: str) -> list[float]:
"""Embed a query into a vector.""" """Embed a query into a vector."""
pass pass

View File

@@ -1,5 +1,4 @@
import asyncio import asyncio
from typing import List
from fastembed import TextEmbedding from fastembed import TextEmbedding
from fastembed.common.model_description import DenseModelDescription from fastembed.common.model_description import DenseModelDescription
@@ -17,7 +16,7 @@ class FastEmbedProvider(EmbeddingProvider):
self.model_name = model_name self.model_name = model_name
self.embedding_model = TextEmbedding(model_name) self.embedding_model = TextEmbedding(model_name)
async def embed_documents(self, documents: List[str]) -> List[List[float]]: async def embed_documents(self, documents: list[str]) -> list[list[float]]:
"""Embed a list of documents into vectors.""" """Embed a list of documents into vectors."""
# Run in a thread pool since FastEmbed is synchronous # Run in a thread pool since FastEmbed is synchronous
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
@@ -26,7 +25,7 @@ class FastEmbedProvider(EmbeddingProvider):
) )
return [embedding.tolist() for embedding in embeddings] return [embedding.tolist() for embedding in embeddings]
async def embed_query(self, query: str) -> List[float]: async def embed_query(self, query: str) -> list[float]:
"""Embed a query into a vector.""" """Embed a query into a vector."""
# Run in a thread pool since FastEmbed is synchronous # Run in a thread pool since FastEmbed is synchronous
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()

View File

@@ -1,6 +1,6 @@
import json import json
import logging import logging
from typing import Annotated, Any, List, Optional from typing import Annotated, Any
from fastmcp import Context, FastMCP from fastmcp import Context, FastMCP
from pydantic import Field from pydantic import Field
@@ -76,7 +76,7 @@ class QdrantMCPServer(FastMCP):
# If we set it to be optional, some of the MCP clients, like Cursor, cannot # If we set it to be optional, some of the MCP clients, like Cursor, cannot
# handle the optional parameter correctly. # handle the optional parameter correctly.
metadata: Annotated[ metadata: Annotated[
Optional[Metadata], Metadata | None,
Field( Field(
description="Extra metadata stored along with memorised information. Any json is accepted." description="Extra metadata stored along with memorised information. Any json is accepted."
), ),
@@ -106,14 +106,15 @@ class QdrantMCPServer(FastMCP):
collection_name: Annotated[ collection_name: Annotated[
str, Field(description="The collection to search in") str, Field(description="The collection to search in")
], ],
query_filter: Optional[ArbitraryFilter] = None, query_filter: ArbitraryFilter | None = None,
) -> List[str]: ) -> list[str]:
""" """
Find memories in Qdrant. Find memories in Qdrant.
:param ctx: The context for the request. :param ctx: The context for the request.
:param query: The query to use for the search. :param query: The query to use for the search.
:param collection_name: The name of the collection to search in, optional. If not provided, :param collection_name: The name of the collection to search in, optional. If not provided,
the default collection is used. the default collection is used.
:param query_filter: The filter to apply to the query.
:return: A list of entries found. :return: A list of entries found.
""" """
@@ -123,10 +124,6 @@ class QdrantMCPServer(FastMCP):
query_filter = models.Filter(**query_filter) if query_filter else None query_filter = models.Filter(**query_filter) if query_filter else None
await ctx.debug(f"Finding results for query {query}") await ctx.debug(f"Finding results for query {query}")
if collection_name:
await ctx.debug(
f"Overriding the collection name with {collection_name}"
)
entries = await self.qdrant_connector.search( entries = await self.qdrant_connector.search(
query, query,

View File

@@ -1,6 +1,6 @@
import logging import logging
import uuid import uuid
from typing import Any, Dict, Optional from typing import Any
from pydantic import BaseModel from pydantic import BaseModel
from qdrant_client import AsyncQdrantClient, models from qdrant_client import AsyncQdrantClient, models
@@ -10,9 +10,8 @@ from mcp_server_qdrant.settings import METADATA_PATH
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
Metadata = Dict[str, Any] Metadata = dict[str, Any]
ArbitraryFilter = dict[str, Any]
ArbitraryFilter = Dict[str, Any]
class Entry(BaseModel): class Entry(BaseModel):
@@ -21,7 +20,7 @@ class Entry(BaseModel):
""" """
content: str content: str
metadata: Optional[Metadata] = None metadata: Metadata | None = None
class QdrantConnector: class QdrantConnector:
@@ -37,12 +36,12 @@ class QdrantConnector:
def __init__( def __init__(
self, self,
qdrant_url: Optional[str], qdrant_url: str | None,
qdrant_api_key: Optional[str], qdrant_api_key: str | None,
collection_name: Optional[str], collection_name: str | None,
embedding_provider: EmbeddingProvider, embedding_provider: EmbeddingProvider,
qdrant_local_path: Optional[str] = None, qdrant_local_path: str | None = None,
field_indexes: Optional[dict[str, models.PayloadSchemaType]] = None, field_indexes: dict[str, models.PayloadSchemaType] | None = None,
): ):
self._qdrant_url = qdrant_url.rstrip("/") if qdrant_url else None self._qdrant_url = qdrant_url.rstrip("/") if qdrant_url else None
self._qdrant_api_key = qdrant_api_key self._qdrant_api_key = qdrant_api_key
@@ -61,7 +60,7 @@ class QdrantConnector:
response = await self._client.get_collections() response = await self._client.get_collections()
return [collection.name for collection in response.collections] return [collection.name for collection in response.collections]
async def store(self, entry: Entry, *, collection_name: Optional[str] = None): async def store(self, entry: Entry, *, collection_name: str | None = None):
""" """
Store some information in the Qdrant collection, along with the specified metadata. Store some information in the Qdrant collection, along with the specified metadata.
:param entry: The entry to store in the Qdrant collection. :param entry: The entry to store in the Qdrant collection.
@@ -95,9 +94,9 @@ class QdrantConnector:
self, self,
query: str, query: str,
*, *,
collection_name: Optional[str] = None, collection_name: str | None = None,
limit: int = 10, limit: int = 10,
query_filter: Optional[models.Filter] = None, query_filter: models.Filter | None = None,
) -> list[Entry]: ) -> list[Entry]:
""" """
Find points in the Qdrant collection. If there are no entries found, an empty list is returned. Find points in the Qdrant collection. If there are no entries found, an empty list is returned.
@@ -105,6 +104,8 @@ class QdrantConnector:
:param collection_name: The name of the collection to search in, optional. If not provided, :param collection_name: The name of the collection to search in, optional. If not provided,
the default collection is used. the default collection is used.
:param limit: The maximum number of entries to return. :param limit: The maximum number of entries to return.
:param query_filter: The filter to apply to the query, if any.
:return: A list of entries found. :return: A list of entries found.
""" """
collection_name = collection_name or self._default_collection_name collection_name = collection_name or self._default_collection_name

View File

@@ -1,6 +1,6 @@
from typing import Literal, Optional from typing import Literal
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, model_validator
from pydantic_settings import BaseSettings from pydantic_settings import BaseSettings
from mcp_server_qdrant.embeddings.types import EmbeddingProviderType from mcp_server_qdrant.embeddings.types import EmbeddingProviderType
@@ -56,7 +56,7 @@ class FilterableField(BaseModel):
field_type: Literal["keyword", "integer", "float", "boolean"] = Field( field_type: Literal["keyword", "integer", "float", "boolean"] = Field(
description="The type of the field" description="The type of the field"
) )
condition: Optional[Literal["==", "!=", ">", ">=", "<", "<=", "any", "except"]] = ( condition: Literal["==", "!=", ">", ">=", "<", "<=", "any", "except"] | None = (
Field( Field(
default=None, default=None,
description=( description=(
@@ -76,18 +76,16 @@ class QdrantSettings(BaseSettings):
Configuration for the Qdrant connector. Configuration for the Qdrant connector.
""" """
location: Optional[str] = Field(default=None, validation_alias="QDRANT_URL") location: str | None = Field(default=None, validation_alias="QDRANT_URL")
api_key: Optional[str] = Field(default=None, validation_alias="QDRANT_API_KEY") api_key: str | None = Field(default=None, validation_alias="QDRANT_API_KEY")
collection_name: Optional[str] = Field( collection_name: str | None = Field(
default=None, validation_alias="COLLECTION_NAME" default=None, validation_alias="COLLECTION_NAME"
) )
local_path: Optional[str] = Field( local_path: str | None = Field(default=None, validation_alias="QDRANT_LOCAL_PATH")
default=None, validation_alias="QDRANT_LOCAL_PATH"
)
search_limit: int = Field(default=10, validation_alias="QDRANT_SEARCH_LIMIT") search_limit: int = Field(default=10, validation_alias="QDRANT_SEARCH_LIMIT")
read_only: bool = Field(default=False, validation_alias="QDRANT_READ_ONLY") read_only: bool = Field(default=False, validation_alias="QDRANT_READ_ONLY")
filterable_fields: Optional[list[FilterableField]] = Field(default=None) filterable_fields: list[FilterableField] | None = Field(default=None)
allow_arbitrary_filter: bool = Field( allow_arbitrary_filter: bool = Field(
default=False, validation_alias="QDRANT_ALLOW_ARBITRARY_FILTER" default=False, validation_alias="QDRANT_ALLOW_ARBITRARY_FILTER"
@@ -106,3 +104,12 @@ class QdrantSettings(BaseSettings):
for field in self.filterable_fields for field in self.filterable_fields
if field.condition is not None if field.condition is not None
} }
@model_validator(mode="after")
def check_local_path_conflict(self) -> "QdrantSettings":
if self.local_path:
if self.location is not None or self.api_key is not None:
raise ValueError(
"If 'local_path' is set, 'location' and 'api_key' must be None."
)
return self

View File

@@ -1,5 +1,4 @@
import os import pytest
from unittest.mock import patch
from mcp_server_qdrant.embeddings.types import EmbeddingProviderType from mcp_server_qdrant.embeddings.types import EmbeddingProviderType
from mcp_server_qdrant.settings import ( from mcp_server_qdrant.settings import (
@@ -18,34 +17,51 @@ class TestQdrantSettings:
# Should not raise error because there are no required fields # Should not raise error because there are no required fields
QdrantSettings() QdrantSettings()
@patch.dict( def test_minimal_config(self, monkeypatch):
os.environ,
{"QDRANT_URL": "http://localhost:6333", "COLLECTION_NAME": "test_collection"},
)
def test_minimal_config(self):
"""Test loading minimal configuration from environment variables.""" """Test loading minimal configuration from environment variables."""
monkeypatch.setenv("QDRANT_URL", "http://localhost:6333")
monkeypatch.setenv("COLLECTION_NAME", "test_collection")
settings = QdrantSettings() settings = QdrantSettings()
assert settings.location == "http://localhost:6333" assert settings.location == "http://localhost:6333"
assert settings.collection_name == "test_collection" assert settings.collection_name == "test_collection"
assert settings.api_key is None assert settings.api_key is None
assert settings.local_path is None assert settings.local_path is None
@patch.dict( def test_full_config(self, monkeypatch):
os.environ,
{
"QDRANT_URL": "http://qdrant.example.com:6333",
"QDRANT_API_KEY": "test_api_key",
"COLLECTION_NAME": "my_memories",
"QDRANT_LOCAL_PATH": "/tmp/qdrant",
},
)
def test_full_config(self):
"""Test loading full configuration from environment variables.""" """Test loading full configuration from environment variables."""
monkeypatch.setenv("QDRANT_URL", "http://qdrant.example.com:6333")
monkeypatch.setenv("QDRANT_API_KEY", "test_api_key")
monkeypatch.setenv("COLLECTION_NAME", "my_memories")
monkeypatch.setenv("QDRANT_SEARCH_LIMIT", "15")
monkeypatch.setenv("QDRANT_READ_ONLY", "1")
settings = QdrantSettings() settings = QdrantSettings()
assert settings.location == "http://qdrant.example.com:6333" assert settings.location == "http://qdrant.example.com:6333"
assert settings.api_key == "test_api_key" assert settings.api_key == "test_api_key"
assert settings.collection_name == "my_memories" assert settings.collection_name == "my_memories"
assert settings.local_path == "/tmp/qdrant" assert settings.search_limit == 15
assert settings.read_only is True
def test_local_path_config(self, monkeypatch):
"""Test loading local path configuration from environment variables."""
monkeypatch.setenv("QDRANT_LOCAL_PATH", "/path/to/local/qdrant")
settings = QdrantSettings()
assert settings.local_path == "/path/to/local/qdrant"
def test_local_path_is_exclusive_with_url(self, monkeypatch):
"""Test that local path cannot be set if Qdrant URL is provided."""
monkeypatch.setenv("QDRANT_URL", "http://localhost:6333")
monkeypatch.setenv("QDRANT_LOCAL_PATH", "/path/to/local/qdrant")
with pytest.raises(ValueError):
QdrantSettings()
monkeypatch.delenv("QDRANT_URL", raising=False)
monkeypatch.setenv("QDRANT_API_KEY", "test_api_key")
with pytest.raises(ValueError):
QdrantSettings()
class TestEmbeddingProviderSettings: class TestEmbeddingProviderSettings:
@@ -55,12 +71,9 @@ class TestEmbeddingProviderSettings:
assert settings.provider_type == EmbeddingProviderType.FASTEMBED assert settings.provider_type == EmbeddingProviderType.FASTEMBED
assert settings.model_name == "sentence-transformers/all-MiniLM-L6-v2" assert settings.model_name == "sentence-transformers/all-MiniLM-L6-v2"
@patch.dict( def test_custom_values(self, monkeypatch):
os.environ,
{"EMBEDDING_MODEL": "custom_model"},
)
def test_custom_values(self):
"""Test loading custom values from environment variables.""" """Test loading custom values from environment variables."""
monkeypatch.setenv("EMBEDDING_MODEL", "custom_model")
settings = EmbeddingProviderSettings() settings = EmbeddingProviderSettings()
assert settings.provider_type == EmbeddingProviderType.FASTEMBED assert settings.provider_type == EmbeddingProviderType.FASTEMBED
assert settings.model_name == "custom_model" assert settings.model_name == "custom_model"
@@ -73,35 +86,24 @@ class TestToolSettings:
assert settings.tool_store_description == DEFAULT_TOOL_STORE_DESCRIPTION assert settings.tool_store_description == DEFAULT_TOOL_STORE_DESCRIPTION
assert settings.tool_find_description == DEFAULT_TOOL_FIND_DESCRIPTION assert settings.tool_find_description == DEFAULT_TOOL_FIND_DESCRIPTION
@patch.dict( def test_custom_store_description(self, monkeypatch):
os.environ,
{"TOOL_STORE_DESCRIPTION": "Custom store description"},
)
def test_custom_store_description(self):
"""Test loading custom store description from environment variable.""" """Test loading custom store description from environment variable."""
monkeypatch.setenv("TOOL_STORE_DESCRIPTION", "Custom store description")
settings = ToolSettings() settings = ToolSettings()
assert settings.tool_store_description == "Custom store description" assert settings.tool_store_description == "Custom store description"
assert settings.tool_find_description == DEFAULT_TOOL_FIND_DESCRIPTION assert settings.tool_find_description == DEFAULT_TOOL_FIND_DESCRIPTION
@patch.dict( def test_custom_find_description(self, monkeypatch):
os.environ,
{"TOOL_FIND_DESCRIPTION": "Custom find description"},
)
def test_custom_find_description(self):
"""Test loading custom find description from environment variable.""" """Test loading custom find description from environment variable."""
monkeypatch.setenv("TOOL_FIND_DESCRIPTION", "Custom find description")
settings = ToolSettings() settings = ToolSettings()
assert settings.tool_store_description == DEFAULT_TOOL_STORE_DESCRIPTION assert settings.tool_store_description == DEFAULT_TOOL_STORE_DESCRIPTION
assert settings.tool_find_description == "Custom find description" assert settings.tool_find_description == "Custom find description"
@patch.dict( def test_all_custom_values(self, monkeypatch):
os.environ,
{
"TOOL_STORE_DESCRIPTION": "Custom store description",
"TOOL_FIND_DESCRIPTION": "Custom find description",
},
)
def test_all_custom_values(self):
"""Test loading all custom values from environment variables.""" """Test loading all custom values from environment variables."""
monkeypatch.setenv("TOOL_STORE_DESCRIPTION", "Custom store description")
monkeypatch.setenv("TOOL_FIND_DESCRIPTION", "Custom find description")
settings = ToolSettings() settings = ToolSettings()
assert settings.tool_store_description == "Custom store description" assert settings.tool_store_description == "Custom store description"
assert settings.tool_find_description == "Custom find description" assert settings.tool_find_description == "Custom find description"