Merge pull request #19 from qdrant/chore/simplify-server-script
Refactor server and embedding provider to improve modularity
This commit is contained in:
@@ -1,18 +1,17 @@
|
|||||||
from mcp_server_qdrant.embeddings import EmbeddingProvider
|
from mcp_server_qdrant.embeddings import EmbeddingProvider
|
||||||
|
|
||||||
|
|
||||||
def create_embedding_provider(provider_type: str, **kwargs) -> EmbeddingProvider:
|
def create_embedding_provider(provider_type: str, model_name: str) -> EmbeddingProvider:
|
||||||
"""
|
"""
|
||||||
Create an embedding provider based on the specified type.
|
Create an embedding provider based on the specified type.
|
||||||
|
|
||||||
:param provider_type: The type of embedding provider to create.
|
:param provider_type: The type of embedding provider to create.
|
||||||
:param kwargs: Additional arguments to pass to the provider constructor.
|
:param model_name: The name of the model to use for embeddings, specific to the provider type.
|
||||||
:return: An instance of the specified embedding provider.
|
:return: An instance of the specified embedding provider.
|
||||||
"""
|
"""
|
||||||
if provider_type.lower() == "fastembed":
|
if provider_type.lower() == "fastembed":
|
||||||
from .fastembed import FastEmbedProvider
|
from .fastembed import FastEmbedProvider
|
||||||
|
|
||||||
model_name = kwargs.get("model_name", "sentence-transformers/all-MiniLM-L6-v2")
|
|
||||||
return FastEmbedProvider(model_name)
|
return FastEmbedProvider(model_name)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported embedding provider: {provider_type}")
|
raise ValueError(f"Unsupported embedding provider: {provider_type}")
|
||||||
|
|||||||
@@ -22,37 +22,14 @@ def get_package_version() -> str:
|
|||||||
|
|
||||||
|
|
||||||
def serve(
|
def serve(
|
||||||
qdrant_url: Optional[str],
|
qdrant_connector: QdrantConnector,
|
||||||
qdrant_api_key: Optional[str],
|
|
||||||
collection_name: str,
|
|
||||||
embedding_provider_type: str,
|
|
||||||
embedding_model_name: str,
|
|
||||||
qdrant_local_path: Optional[str] = None,
|
|
||||||
) -> Server:
|
) -> Server:
|
||||||
"""
|
"""
|
||||||
Instantiate the server and configure tools to store and find memories in Qdrant.
|
Instantiate the server and configure tools to store and find memories in Qdrant.
|
||||||
:param qdrant_url: The URL of the Qdrant server.
|
:param qdrant_connector: An instance of QdrantConnector to use for storing and retrieving memories.
|
||||||
:param qdrant_api_key: The API key to use for the Qdrant server.
|
|
||||||
:param collection_name: The name of the collection to use.
|
|
||||||
:param embedding_provider_type: The type of embedding provider to use.
|
|
||||||
:param embedding_model_name: The name of the embedding model to use.
|
|
||||||
:param qdrant_local_path: The path to the storage directory for the Qdrant client, if local mode is used.
|
|
||||||
"""
|
"""
|
||||||
server = Server("qdrant")
|
server = Server("qdrant")
|
||||||
|
|
||||||
# Create the embedding provider
|
|
||||||
embedding_provider = create_embedding_provider(
|
|
||||||
embedding_provider_type, model_name=embedding_model_name
|
|
||||||
)
|
|
||||||
|
|
||||||
qdrant = QdrantConnector(
|
|
||||||
qdrant_url,
|
|
||||||
qdrant_api_key,
|
|
||||||
collection_name,
|
|
||||||
embedding_provider,
|
|
||||||
qdrant_local_path,
|
|
||||||
)
|
|
||||||
|
|
||||||
@server.list_tools()
|
@server.list_tools()
|
||||||
async def handle_list_tools() -> list[types.Tool]:
|
async def handle_list_tools() -> list[types.Tool]:
|
||||||
"""
|
"""
|
||||||
@@ -90,8 +67,8 @@ def serve(
|
|||||||
"properties": {
|
"properties": {
|
||||||
"query": {
|
"query": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "The query to search for in the memories",
|
"description": "The query to search for",
|
||||||
},
|
}
|
||||||
},
|
},
|
||||||
"required": ["query"],
|
"required": ["query"],
|
||||||
},
|
},
|
||||||
@@ -109,14 +86,14 @@ def serve(
|
|||||||
if not arguments or "information" not in arguments:
|
if not arguments or "information" not in arguments:
|
||||||
raise ValueError("Missing required argument 'information'")
|
raise ValueError("Missing required argument 'information'")
|
||||||
information = arguments["information"]
|
information = arguments["information"]
|
||||||
await qdrant.store_memory(information)
|
await qdrant_connector.store_memory(information)
|
||||||
return [types.TextContent(type="text", text=f"Remembered: {information}")]
|
return [types.TextContent(type="text", text=f"Remembered: {information}")]
|
||||||
|
|
||||||
if name == "qdrant-find-memories":
|
if name == "qdrant-find-memories":
|
||||||
if not arguments or "query" not in arguments:
|
if not arguments or "query" not in arguments:
|
||||||
raise ValueError("Missing required argument 'query'")
|
raise ValueError("Missing required argument 'query'")
|
||||||
query = arguments["query"]
|
query = arguments["query"]
|
||||||
memories = await qdrant.find_memories(query)
|
memories = await qdrant_connector.find_memories(query)
|
||||||
content = [
|
content = [
|
||||||
types.TextContent(
|
types.TextContent(
|
||||||
type="text", text=f"Memories for the query '{query}'"
|
type="text", text=f"Memories for the query '{query}'"
|
||||||
@@ -128,6 +105,8 @@ def serve(
|
|||||||
)
|
)
|
||||||
return content
|
return content
|
||||||
|
|
||||||
|
raise ValueError(f"Unknown tool: {name}")
|
||||||
|
|
||||||
return server
|
return server
|
||||||
|
|
||||||
|
|
||||||
@@ -203,14 +182,22 @@ def main(
|
|||||||
|
|
||||||
async def _run():
|
async def _run():
|
||||||
async with mcp.server.stdio.stdio_server() as (read_stream, write_stream):
|
async with mcp.server.stdio.stdio_server() as (read_stream, write_stream):
|
||||||
server = serve(
|
# Create the embedding provider
|
||||||
|
provider = create_embedding_provider(
|
||||||
|
provider_type=embedding_provider, model_name=embedding_model
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create the Qdrant connector
|
||||||
|
qdrant_connector = QdrantConnector(
|
||||||
qdrant_url,
|
qdrant_url,
|
||||||
qdrant_api_key,
|
qdrant_api_key,
|
||||||
collection_name,
|
collection_name,
|
||||||
embedding_provider,
|
provider,
|
||||||
embedding_model,
|
|
||||||
qdrant_local_path,
|
qdrant_local_path,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Create and run the server
|
||||||
|
server = serve(qdrant_connector)
|
||||||
await server.run(
|
await server.run(
|
||||||
read_stream,
|
read_stream,
|
||||||
write_stream,
|
write_stream,
|
||||||
|
|||||||
Reference in New Issue
Block a user