Abstract the embedding providers
This commit is contained in:
5
src/mcp_server_qdrant/embeddings/__init__.py
Normal file
5
src/mcp_server_qdrant/embeddings/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from .base import EmbeddingProvider
|
||||
from .factory import create_embedding_provider
|
||||
from .fastembed import FastEmbedProvider
|
||||
|
||||
# Names re-exported by the embeddings package: the abstract base class, the
# FastEmbed implementation, and the provider factory function.
__all__ = ["EmbeddingProvider", "FastEmbedProvider", "create_embedding_provider"]
|
||||
16
src/mcp_server_qdrant/embeddings/base.py
Normal file
16
src/mcp_server_qdrant/embeddings/base.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List
|
||||
|
||||
|
||||
class EmbeddingProvider(ABC):
    """Interface that every embedding backend has to implement.

    Both operations are coroutines; concrete subclasses must override
    each of them before they can be instantiated.
    """

    @abstractmethod
    async def embed_documents(self, documents: List[str]) -> List[List[float]]:
        """Turn each document in ``documents`` into one embedding vector.

        :param documents: The texts to embed.
        :return: One list of floats per input document.
        """
        ...

    @abstractmethod
    async def embed_query(self, query: str) -> List[float]:
        """Turn a single query string into an embedding vector.

        :param query: The query text to embed.
        :return: The embedding as a list of floats.
        """
        ...
|
||||
67
src/mcp_server_qdrant/embeddings/factory.py
Normal file
67
src/mcp_server_qdrant/embeddings/factory.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from mcp_server_qdrant.embeddings import EmbeddingProvider
|
||||
|
||||
|
||||
def create_embedding_provider(provider_type: str, **kwargs) -> EmbeddingProvider:
    """
    Build the embedding provider registered under ``provider_type``.

    NOTE(review): this diff shows further definitions of the same name later
    in what appears to be the same module; if so, the last one is the
    binding callers actually get -- confirm.

    :param provider_type: Provider name, compared case-insensitively. Only
        ``"fastembed"`` is recognised.
    :param kwargs: Provider options. Only ``model_name`` is read; when absent
        it defaults to ``"sentence-transformers/all-MiniLM-L6-v2"``.
    :return: A ``FastEmbedProvider`` built with the selected model name.
    :raises ValueError: If ``provider_type`` names an unsupported provider.
    """
    # Guard clause: unknown provider names are rejected up front.
    if provider_type.lower() != "fastembed":
        raise ValueError(f"Unsupported embedding provider: {provider_type}")

    # The import only happens on the fastembed path.
    from .fastembed import FastEmbedProvider

    chosen_model = kwargs.get("model_name", "sentence-transformers/all-MiniLM-L6-v2")
    return FastEmbedProvider(chosen_model)
|
||||
from typing import Optional
|
||||
from .fastembed import FastEmbedProvider
|
||||
from .base import EmbeddingProvider
|
||||
|
||||
|
||||
def create_embedding_provider(provider_type: str, model_name: Optional[str] = None) -> EmbeddingProvider:
    """
    Create an embedding provider based on the provider type.

    NOTE(review): this diff shows other definitions of the same name in what
    appears to be the same module; if so, only the last one is live -- confirm.

    Args:
        provider_type: The type of embedding provider to create, compared
            case-insensitively. Only ``"fastembed"`` is supported.
        model_name: The name of the model to use. When ``None`` (the
            default), ``"sentence-transformers/all-MiniLM-L6-v2"`` is used.

    Returns:
        An instance of EmbeddingProvider.

    Raises:
        ValueError: If the provider type is not supported.
    """
    if provider_type.lower() != "fastembed":
        raise ValueError(f"Unsupported embedding provider: {provider_type}")

    # FastEmbedProvider's constructor is declared to take a ``str``; never
    # forward the ``None`` placeholder to it. Fall back to the same default
    # model the sibling factory definitions use.
    if model_name is None:
        model_name = "sentence-transformers/all-MiniLM-L6-v2"
    return FastEmbedProvider(model_name)
|
||||
from typing import Literal
|
||||
|
||||
from .fastembed import FastEmbedProvider
|
||||
|
||||
|
||||
def create_embedding_provider(
    provider_type: Literal["fastembed"],
    **kwargs
) -> FastEmbedProvider:
    """
    Construct the embedding provider selected by ``provider_type``.

    Args:
        provider_type: Name of the provider to build. The comparison is an
            exact, case-sensitive match against ``"fastembed"``.
        **kwargs: Provider options. Only ``model_name`` is consulted; it
            defaults to ``"sentence-transformers/all-MiniLM-L6-v2"``.

    Returns:
        A ``FastEmbedProvider`` created with the resolved model name.

    Raises:
        ValueError: If ``provider_type`` is anything other than ``"fastembed"``.
    """
    # Reject everything that is not the single supported provider first.
    if provider_type != "fastembed":
        raise ValueError(f"Unsupported embedding provider: {provider_type}")

    resolved_model = kwargs.get("model_name", "sentence-transformers/all-MiniLM-L6-v2")
    return FastEmbedProvider(resolved_model)
|
||||
36
src/mcp_server_qdrant/embeddings/fastembed.py
Normal file
36
src/mcp_server_qdrant/embeddings/fastembed.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from typing import List
|
||||
import asyncio
|
||||
from fastembed import TextEmbedding
|
||||
|
||||
from .base import EmbeddingProvider
|
||||
|
||||
|
||||
class FastEmbedProvider(EmbeddingProvider):
    """FastEmbed implementation of the embedding provider.

    The FastEmbed calls are synchronous, so each coroutine hands the work to
    the event loop's default executor instead of running it inline.
    """

    def __init__(self, model_name: str):
        """
        Initialize the FastEmbed provider.

        :param model_name: The name of the FastEmbed model to use.
        """
        self.model_name = model_name
        self.embedding_model = TextEmbedding(model_name)

    async def embed_documents(self, documents: List[str]) -> List[List[float]]:
        """
        Embed a list of documents into vectors.

        :param documents: The texts to embed.
        :return: One vector (list of floats) per document; an empty list when
            ``documents`` is empty.
        """
        # Nothing to embed: skip the executor round-trip entirely.
        if not documents:
            return []

        # Run in a thread pool since FastEmbed is synchronous. Inside a
        # coroutine the running loop is the one to use.
        loop = asyncio.get_running_loop()
        embeddings = await loop.run_in_executor(
            None, lambda: list(self.embedding_model.passage_embed(documents))
        )
        return [embedding.tolist() for embedding in embeddings]

    async def embed_query(self, query: str) -> List[float]:
        """
        Embed a query into a vector.

        :param query: The query text to embed.
        :return: The query embedding as a list of floats.
        """
        # Run in a thread pool since FastEmbed is synchronous.
        loop = asyncio.get_running_loop()
        embeddings = await loop.run_in_executor(
            None, lambda: list(self.embedding_model.query_embed([query]))
        )
        # query_embed is given a single-element list; take its only result.
        return embeddings[0].tolist()
|
||||
Reference in New Issue
Block a user