Abstract the embedding providers
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
from typing import Optional
|
||||
|
||||
from qdrant_client import AsyncQdrantClient
|
||||
from typing import Optional, List
|
||||
from qdrant_client import AsyncQdrantClient, models
|
||||
from .embeddings.base import EmbeddingProvider
|
||||
|
||||
|
||||
class QdrantConnector:
|
||||
@@ -9,7 +9,7 @@ class QdrantConnector:
|
||||
:param qdrant_url: The URL of the Qdrant server.
|
||||
:param qdrant_api_key: The API key to use for the Qdrant server.
|
||||
:param collection_name: The name of the collection to use.
|
||||
:param fastembed_model_name: The name of the FastEmbed model to use.
|
||||
:param embedding_provider: The embedding provider to use.
|
||||
:param qdrant_local_path: The path to the storage directory for the Qdrant client, if local mode is used.
|
||||
"""
|
||||
|
||||
@@ -18,29 +18,52 @@ class QdrantConnector:
|
||||
qdrant_url: Optional[str],
|
||||
qdrant_api_key: Optional[str],
|
||||
collection_name: str,
|
||||
fastembed_model_name: str,
|
||||
embedding_provider: EmbeddingProvider,
|
||||
qdrant_local_path: Optional[str] = None,
|
||||
):
|
||||
self._qdrant_url = qdrant_url.rstrip("/") if qdrant_url else None
|
||||
self._qdrant_api_key = qdrant_api_key
|
||||
self._collection_name = collection_name
|
||||
self._fastembed_model_name = fastembed_model_name
|
||||
# For the time being, FastEmbed models are the only supported ones.
|
||||
# A list of all available models can be found here:
|
||||
# https://qdrant.github.io/fastembed/examples/Supported_Models/
|
||||
self._client = AsyncQdrantClient(
|
||||
location=qdrant_url, api_key=qdrant_api_key, path=qdrant_local_path
|
||||
)
|
||||
self._client.set_model(fastembed_model_name)
|
||||
self._embedding_provider = embedding_provider
|
||||
self._client = AsyncQdrantClient(location=qdrant_url, api_key=qdrant_api_key, path=qdrant_local_path)
|
||||
|
||||
async def _ensure_collection_exists(self):
    """
    Create the configured collection when it is not present yet.

    The vector size is not stored on the connector, so it is discovered by
    embedding a throwaway string with the embedding provider and measuring
    the length of the vector that comes back. The collection is created
    with cosine distance.
    """
    if await self._client.collection_exists(self._collection_name):
        # The collection is already there; nothing to create.
        return

    # Probe the provider once to learn the dimensionality of its vectors.
    probe_vector = await self._embedding_provider.embed_query("sample text")
    vector_params = models.VectorParams(
        size=len(probe_vector),
        distance=models.Distance.COSINE,
    )

    await self._client.create_collection(
        collection_name=self._collection_name,
        vectors_config=vector_params,
    )
|
||||
|
||||
async def store_memory(self, information: str):
    """
    Store a memory in the Qdrant collection.

    The collection is created on first use, the text is embedded with the
    configured embedding provider, and the resulting vector is upserted
    together with the original text in the ``document`` payload field.

    :param information: The information to store.
    """
    # Function-scope import keeps this method self-contained.
    import uuid

    await self._ensure_collection_exists()

    # Embed the document; a single-element batch yields a single vector.
    embeddings = await self._embedding_provider.embed_documents([information])

    # The built-in hash() is unsuitable as a point ID: for str it is salted
    # per process (so the ID changes between runs) and it is signed (so it is
    # frequently negative), while Qdrant only accepts unsigned integers or
    # UUIDs. A name-based UUID is valid, stable across processes, and still
    # maps identical text to the same point, so re-storing a memory
    # overwrites it instead of duplicating it.
    point_id = str(uuid.uuid5(uuid.NAMESPACE_OID, information))

    # Add to Qdrant
    await self._client.upsert(
        collection_name=self._collection_name,
        points=[
            models.PointStruct(
                id=point_id,
                vector=embeddings[0],
                payload={"document": information},
            )
        ],
    )
|
||||
|
||||
async def find_memories(self, query: str) -> list[str]:
|
||||
@@ -53,9 +76,14 @@ class QdrantConnector:
|
||||
if not collection_exists:
|
||||
return []
|
||||
|
||||
search_results = await self._client.query(
|
||||
self._collection_name,
|
||||
query_text=query,
|
||||
# Embed the query
|
||||
query_vector = await self._embedding_provider.embed_query(query)
|
||||
|
||||
# Search in Qdrant
|
||||
search_results = await self._client.search(
|
||||
collection_name=self._collection_name,
|
||||
query_vector=query_vector,
|
||||
limit=10,
|
||||
)
|
||||
return [result.document for result in search_results]
|
||||
|
||||
return [result.payload["document"] for result in search_results]
|
||||
|
||||
Reference in New Issue
Block a user