Abstract the embedding providers
This commit is contained in:
@@ -1,20 +1,23 @@
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
|
||||
import click
|
||||
import mcp
|
||||
import mcp.types as types
|
||||
from mcp.server import NotificationOptions, Server
|
||||
from mcp.server import Server, NotificationOptions
|
||||
from mcp.server.models import InitializationOptions
|
||||
|
||||
import click
|
||||
import mcp.types as types
|
||||
import asyncio
|
||||
import mcp
|
||||
|
||||
from .qdrant import QdrantConnector
|
||||
from .embeddings.factory import create_embedding_provider
|
||||
|
||||
|
||||
def serve(
|
||||
qdrant_url: Optional[str],
|
||||
qdrant_api_key: Optional[str],
|
||||
collection_name: str,
|
||||
fastembed_model_name: str,
|
||||
embedding_provider_type: str,
|
||||
embedding_model_name: str,
|
||||
qdrant_local_path: Optional[str] = None,
|
||||
) -> Server:
|
||||
"""
|
||||
@@ -22,17 +25,20 @@ def serve(
|
||||
:param qdrant_url: The URL of the Qdrant server.
|
||||
:param qdrant_api_key: The API key to use for the Qdrant server.
|
||||
:param collection_name: The name of the collection to use.
|
||||
:param fastembed_model_name: The name of the FastEmbed model to use.
|
||||
:param embedding_provider_type: The type of embedding provider to use.
|
||||
:param embedding_model_name: The name of the embedding model to use.
|
||||
:param qdrant_local_path: The path to the storage directory for the Qdrant client, if local mode is used.
|
||||
"""
|
||||
server = Server("qdrant")
|
||||
|
||||
# Create the embedding provider
|
||||
embedding_provider = create_embedding_provider(
|
||||
embedding_provider_type,
|
||||
model_name=embedding_model_name
|
||||
)
|
||||
|
||||
qdrant = QdrantConnector(
|
||||
qdrant_url,
|
||||
qdrant_api_key,
|
||||
collection_name,
|
||||
fastembed_model_name,
|
||||
qdrant_local_path,
|
||||
qdrant_url, qdrant_api_key, collection_name, embedding_provider, qdrant_local_path
|
||||
)
|
||||
|
||||
@server.list_tools()
|
||||
@@ -133,10 +139,18 @@ def serve(
|
||||
help="Collection name",
|
||||
)
|
||||
@click.option(
|
||||
"--fastembed-model-name",
|
||||
envvar="FASTEMBED_MODEL_NAME",
|
||||
required=True,
|
||||
help="FastEmbed model name",
|
||||
"--embedding-provider",
|
||||
envvar="EMBEDDING_PROVIDER",
|
||||
required=False,
|
||||
help="Embedding provider to use",
|
||||
default="fastembed",
|
||||
type=click.Choice(["fastembed"], case_sensitive=False),
|
||||
)
|
||||
@click.option(
|
||||
"--embedding-model",
|
||||
envvar="EMBEDDING_MODEL",
|
||||
required=False,
|
||||
help="Embedding model name",
|
||||
default="sentence-transformers/all-MiniLM-L6-v2",
|
||||
)
|
||||
@click.option(
|
||||
@@ -149,14 +163,13 @@ def main(
|
||||
qdrant_url: Optional[str],
|
||||
qdrant_api_key: str,
|
||||
collection_name: Optional[str],
|
||||
fastembed_model_name: str,
|
||||
embedding_provider: str,
|
||||
embedding_model: str,
|
||||
qdrant_local_path: Optional[str],
|
||||
):
|
||||
# XOR of url and local path, since we accept only one of them
|
||||
if not (bool(qdrant_url) ^ bool(qdrant_local_path)):
|
||||
raise ValueError(
|
||||
"Exactly one of qdrant-url or qdrant-local-path must be provided"
|
||||
)
|
||||
raise ValueError("Exactly one of qdrant-url or qdrant-local-path must be provided")
|
||||
|
||||
async def _run():
|
||||
async with mcp.server.stdio.stdio_server() as (read_stream, write_stream):
|
||||
@@ -164,7 +177,8 @@ def main(
|
||||
qdrant_url,
|
||||
qdrant_api_key,
|
||||
collection_name,
|
||||
fastembed_model_name,
|
||||
embedding_provider,
|
||||
embedding_model,
|
||||
qdrant_local_path,
|
||||
)
|
||||
await server.run(
|
||||
|
||||
Reference in New Issue
Block a user