Configurable filters (#58)
* add configurable filters * hello to hr department * rollback debug code * add arbitrary filter * dont consider fields without conditions * in and except condition * proper annotation types for optional and list fields * fix types import * skip non-required fields * fix: fix match except condition, fix boolean filter * fix: apply ruff * fix: make condition optional in filterable field * fix: do not set default value for required fields (#63) * fix: do not set default value for required fields * fix: temp fix fastmcp to <2.8.0 cause of the breaking changes in the api * fix: add missing changes to pyproject.toml * fix: downgrade fastmcp even further to <2.7.0 --------- Co-authored-by: George Panchuk <george.panchuk@qdrant.tech> Co-authored-by: George <panchuk.george@outlook.com>
This commit is contained in:
@@ -6,11 +6,14 @@ from pydantic import BaseModel
|
||||
from qdrant_client import AsyncQdrantClient, models
|
||||
|
||||
from mcp_server_qdrant.embeddings.base import EmbeddingProvider
|
||||
from mcp_server_qdrant.settings import METADATA_PATH
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
Metadata = Dict[str, Any]
|
||||
|
||||
ArbitraryFilter = Dict[str, Any]
|
||||
|
||||
|
||||
class Entry(BaseModel):
|
||||
"""
|
||||
@@ -39,6 +42,7 @@ class QdrantConnector:
|
||||
collection_name: Optional[str],
|
||||
embedding_provider: EmbeddingProvider,
|
||||
qdrant_local_path: Optional[str] = None,
|
||||
field_indexes: Optional[dict[str, models.PayloadSchemaType]] = None,
|
||||
):
|
||||
self._qdrant_url = qdrant_url.rstrip("/") if qdrant_url else None
|
||||
self._qdrant_api_key = qdrant_api_key
|
||||
@@ -47,6 +51,7 @@ class QdrantConnector:
|
||||
self._client = AsyncQdrantClient(
|
||||
location=qdrant_url, api_key=qdrant_api_key, path=qdrant_local_path
|
||||
)
|
||||
self._field_indexes = field_indexes
|
||||
|
||||
async def get_collection_names(self) -> list[str]:
|
||||
"""
|
||||
@@ -74,7 +79,7 @@ class QdrantConnector:
|
||||
|
||||
# Add to Qdrant
|
||||
vector_name = self._embedding_provider.get_vector_name()
|
||||
payload = {"document": entry.content, "metadata": entry.metadata}
|
||||
payload = {"document": entry.content, METADATA_PATH: entry.metadata}
|
||||
await self._client.upsert(
|
||||
collection_name=collection_name,
|
||||
points=[
|
||||
@@ -87,7 +92,12 @@ class QdrantConnector:
|
||||
)
|
||||
|
||||
async def search(
|
||||
self, query: str, *, collection_name: Optional[str] = None, limit: int = 10
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
collection_name: Optional[str] = None,
|
||||
limit: int = 10,
|
||||
query_filter: Optional[models.Filter] = None,
|
||||
) -> list[Entry]:
|
||||
"""
|
||||
Find points in the Qdrant collection. If there are no entries found, an empty list is returned.
|
||||
@@ -115,6 +125,7 @@ class QdrantConnector:
|
||||
query=query_vector,
|
||||
using=vector_name,
|
||||
limit=limit,
|
||||
query_filter=query_filter,
|
||||
)
|
||||
|
||||
return [
|
||||
@@ -146,3 +157,13 @@ class QdrantConnector:
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
# Create payload indexes if configured
|
||||
|
||||
if self._field_indexes:
|
||||
for field_name, field_type in self._field_indexes.items():
|
||||
await self._client.create_payload_index(
|
||||
collection_name=collection_name,
|
||||
field_name=field_name,
|
||||
field_schema=field_type,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user