From bf471783fd5328cf8ee6d961671332170ce8c3bf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kacper=20=C5=81ukawski?= Date: Fri, 7 Mar 2025 13:08:03 +0100 Subject: [PATCH] Refactor server and embedding provider to improve modularity - Modify `serve()` function to accept a pre-configured QdrantConnector - Update `create_embedding_provider()` to simplify model name handling - Improve error handling and parameter passing in server tools - Restructure main function to create connector and server more explicitly --- src/mcp_server_qdrant/embeddings/factory.py | 5 +- src/mcp_server_qdrant/server.py | 51 ++++++++------------- 2 files changed, 21 insertions(+), 35 deletions(-) diff --git a/src/mcp_server_qdrant/embeddings/factory.py b/src/mcp_server_qdrant/embeddings/factory.py index b8b7f4e..c69a194 100644 --- a/src/mcp_server_qdrant/embeddings/factory.py +++ b/src/mcp_server_qdrant/embeddings/factory.py @@ -1,18 +1,17 @@ from mcp_server_qdrant.embeddings import EmbeddingProvider -def create_embedding_provider(provider_type: str, **kwargs) -> EmbeddingProvider: +def create_embedding_provider(provider_type: str, model_name: str) -> EmbeddingProvider: """ Create an embedding provider based on the specified type. :param provider_type: The type of embedding provider to create. - :param kwargs: Additional arguments to pass to the provider constructor. + :param model_name: The name of the model to use for embeddings, specific to the provider type. :return: An instance of the specified embedding provider. """ if provider_type.lower() == "fastembed": from .fastembed import FastEmbedProvider - model_name = kwargs.get("model_name", "sentence-transformers/all-MiniLM-L6-v2") return FastEmbedProvider(model_name) else: raise ValueError(f"Unsupported embedding provider: {provider_type}") diff --git a/src/mcp_server_qdrant/server.py b/src/mcp_server_qdrant/server.py index ca2ca8a..95db682 100644 --- a/src/mcp_server_qdrant/server.py +++ b/src/mcp_server_qdrant/server.py @@ -22,37 +22,14 @@ def get_package_version() -> str: def serve( - qdrant_url: Optional[str], - qdrant_api_key: Optional[str], - collection_name: str, - embedding_provider_type: str, - embedding_model_name: str, - qdrant_local_path: Optional[str] = None, + qdrant_connector: QdrantConnector, ) -> Server: """ Instantiate the server and configure tools to store and find memories in Qdrant. - :param qdrant_url: The URL of the Qdrant server. - :param qdrant_api_key: The API key to use for the Qdrant server. - :param collection_name: The name of the collection to use. - :param embedding_provider_type: The type of embedding provider to use. - :param embedding_model_name: The name of the embedding model to use. - :param qdrant_local_path: The path to the storage directory for the Qdrant client, if local mode is used. + :param qdrant_connector: An instance of QdrantConnector to use for storing and retrieving memories. """ server = Server("qdrant") - # Create the embedding provider - embedding_provider = create_embedding_provider( - embedding_provider_type, model_name=embedding_model_name - ) - - qdrant = QdrantConnector( - qdrant_url, - qdrant_api_key, - collection_name, - embedding_provider, - qdrant_local_path, - ) - @server.list_tools() async def handle_list_tools() -> list[types.Tool]: """ @@ -90,8 +67,8 @@ def serve( "properties": { "query": { "type": "string", - "description": "The query to search for in the memories", - }, + "description": "The query to search for", + } }, "required": ["query"], }, @@ -109,14 +86,14 @@ def serve( if not arguments or "information" not in arguments: raise ValueError("Missing required argument 'information'") information = arguments["information"] - await qdrant.store_memory(information) + await qdrant_connector.store_memory(information) return [types.TextContent(type="text", text=f"Remembered: {information}")] if name == "qdrant-find-memories": if not arguments or "query" not in arguments: raise ValueError("Missing required argument 'query'") query = arguments["query"] - memories = await qdrant.find_memories(query) + memories = await qdrant_connector.find_memories(query) content = [ types.TextContent( type="text", text=f"Memories for the query '{query}'" @@ -128,6 +105,8 @@ def serve( ) return content + raise ValueError(f"Unknown tool: {name}") + return server @@ -203,14 +182,22 @@ def main( async def _run(): async with mcp.server.stdio.stdio_server() as (read_stream, write_stream): - server = serve( + # Create the embedding provider + provider = create_embedding_provider( + provider_type=embedding_provider, model_name=embedding_model + ) + + # Create the Qdrant connector + qdrant_connector = QdrantConnector( qdrant_url, qdrant_api_key, collection_name, - embedding_provider, - embedding_model, + provider, qdrant_local_path, ) + + # Create and run the server + server = serve(qdrant_connector) await server.run( read_stream, write_stream,