Run pre-commit

This commit is contained in:
Kacper Łukawski
2025-03-05 22:54:21 +01:00
parent 56f1c797fc
commit 6aa21b6d99
5 changed files with 27 additions and 67 deletions

View File

@@ -11,56 +11,7 @@ def create_embedding_provider(provider_type: str, **kwargs) -> EmbeddingProvider
""" """
if provider_type.lower() == "fastembed": if provider_type.lower() == "fastembed":
from .fastembed import FastEmbedProvider from .fastembed import FastEmbedProvider
model_name = kwargs.get("model_name", "sentence-transformers/all-MiniLM-L6-v2")
return FastEmbedProvider(model_name)
else:
raise ValueError(f"Unsupported embedding provider: {provider_type}")
from typing import Optional
from .fastembed import FastEmbedProvider
from .base import EmbeddingProvider
def create_embedding_provider(provider_type: str, model_name: Optional[str] = None) -> EmbeddingProvider:
"""
Create an embedding provider based on the provider type.
Args:
provider_type: The type of embedding provider to create.
model_name: The name of the model to use.
Returns:
An instance of EmbeddingProvider.
Raises:
ValueError: If the provider type is not supported.
"""
if provider_type.lower() == "fastembed":
return FastEmbedProvider(model_name)
else:
raise ValueError(f"Unsupported embedding provider: {provider_type}")
from typing import Literal
from .fastembed import FastEmbedProvider
def create_embedding_provider(
provider_type: Literal["fastembed"],
**kwargs
) -> FastEmbedProvider:
"""
Factory function to create an embedding provider.
Args:
provider_type: The type of embedding provider to create.
**kwargs: Additional arguments to pass to the provider constructor.
Returns:
An instance of the requested embedding provider.
Raises:
ValueError: If the provider type is not supported.
"""
if provider_type == "fastembed":
model_name = kwargs.get("model_name", "sentence-transformers/all-MiniLM-L6-v2") model_name = kwargs.get("model_name", "sentence-transformers/all-MiniLM-L6-v2")
return FastEmbedProvider(model_name) return FastEmbedProvider(model_name)
else: else:

View File

@@ -1,5 +1,6 @@
from typing import List
import asyncio import asyncio
from typing import List
from fastembed import TextEmbedding from fastembed import TextEmbedding
from .base import EmbeddingProvider from .base import EmbeddingProvider

View File

@@ -1,5 +1,7 @@
from typing import Optional, List from typing import Optional
from qdrant_client import AsyncQdrantClient, models from qdrant_client import AsyncQdrantClient, models
from .embeddings.base import EmbeddingProvider from .embeddings.base import EmbeddingProvider
@@ -25,7 +27,9 @@ class QdrantConnector:
self._qdrant_api_key = qdrant_api_key self._qdrant_api_key = qdrant_api_key
self._collection_name = collection_name self._collection_name = collection_name
self._embedding_provider = embedding_provider self._embedding_provider = embedding_provider
self._client = AsyncQdrantClient(location=qdrant_url, api_key=qdrant_api_key, path=qdrant_local_path) self._client = AsyncQdrantClient(
location=qdrant_url, api_key=qdrant_api_key, path=qdrant_local_path
)
async def _ensure_collection_exists(self): async def _ensure_collection_exists(self):
"""Ensure that the collection exists, creating it if necessary.""" """Ensure that the collection exists, creating it if necessary."""

View File

@@ -1,15 +1,14 @@
import asyncio
from typing import Optional from typing import Optional
from mcp.server import Server, NotificationOptions import click
import mcp
import mcp.types as types
from mcp.server import NotificationOptions, Server
from mcp.server.models import InitializationOptions from mcp.server.models import InitializationOptions
import click
import mcp.types as types
import asyncio
import mcp
from .qdrant import QdrantConnector
from .embeddings.factory import create_embedding_provider from .embeddings.factory import create_embedding_provider
from .qdrant import QdrantConnector
def serve( def serve(
@@ -33,12 +32,15 @@ def serve(
# Create the embedding provider # Create the embedding provider
embedding_provider = create_embedding_provider( embedding_provider = create_embedding_provider(
embedding_provider_type, embedding_provider_type, model_name=embedding_model_name
model_name=embedding_model_name
) )
qdrant = QdrantConnector( qdrant = QdrantConnector(
qdrant_url, qdrant_api_key, collection_name, embedding_provider, qdrant_local_path qdrant_url,
qdrant_api_key,
collection_name,
embedding_provider,
qdrant_local_path,
) )
@server.list_tools() @server.list_tools()
@@ -169,7 +171,9 @@ def main(
): ):
# XOR of url and local path, since we accept only one of them # XOR of url and local path, since we accept only one of them
if not (bool(qdrant_url) ^ bool(qdrant_local_path)): if not (bool(qdrant_url) ^ bool(qdrant_local_path)):
raise ValueError("Exactly one of qdrant-url or qdrant-local-path must be provided") raise ValueError(
"Exactly one of qdrant-url or qdrant-local-path must be provided"
)
async def _run(): async def _run():
async with mcp.server.stdio.stdio_server() as (read_stream, write_stream): async with mcp.server.stdio.stdio_server() as (read_stream, write_stream):