Abstract the embedding providers

This commit is contained in:
Kacper Łukawski
2025-03-05 22:40:58 +01:00
parent f8ecab23be
commit 56f1c797fc
11 changed files with 257 additions and 46 deletions

View File

@@ -0,0 +1,67 @@
from mcp_server_qdrant.embeddings import EmbeddingProvider
def create_embedding_provider(provider_type: str, **kwargs) -> EmbeddingProvider:
"""
Create an embedding provider based on the specified type.
:param provider_type: The type of embedding provider to create.
:param kwargs: Additional arguments to pass to the provider constructor.
:return: An instance of the specified embedding provider.
"""
if provider_type.lower() == "fastembed":
from .fastembed import FastEmbedProvider
model_name = kwargs.get("model_name", "sentence-transformers/all-MiniLM-L6-v2")
return FastEmbedProvider(model_name)
else:
raise ValueError(f"Unsupported embedding provider: {provider_type}")
from typing import Optional
from .fastembed import FastEmbedProvider
from .base import EmbeddingProvider
def create_embedding_provider(provider_type: str, model_name: Optional[str] = None) -> EmbeddingProvider:
"""
Create an embedding provider based on the provider type.
Args:
provider_type: The type of embedding provider to create.
model_name: The name of the model to use.
Returns:
An instance of EmbeddingProvider.
Raises:
ValueError: If the provider type is not supported.
"""
if provider_type.lower() == "fastembed":
return FastEmbedProvider(model_name)
else:
raise ValueError(f"Unsupported embedding provider: {provider_type}")
from typing import Literal
from .fastembed import FastEmbedProvider
def create_embedding_provider(
provider_type: Literal["fastembed"],
**kwargs
) -> FastEmbedProvider:
"""
Factory function to create an embedding provider.
Args:
provider_type: The type of embedding provider to create.
**kwargs: Additional arguments to pass to the provider constructor.
Returns:
An instance of the requested embedding provider.
Raises:
ValueError: If the provider type is not supported.
"""
if provider_type == "fastembed":
model_name = kwargs.get("model_name", "sentence-transformers/all-MiniLM-L6-v2")
return FastEmbedProvider(model_name)
else:
raise ValueError(f"Unsupported embedding provider: {provider_type}")