Implement the server with FastMCP
This commit is contained in:
@@ -1,170 +1,122 @@
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
import logging
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncIterator, List
|
||||
|
||||
import click
|
||||
import mcp
|
||||
import mcp.types as types
|
||||
from mcp.server import NotificationOptions, Server
|
||||
from mcp.server.models import InitializationOptions
|
||||
from mcp.server import Server
|
||||
from mcp.server.fastmcp import Context, FastMCP
|
||||
|
||||
from .custom_server import QdrantMCPServer
|
||||
from .embeddings.factory import create_embedding_provider
|
||||
from .helper import get_package_version
|
||||
from .qdrant import QdrantConnector
|
||||
from mcp_server_qdrant.embeddings.factory import create_embedding_provider
|
||||
from mcp_server_qdrant.qdrant import QdrantConnector
|
||||
from mcp_server_qdrant.settings import (
|
||||
EmbeddingProviderSettings,
|
||||
QdrantSettings,
|
||||
parse_args,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Parse command line arguments and set them as environment variables.
|
||||
# This is done for backwards compatibility with the previous versions
|
||||
# of the MCP server.
|
||||
env_vars = parse_args()
|
||||
for key, value in env_vars.items():
|
||||
os.environ[key] = value
|
||||
|
||||
|
||||
def serve(
|
||||
qdrant_connector: QdrantConnector,
|
||||
) -> Server:
|
||||
@asynccontextmanager
|
||||
async def server_lifespan(server: Server) -> AsyncIterator[dict]: # noqa
|
||||
"""
|
||||
Instantiate the server and configure tools to store and find memories in Qdrant.
|
||||
:param qdrant_connector: An instance of QdrantConnector to use for storing and retrieving memories.
|
||||
Context manager to handle the lifespan of the server.
|
||||
This is used to configure the embedding provider and Qdrant connector.
|
||||
"""
|
||||
server = QdrantMCPServer("qdrant")
|
||||
|
||||
@server.register_tool(
|
||||
name="qdrant-store-memory",
|
||||
description=(
|
||||
"Keep the memory for later use, when you are asked to remember something."
|
||||
),
|
||||
)
|
||||
async def store_memory(information: str):
|
||||
"""
|
||||
Store a memory in Qdrant.
|
||||
:param information: The information to store.
|
||||
"""
|
||||
nonlocal qdrant_connector
|
||||
await qdrant_connector.store_memory(information)
|
||||
return [types.TextContent(type="text", text=f"Remembered: {information}")]
|
||||
|
||||
@server.register_tool(
|
||||
name="qdrant-find-memories",
|
||||
description=(
|
||||
"Look up memories in Qdrant. Use this tool when you need to: \n"
|
||||
" - Find memories by their content \n"
|
||||
" - Access memories for further analysis \n"
|
||||
" - Get some personal information about the user"
|
||||
),
|
||||
)
|
||||
async def find_memories(query: str):
|
||||
"""
|
||||
Find memories in Qdrant.
|
||||
:param query: The query to use for the search.
|
||||
:return: A list of memories found.
|
||||
"""
|
||||
nonlocal qdrant_connector
|
||||
memories = await qdrant_connector.find_memories(query)
|
||||
content = [
|
||||
types.TextContent(type="text", text=f"Memories for the query '{query}'"),
|
||||
]
|
||||
for memory in memories:
|
||||
content.append(
|
||||
types.TextContent(type="text", text=f"<memory>{memory}</memory>")
|
||||
)
|
||||
return content
|
||||
|
||||
return server
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.option(
|
||||
"--qdrant-url",
|
||||
envvar="QDRANT_URL",
|
||||
required=False,
|
||||
help="Qdrant URL",
|
||||
)
|
||||
@click.option(
|
||||
"--qdrant-api-key",
|
||||
envvar="QDRANT_API_KEY",
|
||||
required=False,
|
||||
help="Qdrant API key",
|
||||
)
|
||||
@click.option(
|
||||
"--collection-name",
|
||||
envvar="COLLECTION_NAME",
|
||||
required=True,
|
||||
help="Collection name",
|
||||
)
|
||||
@click.option(
|
||||
"--fastembed-model-name",
|
||||
envvar="FASTEMBED_MODEL_NAME",
|
||||
required=False,
|
||||
help="FastEmbed model name",
|
||||
default="sentence-transformers/all-MiniLM-L6-v2",
|
||||
)
|
||||
@click.option(
|
||||
"--embedding-provider",
|
||||
envvar="EMBEDDING_PROVIDER",
|
||||
required=False,
|
||||
help="Embedding provider to use",
|
||||
default="fastembed",
|
||||
type=click.Choice(["fastembed"], case_sensitive=False),
|
||||
)
|
||||
@click.option(
|
||||
"--embedding-model",
|
||||
envvar="EMBEDDING_MODEL",
|
||||
required=False,
|
||||
help="Embedding model name",
|
||||
default="sentence-transformers/all-MiniLM-L6-v2",
|
||||
)
|
||||
@click.option(
|
||||
"--qdrant-local-path",
|
||||
envvar="QDRANT_LOCAL_PATH",
|
||||
required=False,
|
||||
help="Qdrant local path",
|
||||
)
|
||||
def main(
|
||||
qdrant_url: Optional[str],
|
||||
qdrant_api_key: str,
|
||||
collection_name: Optional[str],
|
||||
fastembed_model_name: Optional[str],
|
||||
embedding_provider: str,
|
||||
embedding_model: str,
|
||||
qdrant_local_path: Optional[str],
|
||||
):
|
||||
# XOR of url and local path, since we accept only one of them
|
||||
if not (bool(qdrant_url) ^ bool(qdrant_local_path)):
|
||||
raise ValueError(
|
||||
"Exactly one of qdrant-url or qdrant-local-path must be provided"
|
||||
try:
|
||||
# Embedding provider is created with a factory function so we can add
|
||||
# some more providers in the future. Currently, only FastEmbed is supported.
|
||||
embedding_provider_settings = EmbeddingProviderSettings()
|
||||
embedding_provider = create_embedding_provider(embedding_provider_settings)
|
||||
logger.info(
|
||||
f"Using embedding provider {embedding_provider_settings.provider_type} with "
|
||||
f"model {embedding_provider_settings.model_name}"
|
||||
)
|
||||
|
||||
# Warn if fastembed_model_name is provided, as this is going to be deprecated
|
||||
if fastembed_model_name:
|
||||
click.echo(
|
||||
"Warning: --fastembed-model-name parameter is deprecated and will be removed in a future version. "
|
||||
"Please use --embedding-provider and --embedding-model instead",
|
||||
err=True,
|
||||
qdrant_configuration = QdrantSettings()
|
||||
qdrant_connector = QdrantConnector(
|
||||
qdrant_configuration.location,
|
||||
qdrant_configuration.api_key,
|
||||
qdrant_configuration.collection_name,
|
||||
embedding_provider,
|
||||
qdrant_configuration.local_path,
|
||||
)
|
||||
logger.info(
|
||||
f"Connecting to Qdrant at {qdrant_configuration.get_qdrant_location()}"
|
||||
)
|
||||
|
||||
async def _run():
|
||||
async with mcp.server.stdio.stdio_server() as (read_stream, write_stream):
|
||||
# Create the embedding provider
|
||||
provider = create_embedding_provider(
|
||||
provider_type=embedding_provider, model_name=embedding_model
|
||||
)
|
||||
yield {
|
||||
"embedding_provider": embedding_provider,
|
||||
"qdrant_connector": qdrant_connector,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
raise e
|
||||
finally:
|
||||
pass
|
||||
|
||||
# Create the Qdrant connector
|
||||
qdrant_connector = QdrantConnector(
|
||||
qdrant_url,
|
||||
qdrant_api_key,
|
||||
collection_name,
|
||||
provider,
|
||||
qdrant_local_path,
|
||||
)
|
||||
|
||||
# Create and run the server
|
||||
server = serve(qdrant_connector)
|
||||
await server.run(
|
||||
read_stream,
|
||||
write_stream,
|
||||
InitializationOptions(
|
||||
server_name="qdrant",
|
||||
server_version=get_package_version(),
|
||||
capabilities=server.get_capabilities(
|
||||
notification_options=NotificationOptions(),
|
||||
experimental_capabilities={},
|
||||
),
|
||||
),
|
||||
)
|
||||
mcp = FastMCP("Qdrant", lifespan=server_lifespan)
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
@mcp.tool(
|
||||
name="qdrant-store-memory",
|
||||
description=(
|
||||
"Keep the memory for later use, when you are asked to remember something."
|
||||
),
|
||||
)
|
||||
async def store(information: str, ctx: Context) -> str:
|
||||
"""
|
||||
Store a memory in Qdrant.
|
||||
:param information: The information to store.
|
||||
:param ctx: The context for the request.
|
||||
:return: A message indicating that the information was stored.
|
||||
"""
|
||||
await ctx.debug(f"Storing information {information} in Qdrant")
|
||||
qdrant_connector: QdrantConnector = ctx.request_context.lifespan_context[
|
||||
"qdrant_connector"
|
||||
]
|
||||
await qdrant_connector.store(information)
|
||||
return f"Remembered: {information}"
|
||||
|
||||
|
||||
@mcp.tool(
|
||||
name="qdrant-find-memories",
|
||||
description=(
|
||||
"Look up memories in Qdrant. Use this tool when you need to: \n"
|
||||
" - Find memories by their content \n"
|
||||
" - Access memories for further analysis \n"
|
||||
" - Get some personal information about the user"
|
||||
),
|
||||
)
|
||||
async def find(query: str, ctx: Context) -> List[str]:
|
||||
"""
|
||||
Find memories in Qdrant.
|
||||
:param query: The query to use for the search.
|
||||
:param ctx: The context for the request.
|
||||
:return: A list of entries found.
|
||||
"""
|
||||
await ctx.debug(f"Finding points for query {query}")
|
||||
qdrant_connector: QdrantConnector = ctx.request_context.lifespan_context[
|
||||
"qdrant_connector"
|
||||
]
|
||||
entries = await qdrant_connector.search(query)
|
||||
if not entries:
|
||||
return [f"No memories found for the query '{query}'"]
|
||||
content = [
|
||||
f"Memories for the query '{query}'",
|
||||
]
|
||||
for entry in entries:
|
||||
content.append(f"<entry>{entry}</entry>")
|
||||
return content
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mcp.run()
|
||||
|
||||
Reference in New Issue
Block a user