Abstract the embedding providers

This commit is contained in:
Kacper Łukawski
2025-03-05 22:40:58 +01:00
parent f8ecab23be
commit 56f1c797fc
11 changed files with 257 additions and 46 deletions

View File

@@ -0,0 +1,5 @@
from .base import EmbeddingProvider
from .factory import create_embedding_provider
from .fastembed import FastEmbedProvider
__all__ = ["EmbeddingProvider", "FastEmbedProvider", "create_embedding_provider"]

View File

@@ -0,0 +1,16 @@
from abc import ABC, abstractmethod
from typing import List
class EmbeddingProvider(ABC):
"""Abstract base class for embedding providers."""
@abstractmethod
async def embed_documents(self, documents: List[str]) -> List[List[float]]:
"""Embed a list of documents into vectors."""
pass
@abstractmethod
async def embed_query(self, query: str) -> List[float]:
"""Embed a query into a vector."""
pass

View File

@@ -0,0 +1,67 @@
from mcp_server_qdrant.embeddings import EmbeddingProvider
def create_embedding_provider(provider_type: str, **kwargs) -> EmbeddingProvider:
"""
Create an embedding provider based on the specified type.
:param provider_type: The type of embedding provider to create.
:param kwargs: Additional arguments to pass to the provider constructor.
:return: An instance of the specified embedding provider.
"""
if provider_type.lower() == "fastembed":
from .fastembed import FastEmbedProvider
model_name = kwargs.get("model_name", "sentence-transformers/all-MiniLM-L6-v2")
return FastEmbedProvider(model_name)
else:
raise ValueError(f"Unsupported embedding provider: {provider_type}")
from typing import Optional
from .fastembed import FastEmbedProvider
from .base import EmbeddingProvider
def create_embedding_provider(provider_type: str, model_name: Optional[str] = None) -> EmbeddingProvider:
"""
Create an embedding provider based on the provider type.
Args:
provider_type: The type of embedding provider to create.
model_name: The name of the model to use.
Returns:
An instance of EmbeddingProvider.
Raises:
ValueError: If the provider type is not supported.
"""
if provider_type.lower() == "fastembed":
return FastEmbedProvider(model_name)
else:
raise ValueError(f"Unsupported embedding provider: {provider_type}")
from typing import Literal
from .fastembed import FastEmbedProvider
def create_embedding_provider(
provider_type: Literal["fastembed"],
**kwargs
) -> FastEmbedProvider:
"""
Factory function to create an embedding provider.
Args:
provider_type: The type of embedding provider to create.
**kwargs: Additional arguments to pass to the provider constructor.
Returns:
An instance of the requested embedding provider.
Raises:
ValueError: If the provider type is not supported.
"""
if provider_type == "fastembed":
model_name = kwargs.get("model_name", "sentence-transformers/all-MiniLM-L6-v2")
return FastEmbedProvider(model_name)
else:
raise ValueError(f"Unsupported embedding provider: {provider_type}")

View File

@@ -0,0 +1,36 @@
from typing import List
import asyncio
from fastembed import TextEmbedding
from .base import EmbeddingProvider
class FastEmbedProvider(EmbeddingProvider):
"""FastEmbed implementation of the embedding provider."""
def __init__(self, model_name: str):
"""
Initialize the FastEmbed provider.
:param model_name: The name of the FastEmbed model to use.
"""
self.model_name = model_name
self.embedding_model = TextEmbedding(model_name)
async def embed_documents(self, documents: List[str]) -> List[List[float]]:
"""Embed a list of documents into vectors."""
# Run in a thread pool since FastEmbed is synchronous
loop = asyncio.get_event_loop()
embeddings = await loop.run_in_executor(
None, lambda: list(self.embedding_model.passage_embed(documents))
)
return [embedding.tolist() for embedding in embeddings]
async def embed_query(self, query: str) -> List[float]:
"""Embed a query into a vector."""
# Run in a thread pool since FastEmbed is synchronous
loop = asyncio.get_event_loop()
embeddings = await loop.run_in_executor(
None, lambda: list(self.embedding_model.query_embed([query]))
)
return embeddings[0].tolist()