Add BM25 hybrid search (dense + sparse vectors with RRF)
Some checks failed
pre-commit / main (push) Has been cancelled
Run Tests / Python 3.10 (push) Has been cancelled
Run Tests / Python 3.11 (push) Has been cancelled
Run Tests / Python 3.12 (push) Has been cancelled
Run Tests / Python 3.13 (push) Has been cancelled

- Add SparseTextEmbedding("Qdrant/bm25") to FastEmbedProvider for BM25 tokenization
- Add sparse vector config (IDF modifier) to collection creation
- Store both dense and sparse vectors per document
- Use Qdrant prefetch + Reciprocal Rank Fusion for hybrid search
- Add HYBRID_SEARCH env var (default: false) for backward compatibility

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Mr. Kutin
2026-04-03 10:27:52 +03:00
parent e4ec69b2da
commit e9f0a1fa4a
6 changed files with 134 additions and 24 deletions

View File

@@ -1,4 +1,13 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass
@dataclass
class SparseVector:
"""A sparse vector representation with indices and values."""
indices: list[int]
values: list[float]
class EmbeddingProvider(ABC): class EmbeddingProvider(ABC):
@@ -23,3 +32,15 @@ class EmbeddingProvider(ABC):
def get_vector_size(self) -> int: def get_vector_size(self) -> int:
"""Get the size of the vector for the Qdrant collection.""" """Get the size of the vector for the Qdrant collection."""
pass pass
def supports_sparse(self) -> bool:
"""Whether this provider supports sparse (BM25) embeddings."""
return False
async def embed_documents_sparse(self, documents: list[str]) -> list[SparseVector]:
"""Embed documents into sparse vectors. Override if supports_sparse() is True."""
raise NotImplementedError
async def embed_query_sparse(self, query: str) -> SparseVector:
"""Embed a query into a sparse vector. Override if supports_sparse() is True."""
raise NotImplementedError

View File

@@ -3,15 +3,18 @@ from mcp_server_qdrant.embeddings.types import EmbeddingProviderType
from mcp_server_qdrant.settings import EmbeddingProviderSettings from mcp_server_qdrant.settings import EmbeddingProviderSettings
def create_embedding_provider(settings: EmbeddingProviderSettings) -> EmbeddingProvider: def create_embedding_provider(
settings: EmbeddingProviderSettings, enable_sparse: bool = False
) -> EmbeddingProvider:
""" """
Create an embedding provider based on the specified type. Create an embedding provider based on the specified type.
:param settings: The settings for the embedding provider. :param settings: The settings for the embedding provider.
:param enable_sparse: Whether to enable sparse (BM25) embeddings.
:return: An instance of the specified embedding provider. :return: An instance of the specified embedding provider.
""" """
if settings.provider_type == EmbeddingProviderType.FASTEMBED: if settings.provider_type == EmbeddingProviderType.FASTEMBED:
from mcp_server_qdrant.embeddings.fastembed import FastEmbedProvider from mcp_server_qdrant.embeddings.fastembed import FastEmbedProvider
return FastEmbedProvider(settings.model_name) return FastEmbedProvider(settings.model_name, enable_sparse=enable_sparse)
else: else:
raise ValueError(f"Unsupported embedding provider: {settings.provider_type}") raise ValueError(f"Unsupported embedding provider: {settings.provider_type}")

View File

@@ -1,24 +1,31 @@
import asyncio import asyncio
from fastembed import TextEmbedding from fastembed import SparseTextEmbedding, TextEmbedding
from fastembed.common.model_description import DenseModelDescription from fastembed.common.model_description import DenseModelDescription
from mcp_server_qdrant.embeddings.base import EmbeddingProvider from mcp_server_qdrant.embeddings.base import EmbeddingProvider, SparseVector
class FastEmbedProvider(EmbeddingProvider): class FastEmbedProvider(EmbeddingProvider):
""" """
FastEmbed implementation of the embedding provider. FastEmbed implementation of the embedding provider.
:param model_name: The name of the FastEmbed model to use. :param model_name: The name of the FastEmbed model to use.
:param enable_sparse: Whether to enable BM25 sparse embeddings for hybrid search.
""" """
def __init__(self, model_name: str): def __init__(self, model_name: str, enable_sparse: bool = False):
self.model_name = model_name self.model_name = model_name
self.embedding_model = TextEmbedding(model_name) self.embedding_model = TextEmbedding(model_name)
self._enable_sparse = enable_sparse
self._sparse_model = None
if enable_sparse:
self._sparse_model = SparseTextEmbedding("Qdrant/bm25")
def supports_sparse(self) -> bool:
return self._enable_sparse and self._sparse_model is not None
async def embed_documents(self, documents: list[str]) -> list[list[float]]: async def embed_documents(self, documents: list[str]) -> list[list[float]]:
"""Embed a list of documents into vectors.""" """Embed a list of documents into vectors."""
# Run in a thread pool since FastEmbed is synchronous
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
embeddings = await loop.run_in_executor( embeddings = await loop.run_in_executor(
None, lambda: list(self.embedding_model.passage_embed(documents)) None, lambda: list(self.embedding_model.passage_embed(documents))
@@ -27,13 +34,37 @@ class FastEmbedProvider(EmbeddingProvider):
async def embed_query(self, query: str) -> list[float]: async def embed_query(self, query: str) -> list[float]:
"""Embed a query into a vector.""" """Embed a query into a vector."""
# Run in a thread pool since FastEmbed is synchronous
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
embeddings = await loop.run_in_executor( embeddings = await loop.run_in_executor(
None, lambda: list(self.embedding_model.query_embed([query])) None, lambda: list(self.embedding_model.query_embed([query]))
) )
return embeddings[0].tolist() return embeddings[0].tolist()
async def embed_documents_sparse(self, documents: list[str]) -> list[SparseVector]:
"""Embed documents into BM25 sparse vectors."""
loop = asyncio.get_event_loop()
results = await loop.run_in_executor(
None, lambda: list(self._sparse_model.passage_embed(documents))
)
return [
SparseVector(
indices=r.indices.tolist(),
values=r.values.tolist(),
)
for r in results
]
async def embed_query_sparse(self, query: str) -> SparseVector:
"""Embed a query into a BM25 sparse vector."""
loop = asyncio.get_event_loop()
results = await loop.run_in_executor(
None, lambda: list(self._sparse_model.query_embed([query]))
)
return SparseVector(
indices=results[0].indices.tolist(),
values=results[0].values.tolist(),
)
def get_vector_name(self) -> str: def get_vector_name(self) -> str:
""" """
Return the name of the vector for the Qdrant collection. Return the name of the vector for the Qdrant collection.

View File

@@ -57,7 +57,8 @@ class QdrantMCPServer(FastMCP):
if embedding_provider_settings: if embedding_provider_settings:
self.embedding_provider_settings = embedding_provider_settings self.embedding_provider_settings = embedding_provider_settings
self.embedding_provider = create_embedding_provider( self.embedding_provider = create_embedding_provider(
embedding_provider_settings embedding_provider_settings,
enable_sparse=qdrant_settings.hybrid_search,
) )
else: else:
self.embedding_provider_settings = None self.embedding_provider_settings = None
@@ -72,6 +73,7 @@ class QdrantMCPServer(FastMCP):
self.embedding_provider, self.embedding_provider,
qdrant_settings.local_path, qdrant_settings.local_path,
make_indexes(qdrant_settings.filterable_fields_dict()), make_indexes(qdrant_settings.filterable_fields_dict()),
hybrid_search=qdrant_settings.hybrid_search,
) )
super().__init__(name=name, instructions=instructions, **settings) super().__init__(name=name, instructions=instructions, **settings)

View File

@@ -23,6 +23,9 @@ class Entry(BaseModel):
metadata: Metadata | None = None metadata: Metadata | None = None
SPARSE_VECTOR_NAME = "bm25"
class QdrantConnector: class QdrantConnector:
""" """
Encapsulates the connection to a Qdrant server and all the methods to interact with it. Encapsulates the connection to a Qdrant server and all the methods to interact with it.
@@ -32,6 +35,7 @@ class QdrantConnector:
the collection name to be provided. the collection name to be provided.
:param embedding_provider: The embedding provider to use. :param embedding_provider: The embedding provider to use.
:param qdrant_local_path: The path to the storage directory for the Qdrant client, if local mode is used. :param qdrant_local_path: The path to the storage directory for the Qdrant client, if local mode is used.
:param hybrid_search: Whether to enable hybrid search (dense + BM25 sparse vectors with RRF).
""" """
def __init__( def __init__(
@@ -42,15 +46,19 @@ class QdrantConnector:
embedding_provider: EmbeddingProvider, embedding_provider: EmbeddingProvider,
qdrant_local_path: str | None = None, qdrant_local_path: str | None = None,
field_indexes: dict[str, models.PayloadSchemaType] | None = None, field_indexes: dict[str, models.PayloadSchemaType] | None = None,
hybrid_search: bool = False,
): ):
self._qdrant_url = qdrant_url.rstrip("/") if qdrant_url else None self._qdrant_url = qdrant_url.rstrip("/") if qdrant_url else None
self._qdrant_api_key = qdrant_api_key self._qdrant_api_key = qdrant_api_key
self._default_collection_name = collection_name self._default_collection_name = collection_name
self._embedding_provider = embedding_provider self._embedding_provider = embedding_provider
self._hybrid_search = hybrid_search and embedding_provider.supports_sparse()
self._client = AsyncQdrantClient( self._client = AsyncQdrantClient(
location=qdrant_url, api_key=qdrant_api_key, path=qdrant_local_path location=qdrant_url, api_key=qdrant_api_key, path=qdrant_local_path
) )
self._field_indexes = field_indexes self._field_indexes = field_indexes
if self._hybrid_search:
logger.info("Hybrid search enabled (dense + BM25 sparse vectors with RRF)")
async def get_collection_names(self) -> list[str]: async def get_collection_names(self) -> list[str]:
""" """
@@ -72,19 +80,30 @@ class QdrantConnector:
await self._ensure_collection_exists(collection_name) await self._ensure_collection_exists(collection_name)
# Embed the document # Embed the document
# ToDo: instead of embedding text explicitly, use `models.Document`,
# it should unlock usage of server-side inference.
embeddings = await self._embedding_provider.embed_documents([entry.content]) embeddings = await self._embedding_provider.embed_documents([entry.content])
# Add to Qdrant # Build vector dict
vector_name = self._embedding_provider.get_vector_name() vector_name = self._embedding_provider.get_vector_name()
vector_data: dict = {vector_name: embeddings[0]}
# Add sparse vector if hybrid search is enabled
if self._hybrid_search:
sparse_embeddings = await self._embedding_provider.embed_documents_sparse(
[entry.content]
)
sparse = sparse_embeddings[0]
vector_data[SPARSE_VECTOR_NAME] = models.SparseVector(
indices=sparse.indices, values=sparse.values
)
# Add to Qdrant
payload = {"document": entry.content, METADATA_PATH: entry.metadata} payload = {"document": entry.content, METADATA_PATH: entry.metadata}
await self._client.upsert( await self._client.upsert(
collection_name=collection_name, collection_name=collection_name,
points=[ points=[
models.PointStruct( models.PointStruct(
id=uuid.uuid4().hex, id=uuid.uuid4().hex,
vector={vector_name: embeddings[0]}, vector=vector_data,
payload=payload, payload=payload,
) )
], ],
@@ -113,21 +132,43 @@ class QdrantConnector:
if not collection_exists: if not collection_exists:
return [] return []
# Embed the query
# ToDo: instead of embedding text explicitly, use `models.Document`,
# it should unlock usage of server-side inference.
query_vector = await self._embedding_provider.embed_query(query) query_vector = await self._embedding_provider.embed_query(query)
vector_name = self._embedding_provider.get_vector_name() vector_name = self._embedding_provider.get_vector_name()
# Search in Qdrant # Hybrid search: prefetch dense + sparse, fuse with RRF
search_results = await self._client.query_points( if self._hybrid_search:
collection_name=collection_name, sparse_vector = await self._embedding_provider.embed_query_sparse(query)
query=query_vector, search_results = await self._client.query_points(
using=vector_name, collection_name=collection_name,
limit=limit, prefetch=[
query_filter=query_filter, models.Prefetch(
) query=query_vector,
using=vector_name,
limit=limit,
filter=query_filter,
),
models.Prefetch(
query=models.SparseVector(
indices=sparse_vector.indices,
values=sparse_vector.values,
),
using=SPARSE_VECTOR_NAME,
limit=limit,
filter=query_filter,
),
],
query=models.FusionQuery(fusion=models.Fusion.RRF),
limit=limit,
)
else:
# Dense-only search (original behavior)
search_results = await self._client.query_points(
collection_name=collection_name,
query=query_vector,
using=vector_name,
limit=limit,
query_filter=query_filter,
)
return [ return [
Entry( Entry(
@@ -149,6 +190,16 @@ class QdrantConnector:
# Use the vector name as defined in the embedding provider # Use the vector name as defined in the embedding provider
vector_name = self._embedding_provider.get_vector_name() vector_name = self._embedding_provider.get_vector_name()
# Sparse vectors config for hybrid search (BM25)
sparse_config = None
if self._hybrid_search:
sparse_config = {
SPARSE_VECTOR_NAME: models.SparseVectorParams(
modifier=models.Modifier.IDF,
)
}
await self._client.create_collection( await self._client.create_collection(
collection_name=collection_name, collection_name=collection_name,
vectors_config={ vectors_config={
@@ -157,6 +208,7 @@ class QdrantConnector:
distance=models.Distance.COSINE, distance=models.Distance.COSINE,
) )
}, },
sparse_vectors_config=sparse_config,
) )
# Create payload indexes if configured # Create payload indexes if configured

View File

@@ -78,6 +78,7 @@ class QdrantSettings(BaseSettings):
location: str | None = Field(default=None, validation_alias="QDRANT_URL") location: str | None = Field(default=None, validation_alias="QDRANT_URL")
api_key: str | None = Field(default=None, validation_alias="QDRANT_API_KEY") api_key: str | None = Field(default=None, validation_alias="QDRANT_API_KEY")
hybrid_search: bool = Field(default=False, validation_alias="HYBRID_SEARCH")
collection_name: str | None = Field( collection_name: str | None = Field(
default=None, validation_alias="COLLECTION_NAME" default=None, validation_alias="COLLECTION_NAME"
) )