new: update type hints (#64)
* new: update type hints * fix: do not pass location and path to qdrant client, and do not accept them together * new: update settings tests * fix: revert removal of local path
This commit is contained in:
@@ -1,17 +1,16 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import List
|
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingProvider(ABC):
|
class EmbeddingProvider(ABC):
|
||||||
"""Abstract base class for embedding providers."""
|
"""Abstract base class for embedding providers."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def embed_documents(self, documents: List[str]) -> List[List[float]]:
|
async def embed_documents(self, documents: list[str]) -> list[list[float]]:
|
||||||
"""Embed a list of documents into vectors."""
|
"""Embed a list of documents into vectors."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def embed_query(self, query: str) -> List[float]:
|
async def embed_query(self, query: str) -> list[float]:
|
||||||
"""Embed a query into a vector."""
|
"""Embed a query into a vector."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from typing import List
|
|
||||||
|
|
||||||
from fastembed import TextEmbedding
|
from fastembed import TextEmbedding
|
||||||
from fastembed.common.model_description import DenseModelDescription
|
from fastembed.common.model_description import DenseModelDescription
|
||||||
@@ -17,7 +16,7 @@ class FastEmbedProvider(EmbeddingProvider):
|
|||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self.embedding_model = TextEmbedding(model_name)
|
self.embedding_model = TextEmbedding(model_name)
|
||||||
|
|
||||||
async def embed_documents(self, documents: List[str]) -> List[List[float]]:
|
async def embed_documents(self, documents: list[str]) -> list[list[float]]:
|
||||||
"""Embed a list of documents into vectors."""
|
"""Embed a list of documents into vectors."""
|
||||||
# Run in a thread pool since FastEmbed is synchronous
|
# Run in a thread pool since FastEmbed is synchronous
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
@@ -26,7 +25,7 @@ class FastEmbedProvider(EmbeddingProvider):
|
|||||||
)
|
)
|
||||||
return [embedding.tolist() for embedding in embeddings]
|
return [embedding.tolist() for embedding in embeddings]
|
||||||
|
|
||||||
async def embed_query(self, query: str) -> List[float]:
|
async def embed_query(self, query: str) -> list[float]:
|
||||||
"""Embed a query into a vector."""
|
"""Embed a query into a vector."""
|
||||||
# Run in a thread pool since FastEmbed is synchronous
|
# Run in a thread pool since FastEmbed is synchronous
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Annotated, Any, List, Optional
|
from typing import Annotated, Any
|
||||||
|
|
||||||
from fastmcp import Context, FastMCP
|
from fastmcp import Context, FastMCP
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
@@ -76,7 +76,7 @@ class QdrantMCPServer(FastMCP):
|
|||||||
# If we set it to be optional, some of the MCP clients, like Cursor, cannot
|
# If we set it to be optional, some of the MCP clients, like Cursor, cannot
|
||||||
# handle the optional parameter correctly.
|
# handle the optional parameter correctly.
|
||||||
metadata: Annotated[
|
metadata: Annotated[
|
||||||
Optional[Metadata],
|
Metadata | None,
|
||||||
Field(
|
Field(
|
||||||
description="Extra metadata stored along with memorised information. Any json is accepted."
|
description="Extra metadata stored along with memorised information. Any json is accepted."
|
||||||
),
|
),
|
||||||
@@ -106,14 +106,15 @@ class QdrantMCPServer(FastMCP):
|
|||||||
collection_name: Annotated[
|
collection_name: Annotated[
|
||||||
str, Field(description="The collection to search in")
|
str, Field(description="The collection to search in")
|
||||||
],
|
],
|
||||||
query_filter: Optional[ArbitraryFilter] = None,
|
query_filter: ArbitraryFilter | None = None,
|
||||||
) -> List[str]:
|
) -> list[str]:
|
||||||
"""
|
"""
|
||||||
Find memories in Qdrant.
|
Find memories in Qdrant.
|
||||||
:param ctx: The context for the request.
|
:param ctx: The context for the request.
|
||||||
:param query: The query to use for the search.
|
:param query: The query to use for the search.
|
||||||
:param collection_name: The name of the collection to search in, optional. If not provided,
|
:param collection_name: The name of the collection to search in, optional. If not provided,
|
||||||
the default collection is used.
|
the default collection is used.
|
||||||
|
:param query_filter: The filter to apply to the query.
|
||||||
:return: A list of entries found.
|
:return: A list of entries found.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -123,10 +124,6 @@ class QdrantMCPServer(FastMCP):
|
|||||||
query_filter = models.Filter(**query_filter) if query_filter else None
|
query_filter = models.Filter(**query_filter) if query_filter else None
|
||||||
|
|
||||||
await ctx.debug(f"Finding results for query {query}")
|
await ctx.debug(f"Finding results for query {query}")
|
||||||
if collection_name:
|
|
||||||
await ctx.debug(
|
|
||||||
f"Overriding the collection name with {collection_name}"
|
|
||||||
)
|
|
||||||
|
|
||||||
entries = await self.qdrant_connector.search(
|
entries = await self.qdrant_connector.search(
|
||||||
query,
|
query,
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from qdrant_client import AsyncQdrantClient, models
|
from qdrant_client import AsyncQdrantClient, models
|
||||||
@@ -10,9 +10,8 @@ from mcp_server_qdrant.settings import METADATA_PATH
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
Metadata = Dict[str, Any]
|
Metadata = dict[str, Any]
|
||||||
|
ArbitraryFilter = dict[str, Any]
|
||||||
ArbitraryFilter = Dict[str, Any]
|
|
||||||
|
|
||||||
|
|
||||||
class Entry(BaseModel):
|
class Entry(BaseModel):
|
||||||
@@ -21,7 +20,7 @@ class Entry(BaseModel):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
content: str
|
content: str
|
||||||
metadata: Optional[Metadata] = None
|
metadata: Metadata | None = None
|
||||||
|
|
||||||
|
|
||||||
class QdrantConnector:
|
class QdrantConnector:
|
||||||
@@ -37,12 +36,12 @@ class QdrantConnector:
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
qdrant_url: Optional[str],
|
qdrant_url: str | None,
|
||||||
qdrant_api_key: Optional[str],
|
qdrant_api_key: str | None,
|
||||||
collection_name: Optional[str],
|
collection_name: str | None,
|
||||||
embedding_provider: EmbeddingProvider,
|
embedding_provider: EmbeddingProvider,
|
||||||
qdrant_local_path: Optional[str] = None,
|
qdrant_local_path: str | None = None,
|
||||||
field_indexes: Optional[dict[str, models.PayloadSchemaType]] = None,
|
field_indexes: dict[str, models.PayloadSchemaType] | None = None,
|
||||||
):
|
):
|
||||||
self._qdrant_url = qdrant_url.rstrip("/") if qdrant_url else None
|
self._qdrant_url = qdrant_url.rstrip("/") if qdrant_url else None
|
||||||
self._qdrant_api_key = qdrant_api_key
|
self._qdrant_api_key = qdrant_api_key
|
||||||
@@ -61,7 +60,7 @@ class QdrantConnector:
|
|||||||
response = await self._client.get_collections()
|
response = await self._client.get_collections()
|
||||||
return [collection.name for collection in response.collections]
|
return [collection.name for collection in response.collections]
|
||||||
|
|
||||||
async def store(self, entry: Entry, *, collection_name: Optional[str] = None):
|
async def store(self, entry: Entry, *, collection_name: str | None = None):
|
||||||
"""
|
"""
|
||||||
Store some information in the Qdrant collection, along with the specified metadata.
|
Store some information in the Qdrant collection, along with the specified metadata.
|
||||||
:param entry: The entry to store in the Qdrant collection.
|
:param entry: The entry to store in the Qdrant collection.
|
||||||
@@ -95,9 +94,9 @@ class QdrantConnector:
|
|||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
*,
|
*,
|
||||||
collection_name: Optional[str] = None,
|
collection_name: str | None = None,
|
||||||
limit: int = 10,
|
limit: int = 10,
|
||||||
query_filter: Optional[models.Filter] = None,
|
query_filter: models.Filter | None = None,
|
||||||
) -> list[Entry]:
|
) -> list[Entry]:
|
||||||
"""
|
"""
|
||||||
Find points in the Qdrant collection. If there are no entries found, an empty list is returned.
|
Find points in the Qdrant collection. If there are no entries found, an empty list is returned.
|
||||||
@@ -105,6 +104,8 @@ class QdrantConnector:
|
|||||||
:param collection_name: The name of the collection to search in, optional. If not provided,
|
:param collection_name: The name of the collection to search in, optional. If not provided,
|
||||||
the default collection is used.
|
the default collection is used.
|
||||||
:param limit: The maximum number of entries to return.
|
:param limit: The maximum number of entries to return.
|
||||||
|
:param query_filter: The filter to apply to the query, if any.
|
||||||
|
|
||||||
:return: A list of entries found.
|
:return: A list of entries found.
|
||||||
"""
|
"""
|
||||||
collection_name = collection_name or self._default_collection_name
|
collection_name = collection_name or self._default_collection_name
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from typing import Literal, Optional
|
from typing import Literal
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field, model_validator
|
||||||
from pydantic_settings import BaseSettings
|
from pydantic_settings import BaseSettings
|
||||||
|
|
||||||
from mcp_server_qdrant.embeddings.types import EmbeddingProviderType
|
from mcp_server_qdrant.embeddings.types import EmbeddingProviderType
|
||||||
@@ -56,7 +56,7 @@ class FilterableField(BaseModel):
|
|||||||
field_type: Literal["keyword", "integer", "float", "boolean"] = Field(
|
field_type: Literal["keyword", "integer", "float", "boolean"] = Field(
|
||||||
description="The type of the field"
|
description="The type of the field"
|
||||||
)
|
)
|
||||||
condition: Optional[Literal["==", "!=", ">", ">=", "<", "<=", "any", "except"]] = (
|
condition: Literal["==", "!=", ">", ">=", "<", "<=", "any", "except"] | None = (
|
||||||
Field(
|
Field(
|
||||||
default=None,
|
default=None,
|
||||||
description=(
|
description=(
|
||||||
@@ -76,18 +76,16 @@ class QdrantSettings(BaseSettings):
|
|||||||
Configuration for the Qdrant connector.
|
Configuration for the Qdrant connector.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
location: Optional[str] = Field(default=None, validation_alias="QDRANT_URL")
|
location: str | None = Field(default=None, validation_alias="QDRANT_URL")
|
||||||
api_key: Optional[str] = Field(default=None, validation_alias="QDRANT_API_KEY")
|
api_key: str | None = Field(default=None, validation_alias="QDRANT_API_KEY")
|
||||||
collection_name: Optional[str] = Field(
|
collection_name: str | None = Field(
|
||||||
default=None, validation_alias="COLLECTION_NAME"
|
default=None, validation_alias="COLLECTION_NAME"
|
||||||
)
|
)
|
||||||
local_path: Optional[str] = Field(
|
local_path: str | None = Field(default=None, validation_alias="QDRANT_LOCAL_PATH")
|
||||||
default=None, validation_alias="QDRANT_LOCAL_PATH"
|
|
||||||
)
|
|
||||||
search_limit: int = Field(default=10, validation_alias="QDRANT_SEARCH_LIMIT")
|
search_limit: int = Field(default=10, validation_alias="QDRANT_SEARCH_LIMIT")
|
||||||
read_only: bool = Field(default=False, validation_alias="QDRANT_READ_ONLY")
|
read_only: bool = Field(default=False, validation_alias="QDRANT_READ_ONLY")
|
||||||
|
|
||||||
filterable_fields: Optional[list[FilterableField]] = Field(default=None)
|
filterable_fields: list[FilterableField] | None = Field(default=None)
|
||||||
|
|
||||||
allow_arbitrary_filter: bool = Field(
|
allow_arbitrary_filter: bool = Field(
|
||||||
default=False, validation_alias="QDRANT_ALLOW_ARBITRARY_FILTER"
|
default=False, validation_alias="QDRANT_ALLOW_ARBITRARY_FILTER"
|
||||||
@@ -106,3 +104,12 @@ class QdrantSettings(BaseSettings):
|
|||||||
for field in self.filterable_fields
|
for field in self.filterable_fields
|
||||||
if field.condition is not None
|
if field.condition is not None
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def check_local_path_conflict(self) -> "QdrantSettings":
|
||||||
|
if self.local_path:
|
||||||
|
if self.location is not None or self.api_key is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"If 'local_path' is set, 'location' and 'api_key' must be None."
|
||||||
|
)
|
||||||
|
return self
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
import os
|
import pytest
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
from mcp_server_qdrant.embeddings.types import EmbeddingProviderType
|
from mcp_server_qdrant.embeddings.types import EmbeddingProviderType
|
||||||
from mcp_server_qdrant.settings import (
|
from mcp_server_qdrant.settings import (
|
||||||
@@ -18,34 +17,51 @@ class TestQdrantSettings:
|
|||||||
# Should not raise error because there are no required fields
|
# Should not raise error because there are no required fields
|
||||||
QdrantSettings()
|
QdrantSettings()
|
||||||
|
|
||||||
@patch.dict(
|
def test_minimal_config(self, monkeypatch):
|
||||||
os.environ,
|
|
||||||
{"QDRANT_URL": "http://localhost:6333", "COLLECTION_NAME": "test_collection"},
|
|
||||||
)
|
|
||||||
def test_minimal_config(self):
|
|
||||||
"""Test loading minimal configuration from environment variables."""
|
"""Test loading minimal configuration from environment variables."""
|
||||||
|
monkeypatch.setenv("QDRANT_URL", "http://localhost:6333")
|
||||||
|
monkeypatch.setenv("COLLECTION_NAME", "test_collection")
|
||||||
|
|
||||||
settings = QdrantSettings()
|
settings = QdrantSettings()
|
||||||
assert settings.location == "http://localhost:6333"
|
assert settings.location == "http://localhost:6333"
|
||||||
assert settings.collection_name == "test_collection"
|
assert settings.collection_name == "test_collection"
|
||||||
assert settings.api_key is None
|
assert settings.api_key is None
|
||||||
assert settings.local_path is None
|
assert settings.local_path is None
|
||||||
|
|
||||||
@patch.dict(
|
def test_full_config(self, monkeypatch):
|
||||||
os.environ,
|
|
||||||
{
|
|
||||||
"QDRANT_URL": "http://qdrant.example.com:6333",
|
|
||||||
"QDRANT_API_KEY": "test_api_key",
|
|
||||||
"COLLECTION_NAME": "my_memories",
|
|
||||||
"QDRANT_LOCAL_PATH": "/tmp/qdrant",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
def test_full_config(self):
|
|
||||||
"""Test loading full configuration from environment variables."""
|
"""Test loading full configuration from environment variables."""
|
||||||
|
monkeypatch.setenv("QDRANT_URL", "http://qdrant.example.com:6333")
|
||||||
|
monkeypatch.setenv("QDRANT_API_KEY", "test_api_key")
|
||||||
|
monkeypatch.setenv("COLLECTION_NAME", "my_memories")
|
||||||
|
monkeypatch.setenv("QDRANT_SEARCH_LIMIT", "15")
|
||||||
|
monkeypatch.setenv("QDRANT_READ_ONLY", "1")
|
||||||
|
|
||||||
settings = QdrantSettings()
|
settings = QdrantSettings()
|
||||||
assert settings.location == "http://qdrant.example.com:6333"
|
assert settings.location == "http://qdrant.example.com:6333"
|
||||||
assert settings.api_key == "test_api_key"
|
assert settings.api_key == "test_api_key"
|
||||||
assert settings.collection_name == "my_memories"
|
assert settings.collection_name == "my_memories"
|
||||||
assert settings.local_path == "/tmp/qdrant"
|
assert settings.search_limit == 15
|
||||||
|
assert settings.read_only is True
|
||||||
|
|
||||||
|
def test_local_path_config(self, monkeypatch):
|
||||||
|
"""Test loading local path configuration from environment variables."""
|
||||||
|
monkeypatch.setenv("QDRANT_LOCAL_PATH", "/path/to/local/qdrant")
|
||||||
|
|
||||||
|
settings = QdrantSettings()
|
||||||
|
assert settings.local_path == "/path/to/local/qdrant"
|
||||||
|
|
||||||
|
def test_local_path_is_exclusive_with_url(self, monkeypatch):
|
||||||
|
"""Test that local path cannot be set if Qdrant URL is provided."""
|
||||||
|
monkeypatch.setenv("QDRANT_URL", "http://localhost:6333")
|
||||||
|
monkeypatch.setenv("QDRANT_LOCAL_PATH", "/path/to/local/qdrant")
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
QdrantSettings()
|
||||||
|
|
||||||
|
monkeypatch.delenv("QDRANT_URL", raising=False)
|
||||||
|
monkeypatch.setenv("QDRANT_API_KEY", "test_api_key")
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
QdrantSettings()
|
||||||
|
|
||||||
|
|
||||||
class TestEmbeddingProviderSettings:
|
class TestEmbeddingProviderSettings:
|
||||||
@@ -55,12 +71,9 @@ class TestEmbeddingProviderSettings:
|
|||||||
assert settings.provider_type == EmbeddingProviderType.FASTEMBED
|
assert settings.provider_type == EmbeddingProviderType.FASTEMBED
|
||||||
assert settings.model_name == "sentence-transformers/all-MiniLM-L6-v2"
|
assert settings.model_name == "sentence-transformers/all-MiniLM-L6-v2"
|
||||||
|
|
||||||
@patch.dict(
|
def test_custom_values(self, monkeypatch):
|
||||||
os.environ,
|
|
||||||
{"EMBEDDING_MODEL": "custom_model"},
|
|
||||||
)
|
|
||||||
def test_custom_values(self):
|
|
||||||
"""Test loading custom values from environment variables."""
|
"""Test loading custom values from environment variables."""
|
||||||
|
monkeypatch.setenv("EMBEDDING_MODEL", "custom_model")
|
||||||
settings = EmbeddingProviderSettings()
|
settings = EmbeddingProviderSettings()
|
||||||
assert settings.provider_type == EmbeddingProviderType.FASTEMBED
|
assert settings.provider_type == EmbeddingProviderType.FASTEMBED
|
||||||
assert settings.model_name == "custom_model"
|
assert settings.model_name == "custom_model"
|
||||||
@@ -73,35 +86,24 @@ class TestToolSettings:
|
|||||||
assert settings.tool_store_description == DEFAULT_TOOL_STORE_DESCRIPTION
|
assert settings.tool_store_description == DEFAULT_TOOL_STORE_DESCRIPTION
|
||||||
assert settings.tool_find_description == DEFAULT_TOOL_FIND_DESCRIPTION
|
assert settings.tool_find_description == DEFAULT_TOOL_FIND_DESCRIPTION
|
||||||
|
|
||||||
@patch.dict(
|
def test_custom_store_description(self, monkeypatch):
|
||||||
os.environ,
|
|
||||||
{"TOOL_STORE_DESCRIPTION": "Custom store description"},
|
|
||||||
)
|
|
||||||
def test_custom_store_description(self):
|
|
||||||
"""Test loading custom store description from environment variable."""
|
"""Test loading custom store description from environment variable."""
|
||||||
|
monkeypatch.setenv("TOOL_STORE_DESCRIPTION", "Custom store description")
|
||||||
settings = ToolSettings()
|
settings = ToolSettings()
|
||||||
assert settings.tool_store_description == "Custom store description"
|
assert settings.tool_store_description == "Custom store description"
|
||||||
assert settings.tool_find_description == DEFAULT_TOOL_FIND_DESCRIPTION
|
assert settings.tool_find_description == DEFAULT_TOOL_FIND_DESCRIPTION
|
||||||
|
|
||||||
@patch.dict(
|
def test_custom_find_description(self, monkeypatch):
|
||||||
os.environ,
|
|
||||||
{"TOOL_FIND_DESCRIPTION": "Custom find description"},
|
|
||||||
)
|
|
||||||
def test_custom_find_description(self):
|
|
||||||
"""Test loading custom find description from environment variable."""
|
"""Test loading custom find description from environment variable."""
|
||||||
|
monkeypatch.setenv("TOOL_FIND_DESCRIPTION", "Custom find description")
|
||||||
settings = ToolSettings()
|
settings = ToolSettings()
|
||||||
assert settings.tool_store_description == DEFAULT_TOOL_STORE_DESCRIPTION
|
assert settings.tool_store_description == DEFAULT_TOOL_STORE_DESCRIPTION
|
||||||
assert settings.tool_find_description == "Custom find description"
|
assert settings.tool_find_description == "Custom find description"
|
||||||
|
|
||||||
@patch.dict(
|
def test_all_custom_values(self, monkeypatch):
|
||||||
os.environ,
|
|
||||||
{
|
|
||||||
"TOOL_STORE_DESCRIPTION": "Custom store description",
|
|
||||||
"TOOL_FIND_DESCRIPTION": "Custom find description",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
def test_all_custom_values(self):
|
|
||||||
"""Test loading all custom values from environment variables."""
|
"""Test loading all custom values from environment variables."""
|
||||||
|
monkeypatch.setenv("TOOL_STORE_DESCRIPTION", "Custom store description")
|
||||||
|
monkeypatch.setenv("TOOL_FIND_DESCRIPTION", "Custom find description")
|
||||||
settings = ToolSettings()
|
settings = ToolSettings()
|
||||||
assert settings.tool_store_description == "Custom store description"
|
assert settings.tool_store_description == "Custom store description"
|
||||||
assert settings.tool_find_description == "Custom find description"
|
assert settings.tool_find_description == "Custom find description"
|
||||||
|
|||||||
Reference in New Issue
Block a user