Refactor server and embedding provider to improve modularity
- Modify `serve()` function to accept a pre-configured QdrantConnector - Update `create_embedding_provider()` to simplify model name handling - Improve error handling and parameter passing in server tools - Restructure main function to create connector and server more explicitly
This commit is contained in:
@@ -22,37 +22,14 @@ def get_package_version() -> str:
|
||||
|
||||
|
||||
def serve(
|
||||
qdrant_url: Optional[str],
|
||||
qdrant_api_key: Optional[str],
|
||||
collection_name: str,
|
||||
embedding_provider_type: str,
|
||||
embedding_model_name: str,
|
||||
qdrant_local_path: Optional[str] = None,
|
||||
qdrant_connector: QdrantConnector,
|
||||
) -> Server:
|
||||
"""
|
||||
Instantiate the server and configure tools to store and find memories in Qdrant.
|
||||
:param qdrant_url: The URL of the Qdrant server.
|
||||
:param qdrant_api_key: The API key to use for the Qdrant server.
|
||||
:param collection_name: The name of the collection to use.
|
||||
:param embedding_provider_type: The type of embedding provider to use.
|
||||
:param embedding_model_name: The name of the embedding model to use.
|
||||
:param qdrant_local_path: The path to the storage directory for the Qdrant client, if local mode is used.
|
||||
:param qdrant_connector: An instance of QdrantConnector to use for storing and retrieving memories.
|
||||
"""
|
||||
server = Server("qdrant")
|
||||
|
||||
# Create the embedding provider
|
||||
embedding_provider = create_embedding_provider(
|
||||
embedding_provider_type, model_name=embedding_model_name
|
||||
)
|
||||
|
||||
qdrant = QdrantConnector(
|
||||
qdrant_url,
|
||||
qdrant_api_key,
|
||||
collection_name,
|
||||
embedding_provider,
|
||||
qdrant_local_path,
|
||||
)
|
||||
|
||||
@server.list_tools()
|
||||
async def handle_list_tools() -> list[types.Tool]:
|
||||
"""
|
||||
@@ -90,8 +67,8 @@ def serve(
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The query to search for in the memories",
|
||||
},
|
||||
"description": "The query to search for",
|
||||
}
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
@@ -109,14 +86,14 @@ def serve(
|
||||
if not arguments or "information" not in arguments:
|
||||
raise ValueError("Missing required argument 'information'")
|
||||
information = arguments["information"]
|
||||
await qdrant.store_memory(information)
|
||||
await qdrant_connector.store_memory(information)
|
||||
return [types.TextContent(type="text", text=f"Remembered: {information}")]
|
||||
|
||||
if name == "qdrant-find-memories":
|
||||
if not arguments or "query" not in arguments:
|
||||
raise ValueError("Missing required argument 'query'")
|
||||
query = arguments["query"]
|
||||
memories = await qdrant.find_memories(query)
|
||||
memories = await qdrant_connector.find_memories(query)
|
||||
content = [
|
||||
types.TextContent(
|
||||
type="text", text=f"Memories for the query '{query}'"
|
||||
@@ -128,6 +105,8 @@ def serve(
|
||||
)
|
||||
return content
|
||||
|
||||
raise ValueError(f"Unknown tool: {name}")
|
||||
|
||||
return server
|
||||
|
||||
|
||||
@@ -203,14 +182,22 @@ def main(
|
||||
|
||||
async def _run():
|
||||
async with mcp.server.stdio.stdio_server() as (read_stream, write_stream):
|
||||
server = serve(
|
||||
# Create the embedding provider
|
||||
provider = create_embedding_provider(
|
||||
provider_type=embedding_provider, model_name=embedding_model
|
||||
)
|
||||
|
||||
# Create the Qdrant connector
|
||||
qdrant_connector = QdrantConnector(
|
||||
qdrant_url,
|
||||
qdrant_api_key,
|
||||
collection_name,
|
||||
embedding_provider,
|
||||
embedding_model,
|
||||
provider,
|
||||
qdrant_local_path,
|
||||
)
|
||||
|
||||
# Create and run the server
|
||||
server = serve(qdrant_connector)
|
||||
await server.run(
|
||||
read_stream,
|
||||
write_stream,
|
||||
|
||||
Reference in New Issue
Block a user