Add BM25 hybrid search (dense + sparse vectors with RRF)
Some checks failed
Some checks failed
- Add SparseTextEmbedding("Qdrant/bm25") to FastEmbedProvider for BM25 tokenization
- Add sparse vector config (IDF modifier) to collection creation
- Store both dense and sparse vectors per document
- Use Qdrant prefetch + Reciprocal Rank Fusion for hybrid search
- Add HYBRID_SEARCH env var (default: false) for backward compatibility
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1,4 +1,13 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SparseVector:
|
||||||
|
"""A sparse vector representation with indices and values."""
|
||||||
|
|
||||||
|
indices: list[int]
|
||||||
|
values: list[float]
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingProvider(ABC):
|
class EmbeddingProvider(ABC):
|
||||||
@@ -23,3 +32,15 @@ class EmbeddingProvider(ABC):
|
|||||||
def get_vector_size(self) -> int:
|
def get_vector_size(self) -> int:
|
||||||
"""Get the size of the vector for the Qdrant collection."""
|
"""Get the size of the vector for the Qdrant collection."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def supports_sparse(self) -> bool:
|
||||||
|
"""Whether this provider supports sparse (BM25) embeddings."""
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def embed_documents_sparse(self, documents: list[str]) -> list[SparseVector]:
|
||||||
|
"""Embed documents into sparse vectors. Override if supports_sparse() is True."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def embed_query_sparse(self, query: str) -> SparseVector:
|
||||||
|
"""Embed a query into a sparse vector. Override if supports_sparse() is True."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|||||||
@@ -3,15 +3,18 @@ from mcp_server_qdrant.embeddings.types import EmbeddingProviderType
|
|||||||
from mcp_server_qdrant.settings import EmbeddingProviderSettings
|
from mcp_server_qdrant.settings import EmbeddingProviderSettings
|
||||||
|
|
||||||
|
|
||||||
def create_embedding_provider(settings: EmbeddingProviderSettings) -> EmbeddingProvider:
|
def create_embedding_provider(
|
||||||
|
settings: EmbeddingProviderSettings, enable_sparse: bool = False
|
||||||
|
) -> EmbeddingProvider:
|
||||||
"""
|
"""
|
||||||
Create an embedding provider based on the specified type.
|
Create an embedding provider based on the specified type.
|
||||||
:param settings: The settings for the embedding provider.
|
:param settings: The settings for the embedding provider.
|
||||||
|
:param enable_sparse: Whether to enable sparse (BM25) embeddings.
|
||||||
:return: An instance of the specified embedding provider.
|
:return: An instance of the specified embedding provider.
|
||||||
"""
|
"""
|
||||||
if settings.provider_type == EmbeddingProviderType.FASTEMBED:
|
if settings.provider_type == EmbeddingProviderType.FASTEMBED:
|
||||||
from mcp_server_qdrant.embeddings.fastembed import FastEmbedProvider
|
from mcp_server_qdrant.embeddings.fastembed import FastEmbedProvider
|
||||||
|
|
||||||
return FastEmbedProvider(settings.model_name)
|
return FastEmbedProvider(settings.model_name, enable_sparse=enable_sparse)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported embedding provider: {settings.provider_type}")
|
raise ValueError(f"Unsupported embedding provider: {settings.provider_type}")
|
||||||
|
|||||||
@@ -1,24 +1,31 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
from fastembed import TextEmbedding
|
from fastembed import SparseTextEmbedding, TextEmbedding
|
||||||
from fastembed.common.model_description import DenseModelDescription
|
from fastembed.common.model_description import DenseModelDescription
|
||||||
|
|
||||||
from mcp_server_qdrant.embeddings.base import EmbeddingProvider
|
from mcp_server_qdrant.embeddings.base import EmbeddingProvider, SparseVector
|
||||||
|
|
||||||
|
|
||||||
class FastEmbedProvider(EmbeddingProvider):
|
class FastEmbedProvider(EmbeddingProvider):
|
||||||
"""
|
"""
|
||||||
FastEmbed implementation of the embedding provider.
|
FastEmbed implementation of the embedding provider.
|
||||||
:param model_name: The name of the FastEmbed model to use.
|
:param model_name: The name of the FastEmbed model to use.
|
||||||
|
:param enable_sparse: Whether to enable BM25 sparse embeddings for hybrid search.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, model_name: str):
|
def __init__(self, model_name: str, enable_sparse: bool = False):
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self.embedding_model = TextEmbedding(model_name)
|
self.embedding_model = TextEmbedding(model_name)
|
||||||
|
self._enable_sparse = enable_sparse
|
||||||
|
self._sparse_model = None
|
||||||
|
if enable_sparse:
|
||||||
|
self._sparse_model = SparseTextEmbedding("Qdrant/bm25")
|
||||||
|
|
||||||
|
def supports_sparse(self) -> bool:
|
||||||
|
return self._enable_sparse and self._sparse_model is not None
|
||||||
|
|
||||||
async def embed_documents(self, documents: list[str]) -> list[list[float]]:
|
async def embed_documents(self, documents: list[str]) -> list[list[float]]:
|
||||||
"""Embed a list of documents into vectors."""
|
"""Embed a list of documents into vectors."""
|
||||||
# Run in a thread pool since FastEmbed is synchronous
|
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
embeddings = await loop.run_in_executor(
|
embeddings = await loop.run_in_executor(
|
||||||
None, lambda: list(self.embedding_model.passage_embed(documents))
|
None, lambda: list(self.embedding_model.passage_embed(documents))
|
||||||
@@ -27,13 +34,37 @@ class FastEmbedProvider(EmbeddingProvider):
|
|||||||
|
|
||||||
async def embed_query(self, query: str) -> list[float]:
|
async def embed_query(self, query: str) -> list[float]:
|
||||||
"""Embed a query into a vector."""
|
"""Embed a query into a vector."""
|
||||||
# Run in a thread pool since FastEmbed is synchronous
|
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
embeddings = await loop.run_in_executor(
|
embeddings = await loop.run_in_executor(
|
||||||
None, lambda: list(self.embedding_model.query_embed([query]))
|
None, lambda: list(self.embedding_model.query_embed([query]))
|
||||||
)
|
)
|
||||||
return embeddings[0].tolist()
|
return embeddings[0].tolist()
|
||||||
|
|
||||||
|
async def embed_documents_sparse(self, documents: list[str]) -> list[SparseVector]:
|
||||||
|
"""Embed documents into BM25 sparse vectors."""
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
results = await loop.run_in_executor(
|
||||||
|
None, lambda: list(self._sparse_model.passage_embed(documents))
|
||||||
|
)
|
||||||
|
return [
|
||||||
|
SparseVector(
|
||||||
|
indices=r.indices.tolist(),
|
||||||
|
values=r.values.tolist(),
|
||||||
|
)
|
||||||
|
for r in results
|
||||||
|
]
|
||||||
|
|
||||||
|
async def embed_query_sparse(self, query: str) -> SparseVector:
|
||||||
|
"""Embed a query into a BM25 sparse vector."""
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
results = await loop.run_in_executor(
|
||||||
|
None, lambda: list(self._sparse_model.query_embed([query]))
|
||||||
|
)
|
||||||
|
return SparseVector(
|
||||||
|
indices=results[0].indices.tolist(),
|
||||||
|
values=results[0].values.tolist(),
|
||||||
|
)
|
||||||
|
|
||||||
def get_vector_name(self) -> str:
|
def get_vector_name(self) -> str:
|
||||||
"""
|
"""
|
||||||
Return the name of the vector for the Qdrant collection.
|
Return the name of the vector for the Qdrant collection.
|
||||||
|
|||||||
@@ -57,7 +57,8 @@ class QdrantMCPServer(FastMCP):
|
|||||||
if embedding_provider_settings:
|
if embedding_provider_settings:
|
||||||
self.embedding_provider_settings = embedding_provider_settings
|
self.embedding_provider_settings = embedding_provider_settings
|
||||||
self.embedding_provider = create_embedding_provider(
|
self.embedding_provider = create_embedding_provider(
|
||||||
embedding_provider_settings
|
embedding_provider_settings,
|
||||||
|
enable_sparse=qdrant_settings.hybrid_search,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.embedding_provider_settings = None
|
self.embedding_provider_settings = None
|
||||||
@@ -72,6 +73,7 @@ class QdrantMCPServer(FastMCP):
|
|||||||
self.embedding_provider,
|
self.embedding_provider,
|
||||||
qdrant_settings.local_path,
|
qdrant_settings.local_path,
|
||||||
make_indexes(qdrant_settings.filterable_fields_dict()),
|
make_indexes(qdrant_settings.filterable_fields_dict()),
|
||||||
|
hybrid_search=qdrant_settings.hybrid_search,
|
||||||
)
|
)
|
||||||
|
|
||||||
super().__init__(name=name, instructions=instructions, **settings)
|
super().__init__(name=name, instructions=instructions, **settings)
|
||||||
|
|||||||
@@ -23,6 +23,9 @@ class Entry(BaseModel):
|
|||||||
metadata: Metadata | None = None
|
metadata: Metadata | None = None
|
||||||
|
|
||||||
|
|
||||||
|
SPARSE_VECTOR_NAME = "bm25"
|
||||||
|
|
||||||
|
|
||||||
class QdrantConnector:
|
class QdrantConnector:
|
||||||
"""
|
"""
|
||||||
Encapsulates the connection to a Qdrant server and all the methods to interact with it.
|
Encapsulates the connection to a Qdrant server and all the methods to interact with it.
|
||||||
@@ -32,6 +35,7 @@ class QdrantConnector:
|
|||||||
the collection name to be provided.
|
the collection name to be provided.
|
||||||
:param embedding_provider: The embedding provider to use.
|
:param embedding_provider: The embedding provider to use.
|
||||||
:param qdrant_local_path: The path to the storage directory for the Qdrant client, if local mode is used.
|
:param qdrant_local_path: The path to the storage directory for the Qdrant client, if local mode is used.
|
||||||
|
:param hybrid_search: Whether to enable hybrid search (dense + BM25 sparse vectors with RRF).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -42,15 +46,19 @@ class QdrantConnector:
|
|||||||
embedding_provider: EmbeddingProvider,
|
embedding_provider: EmbeddingProvider,
|
||||||
qdrant_local_path: str | None = None,
|
qdrant_local_path: str | None = None,
|
||||||
field_indexes: dict[str, models.PayloadSchemaType] | None = None,
|
field_indexes: dict[str, models.PayloadSchemaType] | None = None,
|
||||||
|
hybrid_search: bool = False,
|
||||||
):
|
):
|
||||||
self._qdrant_url = qdrant_url.rstrip("/") if qdrant_url else None
|
self._qdrant_url = qdrant_url.rstrip("/") if qdrant_url else None
|
||||||
self._qdrant_api_key = qdrant_api_key
|
self._qdrant_api_key = qdrant_api_key
|
||||||
self._default_collection_name = collection_name
|
self._default_collection_name = collection_name
|
||||||
self._embedding_provider = embedding_provider
|
self._embedding_provider = embedding_provider
|
||||||
|
self._hybrid_search = hybrid_search and embedding_provider.supports_sparse()
|
||||||
self._client = AsyncQdrantClient(
|
self._client = AsyncQdrantClient(
|
||||||
location=qdrant_url, api_key=qdrant_api_key, path=qdrant_local_path
|
location=qdrant_url, api_key=qdrant_api_key, path=qdrant_local_path
|
||||||
)
|
)
|
||||||
self._field_indexes = field_indexes
|
self._field_indexes = field_indexes
|
||||||
|
if self._hybrid_search:
|
||||||
|
logger.info("Hybrid search enabled (dense + BM25 sparse vectors with RRF)")
|
||||||
|
|
||||||
async def get_collection_names(self) -> list[str]:
|
async def get_collection_names(self) -> list[str]:
|
||||||
"""
|
"""
|
||||||
@@ -72,19 +80,30 @@ class QdrantConnector:
|
|||||||
await self._ensure_collection_exists(collection_name)
|
await self._ensure_collection_exists(collection_name)
|
||||||
|
|
||||||
# Embed the document
|
# Embed the document
|
||||||
# ToDo: instead of embedding text explicitly, use `models.Document`,
|
|
||||||
# it should unlock usage of server-side inference.
|
|
||||||
embeddings = await self._embedding_provider.embed_documents([entry.content])
|
embeddings = await self._embedding_provider.embed_documents([entry.content])
|
||||||
|
|
||||||
# Add to Qdrant
|
# Build vector dict
|
||||||
vector_name = self._embedding_provider.get_vector_name()
|
vector_name = self._embedding_provider.get_vector_name()
|
||||||
|
vector_data: dict = {vector_name: embeddings[0]}
|
||||||
|
|
||||||
|
# Add sparse vector if hybrid search is enabled
|
||||||
|
if self._hybrid_search:
|
||||||
|
sparse_embeddings = await self._embedding_provider.embed_documents_sparse(
|
||||||
|
[entry.content]
|
||||||
|
)
|
||||||
|
sparse = sparse_embeddings[0]
|
||||||
|
vector_data[SPARSE_VECTOR_NAME] = models.SparseVector(
|
||||||
|
indices=sparse.indices, values=sparse.values
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add to Qdrant
|
||||||
payload = {"document": entry.content, METADATA_PATH: entry.metadata}
|
payload = {"document": entry.content, METADATA_PATH: entry.metadata}
|
||||||
await self._client.upsert(
|
await self._client.upsert(
|
||||||
collection_name=collection_name,
|
collection_name=collection_name,
|
||||||
points=[
|
points=[
|
||||||
models.PointStruct(
|
models.PointStruct(
|
||||||
id=uuid.uuid4().hex,
|
id=uuid.uuid4().hex,
|
||||||
vector={vector_name: embeddings[0]},
|
vector=vector_data,
|
||||||
payload=payload,
|
payload=payload,
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
@@ -113,21 +132,43 @@ class QdrantConnector:
|
|||||||
if not collection_exists:
|
if not collection_exists:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# Embed the query
|
|
||||||
# ToDo: instead of embedding text explicitly, use `models.Document`,
|
|
||||||
# it should unlock usage of server-side inference.
|
|
||||||
|
|
||||||
query_vector = await self._embedding_provider.embed_query(query)
|
query_vector = await self._embedding_provider.embed_query(query)
|
||||||
vector_name = self._embedding_provider.get_vector_name()
|
vector_name = self._embedding_provider.get_vector_name()
|
||||||
|
|
||||||
# Search in Qdrant
|
# Hybrid search: prefetch dense + sparse, fuse with RRF
|
||||||
search_results = await self._client.query_points(
|
if self._hybrid_search:
|
||||||
collection_name=collection_name,
|
sparse_vector = await self._embedding_provider.embed_query_sparse(query)
|
||||||
query=query_vector,
|
search_results = await self._client.query_points(
|
||||||
using=vector_name,
|
collection_name=collection_name,
|
||||||
limit=limit,
|
prefetch=[
|
||||||
query_filter=query_filter,
|
models.Prefetch(
|
||||||
)
|
query=query_vector,
|
||||||
|
using=vector_name,
|
||||||
|
limit=limit,
|
||||||
|
filter=query_filter,
|
||||||
|
),
|
||||||
|
models.Prefetch(
|
||||||
|
query=models.SparseVector(
|
||||||
|
indices=sparse_vector.indices,
|
||||||
|
values=sparse_vector.values,
|
||||||
|
),
|
||||||
|
using=SPARSE_VECTOR_NAME,
|
||||||
|
limit=limit,
|
||||||
|
filter=query_filter,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
query=models.FusionQuery(fusion=models.Fusion.RRF),
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Dense-only search (original behavior)
|
||||||
|
search_results = await self._client.query_points(
|
||||||
|
collection_name=collection_name,
|
||||||
|
query=query_vector,
|
||||||
|
using=vector_name,
|
||||||
|
limit=limit,
|
||||||
|
query_filter=query_filter,
|
||||||
|
)
|
||||||
|
|
||||||
return [
|
return [
|
||||||
Entry(
|
Entry(
|
||||||
@@ -149,6 +190,16 @@ class QdrantConnector:
|
|||||||
|
|
||||||
# Use the vector name as defined in the embedding provider
|
# Use the vector name as defined in the embedding provider
|
||||||
vector_name = self._embedding_provider.get_vector_name()
|
vector_name = self._embedding_provider.get_vector_name()
|
||||||
|
|
||||||
|
# Sparse vectors config for hybrid search (BM25)
|
||||||
|
sparse_config = None
|
||||||
|
if self._hybrid_search:
|
||||||
|
sparse_config = {
|
||||||
|
SPARSE_VECTOR_NAME: models.SparseVectorParams(
|
||||||
|
modifier=models.Modifier.IDF,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
await self._client.create_collection(
|
await self._client.create_collection(
|
||||||
collection_name=collection_name,
|
collection_name=collection_name,
|
||||||
vectors_config={
|
vectors_config={
|
||||||
@@ -157,6 +208,7 @@ class QdrantConnector:
|
|||||||
distance=models.Distance.COSINE,
|
distance=models.Distance.COSINE,
|
||||||
)
|
)
|
||||||
},
|
},
|
||||||
|
sparse_vectors_config=sparse_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create payload indexes if configured
|
# Create payload indexes if configured
|
||||||
|
|||||||
@@ -78,6 +78,7 @@ class QdrantSettings(BaseSettings):
|
|||||||
|
|
||||||
location: str | None = Field(default=None, validation_alias="QDRANT_URL")
|
location: str | None = Field(default=None, validation_alias="QDRANT_URL")
|
||||||
api_key: str | None = Field(default=None, validation_alias="QDRANT_API_KEY")
|
api_key: str | None = Field(default=None, validation_alias="QDRANT_API_KEY")
|
||||||
|
hybrid_search: bool = Field(default=False, validation_alias="HYBRID_SEARCH")
|
||||||
collection_name: str | None = Field(
|
collection_name: str | None = Field(
|
||||||
default=None, validation_alias="COLLECTION_NAME"
|
default=None, validation_alias="COLLECTION_NAME"
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user