From 8afd73849c6f898d32dd811745cb4ea34629d632 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kacper=20=C5=81ukawski?= Date: Wed, 5 Mar 2025 23:42:19 +0100 Subject: [PATCH] Fix compatibility with the previous versions --- src/mcp_server_qdrant/embeddings/base.py | 5 ++++ src/mcp_server_qdrant/embeddings/fastembed.py | 8 ++++++ src/mcp_server_qdrant/qdrant.py | 21 +++++++++----- src/mcp_server_qdrant/server.py | 28 ++++++++++++++++++- uv.lock | 6 ++-- 5 files changed, 57 insertions(+), 11 deletions(-) diff --git a/src/mcp_server_qdrant/embeddings/base.py b/src/mcp_server_qdrant/embeddings/base.py index 7440772..e2fe34d 100644 --- a/src/mcp_server_qdrant/embeddings/base.py +++ b/src/mcp_server_qdrant/embeddings/base.py @@ -14,3 +14,8 @@ class EmbeddingProvider(ABC): async def embed_query(self, query: str) -> List[float]: """Embed a query into a vector.""" pass + + @abstractmethod + def get_vector_name(self) -> str: + """Get the name of the vector for the Qdrant collection.""" + pass diff --git a/src/mcp_server_qdrant/embeddings/fastembed.py b/src/mcp_server_qdrant/embeddings/fastembed.py index cc79b85..90a9d9e 100644 --- a/src/mcp_server_qdrant/embeddings/fastembed.py +++ b/src/mcp_server_qdrant/embeddings/fastembed.py @@ -35,3 +35,11 @@ class FastEmbedProvider(EmbeddingProvider): None, lambda: list(self.embedding_model.query_embed([query])) ) return embeddings[0].tolist() + + def get_vector_name(self) -> str: + """ + Return the name of the vector for the Qdrant collection. + Important: This is compatible with the FastEmbed logic used before 0.6.0. + """ + model_name = self.embedding_model.model_name.split("/")[-1].lower() + return f"fast-{model_name}" diff --git a/src/mcp_server_qdrant/qdrant.py b/src/mcp_server_qdrant/qdrant.py index f5b4cf3..4340fd0 100644 --- a/src/mcp_server_qdrant/qdrant.py +++ b/src/mcp_server_qdrant/qdrant.py @@ -1,3 +1,4 @@ +import uuid from typing import Optional from qdrant_client import AsyncQdrantClient, models @@ -40,12 +41,16 @@ class QdrantConnector: sample_vector = await self._embedding_provider.embed_query("sample text") vector_size = len(sample_vector) + # Use the vector name as defined in the embedding provider + vector_name = self._embedding_provider.get_vector_name() await self._client.create_collection( collection_name=self._collection_name, - vectors_config=models.VectorParams( - size=vector_size, - distance=models.Distance.COSINE, - ), + vectors_config={ + vector_name: models.VectorParams( + size=vector_size, + distance=models.Distance.COSINE, + ) + }, ) async def store_memory(self, information: str): @@ -59,12 +64,13 @@ class QdrantConnector: embeddings = await self._embedding_provider.embed_documents([information]) # Add to Qdrant + vector_name = self._embedding_provider.get_vector_name() await self._client.upsert( collection_name=self._collection_name, points=[ models.PointStruct( - id=hash(information), # Simple hash as ID - vector=embeddings[0], + id=uuid.uuid4().hex, + vector={vector_name: embeddings[0]}, payload={"document": information}, ) ], @@ -82,11 +88,12 @@ class QdrantConnector: # Embed the query query_vector = await self._embedding_provider.embed_query(query) + vector_name = self._embedding_provider.get_vector_name() # Search in Qdrant search_results = await self._client.search( collection_name=self._collection_name, - query_vector=query_vector, + query_vector=models.NamedVector(name=vector_name, vector=query_vector), limit=10, ) diff --git a/src/mcp_server_qdrant/server.py b/src/mcp_server_qdrant/server.py index 12a03cb..ca2ca8a 100644 --- a/src/mcp_server_qdrant/server.py +++ b/src/mcp_server_qdrant/server.py @@ -1,4 +1,5 @@ import asyncio +import importlib.metadata from typing import Optional import click @@ -11,6 +12,15 @@ from .embeddings.factory import create_embedding_provider from .qdrant import QdrantConnector +def get_package_version() -> str: + """Get the package version using importlib.metadata.""" + try: + return importlib.metadata.version("mcp-server-qdrant") + except importlib.metadata.PackageNotFoundError: + # Fall back to a default version if package is not installed + return "0.0.0" + + def serve( qdrant_url: Optional[str], qdrant_api_key: Optional[str], @@ -140,6 +150,13 @@ def serve( required=True, help="Collection name", ) +@click.option( + "--fastembed-model-name", + envvar="FASTEMBED_MODEL_NAME", + required=False, + help="FastEmbed model name", + default="sentence-transformers/all-MiniLM-L6-v2", +) @click.option( "--embedding-provider", envvar="EMBEDDING_PROVIDER", @@ -165,6 +182,7 @@ def main( qdrant_url: Optional[str], qdrant_api_key: str, collection_name: Optional[str], + fastembed_model_name: Optional[str], embedding_provider: str, embedding_model: str, qdrant_local_path: Optional[str], @@ -175,6 +193,14 @@ def main( "Exactly one of qdrant-url or qdrant-local-path must be provided" ) + # Warn if fastembed_model_name is provided, as this is going to be deprecated + if fastembed_model_name: + click.echo( + "Warning: --fastembed-model-name parameter is deprecated and will be removed in a future version. " + "Please use --embedding-provider and --embedding-model instead", + err=True, + ) + async def _run(): async with mcp.server.stdio.stdio_server() as (read_stream, write_stream): server = serve( @@ -190,7 +216,7 @@ def main( write_stream, InitializationOptions( server_name="qdrant", - server_version="0.5.1", + server_version=get_package_version(), capabilities=server.get_capabilities( notification_options=NotificationOptions(), experimental_capabilities={}, diff --git a/uv.lock b/uv.lock index e91b996..d349834 100644 --- a/uv.lock +++ b/uv.lock @@ -118,14 +118,14 @@ wheels = [ [[package]] name = "click" -version = "8.1.7" +version = "8.1.8" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "colorama", marker = "platform_system == 'Windows'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/96/d3/f04c7bfcf5c1862a2a5b845c6b2b360488cf47af55dfa79c98f6a6bf98b5/click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de", size = 336121 } +sdist = { url = "https://files.pythonhosted.org/packages/b9/2e/0090cbf739cee7d23781ad4b89a9894a41538e4fcf4c31dcdd705b78eb8b/click-8.1.8.tar.gz", hash = "sha256:ed53c9d8990d83c2a27deae68e4ee337473f6330c040a31d4225c9574d16096a", size = 226593 } wheels = [ - { url = "https://files.pythonhosted.org/packages/00/2e/d53fa4befbf2cfa713304affc7ca780ce4fc1fd8710527771b58311a3229/click-8.1.7-py3-none-any.whl", hash = "sha256:ae74fb96c20a0277a1d615f1e4d73c8414f5a98db8b799a7931d1582f3390c28", size = 97941 }, + { url = "https://files.pythonhosted.org/packages/7e/d4/7ebdbd03970677812aac39c869717059dbb71a4cfc033ca6e5221787892c/click-8.1.8-py3-none-any.whl", hash = "sha256:63c132bbbed01578a06712a2d1f497bb62d9c1c0d329b7903a866228027263b2", size = 98188 }, ] [[package]]