new: update type hints (#64)

* new: update type hints

* fix: do not pass location and path to qdrant client, and do not accept them together

* new: update settings tests

* fix: revert removal of local path
This commit is contained in:
George
2025-06-12 00:55:07 +04:00
committed by GitHub
parent b657656363
commit 28bf298a32
6 changed files with 83 additions and 78 deletions

View File

@@ -1,17 +1,16 @@
from abc import ABC, abstractmethod
from typing import List
class EmbeddingProvider(ABC):
"""Abstract base class for embedding providers."""
@abstractmethod
async def embed_documents(self, documents: List[str]) -> List[List[float]]:
async def embed_documents(self, documents: list[str]) -> list[list[float]]:
"""Embed a list of documents into vectors."""
pass
@abstractmethod
async def embed_query(self, query: str) -> List[float]:
async def embed_query(self, query: str) -> list[float]:
"""Embed a query into a vector."""
pass

View File

@@ -1,5 +1,4 @@
import asyncio
from typing import List
from fastembed import TextEmbedding
from fastembed.common.model_description import DenseModelDescription
@@ -17,7 +16,7 @@ class FastEmbedProvider(EmbeddingProvider):
self.model_name = model_name
self.embedding_model = TextEmbedding(model_name)
async def embed_documents(self, documents: List[str]) -> List[List[float]]:
async def embed_documents(self, documents: list[str]) -> list[list[float]]:
"""Embed a list of documents into vectors."""
# Run in a thread pool since FastEmbed is synchronous
loop = asyncio.get_event_loop()
@@ -26,7 +25,7 @@ class FastEmbedProvider(EmbeddingProvider):
)
return [embedding.tolist() for embedding in embeddings]
async def embed_query(self, query: str) -> List[float]:
async def embed_query(self, query: str) -> list[float]:
"""Embed a query into a vector."""
# Run in a thread pool since FastEmbed is synchronous
loop = asyncio.get_event_loop()

View File

@@ -1,6 +1,6 @@
import json
import logging
from typing import Annotated, Any, List, Optional
from typing import Annotated, Any
from fastmcp import Context, FastMCP
from pydantic import Field
@@ -76,7 +76,7 @@ class QdrantMCPServer(FastMCP):
# If we set it to be optional, some of the MCP clients, like Cursor, cannot
# handle the optional parameter correctly.
metadata: Annotated[
Optional[Metadata],
Metadata | None,
Field(
description="Extra metadata stored along with memorised information. Any json is accepted."
),
@@ -106,14 +106,15 @@ class QdrantMCPServer(FastMCP):
collection_name: Annotated[
str, Field(description="The collection to search in")
],
query_filter: Optional[ArbitraryFilter] = None,
) -> List[str]:
query_filter: ArbitraryFilter | None = None,
) -> list[str]:
"""
Find memories in Qdrant.
:param ctx: The context for the request.
:param query: The query to use for the search.
:param collection_name: The name of the collection to search in, optional. If not provided,
the default collection is used.
:param query_filter: The filter to apply to the query.
:return: A list of entries found.
"""
@@ -123,10 +124,6 @@ class QdrantMCPServer(FastMCP):
query_filter = models.Filter(**query_filter) if query_filter else None
await ctx.debug(f"Finding results for query {query}")
if collection_name:
await ctx.debug(
f"Overriding the collection name with {collection_name}"
)
entries = await self.qdrant_connector.search(
query,

View File

@@ -1,6 +1,6 @@
import logging
import uuid
from typing import Any, Dict, Optional
from typing import Any
from pydantic import BaseModel
from qdrant_client import AsyncQdrantClient, models
@@ -10,9 +10,8 @@ from mcp_server_qdrant.settings import METADATA_PATH
logger = logging.getLogger(__name__)
Metadata = Dict[str, Any]
ArbitraryFilter = Dict[str, Any]
Metadata = dict[str, Any]
ArbitraryFilter = dict[str, Any]
class Entry(BaseModel):
@@ -21,7 +20,7 @@ class Entry(BaseModel):
"""
content: str
metadata: Optional[Metadata] = None
metadata: Metadata | None = None
class QdrantConnector:
@@ -37,12 +36,12 @@ class QdrantConnector:
def __init__(
self,
qdrant_url: Optional[str],
qdrant_api_key: Optional[str],
collection_name: Optional[str],
qdrant_url: str | None,
qdrant_api_key: str | None,
collection_name: str | None,
embedding_provider: EmbeddingProvider,
qdrant_local_path: Optional[str] = None,
field_indexes: Optional[dict[str, models.PayloadSchemaType]] = None,
qdrant_local_path: str | None = None,
field_indexes: dict[str, models.PayloadSchemaType] | None = None,
):
self._qdrant_url = qdrant_url.rstrip("/") if qdrant_url else None
self._qdrant_api_key = qdrant_api_key
@@ -61,7 +60,7 @@ class QdrantConnector:
response = await self._client.get_collections()
return [collection.name for collection in response.collections]
async def store(self, entry: Entry, *, collection_name: Optional[str] = None):
async def store(self, entry: Entry, *, collection_name: str | None = None):
"""
Store some information in the Qdrant collection, along with the specified metadata.
:param entry: The entry to store in the Qdrant collection.
@@ -95,9 +94,9 @@ class QdrantConnector:
self,
query: str,
*,
collection_name: Optional[str] = None,
collection_name: str | None = None,
limit: int = 10,
query_filter: Optional[models.Filter] = None,
query_filter: models.Filter | None = None,
) -> list[Entry]:
"""
Find points in the Qdrant collection. If there are no entries found, an empty list is returned.
@@ -105,6 +104,8 @@ class QdrantConnector:
:param collection_name: The name of the collection to search in, optional. If not provided,
the default collection is used.
:param limit: The maximum number of entries to return.
:param query_filter: The filter to apply to the query, if any.
:return: A list of entries found.
"""
collection_name = collection_name or self._default_collection_name

View File

@@ -1,6 +1,6 @@
from typing import Literal, Optional
from typing import Literal
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, model_validator
from pydantic_settings import BaseSettings
from mcp_server_qdrant.embeddings.types import EmbeddingProviderType
@@ -56,7 +56,7 @@ class FilterableField(BaseModel):
field_type: Literal["keyword", "integer", "float", "boolean"] = Field(
description="The type of the field"
)
condition: Optional[Literal["==", "!=", ">", ">=", "<", "<=", "any", "except"]] = (
condition: Literal["==", "!=", ">", ">=", "<", "<=", "any", "except"] | None = (
Field(
default=None,
description=(
@@ -76,18 +76,16 @@ class QdrantSettings(BaseSettings):
Configuration for the Qdrant connector.
"""
location: Optional[str] = Field(default=None, validation_alias="QDRANT_URL")
api_key: Optional[str] = Field(default=None, validation_alias="QDRANT_API_KEY")
collection_name: Optional[str] = Field(
location: str | None = Field(default=None, validation_alias="QDRANT_URL")
api_key: str | None = Field(default=None, validation_alias="QDRANT_API_KEY")
collection_name: str | None = Field(
default=None, validation_alias="COLLECTION_NAME"
)
local_path: Optional[str] = Field(
default=None, validation_alias="QDRANT_LOCAL_PATH"
)
local_path: str | None = Field(default=None, validation_alias="QDRANT_LOCAL_PATH")
search_limit: int = Field(default=10, validation_alias="QDRANT_SEARCH_LIMIT")
read_only: bool = Field(default=False, validation_alias="QDRANT_READ_ONLY")
filterable_fields: Optional[list[FilterableField]] = Field(default=None)
filterable_fields: list[FilterableField] | None = Field(default=None)
allow_arbitrary_filter: bool = Field(
default=False, validation_alias="QDRANT_ALLOW_ARBITRARY_FILTER"
@@ -106,3 +104,12 @@ class QdrantSettings(BaseSettings):
for field in self.filterable_fields
if field.condition is not None
}
@model_validator(mode="after")
def check_local_path_conflict(self) -> "QdrantSettings":
if self.local_path:
if self.location is not None or self.api_key is not None:
raise ValueError(
"If 'local_path' is set, 'location' and 'api_key' must be None."
)
return self