From e9f0a1fa4ad1e4ffb5ca59e76b3996add2f29b45 Mon Sep 17 00:00:00 2001 From: "Mr. Kutin" Date: Fri, 3 Apr 2026 10:27:52 +0300 Subject: [PATCH] Add BM25 hybrid search (dense + sparse vectors with RRF) - Add SparseTextEmbedding("Qdrant/bm25") to FastEmbedProvider for BM25 tokenization - Add sparse vector config (IDF modifier) to collection creation - Store both dense and sparse vectors per document - Use Qdrant prefetch + Reciprocal Rank Fusion for hybrid search - Add HYBRID_SEARCH env var (default: false) for backward compatibility Co-Authored-By: Claude Opus 4.6 --- src/mcp_server_qdrant/embeddings/base.py | 21 +++++ src/mcp_server_qdrant/embeddings/factory.py | 7 +- src/mcp_server_qdrant/embeddings/fastembed.py | 41 +++++++-- src/mcp_server_qdrant/mcp_server.py | 4 +- src/mcp_server_qdrant/qdrant.py | 84 +++++++++++++++---- src/mcp_server_qdrant/settings.py | 1 + 6 files changed, 134 insertions(+), 24 deletions(-) diff --git a/src/mcp_server_qdrant/embeddings/base.py b/src/mcp_server_qdrant/embeddings/base.py index 5c47a17..abcbaff 100644 --- a/src/mcp_server_qdrant/embeddings/base.py +++ b/src/mcp_server_qdrant/embeddings/base.py @@ -1,4 +1,13 @@ from abc import ABC, abstractmethod +from dataclasses import dataclass + + +@dataclass +class SparseVector: + """A sparse vector representation with indices and values.""" + + indices: list[int] + values: list[float] class EmbeddingProvider(ABC): @@ -23,3 +32,15 @@ class EmbeddingProvider(ABC): def get_vector_size(self) -> int: """Get the size of the vector for the Qdrant collection.""" pass + + def supports_sparse(self) -> bool: + """Whether this provider supports sparse (BM25) embeddings.""" + return False + + async def embed_documents_sparse(self, documents: list[str]) -> list[SparseVector]: + """Embed documents into sparse vectors. Override if supports_sparse() is True.""" + raise NotImplementedError + + async def embed_query_sparse(self, query: str) -> SparseVector: + """Embed a query into a sparse vector. Override if supports_sparse() is True.""" + raise NotImplementedError diff --git a/src/mcp_server_qdrant/embeddings/factory.py b/src/mcp_server_qdrant/embeddings/factory.py index efac851..519d393 100644 --- a/src/mcp_server_qdrant/embeddings/factory.py +++ b/src/mcp_server_qdrant/embeddings/factory.py @@ -3,15 +3,18 @@ from mcp_server_qdrant.embeddings.types import EmbeddingProviderType from mcp_server_qdrant.settings import EmbeddingProviderSettings -def create_embedding_provider(settings: EmbeddingProviderSettings) -> EmbeddingProvider: +def create_embedding_provider( + settings: EmbeddingProviderSettings, enable_sparse: bool = False +) -> EmbeddingProvider: """ Create an embedding provider based on the specified type. :param settings: The settings for the embedding provider. + :param enable_sparse: Whether to enable sparse (BM25) embeddings. :return: An instance of the specified embedding provider. """ if settings.provider_type == EmbeddingProviderType.FASTEMBED: from mcp_server_qdrant.embeddings.fastembed import FastEmbedProvider - return FastEmbedProvider(settings.model_name) + return FastEmbedProvider(settings.model_name, enable_sparse=enable_sparse) else: raise ValueError(f"Unsupported embedding provider: {settings.provider_type}") diff --git a/src/mcp_server_qdrant/embeddings/fastembed.py b/src/mcp_server_qdrant/embeddings/fastembed.py index 1655f83..7e72454 100644 --- a/src/mcp_server_qdrant/embeddings/fastembed.py +++ b/src/mcp_server_qdrant/embeddings/fastembed.py @@ -1,24 +1,31 @@ import asyncio -from fastembed import TextEmbedding +from fastembed import SparseTextEmbedding, TextEmbedding from fastembed.common.model_description import DenseModelDescription -from mcp_server_qdrant.embeddings.base import EmbeddingProvider +from mcp_server_qdrant.embeddings.base import EmbeddingProvider, SparseVector class FastEmbedProvider(EmbeddingProvider): """ FastEmbed implementation of the embedding provider. :param model_name: The name of the FastEmbed model to use. + :param enable_sparse: Whether to enable BM25 sparse embeddings for hybrid search. """ - def __init__(self, model_name: str): + def __init__(self, model_name: str, enable_sparse: bool = False): self.model_name = model_name self.embedding_model = TextEmbedding(model_name) + self._enable_sparse = enable_sparse + self._sparse_model = None + if enable_sparse: + self._sparse_model = SparseTextEmbedding("Qdrant/bm25") + + def supports_sparse(self) -> bool: + return self._enable_sparse and self._sparse_model is not None async def embed_documents(self, documents: list[str]) -> list[list[float]]: """Embed a list of documents into vectors.""" - # Run in a thread pool since FastEmbed is synchronous loop = asyncio.get_event_loop() embeddings = await loop.run_in_executor( None, lambda: list(self.embedding_model.passage_embed(documents)) @@ -27,13 +34,37 @@ class FastEmbedProvider(EmbeddingProvider): async def embed_query(self, query: str) -> list[float]: """Embed a query into a vector.""" - # Run in a thread pool since FastEmbed is synchronous loop = asyncio.get_event_loop() embeddings = await loop.run_in_executor( None, lambda: list(self.embedding_model.query_embed([query])) ) return embeddings[0].tolist() + async def embed_documents_sparse(self, documents: list[str]) -> list[SparseVector]: + """Embed documents into BM25 sparse vectors.""" + loop = asyncio.get_event_loop() + results = await loop.run_in_executor( + None, lambda: list(self._sparse_model.passage_embed(documents)) + ) + return [ + SparseVector( + indices=r.indices.tolist(), + values=r.values.tolist(), + ) + for r in results + ] + + async def embed_query_sparse(self, query: str) -> SparseVector: + """Embed a query into a BM25 sparse vector.""" + loop = asyncio.get_event_loop() + results = await loop.run_in_executor( + None, lambda: list(self._sparse_model.query_embed([query])) + ) + return SparseVector( + indices=results[0].indices.tolist(), + values=results[0].values.tolist(), + ) + def get_vector_name(self) -> str: """ Return the name of the vector for the Qdrant collection. diff --git a/src/mcp_server_qdrant/mcp_server.py b/src/mcp_server_qdrant/mcp_server.py index 0617b9d..58a7cc0 100644 --- a/src/mcp_server_qdrant/mcp_server.py +++ b/src/mcp_server_qdrant/mcp_server.py @@ -57,7 +57,8 @@ class QdrantMCPServer(FastMCP): if embedding_provider_settings: self.embedding_provider_settings = embedding_provider_settings self.embedding_provider = create_embedding_provider( - embedding_provider_settings + embedding_provider_settings, + enable_sparse=qdrant_settings.hybrid_search, ) else: self.embedding_provider_settings = None @@ -72,6 +73,7 @@ class QdrantMCPServer(FastMCP): self.embedding_provider, qdrant_settings.local_path, make_indexes(qdrant_settings.filterable_fields_dict()), + hybrid_search=qdrant_settings.hybrid_search, ) super().__init__(name=name, instructions=instructions, **settings) diff --git a/src/mcp_server_qdrant/qdrant.py b/src/mcp_server_qdrant/qdrant.py index 8d3e5aa..0030cac 100644 --- a/src/mcp_server_qdrant/qdrant.py +++ b/src/mcp_server_qdrant/qdrant.py @@ -23,6 +23,9 @@ class Entry(BaseModel): metadata: Metadata | None = None +SPARSE_VECTOR_NAME = "bm25" + + class QdrantConnector: """ Encapsulates the connection to a Qdrant server and all the methods to interact with it. @@ -32,6 +35,7 @@ class QdrantConnector: the collection name to be provided. :param embedding_provider: The embedding provider to use. :param qdrant_local_path: The path to the storage directory for the Qdrant client, if local mode is used. + :param hybrid_search: Whether to enable hybrid search (dense + BM25 sparse vectors with RRF). """ def __init__( @@ -42,15 +46,19 @@ class QdrantConnector: embedding_provider: EmbeddingProvider, qdrant_local_path: str | None = None, field_indexes: dict[str, models.PayloadSchemaType] | None = None, + hybrid_search: bool = False, ): self._qdrant_url = qdrant_url.rstrip("/") if qdrant_url else None self._qdrant_api_key = qdrant_api_key self._default_collection_name = collection_name self._embedding_provider = embedding_provider + self._hybrid_search = hybrid_search and embedding_provider.supports_sparse() self._client = AsyncQdrantClient( location=qdrant_url, api_key=qdrant_api_key, path=qdrant_local_path ) self._field_indexes = field_indexes + if self._hybrid_search: + logger.info("Hybrid search enabled (dense + BM25 sparse vectors with RRF)") async def get_collection_names(self) -> list[str]: """ @@ -72,19 +80,30 @@ class QdrantConnector: await self._ensure_collection_exists(collection_name) # Embed the document - # ToDo: instead of embedding text explicitly, use `models.Document`, - # it should unlock usage of server-side inference. embeddings = await self._embedding_provider.embed_documents([entry.content]) - # Add to Qdrant + # Build vector dict vector_name = self._embedding_provider.get_vector_name() + vector_data: dict = {vector_name: embeddings[0]} + + # Add sparse vector if hybrid search is enabled + if self._hybrid_search: + sparse_embeddings = await self._embedding_provider.embed_documents_sparse( + [entry.content] + ) + sparse = sparse_embeddings[0] + vector_data[SPARSE_VECTOR_NAME] = models.SparseVector( + indices=sparse.indices, values=sparse.values + ) + + # Add to Qdrant payload = {"document": entry.content, METADATA_PATH: entry.metadata} await self._client.upsert( collection_name=collection_name, points=[ models.PointStruct( id=uuid.uuid4().hex, - vector={vector_name: embeddings[0]}, + vector=vector_data, payload=payload, ) ], @@ -113,21 +132,43 @@ class QdrantConnector: if not collection_exists: return [] - # Embed the query - # ToDo: instead of embedding text explicitly, use `models.Document`, - # it should unlock usage of server-side inference. - query_vector = await self._embedding_provider.embed_query(query) vector_name = self._embedding_provider.get_vector_name() - # Search in Qdrant - search_results = await self._client.query_points( - collection_name=collection_name, - query=query_vector, - using=vector_name, - limit=limit, - query_filter=query_filter, - ) + # Hybrid search: prefetch dense + sparse, fuse with RRF + if self._hybrid_search: + sparse_vector = await self._embedding_provider.embed_query_sparse(query) + search_results = await self._client.query_points( + collection_name=collection_name, + prefetch=[ + models.Prefetch( + query=query_vector, + using=vector_name, + limit=limit, + filter=query_filter, + ), + models.Prefetch( + query=models.SparseVector( + indices=sparse_vector.indices, + values=sparse_vector.values, + ), + using=SPARSE_VECTOR_NAME, + limit=limit, + filter=query_filter, + ), + ], + query=models.FusionQuery(fusion=models.Fusion.RRF), + limit=limit, + ) + else: + # Dense-only search (original behavior) + search_results = await self._client.query_points( + collection_name=collection_name, + query=query_vector, + using=vector_name, + limit=limit, + query_filter=query_filter, + ) return [ Entry( @@ -149,6 +190,16 @@ class QdrantConnector: # Use the vector name as defined in the embedding provider vector_name = self._embedding_provider.get_vector_name() + + # Sparse vectors config for hybrid search (BM25) + sparse_config = None + if self._hybrid_search: + sparse_config = { + SPARSE_VECTOR_NAME: models.SparseVectorParams( + modifier=models.Modifier.IDF, + ) + } + await self._client.create_collection( collection_name=collection_name, vectors_config={ @@ -157,6 +208,7 @@ class QdrantConnector: distance=models.Distance.COSINE, ) }, + sparse_vectors_config=sparse_config, ) # Create payload indexes if configured diff --git a/src/mcp_server_qdrant/settings.py b/src/mcp_server_qdrant/settings.py index e48c10d..2d334d4 100644 --- a/src/mcp_server_qdrant/settings.py +++ b/src/mcp_server_qdrant/settings.py @@ -78,6 +78,7 @@ class QdrantSettings(BaseSettings): location: str | None = Field(default=None, validation_alias="QDRANT_URL") api_key: str | None = Field(default=None, validation_alias="QDRANT_API_KEY") + hybrid_search: bool = Field(default=False, validation_alias="HYBRID_SEARCH") collection_name: str | None = Field( default=None, validation_alias="COLLECTION_NAME" )