Configurable filters (#58)
* add configurable filters * hello to hr department * rollback debug code * add arbitrary filter * dont consider fields without conditions * in and except condition * proper annotation types for optional and list fields * fix types import * skip non-required fields * fix: fix match except condition, fix boolean filter * fix: apply ruff * fix: make condition optional in filterable field * fix: do not set default value for required fields (#63) * fix: do not set default value for required fields * fix: temp fix fastmcp to <2.8.0 cause of the breaking changes in the api * fix: add missing changes to pyproject.toml * fix: downgrade fastmcp even further to <2.7.0 --------- Co-authored-by: George Panchuk <george.panchuk@qdrant.tech> Co-authored-by: George <panchuk.george@outlook.com>
This commit is contained in:
@@ -9,7 +9,7 @@ dependencies = [
|
|||||||
"fastembed>=0.6.0",
|
"fastembed>=0.6.0",
|
||||||
"qdrant-client>=1.12.0",
|
"qdrant-client>=1.12.0",
|
||||||
"pydantic>=2.10.6",
|
"pydantic>=2.10.6",
|
||||||
"fastmcp>=2.5.1",
|
"fastmcp>=2.5.1,<2.7.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
@@ -18,6 +18,7 @@ build-backend = "hatchling.build"
|
|||||||
|
|
||||||
[tool.uv]
|
[tool.uv]
|
||||||
dev-dependencies = [
|
dev-dependencies = [
|
||||||
|
"ipdb>=0.13.13",
|
||||||
"isort>=6.0.1",
|
"isort>=6.0.1",
|
||||||
"mypy>=1.9.0",
|
"mypy>=1.9.0",
|
||||||
"pre-commit>=4.1.0",
|
"pre-commit>=4.1.0",
|
||||||
|
|||||||
194
src/mcp_server_qdrant/common/filters.py
Normal file
194
src/mcp_server_qdrant/common/filters.py
Normal file
@@ -0,0 +1,194 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from qdrant_client import models
|
||||||
|
|
||||||
|
from mcp_server_qdrant.qdrant import ArbitraryFilter
|
||||||
|
from mcp_server_qdrant.settings import METADATA_PATH, FilterableField
|
||||||
|
|
||||||
|
|
||||||
|
def make_filter(
|
||||||
|
filterable_fields: dict[str, FilterableField], values: dict[str, Any]
|
||||||
|
) -> ArbitraryFilter:
|
||||||
|
must_conditions = []
|
||||||
|
must_not_conditions = []
|
||||||
|
|
||||||
|
for raw_field_name, field_value in values.items():
|
||||||
|
if raw_field_name not in filterable_fields:
|
||||||
|
raise ValueError(f"Field {raw_field_name} is not a filterable field")
|
||||||
|
|
||||||
|
field = filterable_fields[raw_field_name]
|
||||||
|
|
||||||
|
if field_value is None:
|
||||||
|
if field.required:
|
||||||
|
raise ValueError(f"Field {raw_field_name} is required")
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
|
||||||
|
field_name = f"{METADATA_PATH}.{raw_field_name}"
|
||||||
|
|
||||||
|
if field.field_type == "keyword":
|
||||||
|
if field.condition == "==":
|
||||||
|
must_conditions.append(
|
||||||
|
models.FieldCondition(
|
||||||
|
key=field_name, match=models.MatchValue(value=field_value)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif field.condition == "!=":
|
||||||
|
must_not_conditions.append(
|
||||||
|
models.FieldCondition(
|
||||||
|
key=field_name, match=models.MatchValue(value=field_value)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif field.condition == "any":
|
||||||
|
must_conditions.append(
|
||||||
|
models.FieldCondition(
|
||||||
|
key=field_name, match=models.MatchAny(any=field_value)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif field.condition == "except":
|
||||||
|
must_conditions.append(
|
||||||
|
models.FieldCondition(
|
||||||
|
key=field_name,
|
||||||
|
match=models.MatchExcept(**{"except": field_value}),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif field.condition is not None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid condition {field.condition} for keyword field {field_name}"
|
||||||
|
)
|
||||||
|
|
||||||
|
elif field.field_type == "integer":
|
||||||
|
if field.condition == "==":
|
||||||
|
must_conditions.append(
|
||||||
|
models.FieldCondition(
|
||||||
|
key=field_name, match=models.MatchValue(value=field_value)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif field.condition == "!=":
|
||||||
|
must_not_conditions.append(
|
||||||
|
models.FieldCondition(
|
||||||
|
key=field_name, match=models.MatchValue(value=field_value)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif field.condition == ">":
|
||||||
|
must_conditions.append(
|
||||||
|
models.FieldCondition(
|
||||||
|
key=field_name, range=models.Range(gt=field_value)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif field.condition == ">=":
|
||||||
|
must_conditions.append(
|
||||||
|
models.FieldCondition(
|
||||||
|
key=field_name, range=models.Range(gte=field_value)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif field.condition == "<":
|
||||||
|
must_conditions.append(
|
||||||
|
models.FieldCondition(
|
||||||
|
key=field_name, range=models.Range(lt=field_value)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif field.condition == "<=":
|
||||||
|
must_conditions.append(
|
||||||
|
models.FieldCondition(
|
||||||
|
key=field_name, range=models.Range(lte=field_value)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif field.condition == "any":
|
||||||
|
must_conditions.append(
|
||||||
|
models.FieldCondition(
|
||||||
|
key=field_name, match=models.MatchAny(any=field_value)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif field.condition == "except":
|
||||||
|
must_conditions.append(
|
||||||
|
models.FieldCondition(
|
||||||
|
key=field_name,
|
||||||
|
match=models.MatchExcept(**{"except": field_value}),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif field.condition is not None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid condition {field.condition} for integer field {field_name}"
|
||||||
|
)
|
||||||
|
|
||||||
|
elif field.field_type == "float":
|
||||||
|
# For float values, we only support range comparisons
|
||||||
|
if field.condition == ">":
|
||||||
|
must_conditions.append(
|
||||||
|
models.FieldCondition(
|
||||||
|
key=field_name, range=models.Range(gt=field_value)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif field.condition == ">=":
|
||||||
|
must_conditions.append(
|
||||||
|
models.FieldCondition(
|
||||||
|
key=field_name, range=models.Range(gte=field_value)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif field.condition == "<":
|
||||||
|
must_conditions.append(
|
||||||
|
models.FieldCondition(
|
||||||
|
key=field_name, range=models.Range(lt=field_value)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif field.condition == "<=":
|
||||||
|
must_conditions.append(
|
||||||
|
models.FieldCondition(
|
||||||
|
key=field_name, range=models.Range(lte=field_value)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif field.condition is not None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid condition {field.condition} for float field {field_name}. "
|
||||||
|
"Only range comparisons (>, >=, <, <=) are supported for float values."
|
||||||
|
)
|
||||||
|
|
||||||
|
elif field.field_type == "boolean":
|
||||||
|
if field.condition == "==":
|
||||||
|
must_conditions.append(
|
||||||
|
models.FieldCondition(
|
||||||
|
key=field_name, match=models.MatchValue(value=field_value)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif field.condition == "!=":
|
||||||
|
must_not_conditions.append(
|
||||||
|
models.FieldCondition(
|
||||||
|
key=field_name, match=models.MatchValue(value=field_value)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif field.condition is not None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid condition {field.condition} for boolean field {field_name}"
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported field type {field.field_type} for field {field_name}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return models.Filter(
|
||||||
|
must=must_conditions, must_not=must_not_conditions
|
||||||
|
).model_dump()
|
||||||
|
|
||||||
|
|
||||||
|
def make_indexes(
|
||||||
|
filterable_fields: dict[str, FilterableField],
|
||||||
|
) -> dict[str, models.PayloadSchemaType]:
|
||||||
|
indexes = {}
|
||||||
|
|
||||||
|
for field_name, field in filterable_fields.items():
|
||||||
|
if field.field_type == "keyword":
|
||||||
|
indexes[f"{METADATA_PATH}.{field_name}"] = models.PayloadSchemaType.KEYWORD
|
||||||
|
elif field.field_type == "integer":
|
||||||
|
indexes[f"{METADATA_PATH}.{field_name}"] = models.PayloadSchemaType.INTEGER
|
||||||
|
elif field.field_type == "float":
|
||||||
|
indexes[f"{METADATA_PATH}.{field_name}"] = models.PayloadSchemaType.FLOAT
|
||||||
|
elif field.field_type == "boolean":
|
||||||
|
indexes[f"{METADATA_PATH}.{field_name}"] = models.PayloadSchemaType.BOOL
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported field type {field.field_type} for field {field_name}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return indexes
|
||||||
150
src/mcp_server_qdrant/common/wrap_filters.py
Normal file
150
src/mcp_server_qdrant/common/wrap_filters.py
Normal file
@@ -0,0 +1,150 @@
|
|||||||
|
import inspect
|
||||||
|
from functools import wraps
|
||||||
|
from typing import Annotated, Callable, Optional
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
from mcp_server_qdrant.common.filters import make_filter
|
||||||
|
from mcp_server_qdrant.settings import FilterableField
|
||||||
|
|
||||||
|
|
||||||
|
def wrap_filters(
|
||||||
|
original_func: Callable, filterable_fields: dict[str, FilterableField]
|
||||||
|
) -> Callable:
|
||||||
|
"""
|
||||||
|
Wraps the original_func function: replaces `filter` parameter with multiple parameters defined by `filterable_fields`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
sig = inspect.signature(original_func)
|
||||||
|
|
||||||
|
@wraps(original_func)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
# Start with fixed values
|
||||||
|
filter_values = {}
|
||||||
|
|
||||||
|
for field_name in filterable_fields:
|
||||||
|
if field_name in kwargs:
|
||||||
|
filter_values[field_name] = kwargs.pop(field_name)
|
||||||
|
|
||||||
|
query_filter = make_filter(filterable_fields, filter_values)
|
||||||
|
|
||||||
|
return original_func(**kwargs, query_filter=query_filter)
|
||||||
|
|
||||||
|
# Replace `query_filter` signature with parameters from `filterable_fields`
|
||||||
|
|
||||||
|
param_names = []
|
||||||
|
|
||||||
|
for param_name in sig.parameters:
|
||||||
|
if param_name == "query_filter":
|
||||||
|
continue
|
||||||
|
param_names.append(param_name)
|
||||||
|
|
||||||
|
new_params = [sig.parameters[param_name] for param_name in param_names]
|
||||||
|
required_new_params = []
|
||||||
|
optional_new_params = []
|
||||||
|
|
||||||
|
# Create a new signature parameters from `filterable_fields`
|
||||||
|
for field in filterable_fields.values():
|
||||||
|
field_name = field.name
|
||||||
|
field_type: type
|
||||||
|
if field.field_type == "keyword":
|
||||||
|
field_type = str
|
||||||
|
elif field.field_type == "integer":
|
||||||
|
field_type = int
|
||||||
|
elif field.field_type == "float":
|
||||||
|
field_type = float
|
||||||
|
elif field.field_type == "boolean":
|
||||||
|
field_type = bool
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported field type: {field.field_type}")
|
||||||
|
|
||||||
|
if field.condition in {"any", "except"}:
|
||||||
|
if field_type not in {str, int}:
|
||||||
|
raise ValueError(
|
||||||
|
f'Only "keyword" and "integer" types are supported for "{field.condition}" condition'
|
||||||
|
)
|
||||||
|
field_type = list[field_type] # type: ignore
|
||||||
|
|
||||||
|
if field.required:
|
||||||
|
annotation = Annotated[field_type, Field(description=field.description)] # type: ignore
|
||||||
|
parameter = inspect.Parameter(
|
||||||
|
name=field_name,
|
||||||
|
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||||
|
annotation=annotation,
|
||||||
|
)
|
||||||
|
required_new_params.append(parameter)
|
||||||
|
else:
|
||||||
|
annotation = Annotated[ # type: ignore
|
||||||
|
Optional[field_type], Field(description=field.description)
|
||||||
|
]
|
||||||
|
parameter = inspect.Parameter(
|
||||||
|
name=field_name,
|
||||||
|
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||||
|
default=None,
|
||||||
|
annotation=annotation,
|
||||||
|
)
|
||||||
|
optional_new_params.append(parameter)
|
||||||
|
|
||||||
|
new_params.extend(required_new_params)
|
||||||
|
new_params.extend(optional_new_params)
|
||||||
|
|
||||||
|
# Set the new __signature__ for introspection
|
||||||
|
new_signature = sig.replace(parameters=new_params)
|
||||||
|
wrapper.__signature__ = new_signature # type: ignore
|
||||||
|
|
||||||
|
# Set the new __annotations__ for introspection
|
||||||
|
new_annotations = {}
|
||||||
|
for param in new_signature.parameters.values():
|
||||||
|
if param.annotation != inspect.Parameter.empty:
|
||||||
|
new_annotations[param.name] = param.annotation
|
||||||
|
|
||||||
|
# Add return type annotation if it exists
|
||||||
|
if new_signature.return_annotation != inspect.Parameter.empty:
|
||||||
|
new_annotations["return"] = new_signature.return_annotation
|
||||||
|
|
||||||
|
wrapper.__annotations__ = new_annotations
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
from pydantic._internal._typing_extra import get_function_type_hints
|
||||||
|
from qdrant_client import models
|
||||||
|
|
||||||
|
def find(
|
||||||
|
query: Annotated[str, Field(description="What to search for")],
|
||||||
|
collection_name: Annotated[
|
||||||
|
str, Field(description="The collection to search in")
|
||||||
|
],
|
||||||
|
query_filter: Optional[models.Filter] = None,
|
||||||
|
) -> list[str]:
|
||||||
|
print("query", query)
|
||||||
|
print("collection_name", collection_name)
|
||||||
|
print("query_filter", query_filter)
|
||||||
|
return ["mypy rules"]
|
||||||
|
|
||||||
|
wrapped_find = wrap_filters(
|
||||||
|
find,
|
||||||
|
{
|
||||||
|
"color": FilterableField(
|
||||||
|
name="color",
|
||||||
|
description="The color of the object",
|
||||||
|
field_type="keyword",
|
||||||
|
condition="==",
|
||||||
|
),
|
||||||
|
"size": FilterableField(
|
||||||
|
name="size",
|
||||||
|
description="The size of the object",
|
||||||
|
field_type="keyword",
|
||||||
|
condition="==",
|
||||||
|
required=True,
|
||||||
|
),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
wrapped_find(query="dress", collection_name="test", color="red")
|
||||||
|
|
||||||
|
print("get_function_type_hints(find)", get_function_type_hints(find))
|
||||||
|
print(
|
||||||
|
"get_function_type_hints(wrapped_find)", get_function_type_hints(wrapped_find)
|
||||||
|
)
|
||||||
@@ -1,12 +1,16 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, List, Optional
|
from typing import Annotated, Any, List, Optional
|
||||||
|
|
||||||
from fastmcp import Context, FastMCP
|
from fastmcp import Context, FastMCP
|
||||||
|
from pydantic import Field
|
||||||
|
from qdrant_client import models
|
||||||
|
|
||||||
|
from mcp_server_qdrant.common.filters import make_indexes
|
||||||
from mcp_server_qdrant.common.func_tools import make_partial_function
|
from mcp_server_qdrant.common.func_tools import make_partial_function
|
||||||
|
from mcp_server_qdrant.common.wrap_filters import wrap_filters
|
||||||
from mcp_server_qdrant.embeddings.factory import create_embedding_provider
|
from mcp_server_qdrant.embeddings.factory import create_embedding_provider
|
||||||
from mcp_server_qdrant.qdrant import Entry, Metadata, QdrantConnector
|
from mcp_server_qdrant.qdrant import ArbitraryFilter, Entry, Metadata, QdrantConnector
|
||||||
from mcp_server_qdrant.settings import (
|
from mcp_server_qdrant.settings import (
|
||||||
EmbeddingProviderSettings,
|
EmbeddingProviderSettings,
|
||||||
QdrantSettings,
|
QdrantSettings,
|
||||||
@@ -43,6 +47,7 @@ class QdrantMCPServer(FastMCP):
|
|||||||
qdrant_settings.collection_name,
|
qdrant_settings.collection_name,
|
||||||
self.embedding_provider,
|
self.embedding_provider,
|
||||||
qdrant_settings.local_path,
|
qdrant_settings.local_path,
|
||||||
|
make_indexes(qdrant_settings.filterable_fields_dict()),
|
||||||
)
|
)
|
||||||
|
|
||||||
super().__init__(name=name, instructions=instructions, **settings)
|
super().__init__(name=name, instructions=instructions, **settings)
|
||||||
@@ -63,12 +68,19 @@ class QdrantMCPServer(FastMCP):
|
|||||||
|
|
||||||
async def store(
|
async def store(
|
||||||
ctx: Context,
|
ctx: Context,
|
||||||
information: str,
|
information: Annotated[str, Field(description="Text to store")],
|
||||||
collection_name: str,
|
collection_name: Annotated[
|
||||||
|
str, Field(description="The collection to store the information in")
|
||||||
|
],
|
||||||
# The `metadata` parameter is defined as non-optional, but it can be None.
|
# The `metadata` parameter is defined as non-optional, but it can be None.
|
||||||
# If we set it to be optional, some of the MCP clients, like Cursor, cannot
|
# If we set it to be optional, some of the MCP clients, like Cursor, cannot
|
||||||
# handle the optional parameter correctly.
|
# handle the optional parameter correctly.
|
||||||
metadata: Optional[Metadata] = None, # type: ignore
|
metadata: Annotated[
|
||||||
|
Optional[Metadata],
|
||||||
|
Field(
|
||||||
|
description="Extra metadata stored along with memorised information. Any json is accepted."
|
||||||
|
),
|
||||||
|
] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Store some information in Qdrant.
|
Store some information in Qdrant.
|
||||||
@@ -90,8 +102,11 @@ class QdrantMCPServer(FastMCP):
|
|||||||
|
|
||||||
async def find(
|
async def find(
|
||||||
ctx: Context,
|
ctx: Context,
|
||||||
query: str,
|
query: Annotated[str, Field(description="What to search for")],
|
||||||
collection_name: str,
|
collection_name: Annotated[
|
||||||
|
str, Field(description="The collection to search in")
|
||||||
|
],
|
||||||
|
query_filter: Optional[ArbitraryFilter] = None,
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
"""
|
"""
|
||||||
Find memories in Qdrant.
|
Find memories in Qdrant.
|
||||||
@@ -101,6 +116,12 @@ class QdrantMCPServer(FastMCP):
|
|||||||
the default collection is used.
|
the default collection is used.
|
||||||
:return: A list of entries found.
|
:return: A list of entries found.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Log query_filter
|
||||||
|
await ctx.debug(f"Query filter: {query_filter}")
|
||||||
|
|
||||||
|
query_filter = models.Filter(**query_filter) if query_filter else None
|
||||||
|
|
||||||
await ctx.debug(f"Finding results for query {query}")
|
await ctx.debug(f"Finding results for query {query}")
|
||||||
if collection_name:
|
if collection_name:
|
||||||
await ctx.debug(
|
await ctx.debug(
|
||||||
@@ -111,6 +132,7 @@ class QdrantMCPServer(FastMCP):
|
|||||||
query,
|
query,
|
||||||
collection_name=collection_name,
|
collection_name=collection_name,
|
||||||
limit=self.qdrant_settings.search_limit,
|
limit=self.qdrant_settings.search_limit,
|
||||||
|
query_filter=query_filter,
|
||||||
)
|
)
|
||||||
if not entries:
|
if not entries:
|
||||||
return [f"No information found for the query '{query}'"]
|
return [f"No information found for the query '{query}'"]
|
||||||
@@ -124,6 +146,15 @@ class QdrantMCPServer(FastMCP):
|
|||||||
find_foo = find
|
find_foo = find
|
||||||
store_foo = store
|
store_foo = store
|
||||||
|
|
||||||
|
filterable_conditions = (
|
||||||
|
self.qdrant_settings.filterable_fields_dict_with_conditions()
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(filterable_conditions) > 0:
|
||||||
|
find_foo = wrap_filters(find_foo, filterable_conditions)
|
||||||
|
elif not self.qdrant_settings.allow_arbitrary_filter:
|
||||||
|
find_foo = make_partial_function(find_foo, {"query_filter": None})
|
||||||
|
|
||||||
if self.qdrant_settings.collection_name:
|
if self.qdrant_settings.collection_name:
|
||||||
find_foo = make_partial_function(
|
find_foo = make_partial_function(
|
||||||
find_foo, {"collection_name": self.qdrant_settings.collection_name}
|
find_foo, {"collection_name": self.qdrant_settings.collection_name}
|
||||||
|
|||||||
@@ -6,11 +6,14 @@ from pydantic import BaseModel
|
|||||||
from qdrant_client import AsyncQdrantClient, models
|
from qdrant_client import AsyncQdrantClient, models
|
||||||
|
|
||||||
from mcp_server_qdrant.embeddings.base import EmbeddingProvider
|
from mcp_server_qdrant.embeddings.base import EmbeddingProvider
|
||||||
|
from mcp_server_qdrant.settings import METADATA_PATH
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
Metadata = Dict[str, Any]
|
Metadata = Dict[str, Any]
|
||||||
|
|
||||||
|
ArbitraryFilter = Dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
class Entry(BaseModel):
|
class Entry(BaseModel):
|
||||||
"""
|
"""
|
||||||
@@ -39,6 +42,7 @@ class QdrantConnector:
|
|||||||
collection_name: Optional[str],
|
collection_name: Optional[str],
|
||||||
embedding_provider: EmbeddingProvider,
|
embedding_provider: EmbeddingProvider,
|
||||||
qdrant_local_path: Optional[str] = None,
|
qdrant_local_path: Optional[str] = None,
|
||||||
|
field_indexes: Optional[dict[str, models.PayloadSchemaType]] = None,
|
||||||
):
|
):
|
||||||
self._qdrant_url = qdrant_url.rstrip("/") if qdrant_url else None
|
self._qdrant_url = qdrant_url.rstrip("/") if qdrant_url else None
|
||||||
self._qdrant_api_key = qdrant_api_key
|
self._qdrant_api_key = qdrant_api_key
|
||||||
@@ -47,6 +51,7 @@ class QdrantConnector:
|
|||||||
self._client = AsyncQdrantClient(
|
self._client = AsyncQdrantClient(
|
||||||
location=qdrant_url, api_key=qdrant_api_key, path=qdrant_local_path
|
location=qdrant_url, api_key=qdrant_api_key, path=qdrant_local_path
|
||||||
)
|
)
|
||||||
|
self._field_indexes = field_indexes
|
||||||
|
|
||||||
async def get_collection_names(self) -> list[str]:
|
async def get_collection_names(self) -> list[str]:
|
||||||
"""
|
"""
|
||||||
@@ -74,7 +79,7 @@ class QdrantConnector:
|
|||||||
|
|
||||||
# Add to Qdrant
|
# Add to Qdrant
|
||||||
vector_name = self._embedding_provider.get_vector_name()
|
vector_name = self._embedding_provider.get_vector_name()
|
||||||
payload = {"document": entry.content, "metadata": entry.metadata}
|
payload = {"document": entry.content, METADATA_PATH: entry.metadata}
|
||||||
await self._client.upsert(
|
await self._client.upsert(
|
||||||
collection_name=collection_name,
|
collection_name=collection_name,
|
||||||
points=[
|
points=[
|
||||||
@@ -87,7 +92,12 @@ class QdrantConnector:
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def search(
|
async def search(
|
||||||
self, query: str, *, collection_name: Optional[str] = None, limit: int = 10
|
self,
|
||||||
|
query: str,
|
||||||
|
*,
|
||||||
|
collection_name: Optional[str] = None,
|
||||||
|
limit: int = 10,
|
||||||
|
query_filter: Optional[models.Filter] = None,
|
||||||
) -> list[Entry]:
|
) -> list[Entry]:
|
||||||
"""
|
"""
|
||||||
Find points in the Qdrant collection. If there are no entries found, an empty list is returned.
|
Find points in the Qdrant collection. If there are no entries found, an empty list is returned.
|
||||||
@@ -115,6 +125,7 @@ class QdrantConnector:
|
|||||||
query=query_vector,
|
query=query_vector,
|
||||||
using=vector_name,
|
using=vector_name,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
|
query_filter=query_filter,
|
||||||
)
|
)
|
||||||
|
|
||||||
return [
|
return [
|
||||||
@@ -146,3 +157,13 @@ class QdrantConnector:
|
|||||||
)
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Create payload indexes if configured
|
||||||
|
|
||||||
|
if self._field_indexes:
|
||||||
|
for field_name, field_type in self._field_indexes.items():
|
||||||
|
await self._client.create_payload_index(
|
||||||
|
collection_name=collection_name,
|
||||||
|
field_name=field_name,
|
||||||
|
field_schema=field_type,
|
||||||
|
)
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from typing import Optional
|
from typing import Literal, Optional
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import BaseModel, Field
|
||||||
from pydantic_settings import BaseSettings
|
from pydantic_settings import BaseSettings
|
||||||
|
|
||||||
from mcp_server_qdrant.embeddings.types import EmbeddingProviderType
|
from mcp_server_qdrant.embeddings.types import EmbeddingProviderType
|
||||||
@@ -15,6 +15,8 @@ DEFAULT_TOOL_FIND_DESCRIPTION = (
|
|||||||
" - Get some personal information about the user"
|
" - Get some personal information about the user"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
METADATA_PATH = "metadata"
|
||||||
|
|
||||||
|
|
||||||
class ToolSettings(BaseSettings):
|
class ToolSettings(BaseSettings):
|
||||||
"""
|
"""
|
||||||
@@ -46,6 +48,29 @@ class EmbeddingProviderSettings(BaseSettings):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FilterableField(BaseModel):
|
||||||
|
name: str = Field(description="The name of the field payload field to filter on")
|
||||||
|
description: str = Field(
|
||||||
|
description="A description for the field used in the tool description"
|
||||||
|
)
|
||||||
|
field_type: Literal["keyword", "integer", "float", "boolean"] = Field(
|
||||||
|
description="The type of the field"
|
||||||
|
)
|
||||||
|
condition: Optional[Literal["==", "!=", ">", ">=", "<", "<=", "any", "except"]] = (
|
||||||
|
Field(
|
||||||
|
default=None,
|
||||||
|
description=(
|
||||||
|
"The condition to use for the filter. If not provided, the field will be indexed, but no "
|
||||||
|
"filter argument will be exposed to MCP tool."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
required: bool = Field(
|
||||||
|
default=False,
|
||||||
|
description="Whether the field is required for the filter.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class QdrantSettings(BaseSettings):
|
class QdrantSettings(BaseSettings):
|
||||||
"""
|
"""
|
||||||
Configuration for the Qdrant connector.
|
Configuration for the Qdrant connector.
|
||||||
@@ -61,3 +86,23 @@ class QdrantSettings(BaseSettings):
|
|||||||
)
|
)
|
||||||
search_limit: int = Field(default=10, validation_alias="QDRANT_SEARCH_LIMIT")
|
search_limit: int = Field(default=10, validation_alias="QDRANT_SEARCH_LIMIT")
|
||||||
read_only: bool = Field(default=False, validation_alias="QDRANT_READ_ONLY")
|
read_only: bool = Field(default=False, validation_alias="QDRANT_READ_ONLY")
|
||||||
|
|
||||||
|
filterable_fields: Optional[list[FilterableField]] = Field(default=None)
|
||||||
|
|
||||||
|
allow_arbitrary_filter: bool = Field(
|
||||||
|
default=False, validation_alias="QDRANT_ALLOW_ARBITRARY_FILTER"
|
||||||
|
)
|
||||||
|
|
||||||
|
def filterable_fields_dict(self) -> dict[str, FilterableField]:
|
||||||
|
if self.filterable_fields is None:
|
||||||
|
return {}
|
||||||
|
return {field.name: field for field in self.filterable_fields}
|
||||||
|
|
||||||
|
def filterable_fields_dict_with_conditions(self) -> dict[str, FilterableField]:
|
||||||
|
if self.filterable_fields is None:
|
||||||
|
return {}
|
||||||
|
return {
|
||||||
|
field.name: field
|
||||||
|
for field in self.filterable_fields
|
||||||
|
if field.condition is not None
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user