Abstract the embedding providers

This commit is contained in:
Kacper Łukawski
2025-03-05 22:40:58 +01:00
parent f8ecab23be
commit 56f1c797fc
11 changed files with 257 additions and 46 deletions

View File

@@ -1,20 +1,23 @@
import asyncio
from typing import Optional
import click
import mcp
import mcp.types as types
from mcp.server import NotificationOptions, Server
from mcp.server import Server, NotificationOptions
from mcp.server.models import InitializationOptions
import click
import mcp.types as types
import asyncio
import mcp
from .qdrant import QdrantConnector
from .embeddings.factory import create_embedding_provider
def serve(
qdrant_url: Optional[str],
qdrant_api_key: Optional[str],
collection_name: str,
fastembed_model_name: str,
embedding_provider_type: str,
embedding_model_name: str,
qdrant_local_path: Optional[str] = None,
) -> Server:
"""
@@ -22,17 +25,20 @@ def serve(
:param qdrant_url: The URL of the Qdrant server.
:param qdrant_api_key: The API key to use for the Qdrant server.
:param collection_name: The name of the collection to use.
:param fastembed_model_name: The name of the FastEmbed model to use.
:param embedding_provider_type: The type of embedding provider to use.
:param embedding_model_name: The name of the embedding model to use.
:param qdrant_local_path: The path to the storage directory for the Qdrant client, if local mode is used.
"""
server = Server("qdrant")
# Create the embedding provider
embedding_provider = create_embedding_provider(
embedding_provider_type,
model_name=embedding_model_name
)
qdrant = QdrantConnector(
qdrant_url,
qdrant_api_key,
collection_name,
fastembed_model_name,
qdrant_local_path,
qdrant_url, qdrant_api_key, collection_name, embedding_provider, qdrant_local_path
)
@server.list_tools()
@@ -133,10 +139,18 @@ def serve(
help="Collection name",
)
@click.option(
"--fastembed-model-name",
envvar="FASTEMBED_MODEL_NAME",
required=True,
help="FastEmbed model name",
"--embedding-provider",
envvar="EMBEDDING_PROVIDER",
required=False,
help="Embedding provider to use",
default="fastembed",
type=click.Choice(["fastembed"], case_sensitive=False),
)
@click.option(
"--embedding-model",
envvar="EMBEDDING_MODEL",
required=False,
help="Embedding model name",
default="sentence-transformers/all-MiniLM-L6-v2",
)
@click.option(
@@ -149,14 +163,13 @@ def main(
qdrant_url: Optional[str],
qdrant_api_key: str,
collection_name: Optional[str],
fastembed_model_name: str,
embedding_provider: str,
embedding_model: str,
qdrant_local_path: Optional[str],
):
# XOR of url and local path, since we accept only one of them
if not (bool(qdrant_url) ^ bool(qdrant_local_path)):
raise ValueError(
"Exactly one of qdrant-url or qdrant-local-path must be provided"
)
raise ValueError("Exactly one of qdrant-url or qdrant-local-path must be provided")
async def _run():
async with mcp.server.stdio.stdio_server() as (read_stream, write_stream):
@@ -164,7 +177,8 @@ def main(
qdrant_url,
qdrant_api_key,
collection_name,
fastembed_model_name,
embedding_provider,
embedding_model,
qdrant_local_path,
)
await server.run(