Fix compatibility with the previous versions

This commit is contained in:
Kacper Łukawski
2025-03-05 23:42:19 +01:00
parent f252489c3d
commit 8afd73849c
5 changed files with 57 additions and 11 deletions

View File

@@ -14,3 +14,8 @@ class EmbeddingProvider(ABC):
async def embed_query(self, query: str) -> List[float]:
"""Embed a query into a vector."""
pass
@abstractmethod
def get_vector_name(self) -> str:
"""Get the name of the vector for the Qdrant collection."""
pass

View File

@@ -35,3 +35,11 @@ class FastEmbedProvider(EmbeddingProvider):
None, lambda: list(self.embedding_model.query_embed([query]))
)
return embeddings[0].tolist()
def get_vector_name(self) -> str:
"""
Return the name of the vector for the Qdrant collection.
Important: This is compatible with the FastEmbed logic used before 0.6.0.
"""
model_name = self.embedding_model.model_name.split("/")[-1].lower()
return f"fast-{model_name}"

View File

@@ -1,3 +1,4 @@
import uuid
from typing import Optional
from qdrant_client import AsyncQdrantClient, models
@@ -40,12 +41,16 @@ class QdrantConnector:
sample_vector = await self._embedding_provider.embed_query("sample text")
vector_size = len(sample_vector)
# Use the vector name as defined in the embedding provider
vector_name = self._embedding_provider.get_vector_name()
await self._client.create_collection(
collection_name=self._collection_name,
vectors_config=models.VectorParams(
size=vector_size,
distance=models.Distance.COSINE,
),
vectors_config={
vector_name: models.VectorParams(
size=vector_size,
distance=models.Distance.COSINE,
)
},
)
async def store_memory(self, information: str):
@@ -59,12 +64,13 @@ class QdrantConnector:
embeddings = await self._embedding_provider.embed_documents([information])
# Add to Qdrant
vector_name = self._embedding_provider.get_vector_name()
await self._client.upsert(
collection_name=self._collection_name,
points=[
models.PointStruct(
id=hash(information), # Simple hash as ID
vector=embeddings[0],
id=uuid.uuid4().hex,
vector={vector_name: embeddings[0]},
payload={"document": information},
)
],
@@ -82,11 +88,12 @@ class QdrantConnector:
# Embed the query
query_vector = await self._embedding_provider.embed_query(query)
vector_name = self._embedding_provider.get_vector_name()
# Search in Qdrant
search_results = await self._client.search(
collection_name=self._collection_name,
query_vector=query_vector,
query_vector=models.NamedVector(name=vector_name, vector=query_vector),
limit=10,
)

View File

@@ -1,4 +1,5 @@
import asyncio
import importlib.metadata
from typing import Optional
import click
@@ -11,6 +12,15 @@ from .embeddings.factory import create_embedding_provider
from .qdrant import QdrantConnector
def get_package_version() -> str:
"""Get the package version using importlib.metadata."""
try:
return importlib.metadata.version("mcp-server-qdrant")
except importlib.metadata.PackageNotFoundError:
# Fall back to a default version if package is not installed
return "0.0.0"
def serve(
qdrant_url: Optional[str],
qdrant_api_key: Optional[str],
@@ -140,6 +150,13 @@ def serve(
required=True,
help="Collection name",
)
@click.option(
"--fastembed-model-name",
envvar="FASTEMBED_MODEL_NAME",
required=False,
help="FastEmbed model name",
default="sentence-transformers/all-MiniLM-L6-v2",
)
@click.option(
"--embedding-provider",
envvar="EMBEDDING_PROVIDER",
@@ -165,6 +182,7 @@ def main(
qdrant_url: Optional[str],
qdrant_api_key: str,
collection_name: Optional[str],
fastembed_model_name: Optional[str],
embedding_provider: str,
embedding_model: str,
qdrant_local_path: Optional[str],
@@ -175,6 +193,14 @@ def main(
"Exactly one of qdrant-url or qdrant-local-path must be provided"
)
# Warn if fastembed_model_name is provided, as this is going to be deprecated
if fastembed_model_name:
click.echo(
"Warning: --fastembed-model-name parameter is deprecated and will be removed in a future version. "
"Please use --embedding-provider and --embedding-model instead",
err=True,
)
async def _run():
async with mcp.server.stdio.stdio_server() as (read_stream, write_stream):
server = serve(
@@ -190,7 +216,7 @@ def main(
write_stream,
InitializationOptions(
server_name="qdrant",
server_version="0.5.1",
server_version=get_package_version(),
capabilities=server.get_capabilities(
notification_options=NotificationOptions(),
experimental_capabilities={},