Compare commits
10 Commits
244139beb5
...
e4ec69b2da
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e4ec69b2da | ||
|
|
860ab93a96 | ||
|
|
20825bca92 | ||
|
|
8d6f388543 | ||
|
|
59fca57369 | ||
|
|
5a7237389e | ||
|
|
598ed6fa72 | ||
|
|
3fdb4c4b1b | ||
|
|
28bf298a32 | ||
|
|
b657656363 |
6
.github/workflows/pre-commit.yaml
vendored
6
.github/workflows/pre-commit.yaml
vendored
@@ -9,8 +9,8 @@ jobs:
|
||||
main:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/setup-python@v4
|
||||
- uses: actions/checkout@f43a0e5ff2bd294095638e18286ca9a3d1956744 # v3.6.0
|
||||
- uses: actions/setup-python@7f4fc3e22c37d6ff65e88745f38bd3157c663f7c # v4.9.1
|
||||
with:
|
||||
python-version: 3.x
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
- uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1
|
||||
|
||||
4
.github/workflows/pypi-publish.yaml
vendored
4
.github/workflows/pypi-publish.yaml
vendored
@@ -24,10 +24,10 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- uses: actions/checkout@ee0669bd1cc54295c223e0bb666b733df41de1c5 # v2.7.0
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v2
|
||||
uses: actions/setup-python@e9aba2c848f5ebd159c070c61ea2c4e2b122355e # v2.3.4
|
||||
with:
|
||||
python-version: '3.10.x'
|
||||
|
||||
|
||||
4
.github/workflows/pytest.yaml
vendored
4
.github/workflows/pytest.yaml
vendored
@@ -16,10 +16,10 @@ jobs:
|
||||
name: Python ${{ matrix.python-version }}
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/checkout@f43a0e5ff2bd294095638e18286ca9a3d1956744 # v3.6.0
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v4
|
||||
uses: actions/setup-python@7f4fc3e22c37d6ff65e88745f38bd3157c663f7c # v4.9.1
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
|
||||
@@ -63,7 +63,7 @@ important ones are listed below:
|
||||
|---------------------------------------|-----------------------------------------------------------|---------------|
|
||||
| `FASTMCP_DEBUG` | Enable debug mode | `false` |
|
||||
| `FASTMCP_LOG_LEVEL` | Set logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL) | `INFO` |
|
||||
| `FASTMCP_HOST` | Host address to bind the server to | `0.0.0.0` |
|
||||
| `FASTMCP_HOST` | Host address to bind the server to | `127.0.0.1` |
|
||||
| `FASTMCP_PORT` | Port to run the server on | `8000` |
|
||||
| `FASTMCP_WARN_ON_DUPLICATE_RESOURCES` | Show warnings for duplicate resources | `true` |
|
||||
| `FASTMCP_WARN_ON_DUPLICATE_TOOLS` | Show warnings for duplicate tools | `true` |
|
||||
@@ -121,12 +121,17 @@ docker build -t mcp-server-qdrant .
|
||||
|
||||
# Run the container
|
||||
docker run -p 8000:8000 \
|
||||
-e FASTMCP_HOST="0.0.0.0" \
|
||||
-e QDRANT_URL="http://your-qdrant-server:6333" \
|
||||
-e QDRANT_API_KEY="your-api-key" \
|
||||
-e COLLECTION_NAME="your-collection" \
|
||||
mcp-server-qdrant
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> Please note that we set `FASTMCP_HOST="0.0.0.0"` to make the server listen on all network interfaces. This is
|
||||
> necessary when running the server in a Docker container.
|
||||
|
||||
### Installing via Smithery
|
||||
|
||||
To install Qdrant MCP Server for Claude Desktop automatically via [Smithery](https://smithery.ai/protocol/mcp-server-qdrant):
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "mcp-server-qdrant"
|
||||
version = "0.7.1"
|
||||
version = "0.8.1"
|
||||
description = "MCP server for retrieving context from a Qdrant vector database"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
@@ -8,8 +8,8 @@ license = "Apache-2.0"
|
||||
dependencies = [
|
||||
"fastembed>=0.6.0",
|
||||
"qdrant-client>=1.12.0",
|
||||
"pydantic>=2.10.6",
|
||||
"fastmcp>=2.5.1",
|
||||
"pydantic>=2.10.6,<2.12.0",
|
||||
"fastmcp==2.7.0",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
@@ -18,6 +18,7 @@ build-backend = "hatchling.build"
|
||||
|
||||
[tool.uv]
|
||||
dev-dependencies = [
|
||||
"ipdb>=0.13.13",
|
||||
"isort>=6.0.1",
|
||||
"mypy>=1.9.0",
|
||||
"pre-commit>=4.1.0",
|
||||
|
||||
194
src/mcp_server_qdrant/common/filters.py
Normal file
194
src/mcp_server_qdrant/common/filters.py
Normal file
@@ -0,0 +1,194 @@
|
||||
from typing import Any
|
||||
|
||||
from qdrant_client import models
|
||||
|
||||
from mcp_server_qdrant.qdrant import ArbitraryFilter
|
||||
from mcp_server_qdrant.settings import METADATA_PATH, FilterableField
|
||||
|
||||
|
||||
def make_filter(
|
||||
filterable_fields: dict[str, FilterableField], values: dict[str, Any]
|
||||
) -> ArbitraryFilter:
|
||||
must_conditions = []
|
||||
must_not_conditions = []
|
||||
|
||||
for raw_field_name, field_value in values.items():
|
||||
if raw_field_name not in filterable_fields:
|
||||
raise ValueError(f"Field {raw_field_name} is not a filterable field")
|
||||
|
||||
field = filterable_fields[raw_field_name]
|
||||
|
||||
if field_value is None:
|
||||
if field.required:
|
||||
raise ValueError(f"Field {raw_field_name} is required")
|
||||
else:
|
||||
continue
|
||||
|
||||
field_name = f"{METADATA_PATH}.{raw_field_name}"
|
||||
|
||||
if field.field_type == "keyword":
|
||||
if field.condition == "==":
|
||||
must_conditions.append(
|
||||
models.FieldCondition(
|
||||
key=field_name, match=models.MatchValue(value=field_value)
|
||||
)
|
||||
)
|
||||
elif field.condition == "!=":
|
||||
must_not_conditions.append(
|
||||
models.FieldCondition(
|
||||
key=field_name, match=models.MatchValue(value=field_value)
|
||||
)
|
||||
)
|
||||
elif field.condition == "any":
|
||||
must_conditions.append(
|
||||
models.FieldCondition(
|
||||
key=field_name, match=models.MatchAny(any=field_value)
|
||||
)
|
||||
)
|
||||
elif field.condition == "except":
|
||||
must_conditions.append(
|
||||
models.FieldCondition(
|
||||
key=field_name,
|
||||
match=models.MatchExcept(**{"except": field_value}),
|
||||
)
|
||||
)
|
||||
elif field.condition is not None:
|
||||
raise ValueError(
|
||||
f"Invalid condition {field.condition} for keyword field {field_name}"
|
||||
)
|
||||
|
||||
elif field.field_type == "integer":
|
||||
if field.condition == "==":
|
||||
must_conditions.append(
|
||||
models.FieldCondition(
|
||||
key=field_name, match=models.MatchValue(value=field_value)
|
||||
)
|
||||
)
|
||||
elif field.condition == "!=":
|
||||
must_not_conditions.append(
|
||||
models.FieldCondition(
|
||||
key=field_name, match=models.MatchValue(value=field_value)
|
||||
)
|
||||
)
|
||||
elif field.condition == ">":
|
||||
must_conditions.append(
|
||||
models.FieldCondition(
|
||||
key=field_name, range=models.Range(gt=field_value)
|
||||
)
|
||||
)
|
||||
elif field.condition == ">=":
|
||||
must_conditions.append(
|
||||
models.FieldCondition(
|
||||
key=field_name, range=models.Range(gte=field_value)
|
||||
)
|
||||
)
|
||||
elif field.condition == "<":
|
||||
must_conditions.append(
|
||||
models.FieldCondition(
|
||||
key=field_name, range=models.Range(lt=field_value)
|
||||
)
|
||||
)
|
||||
elif field.condition == "<=":
|
||||
must_conditions.append(
|
||||
models.FieldCondition(
|
||||
key=field_name, range=models.Range(lte=field_value)
|
||||
)
|
||||
)
|
||||
elif field.condition == "any":
|
||||
must_conditions.append(
|
||||
models.FieldCondition(
|
||||
key=field_name, match=models.MatchAny(any=field_value)
|
||||
)
|
||||
)
|
||||
elif field.condition == "except":
|
||||
must_conditions.append(
|
||||
models.FieldCondition(
|
||||
key=field_name,
|
||||
match=models.MatchExcept(**{"except": field_value}),
|
||||
)
|
||||
)
|
||||
elif field.condition is not None:
|
||||
raise ValueError(
|
||||
f"Invalid condition {field.condition} for integer field {field_name}"
|
||||
)
|
||||
|
||||
elif field.field_type == "float":
|
||||
# For float values, we only support range comparisons
|
||||
if field.condition == ">":
|
||||
must_conditions.append(
|
||||
models.FieldCondition(
|
||||
key=field_name, range=models.Range(gt=field_value)
|
||||
)
|
||||
)
|
||||
elif field.condition == ">=":
|
||||
must_conditions.append(
|
||||
models.FieldCondition(
|
||||
key=field_name, range=models.Range(gte=field_value)
|
||||
)
|
||||
)
|
||||
elif field.condition == "<":
|
||||
must_conditions.append(
|
||||
models.FieldCondition(
|
||||
key=field_name, range=models.Range(lt=field_value)
|
||||
)
|
||||
)
|
||||
elif field.condition == "<=":
|
||||
must_conditions.append(
|
||||
models.FieldCondition(
|
||||
key=field_name, range=models.Range(lte=field_value)
|
||||
)
|
||||
)
|
||||
elif field.condition is not None:
|
||||
raise ValueError(
|
||||
f"Invalid condition {field.condition} for float field {field_name}. "
|
||||
"Only range comparisons (>, >=, <, <=) are supported for float values."
|
||||
)
|
||||
|
||||
elif field.field_type == "boolean":
|
||||
if field.condition == "==":
|
||||
must_conditions.append(
|
||||
models.FieldCondition(
|
||||
key=field_name, match=models.MatchValue(value=field_value)
|
||||
)
|
||||
)
|
||||
elif field.condition == "!=":
|
||||
must_not_conditions.append(
|
||||
models.FieldCondition(
|
||||
key=field_name, match=models.MatchValue(value=field_value)
|
||||
)
|
||||
)
|
||||
elif field.condition is not None:
|
||||
raise ValueError(
|
||||
f"Invalid condition {field.condition} for boolean field {field_name}"
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported field type {field.field_type} for field {field_name}"
|
||||
)
|
||||
|
||||
return models.Filter(
|
||||
must=must_conditions, must_not=must_not_conditions
|
||||
).model_dump()
|
||||
|
||||
|
||||
def make_indexes(
|
||||
filterable_fields: dict[str, FilterableField],
|
||||
) -> dict[str, models.PayloadSchemaType]:
|
||||
indexes = {}
|
||||
|
||||
for field_name, field in filterable_fields.items():
|
||||
if field.field_type == "keyword":
|
||||
indexes[f"{METADATA_PATH}.{field_name}"] = models.PayloadSchemaType.KEYWORD
|
||||
elif field.field_type == "integer":
|
||||
indexes[f"{METADATA_PATH}.{field_name}"] = models.PayloadSchemaType.INTEGER
|
||||
elif field.field_type == "float":
|
||||
indexes[f"{METADATA_PATH}.{field_name}"] = models.PayloadSchemaType.FLOAT
|
||||
elif field.field_type == "boolean":
|
||||
indexes[f"{METADATA_PATH}.{field_name}"] = models.PayloadSchemaType.BOOL
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported field type {field.field_type} for field {field_name}"
|
||||
)
|
||||
|
||||
return indexes
|
||||
150
src/mcp_server_qdrant/common/wrap_filters.py
Normal file
150
src/mcp_server_qdrant/common/wrap_filters.py
Normal file
@@ -0,0 +1,150 @@
|
||||
import inspect
|
||||
from functools import wraps
|
||||
from typing import Annotated, Callable, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from mcp_server_qdrant.common.filters import make_filter
|
||||
from mcp_server_qdrant.settings import FilterableField
|
||||
|
||||
|
||||
def wrap_filters(
|
||||
original_func: Callable, filterable_fields: dict[str, FilterableField]
|
||||
) -> Callable:
|
||||
"""
|
||||
Wraps the original_func function: replaces `filter` parameter with multiple parameters defined by `filterable_fields`.
|
||||
"""
|
||||
|
||||
sig = inspect.signature(original_func)
|
||||
|
||||
@wraps(original_func)
|
||||
def wrapper(*args, **kwargs):
|
||||
# Start with fixed values
|
||||
filter_values = {}
|
||||
|
||||
for field_name in filterable_fields:
|
||||
if field_name in kwargs:
|
||||
filter_values[field_name] = kwargs.pop(field_name)
|
||||
|
||||
query_filter = make_filter(filterable_fields, filter_values)
|
||||
|
||||
return original_func(**kwargs, query_filter=query_filter)
|
||||
|
||||
# Replace `query_filter` signature with parameters from `filterable_fields`
|
||||
|
||||
param_names = []
|
||||
|
||||
for param_name in sig.parameters:
|
||||
if param_name == "query_filter":
|
||||
continue
|
||||
param_names.append(param_name)
|
||||
|
||||
new_params = [sig.parameters[param_name] for param_name in param_names]
|
||||
required_new_params = []
|
||||
optional_new_params = []
|
||||
|
||||
# Create a new signature parameters from `filterable_fields`
|
||||
for field in filterable_fields.values():
|
||||
field_name = field.name
|
||||
field_type: type
|
||||
if field.field_type == "keyword":
|
||||
field_type = str
|
||||
elif field.field_type == "integer":
|
||||
field_type = int
|
||||
elif field.field_type == "float":
|
||||
field_type = float
|
||||
elif field.field_type == "boolean":
|
||||
field_type = bool
|
||||
else:
|
||||
raise ValueError(f"Unsupported field type: {field.field_type}")
|
||||
|
||||
if field.condition in {"any", "except"}:
|
||||
if field_type not in {str, int}:
|
||||
raise ValueError(
|
||||
f'Only "keyword" and "integer" types are supported for "{field.condition}" condition'
|
||||
)
|
||||
field_type = list[field_type] # type: ignore
|
||||
|
||||
if field.required:
|
||||
annotation = Annotated[field_type, Field(description=field.description)] # type: ignore
|
||||
parameter = inspect.Parameter(
|
||||
name=field_name,
|
||||
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||
annotation=annotation,
|
||||
)
|
||||
required_new_params.append(parameter)
|
||||
else:
|
||||
annotation = Annotated[ # type: ignore
|
||||
Optional[field_type], Field(description=field.description)
|
||||
]
|
||||
parameter = inspect.Parameter(
|
||||
name=field_name,
|
||||
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||
default=None,
|
||||
annotation=annotation,
|
||||
)
|
||||
optional_new_params.append(parameter)
|
||||
|
||||
new_params.extend(required_new_params)
|
||||
new_params.extend(optional_new_params)
|
||||
|
||||
# Set the new __signature__ for introspection
|
||||
new_signature = sig.replace(parameters=new_params)
|
||||
wrapper.__signature__ = new_signature # type: ignore
|
||||
|
||||
# Set the new __annotations__ for introspection
|
||||
new_annotations = {}
|
||||
for param in new_signature.parameters.values():
|
||||
if param.annotation != inspect.Parameter.empty:
|
||||
new_annotations[param.name] = param.annotation
|
||||
|
||||
# Add return type annotation if it exists
|
||||
if new_signature.return_annotation != inspect.Parameter.empty:
|
||||
new_annotations["return"] = new_signature.return_annotation
|
||||
|
||||
wrapper.__annotations__ = new_annotations
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pydantic._internal._typing_extra import get_function_type_hints
|
||||
from qdrant_client import models
|
||||
|
||||
def find(
|
||||
query: Annotated[str, Field(description="What to search for")],
|
||||
collection_name: Annotated[
|
||||
str, Field(description="The collection to search in")
|
||||
],
|
||||
query_filter: Optional[models.Filter] = None,
|
||||
) -> list[str]:
|
||||
print("query", query)
|
||||
print("collection_name", collection_name)
|
||||
print("query_filter", query_filter)
|
||||
return ["mypy rules"]
|
||||
|
||||
wrapped_find = wrap_filters(
|
||||
find,
|
||||
{
|
||||
"color": FilterableField(
|
||||
name="color",
|
||||
description="The color of the object",
|
||||
field_type="keyword",
|
||||
condition="==",
|
||||
),
|
||||
"size": FilterableField(
|
||||
name="size",
|
||||
description="The size of the object",
|
||||
field_type="keyword",
|
||||
condition="==",
|
||||
required=True,
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
wrapped_find(query="dress", collection_name="test", color="red")
|
||||
|
||||
print("get_function_type_hints(find)", get_function_type_hints(find))
|
||||
print(
|
||||
"get_function_type_hints(wrapped_find)", get_function_type_hints(wrapped_find)
|
||||
)
|
||||
@@ -1,17 +1,16 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List
|
||||
|
||||
|
||||
class EmbeddingProvider(ABC):
|
||||
"""Abstract base class for embedding providers."""
|
||||
|
||||
@abstractmethod
|
||||
async def embed_documents(self, documents: List[str]) -> List[List[float]]:
|
||||
async def embed_documents(self, documents: list[str]) -> list[list[float]]:
|
||||
"""Embed a list of documents into vectors."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def embed_query(self, query: str) -> List[float]:
|
||||
async def embed_query(self, query: str) -> list[float]:
|
||||
"""Embed a query into a vector."""
|
||||
pass
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import asyncio
|
||||
from typing import List
|
||||
|
||||
from fastembed import TextEmbedding
|
||||
from fastembed.common.model_description import DenseModelDescription
|
||||
@@ -17,7 +16,7 @@ class FastEmbedProvider(EmbeddingProvider):
|
||||
self.model_name = model_name
|
||||
self.embedding_model = TextEmbedding(model_name)
|
||||
|
||||
async def embed_documents(self, documents: List[str]) -> List[List[float]]:
|
||||
async def embed_documents(self, documents: list[str]) -> list[list[float]]:
|
||||
"""Embed a list of documents into vectors."""
|
||||
# Run in a thread pool since FastEmbed is synchronous
|
||||
loop = asyncio.get_event_loop()
|
||||
@@ -26,7 +25,7 @@ class FastEmbedProvider(EmbeddingProvider):
|
||||
)
|
||||
return [embedding.tolist() for embedding in embeddings]
|
||||
|
||||
async def embed_query(self, query: str) -> List[float]:
|
||||
async def embed_query(self, query: str) -> list[float]:
|
||||
"""Embed a query into a vector."""
|
||||
# Run in a thread pool since FastEmbed is synchronous
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
@@ -1,12 +1,17 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, List, Optional
|
||||
from typing import Annotated, Any, Optional
|
||||
|
||||
from fastmcp import Context, FastMCP
|
||||
from pydantic import Field
|
||||
from qdrant_client import models
|
||||
|
||||
from mcp_server_qdrant.common.filters import make_indexes
|
||||
from mcp_server_qdrant.common.func_tools import make_partial_function
|
||||
from mcp_server_qdrant.common.wrap_filters import wrap_filters
|
||||
from mcp_server_qdrant.embeddings.base import EmbeddingProvider
|
||||
from mcp_server_qdrant.embeddings.factory import create_embedding_provider
|
||||
from mcp_server_qdrant.qdrant import Entry, Metadata, QdrantConnector
|
||||
from mcp_server_qdrant.qdrant import ArbitraryFilter, Entry, Metadata, QdrantConnector
|
||||
from mcp_server_qdrant.settings import (
|
||||
EmbeddingProviderSettings,
|
||||
QdrantSettings,
|
||||
@@ -27,22 +32,46 @@ class QdrantMCPServer(FastMCP):
|
||||
self,
|
||||
tool_settings: ToolSettings,
|
||||
qdrant_settings: QdrantSettings,
|
||||
embedding_provider_settings: EmbeddingProviderSettings,
|
||||
embedding_provider_settings: Optional[EmbeddingProviderSettings] = None,
|
||||
embedding_provider: Optional[EmbeddingProvider] = None,
|
||||
name: str = "mcp-server-qdrant",
|
||||
instructions: str | None = None,
|
||||
**settings: Any,
|
||||
):
|
||||
self.tool_settings = tool_settings
|
||||
self.qdrant_settings = qdrant_settings
|
||||
self.embedding_provider_settings = embedding_provider_settings
|
||||
|
||||
self.embedding_provider = create_embedding_provider(embedding_provider_settings)
|
||||
if embedding_provider_settings and embedding_provider:
|
||||
raise ValueError(
|
||||
"Cannot provide both embedding_provider_settings and embedding_provider"
|
||||
)
|
||||
|
||||
if not embedding_provider_settings and not embedding_provider:
|
||||
raise ValueError(
|
||||
"Must provide either embedding_provider_settings or embedding_provider"
|
||||
)
|
||||
|
||||
self.embedding_provider_settings: Optional[EmbeddingProviderSettings] = None
|
||||
self.embedding_provider: Optional[EmbeddingProvider] = None
|
||||
|
||||
if embedding_provider_settings:
|
||||
self.embedding_provider_settings = embedding_provider_settings
|
||||
self.embedding_provider = create_embedding_provider(
|
||||
embedding_provider_settings
|
||||
)
|
||||
else:
|
||||
self.embedding_provider_settings = None
|
||||
self.embedding_provider = embedding_provider
|
||||
|
||||
assert self.embedding_provider is not None, "Embedding provider is required"
|
||||
|
||||
self.qdrant_connector = QdrantConnector(
|
||||
qdrant_settings.location,
|
||||
qdrant_settings.api_key,
|
||||
qdrant_settings.collection_name,
|
||||
self.embedding_provider,
|
||||
qdrant_settings.local_path,
|
||||
make_indexes(qdrant_settings.filterable_fields_dict()),
|
||||
)
|
||||
|
||||
super().__init__(name=name, instructions=instructions, **settings)
|
||||
@@ -63,12 +92,19 @@ class QdrantMCPServer(FastMCP):
|
||||
|
||||
async def store(
|
||||
ctx: Context,
|
||||
information: str,
|
||||
collection_name: str,
|
||||
information: Annotated[str, Field(description="Text to store")],
|
||||
collection_name: Annotated[
|
||||
str, Field(description="The collection to store the information in")
|
||||
],
|
||||
# The `metadata` parameter is defined as non-optional, but it can be None.
|
||||
# If we set it to be optional, some of the MCP clients, like Cursor, cannot
|
||||
# handle the optional parameter correctly.
|
||||
metadata: Optional[Metadata] = None, # type: ignore
|
||||
metadata: Annotated[
|
||||
Metadata | None,
|
||||
Field(
|
||||
description="Extra metadata stored along with memorised information. Any json is accepted."
|
||||
),
|
||||
] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Store some information in Qdrant.
|
||||
@@ -90,30 +126,37 @@ class QdrantMCPServer(FastMCP):
|
||||
|
||||
async def find(
|
||||
ctx: Context,
|
||||
query: str,
|
||||
collection_name: str,
|
||||
) -> List[str]:
|
||||
query: Annotated[str, Field(description="What to search for")],
|
||||
collection_name: Annotated[
|
||||
str, Field(description="The collection to search in")
|
||||
],
|
||||
query_filter: ArbitraryFilter | None = None,
|
||||
) -> list[str] | None:
|
||||
"""
|
||||
Find memories in Qdrant.
|
||||
:param ctx: The context for the request.
|
||||
:param query: The query to use for the search.
|
||||
:param collection_name: The name of the collection to search in, optional. If not provided,
|
||||
the default collection is used.
|
||||
:return: A list of entries found.
|
||||
:param query_filter: The filter to apply to the query.
|
||||
:return: A list of entries found or None.
|
||||
"""
|
||||
|
||||
# Log query_filter
|
||||
await ctx.debug(f"Query filter: {query_filter}")
|
||||
|
||||
query_filter = models.Filter(**query_filter) if query_filter else None
|
||||
|
||||
await ctx.debug(f"Finding results for query {query}")
|
||||
if collection_name:
|
||||
await ctx.debug(
|
||||
f"Overriding the collection name with {collection_name}"
|
||||
)
|
||||
|
||||
entries = await self.qdrant_connector.search(
|
||||
query,
|
||||
collection_name=collection_name,
|
||||
limit=self.qdrant_settings.search_limit,
|
||||
query_filter=query_filter,
|
||||
)
|
||||
if not entries:
|
||||
return [f"No information found for the query '{query}'"]
|
||||
return None
|
||||
content = [
|
||||
f"Results for the query '{query}'",
|
||||
]
|
||||
@@ -124,6 +167,15 @@ class QdrantMCPServer(FastMCP):
|
||||
find_foo = find
|
||||
store_foo = store
|
||||
|
||||
filterable_conditions = (
|
||||
self.qdrant_settings.filterable_fields_dict_with_conditions()
|
||||
)
|
||||
|
||||
if len(filterable_conditions) > 0:
|
||||
find_foo = wrap_filters(find_foo, filterable_conditions)
|
||||
elif not self.qdrant_settings.allow_arbitrary_filter:
|
||||
find_foo = make_partial_function(find_foo, {"query_filter": None})
|
||||
|
||||
if self.qdrant_settings.collection_name:
|
||||
find_foo = make_partial_function(
|
||||
find_foo, {"collection_name": self.qdrant_settings.collection_name}
|
||||
@@ -132,7 +184,7 @@ class QdrantMCPServer(FastMCP):
|
||||
store_foo, {"collection_name": self.qdrant_settings.collection_name}
|
||||
)
|
||||
|
||||
self.add_tool(
|
||||
self.tool(
|
||||
find_foo,
|
||||
name="qdrant-find",
|
||||
description=self.tool_settings.tool_find_description,
|
||||
@@ -140,7 +192,7 @@ class QdrantMCPServer(FastMCP):
|
||||
|
||||
if not self.qdrant_settings.read_only:
|
||||
# Those methods can modify the database
|
||||
self.add_tool(
|
||||
self.tool(
|
||||
store_foo,
|
||||
name="qdrant-store",
|
||||
description=self.tool_settings.tool_store_description,
|
||||
|
||||
@@ -1,15 +1,17 @@
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from qdrant_client import AsyncQdrantClient, models
|
||||
|
||||
from mcp_server_qdrant.embeddings.base import EmbeddingProvider
|
||||
from mcp_server_qdrant.settings import METADATA_PATH
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
Metadata = Dict[str, Any]
|
||||
Metadata = dict[str, Any]
|
||||
ArbitraryFilter = dict[str, Any]
|
||||
|
||||
|
||||
class Entry(BaseModel):
|
||||
@@ -18,7 +20,7 @@ class Entry(BaseModel):
|
||||
"""
|
||||
|
||||
content: str
|
||||
metadata: Optional[Metadata] = None
|
||||
metadata: Metadata | None = None
|
||||
|
||||
|
||||
class QdrantConnector:
|
||||
@@ -34,11 +36,12 @@ class QdrantConnector:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
qdrant_url: Optional[str],
|
||||
qdrant_api_key: Optional[str],
|
||||
collection_name: Optional[str],
|
||||
qdrant_url: str | None,
|
||||
qdrant_api_key: str | None,
|
||||
collection_name: str | None,
|
||||
embedding_provider: EmbeddingProvider,
|
||||
qdrant_local_path: Optional[str] = None,
|
||||
qdrant_local_path: str | None = None,
|
||||
field_indexes: dict[str, models.PayloadSchemaType] | None = None,
|
||||
):
|
||||
self._qdrant_url = qdrant_url.rstrip("/") if qdrant_url else None
|
||||
self._qdrant_api_key = qdrant_api_key
|
||||
@@ -47,6 +50,7 @@ class QdrantConnector:
|
||||
self._client = AsyncQdrantClient(
|
||||
location=qdrant_url, api_key=qdrant_api_key, path=qdrant_local_path
|
||||
)
|
||||
self._field_indexes = field_indexes
|
||||
|
||||
async def get_collection_names(self) -> list[str]:
|
||||
"""
|
||||
@@ -56,7 +60,7 @@ class QdrantConnector:
|
||||
response = await self._client.get_collections()
|
||||
return [collection.name for collection in response.collections]
|
||||
|
||||
async def store(self, entry: Entry, *, collection_name: Optional[str] = None):
|
||||
async def store(self, entry: Entry, *, collection_name: str | None = None):
|
||||
"""
|
||||
Store some information in the Qdrant collection, along with the specified metadata.
|
||||
:param entry: The entry to store in the Qdrant collection.
|
||||
@@ -74,7 +78,7 @@ class QdrantConnector:
|
||||
|
||||
# Add to Qdrant
|
||||
vector_name = self._embedding_provider.get_vector_name()
|
||||
payload = {"document": entry.content, "metadata": entry.metadata}
|
||||
payload = {"document": entry.content, METADATA_PATH: entry.metadata}
|
||||
await self._client.upsert(
|
||||
collection_name=collection_name,
|
||||
points=[
|
||||
@@ -87,7 +91,12 @@ class QdrantConnector:
|
||||
)
|
||||
|
||||
async def search(
|
||||
self, query: str, *, collection_name: Optional[str] = None, limit: int = 10
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
collection_name: str | None = None,
|
||||
limit: int = 10,
|
||||
query_filter: models.Filter | None = None,
|
||||
) -> list[Entry]:
|
||||
"""
|
||||
Find points in the Qdrant collection. If there are no entries found, an empty list is returned.
|
||||
@@ -95,6 +104,8 @@ class QdrantConnector:
|
||||
:param collection_name: The name of the collection to search in, optional. If not provided,
|
||||
the default collection is used.
|
||||
:param limit: The maximum number of entries to return.
|
||||
:param query_filter: The filter to apply to the query, if any.
|
||||
|
||||
:return: A list of entries found.
|
||||
"""
|
||||
collection_name = collection_name or self._default_collection_name
|
||||
@@ -115,6 +126,7 @@ class QdrantConnector:
|
||||
query=query_vector,
|
||||
using=vector_name,
|
||||
limit=limit,
|
||||
query_filter=query_filter,
|
||||
)
|
||||
|
||||
return [
|
||||
@@ -146,3 +158,13 @@ class QdrantConnector:
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
# Create payload indexes if configured
|
||||
|
||||
if self._field_indexes:
|
||||
for field_name, field_type in self._field_indexes.items():
|
||||
await self._client.create_payload_index(
|
||||
collection_name=collection_name,
|
||||
field_name=field_name,
|
||||
field_schema=field_type,
|
||||
)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Optional
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
from mcp_server_qdrant.embeddings.types import EmbeddingProviderType
|
||||
@@ -15,6 +15,8 @@ DEFAULT_TOOL_FIND_DESCRIPTION = (
|
||||
" - Get some personal information about the user"
|
||||
)
|
||||
|
||||
METADATA_PATH = "metadata"
|
||||
|
||||
|
||||
class ToolSettings(BaseSettings):
|
||||
"""
|
||||
@@ -46,18 +48,68 @@ class EmbeddingProviderSettings(BaseSettings):
|
||||
)
|
||||
|
||||
|
||||
class FilterableField(BaseModel):
|
||||
name: str = Field(description="The name of the field payload field to filter on")
|
||||
description: str = Field(
|
||||
description="A description for the field used in the tool description"
|
||||
)
|
||||
field_type: Literal["keyword", "integer", "float", "boolean"] = Field(
|
||||
description="The type of the field"
|
||||
)
|
||||
condition: Literal["==", "!=", ">", ">=", "<", "<=", "any", "except"] | None = (
|
||||
Field(
|
||||
default=None,
|
||||
description=(
|
||||
"The condition to use for the filter. If not provided, the field will be indexed, but no "
|
||||
"filter argument will be exposed to MCP tool."
|
||||
),
|
||||
)
|
||||
)
|
||||
required: bool = Field(
|
||||
default=False,
|
||||
description="Whether the field is required for the filter.",
|
||||
)
|
||||
|
||||
|
||||
class QdrantSettings(BaseSettings):
|
||||
"""
|
||||
Configuration for the Qdrant connector.
|
||||
"""
|
||||
|
||||
location: Optional[str] = Field(default=None, validation_alias="QDRANT_URL")
|
||||
api_key: Optional[str] = Field(default=None, validation_alias="QDRANT_API_KEY")
|
||||
collection_name: Optional[str] = Field(
|
||||
location: str | None = Field(default=None, validation_alias="QDRANT_URL")
|
||||
api_key: str | None = Field(default=None, validation_alias="QDRANT_API_KEY")
|
||||
collection_name: str | None = Field(
|
||||
default=None, validation_alias="COLLECTION_NAME"
|
||||
)
|
||||
local_path: Optional[str] = Field(
|
||||
default=None, validation_alias="QDRANT_LOCAL_PATH"
|
||||
)
|
||||
local_path: str | None = Field(default=None, validation_alias="QDRANT_LOCAL_PATH")
|
||||
search_limit: int = Field(default=10, validation_alias="QDRANT_SEARCH_LIMIT")
|
||||
read_only: bool = Field(default=False, validation_alias="QDRANT_READ_ONLY")
|
||||
|
||||
filterable_fields: list[FilterableField] | None = Field(default=None)
|
||||
|
||||
allow_arbitrary_filter: bool = Field(
|
||||
default=False, validation_alias="QDRANT_ALLOW_ARBITRARY_FILTER"
|
||||
)
|
||||
|
||||
def filterable_fields_dict(self) -> dict[str, FilterableField]:
|
||||
if self.filterable_fields is None:
|
||||
return {}
|
||||
return {field.name: field for field in self.filterable_fields}
|
||||
|
||||
def filterable_fields_dict_with_conditions(self) -> dict[str, FilterableField]:
|
||||
if self.filterable_fields is None:
|
||||
return {}
|
||||
return {
|
||||
field.name: field
|
||||
for field in self.filterable_fields
|
||||
if field.condition is not None
|
||||
}
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_local_path_conflict(self) -> "QdrantSettings":
|
||||
if self.local_path:
|
||||
if self.location is not None or self.api_key is not None:
|
||||
raise ValueError(
|
||||
"If 'local_path' is set, 'location' and 'api_key' must be None."
|
||||
)
|
||||
return self
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
import pytest
|
||||
|
||||
from mcp_server_qdrant.embeddings.types import EmbeddingProviderType
|
||||
from mcp_server_qdrant.settings import (
|
||||
@@ -18,34 +17,51 @@ class TestQdrantSettings:
|
||||
# Should not raise error because there are no required fields
|
||||
QdrantSettings()
|
||||
|
||||
@patch.dict(
|
||||
os.environ,
|
||||
{"QDRANT_URL": "http://localhost:6333", "COLLECTION_NAME": "test_collection"},
|
||||
)
|
||||
def test_minimal_config(self):
|
||||
def test_minimal_config(self, monkeypatch):
|
||||
"""Test loading minimal configuration from environment variables."""
|
||||
monkeypatch.setenv("QDRANT_URL", "http://localhost:6333")
|
||||
monkeypatch.setenv("COLLECTION_NAME", "test_collection")
|
||||
|
||||
settings = QdrantSettings()
|
||||
assert settings.location == "http://localhost:6333"
|
||||
assert settings.collection_name == "test_collection"
|
||||
assert settings.api_key is None
|
||||
assert settings.local_path is None
|
||||
|
||||
@patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"QDRANT_URL": "http://qdrant.example.com:6333",
|
||||
"QDRANT_API_KEY": "test_api_key",
|
||||
"COLLECTION_NAME": "my_memories",
|
||||
"QDRANT_LOCAL_PATH": "/tmp/qdrant",
|
||||
},
|
||||
)
|
||||
def test_full_config(self):
|
||||
def test_full_config(self, monkeypatch):
|
||||
"""Test loading full configuration from environment variables."""
|
||||
monkeypatch.setenv("QDRANT_URL", "http://qdrant.example.com:6333")
|
||||
monkeypatch.setenv("QDRANT_API_KEY", "test_api_key")
|
||||
monkeypatch.setenv("COLLECTION_NAME", "my_memories")
|
||||
monkeypatch.setenv("QDRANT_SEARCH_LIMIT", "15")
|
||||
monkeypatch.setenv("QDRANT_READ_ONLY", "1")
|
||||
|
||||
settings = QdrantSettings()
|
||||
assert settings.location == "http://qdrant.example.com:6333"
|
||||
assert settings.api_key == "test_api_key"
|
||||
assert settings.collection_name == "my_memories"
|
||||
assert settings.local_path == "/tmp/qdrant"
|
||||
assert settings.search_limit == 15
|
||||
assert settings.read_only is True
|
||||
|
||||
def test_local_path_config(self, monkeypatch):
|
||||
"""Test loading local path configuration from environment variables."""
|
||||
monkeypatch.setenv("QDRANT_LOCAL_PATH", "/path/to/local/qdrant")
|
||||
|
||||
settings = QdrantSettings()
|
||||
assert settings.local_path == "/path/to/local/qdrant"
|
||||
|
||||
def test_local_path_is_exclusive_with_url(self, monkeypatch):
|
||||
"""Test that local path cannot be set if Qdrant URL is provided."""
|
||||
monkeypatch.setenv("QDRANT_URL", "http://localhost:6333")
|
||||
monkeypatch.setenv("QDRANT_LOCAL_PATH", "/path/to/local/qdrant")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
QdrantSettings()
|
||||
|
||||
monkeypatch.delenv("QDRANT_URL", raising=False)
|
||||
monkeypatch.setenv("QDRANT_API_KEY", "test_api_key")
|
||||
with pytest.raises(ValueError):
|
||||
QdrantSettings()
|
||||
|
||||
|
||||
class TestEmbeddingProviderSettings:
|
||||
@@ -55,12 +71,9 @@ class TestEmbeddingProviderSettings:
|
||||
assert settings.provider_type == EmbeddingProviderType.FASTEMBED
|
||||
assert settings.model_name == "sentence-transformers/all-MiniLM-L6-v2"
|
||||
|
||||
@patch.dict(
|
||||
os.environ,
|
||||
{"EMBEDDING_MODEL": "custom_model"},
|
||||
)
|
||||
def test_custom_values(self):
|
||||
def test_custom_values(self, monkeypatch):
|
||||
"""Test loading custom values from environment variables."""
|
||||
monkeypatch.setenv("EMBEDDING_MODEL", "custom_model")
|
||||
settings = EmbeddingProviderSettings()
|
||||
assert settings.provider_type == EmbeddingProviderType.FASTEMBED
|
||||
assert settings.model_name == "custom_model"
|
||||
@@ -73,35 +86,24 @@ class TestToolSettings:
|
||||
assert settings.tool_store_description == DEFAULT_TOOL_STORE_DESCRIPTION
|
||||
assert settings.tool_find_description == DEFAULT_TOOL_FIND_DESCRIPTION
|
||||
|
||||
@patch.dict(
|
||||
os.environ,
|
||||
{"TOOL_STORE_DESCRIPTION": "Custom store description"},
|
||||
)
|
||||
def test_custom_store_description(self):
|
||||
def test_custom_store_description(self, monkeypatch):
|
||||
"""Test loading custom store description from environment variable."""
|
||||
monkeypatch.setenv("TOOL_STORE_DESCRIPTION", "Custom store description")
|
||||
settings = ToolSettings()
|
||||
assert settings.tool_store_description == "Custom store description"
|
||||
assert settings.tool_find_description == DEFAULT_TOOL_FIND_DESCRIPTION
|
||||
|
||||
@patch.dict(
|
||||
os.environ,
|
||||
{"TOOL_FIND_DESCRIPTION": "Custom find description"},
|
||||
)
|
||||
def test_custom_find_description(self):
|
||||
def test_custom_find_description(self, monkeypatch):
|
||||
"""Test loading custom find description from environment variable."""
|
||||
monkeypatch.setenv("TOOL_FIND_DESCRIPTION", "Custom find description")
|
||||
settings = ToolSettings()
|
||||
assert settings.tool_store_description == DEFAULT_TOOL_STORE_DESCRIPTION
|
||||
assert settings.tool_find_description == "Custom find description"
|
||||
|
||||
@patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"TOOL_STORE_DESCRIPTION": "Custom store description",
|
||||
"TOOL_FIND_DESCRIPTION": "Custom find description",
|
||||
},
|
||||
)
|
||||
def test_all_custom_values(self):
|
||||
def test_all_custom_values(self, monkeypatch):
|
||||
"""Test loading all custom values from environment variables."""
|
||||
monkeypatch.setenv("TOOL_STORE_DESCRIPTION", "Custom store description")
|
||||
monkeypatch.setenv("TOOL_FIND_DESCRIPTION", "Custom find description")
|
||||
settings = ToolSettings()
|
||||
assert settings.tool_store_description == "Custom store description"
|
||||
assert settings.tool_find_description == "Custom find description"
|
||||
|
||||
Reference in New Issue
Block a user