Merge pull request #15 from qdrant/fix/collection-compatibility

Fix compatibility with the previous versions
This commit is contained in:
Kacper Łukawski
2025-03-05 23:43:30 +01:00
committed by GitHub
5 changed files with 57 additions and 11 deletions

View File

@@ -14,3 +14,8 @@ class EmbeddingProvider(ABC):
async def embed_query(self, query: str) -> List[float]: async def embed_query(self, query: str) -> List[float]:
"""Embed a query into a vector.""" """Embed a query into a vector."""
pass pass
@abstractmethod
def get_vector_name(self) -> str:
"""Get the name of the vector for the Qdrant collection."""
pass

View File

@@ -35,3 +35,11 @@ class FastEmbedProvider(EmbeddingProvider):
None, lambda: list(self.embedding_model.query_embed([query])) None, lambda: list(self.embedding_model.query_embed([query]))
) )
return embeddings[0].tolist() return embeddings[0].tolist()
def get_vector_name(self) -> str:
"""
Return the name of the vector for the Qdrant collection.
Important: This is compatible with the FastEmbed logic used before 0.6.0.
"""
model_name = self.embedding_model.model_name.split("/")[-1].lower()
return f"fast-{model_name}"

View File

@@ -1,3 +1,4 @@
import uuid
from typing import Optional from typing import Optional
from qdrant_client import AsyncQdrantClient, models from qdrant_client import AsyncQdrantClient, models
@@ -40,12 +41,16 @@ class QdrantConnector:
sample_vector = await self._embedding_provider.embed_query("sample text") sample_vector = await self._embedding_provider.embed_query("sample text")
vector_size = len(sample_vector) vector_size = len(sample_vector)
# Use the vector name as defined in the embedding provider
vector_name = self._embedding_provider.get_vector_name()
await self._client.create_collection( await self._client.create_collection(
collection_name=self._collection_name, collection_name=self._collection_name,
vectors_config=models.VectorParams( vectors_config={
size=vector_size, vector_name: models.VectorParams(
distance=models.Distance.COSINE, size=vector_size,
), distance=models.Distance.COSINE,
)
},
) )
async def store_memory(self, information: str): async def store_memory(self, information: str):
@@ -59,12 +64,13 @@ class QdrantConnector:
embeddings = await self._embedding_provider.embed_documents([information]) embeddings = await self._embedding_provider.embed_documents([information])
# Add to Qdrant # Add to Qdrant
vector_name = self._embedding_provider.get_vector_name()
await self._client.upsert( await self._client.upsert(
collection_name=self._collection_name, collection_name=self._collection_name,
points=[ points=[
models.PointStruct( models.PointStruct(
id=hash(information), # Simple hash as ID id=uuid.uuid4().hex,
vector=embeddings[0], vector={vector_name: embeddings[0]},
payload={"document": information}, payload={"document": information},
) )
], ],
@@ -82,11 +88,12 @@ class QdrantConnector:
# Embed the query # Embed the query
query_vector = await self._embedding_provider.embed_query(query) query_vector = await self._embedding_provider.embed_query(query)
vector_name = self._embedding_provider.get_vector_name()
# Search in Qdrant # Search in Qdrant
search_results = await self._client.search( search_results = await self._client.search(
collection_name=self._collection_name, collection_name=self._collection_name,
query_vector=query_vector, query_vector=models.NamedVector(name=vector_name, vector=query_vector),
limit=10, limit=10,
) )

View File

@@ -1,4 +1,5 @@
import asyncio import asyncio
import importlib.metadata
from typing import Optional from typing import Optional
import click import click
@@ -11,6 +12,15 @@ from .embeddings.factory import create_embedding_provider
from .qdrant import QdrantConnector from .qdrant import QdrantConnector
def get_package_version() -> str:
"""Get the package version using importlib.metadata."""
try:
return importlib.metadata.version("mcp-server-qdrant")
except importlib.metadata.PackageNotFoundError:
# Fall back to a default version if package is not installed
return "0.0.0"
def serve( def serve(
qdrant_url: Optional[str], qdrant_url: Optional[str],
qdrant_api_key: Optional[str], qdrant_api_key: Optional[str],
@@ -140,6 +150,13 @@ def serve(
required=True, required=True,
help="Collection name", help="Collection name",
) )
@click.option(
"--fastembed-model-name",
envvar="FASTEMBED_MODEL_NAME",
required=False,
help="FastEmbed model name",
default="sentence-transformers/all-MiniLM-L6-v2",
)
@click.option( @click.option(
"--embedding-provider", "--embedding-provider",
envvar="EMBEDDING_PROVIDER", envvar="EMBEDDING_PROVIDER",
@@ -165,6 +182,7 @@ def main(
qdrant_url: Optional[str], qdrant_url: Optional[str],
qdrant_api_key: str, qdrant_api_key: str,
collection_name: Optional[str], collection_name: Optional[str],
fastembed_model_name: Optional[str],
embedding_provider: str, embedding_provider: str,
embedding_model: str, embedding_model: str,
qdrant_local_path: Optional[str], qdrant_local_path: Optional[str],
@@ -175,6 +193,14 @@ def main(
"Exactly one of qdrant-url or qdrant-local-path must be provided" "Exactly one of qdrant-url or qdrant-local-path must be provided"
) )
# Warn if fastembed_model_name is provided, as this is going to be deprecated
if fastembed_model_name:
click.echo(
"Warning: --fastembed-model-name parameter is deprecated and will be removed in a future version. "
"Please use --embedding-provider and --embedding-model instead",
err=True,
)
async def _run(): async def _run():
async with mcp.server.stdio.stdio_server() as (read_stream, write_stream): async with mcp.server.stdio.stdio_server() as (read_stream, write_stream):
server = serve( server = serve(
@@ -190,7 +216,7 @@ def main(
write_stream, write_stream,
InitializationOptions( InitializationOptions(
server_name="qdrant", server_name="qdrant",
server_version="0.5.1", server_version=get_package_version(),
capabilities=server.get_capabilities( capabilities=server.get_capabilities(
notification_options=NotificationOptions(), notification_options=NotificationOptions(),
experimental_capabilities={}, experimental_capabilities={},

6
uv.lock generated
View File

@@ -118,14 +118,14 @@ wheels = [
[[package]] [[package]]
name = "click" name = "click"
version = "8.1.7" version = "8.1.8"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
dependencies = [ dependencies = [
{ name = "colorama", marker = "platform_system == 'Windows'" }, { name = "colorama", marker = "platform_system == 'Windows'" },
] ]
sdist = { url = "https://files.pythonhosted.org/packages/96/d3/f04c7bfcf5c1862a2a5b845c6b2b360488cf47af55dfa79c98f6a6bf98b5/click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de", size = 336121 } sdist = { url = "https://files.pythonhosted.org/packages/b9/2e/0090cbf739cee7d23781ad4b89a9894a41538e4fcf4c31dcdd705b78eb8b/click-8.1.8.tar.gz", hash = "sha256:ed53c9d8990d83c2a27deae68e4ee337473f6330c040a31d4225c9574d16096a", size = 226593 }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/00/2e/d53fa4befbf2cfa713304affc7ca780ce4fc1fd8710527771b58311a3229/click-8.1.7-py3-none-any.whl", hash = "sha256:ae74fb96c20a0277a1d615f1e4d73c8414f5a98db8b799a7931d1582f3390c28", size = 97941 }, { url = "https://files.pythonhosted.org/packages/7e/d4/7ebdbd03970677812aac39c869717059dbb71a4cfc033ca6e5221787892c/click-8.1.8-py3-none-any.whl", hash = "sha256:63c132bbbed01578a06712a2d1f497bb62d9c1c0d329b7903a866228027263b2", size = 98188 },
] ]
[[package]] [[package]]