Abstract the embedding providers

This commit is contained in:
Kacper Łukawski
2025-03-05 22:40:58 +01:00
parent f8ecab23be
commit 56f1c797fc
11 changed files with 257 additions and 46 deletions

View File

@@ -1,6 +1,6 @@
from typing import Optional
from qdrant_client import AsyncQdrantClient
from typing import Optional, List
from qdrant_client import AsyncQdrantClient, models
from .embeddings.base import EmbeddingProvider
class QdrantConnector:
@@ -9,7 +9,7 @@ class QdrantConnector:
:param qdrant_url: The URL of the Qdrant server.
:param qdrant_api_key: The API key to use for the Qdrant server.
:param collection_name: The name of the collection to use.
:param fastembed_model_name: The name of the FastEmbed model to use.
:param embedding_provider: The embedding provider to use.
:param qdrant_local_path: The path to the storage directory for the Qdrant client, if local mode is used.
"""
@@ -18,29 +18,52 @@ class QdrantConnector:
qdrant_url: Optional[str],
qdrant_api_key: Optional[str],
collection_name: str,
fastembed_model_name: str,
embedding_provider: EmbeddingProvider,
qdrant_local_path: Optional[str] = None,
):
self._qdrant_url = qdrant_url.rstrip("/") if qdrant_url else None
self._qdrant_api_key = qdrant_api_key
self._collection_name = collection_name
self._fastembed_model_name = fastembed_model_name
# For the time being, FastEmbed models are the only supported ones.
# A list of all available models can be found here:
# https://qdrant.github.io/fastembed/examples/Supported_Models/
self._client = AsyncQdrantClient(
location=qdrant_url, api_key=qdrant_api_key, path=qdrant_local_path
)
self._client.set_model(fastembed_model_name)
self._embedding_provider = embedding_provider
self._client = AsyncQdrantClient(location=qdrant_url, api_key=qdrant_api_key, path=qdrant_local_path)
async def _ensure_collection_exists(self):
"""Ensure that the collection exists, creating it if necessary."""
collection_exists = await self._client.collection_exists(self._collection_name)
if not collection_exists:
# Create the collection with the appropriate vector size
# We'll get the vector size by embedding a sample text
sample_vector = await self._embedding_provider.embed_query("sample text")
vector_size = len(sample_vector)
await self._client.create_collection(
collection_name=self._collection_name,
vectors_config=models.VectorParams(
size=vector_size,
distance=models.Distance.COSINE,
),
)
async def store_memory(self, information: str):
"""
Store a memory in the Qdrant collection.
:param information: The information to store.
"""
await self._client.add(
self._collection_name,
documents=[information],
await self._ensure_collection_exists()
# Embed the document
embeddings = await self._embedding_provider.embed_documents([information])
# Add to Qdrant
await self._client.upsert(
collection_name=self._collection_name,
points=[
models.PointStruct(
id=hash(information), # Simple hash as ID
vector=embeddings[0],
payload={"document": information},
)
],
)
async def find_memories(self, query: str) -> list[str]:
@@ -53,9 +76,14 @@ class QdrantConnector:
if not collection_exists:
return []
search_results = await self._client.query(
self._collection_name,
query_text=query,
# Embed the query
query_vector = await self._embedding_provider.embed_query(query)
# Search in Qdrant
search_results = await self._client.search(
collection_name=self._collection_name,
query_vector=query_vector,
limit=10,
)
return [result.document for result in search_results]
return [result.payload["document"] for result in search_results]