Configurable filters (#58)

* add configurable filters

* hello to hr department

* rollback debug code

* add arbitrary filter

* dont consider fields without conditions

* in and except condition

* proper annotation types for optional and list fields

* fix types import

* skip non-required fields

* fix: fix match except condition, fix boolean filter

* fix: apply ruff

* fix: make condition optional in filterable field

* fix: do not set default value for required fields (#63)

* fix: do not set default value for required fields

* fix: temp fix fastmcp to <2.8.0 cause of the breaking changes in the api

* fix: add missing changes to pyproject.toml

* fix: downgrade fastmcp even further to <2.7.0

---------

Co-authored-by: George Panchuk <george.panchuk@qdrant.tech>
Co-authored-by: George <panchuk.george@outlook.com>
This commit is contained in:
Andrey Vasnetsov
2025-06-11 16:19:18 +02:00
committed by GitHub
parent 244139beb5
commit b657656363
7 changed files with 1422 additions and 558 deletions

View File

@@ -9,7 +9,7 @@ dependencies = [
"fastembed>=0.6.0", "fastembed>=0.6.0",
"qdrant-client>=1.12.0", "qdrant-client>=1.12.0",
"pydantic>=2.10.6", "pydantic>=2.10.6",
"fastmcp>=2.5.1", "fastmcp>=2.5.1,<2.7.0",
] ]
[build-system] [build-system]
@@ -18,6 +18,7 @@ build-backend = "hatchling.build"
[tool.uv] [tool.uv]
dev-dependencies = [ dev-dependencies = [
"ipdb>=0.13.13",
"isort>=6.0.1", "isort>=6.0.1",
"mypy>=1.9.0", "mypy>=1.9.0",
"pre-commit>=4.1.0", "pre-commit>=4.1.0",

View File

@@ -0,0 +1,194 @@
from typing import Any
from qdrant_client import models
from mcp_server_qdrant.qdrant import ArbitraryFilter
from mcp_server_qdrant.settings import METADATA_PATH, FilterableField
def make_filter(
filterable_fields: dict[str, FilterableField], values: dict[str, Any]
) -> ArbitraryFilter:
must_conditions = []
must_not_conditions = []
for raw_field_name, field_value in values.items():
if raw_field_name not in filterable_fields:
raise ValueError(f"Field {raw_field_name} is not a filterable field")
field = filterable_fields[raw_field_name]
if field_value is None:
if field.required:
raise ValueError(f"Field {raw_field_name} is required")
else:
continue
field_name = f"{METADATA_PATH}.{raw_field_name}"
if field.field_type == "keyword":
if field.condition == "==":
must_conditions.append(
models.FieldCondition(
key=field_name, match=models.MatchValue(value=field_value)
)
)
elif field.condition == "!=":
must_not_conditions.append(
models.FieldCondition(
key=field_name, match=models.MatchValue(value=field_value)
)
)
elif field.condition == "any":
must_conditions.append(
models.FieldCondition(
key=field_name, match=models.MatchAny(any=field_value)
)
)
elif field.condition == "except":
must_conditions.append(
models.FieldCondition(
key=field_name,
match=models.MatchExcept(**{"except": field_value}),
)
)
elif field.condition is not None:
raise ValueError(
f"Invalid condition {field.condition} for keyword field {field_name}"
)
elif field.field_type == "integer":
if field.condition == "==":
must_conditions.append(
models.FieldCondition(
key=field_name, match=models.MatchValue(value=field_value)
)
)
elif field.condition == "!=":
must_not_conditions.append(
models.FieldCondition(
key=field_name, match=models.MatchValue(value=field_value)
)
)
elif field.condition == ">":
must_conditions.append(
models.FieldCondition(
key=field_name, range=models.Range(gt=field_value)
)
)
elif field.condition == ">=":
must_conditions.append(
models.FieldCondition(
key=field_name, range=models.Range(gte=field_value)
)
)
elif field.condition == "<":
must_conditions.append(
models.FieldCondition(
key=field_name, range=models.Range(lt=field_value)
)
)
elif field.condition == "<=":
must_conditions.append(
models.FieldCondition(
key=field_name, range=models.Range(lte=field_value)
)
)
elif field.condition == "any":
must_conditions.append(
models.FieldCondition(
key=field_name, match=models.MatchAny(any=field_value)
)
)
elif field.condition == "except":
must_conditions.append(
models.FieldCondition(
key=field_name,
match=models.MatchExcept(**{"except": field_value}),
)
)
elif field.condition is not None:
raise ValueError(
f"Invalid condition {field.condition} for integer field {field_name}"
)
elif field.field_type == "float":
# For float values, we only support range comparisons
if field.condition == ">":
must_conditions.append(
models.FieldCondition(
key=field_name, range=models.Range(gt=field_value)
)
)
elif field.condition == ">=":
must_conditions.append(
models.FieldCondition(
key=field_name, range=models.Range(gte=field_value)
)
)
elif field.condition == "<":
must_conditions.append(
models.FieldCondition(
key=field_name, range=models.Range(lt=field_value)
)
)
elif field.condition == "<=":
must_conditions.append(
models.FieldCondition(
key=field_name, range=models.Range(lte=field_value)
)
)
elif field.condition is not None:
raise ValueError(
f"Invalid condition {field.condition} for float field {field_name}. "
"Only range comparisons (>, >=, <, <=) are supported for float values."
)
elif field.field_type == "boolean":
if field.condition == "==":
must_conditions.append(
models.FieldCondition(
key=field_name, match=models.MatchValue(value=field_value)
)
)
elif field.condition == "!=":
must_not_conditions.append(
models.FieldCondition(
key=field_name, match=models.MatchValue(value=field_value)
)
)
elif field.condition is not None:
raise ValueError(
f"Invalid condition {field.condition} for boolean field {field_name}"
)
else:
raise ValueError(
f"Unsupported field type {field.field_type} for field {field_name}"
)
return models.Filter(
must=must_conditions, must_not=must_not_conditions
).model_dump()
def make_indexes(
filterable_fields: dict[str, FilterableField],
) -> dict[str, models.PayloadSchemaType]:
indexes = {}
for field_name, field in filterable_fields.items():
if field.field_type == "keyword":
indexes[f"{METADATA_PATH}.{field_name}"] = models.PayloadSchemaType.KEYWORD
elif field.field_type == "integer":
indexes[f"{METADATA_PATH}.{field_name}"] = models.PayloadSchemaType.INTEGER
elif field.field_type == "float":
indexes[f"{METADATA_PATH}.{field_name}"] = models.PayloadSchemaType.FLOAT
elif field.field_type == "boolean":
indexes[f"{METADATA_PATH}.{field_name}"] = models.PayloadSchemaType.BOOL
else:
raise ValueError(
f"Unsupported field type {field.field_type} for field {field_name}"
)
return indexes

View File

@@ -0,0 +1,150 @@
import inspect
from functools import wraps
from typing import Annotated, Callable, Optional
from pydantic import Field
from mcp_server_qdrant.common.filters import make_filter
from mcp_server_qdrant.settings import FilterableField
def wrap_filters(
original_func: Callable, filterable_fields: dict[str, FilterableField]
) -> Callable:
"""
Wraps the original_func function: replaces `filter` parameter with multiple parameters defined by `filterable_fields`.
"""
sig = inspect.signature(original_func)
@wraps(original_func)
def wrapper(*args, **kwargs):
# Start with fixed values
filter_values = {}
for field_name in filterable_fields:
if field_name in kwargs:
filter_values[field_name] = kwargs.pop(field_name)
query_filter = make_filter(filterable_fields, filter_values)
return original_func(**kwargs, query_filter=query_filter)
# Replace `query_filter` signature with parameters from `filterable_fields`
param_names = []
for param_name in sig.parameters:
if param_name == "query_filter":
continue
param_names.append(param_name)
new_params = [sig.parameters[param_name] for param_name in param_names]
required_new_params = []
optional_new_params = []
# Create a new signature parameters from `filterable_fields`
for field in filterable_fields.values():
field_name = field.name
field_type: type
if field.field_type == "keyword":
field_type = str
elif field.field_type == "integer":
field_type = int
elif field.field_type == "float":
field_type = float
elif field.field_type == "boolean":
field_type = bool
else:
raise ValueError(f"Unsupported field type: {field.field_type}")
if field.condition in {"any", "except"}:
if field_type not in {str, int}:
raise ValueError(
f'Only "keyword" and "integer" types are supported for "{field.condition}" condition'
)
field_type = list[field_type] # type: ignore
if field.required:
annotation = Annotated[field_type, Field(description=field.description)] # type: ignore
parameter = inspect.Parameter(
name=field_name,
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
annotation=annotation,
)
required_new_params.append(parameter)
else:
annotation = Annotated[ # type: ignore
Optional[field_type], Field(description=field.description)
]
parameter = inspect.Parameter(
name=field_name,
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
default=None,
annotation=annotation,
)
optional_new_params.append(parameter)
new_params.extend(required_new_params)
new_params.extend(optional_new_params)
# Set the new __signature__ for introspection
new_signature = sig.replace(parameters=new_params)
wrapper.__signature__ = new_signature # type: ignore
# Set the new __annotations__ for introspection
new_annotations = {}
for param in new_signature.parameters.values():
if param.annotation != inspect.Parameter.empty:
new_annotations[param.name] = param.annotation
# Add return type annotation if it exists
if new_signature.return_annotation != inspect.Parameter.empty:
new_annotations["return"] = new_signature.return_annotation
wrapper.__annotations__ = new_annotations
return wrapper
if __name__ == "__main__":
from pydantic._internal._typing_extra import get_function_type_hints
from qdrant_client import models
def find(
query: Annotated[str, Field(description="What to search for")],
collection_name: Annotated[
str, Field(description="The collection to search in")
],
query_filter: Optional[models.Filter] = None,
) -> list[str]:
print("query", query)
print("collection_name", collection_name)
print("query_filter", query_filter)
return ["mypy rules"]
wrapped_find = wrap_filters(
find,
{
"color": FilterableField(
name="color",
description="The color of the object",
field_type="keyword",
condition="==",
),
"size": FilterableField(
name="size",
description="The size of the object",
field_type="keyword",
condition="==",
required=True,
),
},
)
wrapped_find(query="dress", collection_name="test", color="red")
print("get_function_type_hints(find)", get_function_type_hints(find))
print(
"get_function_type_hints(wrapped_find)", get_function_type_hints(wrapped_find)
)

View File

@@ -1,12 +1,16 @@
import json import json
import logging import logging
from typing import Any, List, Optional from typing import Annotated, Any, List, Optional
from fastmcp import Context, FastMCP from fastmcp import Context, FastMCP
from pydantic import Field
from qdrant_client import models
from mcp_server_qdrant.common.filters import make_indexes
from mcp_server_qdrant.common.func_tools import make_partial_function from mcp_server_qdrant.common.func_tools import make_partial_function
from mcp_server_qdrant.common.wrap_filters import wrap_filters
from mcp_server_qdrant.embeddings.factory import create_embedding_provider from mcp_server_qdrant.embeddings.factory import create_embedding_provider
from mcp_server_qdrant.qdrant import Entry, Metadata, QdrantConnector from mcp_server_qdrant.qdrant import ArbitraryFilter, Entry, Metadata, QdrantConnector
from mcp_server_qdrant.settings import ( from mcp_server_qdrant.settings import (
EmbeddingProviderSettings, EmbeddingProviderSettings,
QdrantSettings, QdrantSettings,
@@ -43,6 +47,7 @@ class QdrantMCPServer(FastMCP):
qdrant_settings.collection_name, qdrant_settings.collection_name,
self.embedding_provider, self.embedding_provider,
qdrant_settings.local_path, qdrant_settings.local_path,
make_indexes(qdrant_settings.filterable_fields_dict()),
) )
super().__init__(name=name, instructions=instructions, **settings) super().__init__(name=name, instructions=instructions, **settings)
@@ -63,12 +68,19 @@ class QdrantMCPServer(FastMCP):
async def store( async def store(
ctx: Context, ctx: Context,
information: str, information: Annotated[str, Field(description="Text to store")],
collection_name: str, collection_name: Annotated[
str, Field(description="The collection to store the information in")
],
# The `metadata` parameter is defined as non-optional, but it can be None. # The `metadata` parameter is defined as non-optional, but it can be None.
# If we set it to be optional, some of the MCP clients, like Cursor, cannot # If we set it to be optional, some of the MCP clients, like Cursor, cannot
# handle the optional parameter correctly. # handle the optional parameter correctly.
metadata: Optional[Metadata] = None, # type: ignore metadata: Annotated[
Optional[Metadata],
Field(
description="Extra metadata stored along with memorised information. Any json is accepted."
),
] = None,
) -> str: ) -> str:
""" """
Store some information in Qdrant. Store some information in Qdrant.
@@ -90,8 +102,11 @@ class QdrantMCPServer(FastMCP):
async def find( async def find(
ctx: Context, ctx: Context,
query: str, query: Annotated[str, Field(description="What to search for")],
collection_name: str, collection_name: Annotated[
str, Field(description="The collection to search in")
],
query_filter: Optional[ArbitraryFilter] = None,
) -> List[str]: ) -> List[str]:
""" """
Find memories in Qdrant. Find memories in Qdrant.
@@ -101,6 +116,12 @@ class QdrantMCPServer(FastMCP):
the default collection is used. the default collection is used.
:return: A list of entries found. :return: A list of entries found.
""" """
# Log query_filter
await ctx.debug(f"Query filter: {query_filter}")
query_filter = models.Filter(**query_filter) if query_filter else None
await ctx.debug(f"Finding results for query {query}") await ctx.debug(f"Finding results for query {query}")
if collection_name: if collection_name:
await ctx.debug( await ctx.debug(
@@ -111,6 +132,7 @@ class QdrantMCPServer(FastMCP):
query, query,
collection_name=collection_name, collection_name=collection_name,
limit=self.qdrant_settings.search_limit, limit=self.qdrant_settings.search_limit,
query_filter=query_filter,
) )
if not entries: if not entries:
return [f"No information found for the query '{query}'"] return [f"No information found for the query '{query}'"]
@@ -124,6 +146,15 @@ class QdrantMCPServer(FastMCP):
find_foo = find find_foo = find
store_foo = store store_foo = store
filterable_conditions = (
self.qdrant_settings.filterable_fields_dict_with_conditions()
)
if len(filterable_conditions) > 0:
find_foo = wrap_filters(find_foo, filterable_conditions)
elif not self.qdrant_settings.allow_arbitrary_filter:
find_foo = make_partial_function(find_foo, {"query_filter": None})
if self.qdrant_settings.collection_name: if self.qdrant_settings.collection_name:
find_foo = make_partial_function( find_foo = make_partial_function(
find_foo, {"collection_name": self.qdrant_settings.collection_name} find_foo, {"collection_name": self.qdrant_settings.collection_name}

View File

@@ -6,11 +6,14 @@ from pydantic import BaseModel
from qdrant_client import AsyncQdrantClient, models from qdrant_client import AsyncQdrantClient, models
from mcp_server_qdrant.embeddings.base import EmbeddingProvider from mcp_server_qdrant.embeddings.base import EmbeddingProvider
from mcp_server_qdrant.settings import METADATA_PATH
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
Metadata = Dict[str, Any] Metadata = Dict[str, Any]
ArbitraryFilter = Dict[str, Any]
class Entry(BaseModel): class Entry(BaseModel):
""" """
@@ -39,6 +42,7 @@ class QdrantConnector:
collection_name: Optional[str], collection_name: Optional[str],
embedding_provider: EmbeddingProvider, embedding_provider: EmbeddingProvider,
qdrant_local_path: Optional[str] = None, qdrant_local_path: Optional[str] = None,
field_indexes: Optional[dict[str, models.PayloadSchemaType]] = None,
): ):
self._qdrant_url = qdrant_url.rstrip("/") if qdrant_url else None self._qdrant_url = qdrant_url.rstrip("/") if qdrant_url else None
self._qdrant_api_key = qdrant_api_key self._qdrant_api_key = qdrant_api_key
@@ -47,6 +51,7 @@ class QdrantConnector:
self._client = AsyncQdrantClient( self._client = AsyncQdrantClient(
location=qdrant_url, api_key=qdrant_api_key, path=qdrant_local_path location=qdrant_url, api_key=qdrant_api_key, path=qdrant_local_path
) )
self._field_indexes = field_indexes
async def get_collection_names(self) -> list[str]: async def get_collection_names(self) -> list[str]:
""" """
@@ -74,7 +79,7 @@ class QdrantConnector:
# Add to Qdrant # Add to Qdrant
vector_name = self._embedding_provider.get_vector_name() vector_name = self._embedding_provider.get_vector_name()
payload = {"document": entry.content, "metadata": entry.metadata} payload = {"document": entry.content, METADATA_PATH: entry.metadata}
await self._client.upsert( await self._client.upsert(
collection_name=collection_name, collection_name=collection_name,
points=[ points=[
@@ -87,7 +92,12 @@ class QdrantConnector:
) )
async def search( async def search(
self, query: str, *, collection_name: Optional[str] = None, limit: int = 10 self,
query: str,
*,
collection_name: Optional[str] = None,
limit: int = 10,
query_filter: Optional[models.Filter] = None,
) -> list[Entry]: ) -> list[Entry]:
""" """
Find points in the Qdrant collection. If there are no entries found, an empty list is returned. Find points in the Qdrant collection. If there are no entries found, an empty list is returned.
@@ -115,6 +125,7 @@ class QdrantConnector:
query=query_vector, query=query_vector,
using=vector_name, using=vector_name,
limit=limit, limit=limit,
query_filter=query_filter,
) )
return [ return [
@@ -146,3 +157,13 @@ class QdrantConnector:
) )
}, },
) )
# Create payload indexes if configured
if self._field_indexes:
for field_name, field_type in self._field_indexes.items():
await self._client.create_payload_index(
collection_name=collection_name,
field_name=field_name,
field_schema=field_type,
)

View File

@@ -1,6 +1,6 @@
from typing import Optional from typing import Literal, Optional
from pydantic import Field from pydantic import BaseModel, Field
from pydantic_settings import BaseSettings from pydantic_settings import BaseSettings
from mcp_server_qdrant.embeddings.types import EmbeddingProviderType from mcp_server_qdrant.embeddings.types import EmbeddingProviderType
@@ -15,6 +15,8 @@ DEFAULT_TOOL_FIND_DESCRIPTION = (
" - Get some personal information about the user" " - Get some personal information about the user"
) )
METADATA_PATH = "metadata"
class ToolSettings(BaseSettings): class ToolSettings(BaseSettings):
""" """
@@ -46,6 +48,29 @@ class EmbeddingProviderSettings(BaseSettings):
) )
class FilterableField(BaseModel):
name: str = Field(description="The name of the field payload field to filter on")
description: str = Field(
description="A description for the field used in the tool description"
)
field_type: Literal["keyword", "integer", "float", "boolean"] = Field(
description="The type of the field"
)
condition: Optional[Literal["==", "!=", ">", ">=", "<", "<=", "any", "except"]] = (
Field(
default=None,
description=(
"The condition to use for the filter. If not provided, the field will be indexed, but no "
"filter argument will be exposed to MCP tool."
),
)
)
required: bool = Field(
default=False,
description="Whether the field is required for the filter.",
)
class QdrantSettings(BaseSettings): class QdrantSettings(BaseSettings):
""" """
Configuration for the Qdrant connector. Configuration for the Qdrant connector.
@@ -61,3 +86,23 @@ class QdrantSettings(BaseSettings):
) )
search_limit: int = Field(default=10, validation_alias="QDRANT_SEARCH_LIMIT") search_limit: int = Field(default=10, validation_alias="QDRANT_SEARCH_LIMIT")
read_only: bool = Field(default=False, validation_alias="QDRANT_READ_ONLY") read_only: bool = Field(default=False, validation_alias="QDRANT_READ_ONLY")
filterable_fields: Optional[list[FilterableField]] = Field(default=None)
allow_arbitrary_filter: bool = Field(
default=False, validation_alias="QDRANT_ALLOW_ARBITRARY_FILTER"
)
def filterable_fields_dict(self) -> dict[str, FilterableField]:
if self.filterable_fields is None:
return {}
return {field.name: field for field in self.filterable_fields}
def filterable_fields_dict_with_conditions(self) -> dict[str, FilterableField]:
if self.filterable_fields is None:
return {}
return {
field.name: field
for field in self.filterable_fields
if field.condition is not None
}

1514
uv.lock generated

File diff suppressed because it is too large Load Diff