diff --git a/README.md b/README.md index adcfaf3..c8e21c1 100644 --- a/README.md +++ b/README.md @@ -19,16 +19,17 @@ It acts as a semantic memory layer on top of the Qdrant database. ### Tools -1. `qdrant-store-memory` - - Store a memory in the Qdrant database +1. `qdrant-store` + - Store some information in the Qdrant database - Input: - - `information` (string): Memory to store + - `information` (string): Information to store + - `metadata` (JSON): Optional metadata to store - Returns: Confirmation message -2. `qdrant-find-memories` - - Retrieve a memory from the Qdrant database +2. `qdrant-find` + - Retrieve relevant information from the Qdrant database - Input: - - `query` (string): Query to retrieve a memory - - Returns: Memories stored in the Qdrant database as separate messages + - `query` (string): Query to use for searching + - Returns: Information stored in the Qdrant database as separate messages ## Installation in Claude Desktop diff --git a/pyproject.toml b/pyproject.toml index c33a87e..ec2d591 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,6 +9,7 @@ dependencies = [ "mcp[cli]>=1.3.0", "fastembed>=0.6.0", "qdrant-client>=1.12.0", + "pydantic>=2.10.6", ] [build-system] diff --git a/src/mcp_server_qdrant/embeddings/factory.py b/src/mcp_server_qdrant/embeddings/factory.py index a0a56f0..efac851 100644 --- a/src/mcp_server_qdrant/embeddings/factory.py +++ b/src/mcp_server_qdrant/embeddings/factory.py @@ -1,4 +1,4 @@ -from mcp_server_qdrant.embeddings import EmbeddingProvider +from mcp_server_qdrant.embeddings.base import EmbeddingProvider from mcp_server_qdrant.embeddings.types import EmbeddingProviderType from mcp_server_qdrant.settings import EmbeddingProviderSettings diff --git a/src/mcp_server_qdrant/qdrant.py b/src/mcp_server_qdrant/qdrant.py index ed6aac7..b56c200 100644 --- a/src/mcp_server_qdrant/qdrant.py +++ b/src/mcp_server_qdrant/qdrant.py @@ -1,9 +1,24 @@ +import logging import uuid -from typing import Optional +from typing import Any, Dict, Optional +from pydantic import BaseModel from qdrant_client import AsyncQdrantClient, models -from .embeddings.base import EmbeddingProvider +from mcp_server_qdrant.embeddings.base import EmbeddingProvider + +logger = logging.getLogger(__name__) + +Metadata = Dict[str, Any] + + +class Entry(BaseModel): + """ + A single entry in the Qdrant collection. + """ + + content: str + metadata: Optional[Metadata] = None class QdrantConnector: @@ -53,30 +68,31 @@ class QdrantConnector: }, ) - async def store(self, information: str): + async def store(self, entry: Entry): """ - Store some information in the Qdrant collection. - :param information: The information to store. + Store some information in the Qdrant collection, along with the specified metadata. + :param entry: The entry to store in the Qdrant collection. """ await self._ensure_collection_exists() # Embed the document - embeddings = await self._embedding_provider.embed_documents([information]) + embeddings = await self._embedding_provider.embed_documents([entry.content]) # Add to Qdrant vector_name = self._embedding_provider.get_vector_name() + payload = {"document": entry.content, "metadata": entry.metadata} await self._client.upsert( collection_name=self._collection_name, points=[ models.PointStruct( id=uuid.uuid4().hex, vector={vector_name: embeddings[0]}, - payload={"document": information}, + payload=payload, ) ], ) - async def search(self, query: str) -> list[str]: + async def search(self, query: str) -> list[Entry]: """ Find points in the Qdrant collection. If there are no entries found, an empty list is returned. :param query: The query to use for the search. @@ -97,4 +113,10 @@ class QdrantConnector: limit=10, ) - return [result.payload["document"] for result in search_results] + return [ + Entry( + content=result.payload["document"], + metadata=result.payload.get("metadata"), + ) + for result in search_results + ] diff --git a/src/mcp_server_qdrant/server.py b/src/mcp_server_qdrant/server.py index bb38f71..df94fe2 100644 --- a/src/mcp_server_qdrant/server.py +++ b/src/mcp_server_qdrant/server.py @@ -1,12 +1,13 @@ +import json import logging from contextlib import asynccontextmanager -from typing import AsyncIterator, List +from typing import AsyncIterator, List, Optional from mcp.server import Server from mcp.server.fastmcp import Context, FastMCP from mcp_server_qdrant.embeddings.factory import create_embedding_provider -from mcp_server_qdrant.qdrant import QdrantConnector +from mcp_server_qdrant.qdrant import Entry, Metadata, QdrantConnector from mcp_server_qdrant.settings import EmbeddingProviderSettings, QdrantSettings logger = logging.getLogger(__name__) @@ -57,28 +58,32 @@ mcp = FastMCP("mcp-server-qdrant", lifespan=server_lifespan) @mcp.tool( - name="qdrant-store-memory", + name="qdrant-store", description=( "Keep the memory for later use, when you are asked to remember something." ), ) -async def store(information: str, ctx: Context) -> str: +async def store( + ctx: Context, information: str, metadata: Optional[Metadata] = None +) -> str: """ Store a memory in Qdrant. - :param information: The information to store. :param ctx: The context for the request. + :param information: The information to store. + :param metadata: JSON metadata to store with the information, optional. :return: A message indicating that the information was stored. """ await ctx.debug(f"Storing information {information} in Qdrant") qdrant_connector: QdrantConnector = ctx.request_context.lifespan_context[ "qdrant_connector" ] - await qdrant_connector.store(information) + entry = Entry(content=information, metadata=metadata) + await qdrant_connector.store(entry) return f"Remembered: {information}" @mcp.tool( - name="qdrant-find-memories", + name="qdrant-find", description=( "Look up memories in Qdrant. Use this tool when you need to: \n" " - Find memories by their content \n" @@ -86,11 +91,11 @@ async def store(information: str, ctx: Context) -> str: " - Get some personal information about the user" ), ) -async def find(query: str, ctx: Context) -> List[str]: +async def find(ctx: Context, query: str) -> List[str]: """ Find memories in Qdrant. - :param query: The query to use for the search. :param ctx: The context for the request. + :param query: The query to use for the search. :return: A list of entries found. """ await ctx.debug(f"Finding points for query {query}") @@ -104,5 +109,9 @@ async def find(query: str, ctx: Context) -> List[str]: f"Memories for the query '{query}'", ] for entry in entries: - content.append(f"{entry}") + # Format the metadata as a JSON string and produce XML-like output + entry_metadata = json.dumps(entry.metadata) if entry.metadata else "" + content.append( + f"{entry.content}{entry_metadata}" + ) return content diff --git a/tests/test_qdrant_integration.py b/tests/test_qdrant_integration.py new file mode 100644 index 0000000..f990b8b --- /dev/null +++ b/tests/test_qdrant_integration.py @@ -0,0 +1,161 @@ +import uuid + +import pytest + +from mcp_server_qdrant.embeddings.fastembed import FastEmbedProvider +from mcp_server_qdrant.qdrant import Entry, QdrantConnector + + +@pytest.fixture +async def embedding_provider(): + """Fixture to provide a FastEmbed embedding provider.""" + return FastEmbedProvider(model_name="sentence-transformers/all-MiniLM-L6-v2") + + +@pytest.fixture +async def qdrant_connector(embedding_provider): + """Fixture to provide a QdrantConnector with in-memory Qdrant client.""" + # Use a random collection name to avoid conflicts between tests + collection_name = f"test_collection_{uuid.uuid4().hex}" + + # Create connector with in-memory Qdrant + connector = QdrantConnector( + qdrant_url=":memory:", + qdrant_api_key=None, + collection_name=collection_name, + embedding_provider=embedding_provider, + ) + + yield connector + + +@pytest.mark.asyncio +async def test_store_and_search(qdrant_connector): + """Test storing an entry and then searching for it.""" + # Store a test entry + test_entry = Entry( + content="The quick brown fox jumps over the lazy dog", + metadata={"source": "test", "importance": "high"}, + ) + await qdrant_connector.store(test_entry) + + # Search for the entry + results = await qdrant_connector.search("fox jumps") + + # Verify results + assert len(results) == 1 + assert results[0].content == test_entry.content + assert results[0].metadata == test_entry.metadata + + +@pytest.mark.asyncio +async def test_search_empty_collection(qdrant_connector): + """Test searching in an empty collection.""" + # Search in an empty collection + results = await qdrant_connector.search("test query") + + # Verify results + assert len(results) == 0 + + +@pytest.mark.asyncio +async def test_multiple_entries(qdrant_connector): + """Test storing and searching multiple entries.""" + # Store multiple entries + entries = [ + Entry( + content="Python is a programming language", + metadata={"topic": "programming"}, + ), + Entry(content="The Eiffel Tower is in Paris", metadata={"topic": "landmarks"}), + Entry(content="Machine learning is a subset of AI", metadata={"topic": "AI"}), + ] + + for entry in entries: + await qdrant_connector.store(entry) + + # Search for programming-related entries + programming_results = await qdrant_connector.search("Python programming") + assert len(programming_results) > 0 + assert any("Python" in result.content for result in programming_results) + + # Search for landmark-related entries + landmark_results = await qdrant_connector.search("Eiffel Tower Paris") + assert len(landmark_results) > 0 + assert any("Eiffel" in result.content for result in landmark_results) + + # Search for AI-related entries + ai_results = await qdrant_connector.search( + "artificial intelligence machine learning" + ) + assert len(ai_results) > 0 + assert any("machine learning" in result.content.lower() for result in ai_results) + + +@pytest.mark.asyncio +async def test_ensure_collection_exists(qdrant_connector): + """Test that the collection is created if it doesn't exist.""" + # The collection shouldn't exist yet + assert not await qdrant_connector._client.collection_exists( + qdrant_connector._collection_name + ) + + # Storing an entry should create the collection + test_entry = Entry(content="Test content") + await qdrant_connector.store(test_entry) + + # Now the collection should exist + assert await qdrant_connector._client.collection_exists( + qdrant_connector._collection_name + ) + + +@pytest.mark.asyncio +async def test_metadata_handling(qdrant_connector): + """Test that metadata is properly stored and retrieved.""" + # Store entries with different metadata + metadata1 = {"source": "book", "author": "Jane Doe", "year": 2023} + metadata2 = {"source": "article", "tags": ["science", "research"]} + + await qdrant_connector.store( + Entry(content="Content with structured metadata", metadata=metadata1) + ) + await qdrant_connector.store( + Entry(content="Content with list in metadata", metadata=metadata2) + ) + + # Search and verify metadata is preserved + results = await qdrant_connector.search("metadata") + + assert len(results) == 2 + + # Check that both metadata objects are present in the results + found_metadata1 = False + found_metadata2 = False + + for result in results: + if result.metadata.get("source") == "book": + assert result.metadata.get("author") == "Jane Doe" + assert result.metadata.get("year") == 2023 + found_metadata1 = True + elif result.metadata.get("source") == "article": + assert "science" in result.metadata.get("tags", []) + assert "research" in result.metadata.get("tags", []) + found_metadata2 = True + + assert found_metadata1 + assert found_metadata2 + + +@pytest.mark.asyncio +async def test_entry_without_metadata(qdrant_connector): + """Test storing and retrieving entries without metadata.""" + # Store an entry without metadata + await qdrant_connector.store(Entry(content="Entry without metadata")) + + # Search and verify + results = await qdrant_connector.search("without metadata") + + assert len(results) == 1 + assert results[0].content == "Entry without metadata" + assert results[0].metadata is None diff --git a/uv.lock b/uv.lock index b954e12..f0d8b2e 100644 --- a/uv.lock +++ b/uv.lock @@ -497,6 +497,7 @@ source = { editable = "." } dependencies = [ { name = "fastembed" }, { name = "mcp", extra = ["cli"] }, + { name = "pydantic" }, { name = "qdrant-client" }, ] @@ -513,6 +514,7 @@ dev = [ requires-dist = [ { name = "fastembed", specifier = ">=0.6.0" }, { name = "mcp", extras = ["cli"], specifier = ">=1.3.0" }, + { name = "pydantic", specifier = ">=2.10.6" }, { name = "qdrant-client", specifier = ">=1.12.0" }, ]