# Listing metadata: 101 lines, 3.8 KiB, Python
import uuid
|
|
from typing import Optional
|
|
|
|
from qdrant_client import AsyncQdrantClient, models
|
|
|
|
from .embeddings.base import EmbeddingProvider
|
|
|
|
|
|
class QdrantConnector:
    """
    Encapsulates the connection to a Qdrant server and all the methods to interact with it.

    Memories are stored as points whose payload carries the original text under the
    ``"document"`` key and whose vector is registered under the name reported by the
    embedding provider.

    :param qdrant_url: The URL of the Qdrant server.
    :param qdrant_api_key: The API key to use for the Qdrant server.
    :param collection_name: The name of the collection to use.
    :param embedding_provider: The embedding provider to use.
    :param qdrant_local_path: The path to the storage directory for the Qdrant client, if local mode is used.
    """

    def __init__(
        self,
        qdrant_url: Optional[str],
        qdrant_api_key: Optional[str],
        collection_name: str,
        embedding_provider: "EmbeddingProvider",
        qdrant_local_path: Optional[str] = None,
    ):
        # Normalise once (drop trailing slashes, treat an empty string as "not set")
        # and hand the same value to the client so the stored attribute and the
        # actual connection target cannot drift apart.
        self._qdrant_url = qdrant_url.rstrip("/") if qdrant_url else None
        self._qdrant_api_key = qdrant_api_key
        self._collection_name = collection_name
        self._embedding_provider = embedding_provider
        self._client = AsyncQdrantClient(
            location=self._qdrant_url,
            api_key=qdrant_api_key,
            path=qdrant_local_path,
        )

    async def _ensure_collection_exists(self) -> None:
        """Ensure that the collection exists, creating it if necessary."""
        if await self._client.collection_exists(self._collection_name):
            return

        # The vector size is not known up front, so embed a throwaway text and
        # measure the length of the vector the provider returns.
        sample_vector = await self._embedding_provider.embed_query("sample text")
        vector_size = len(sample_vector)

        # Use the vector name as defined in the embedding provider
        vector_name = self._embedding_provider.get_vector_name()
        await self._client.create_collection(
            collection_name=self._collection_name,
            vectors_config={
                vector_name: models.VectorParams(
                    size=vector_size,
                    distance=models.Distance.COSINE,
                )
            },
        )

    async def store_memory(self, information: str) -> None:
        """
        Store a memory in the Qdrant collection.

        The collection is created first if it does not exist yet. Each call
        inserts a new point with a freshly generated random UUID as its id.

        :param information: The information to store.
        """
        await self._ensure_collection_exists()

        # Embed the document
        embeddings = await self._embedding_provider.embed_documents([information])

        # Add to Qdrant
        vector_name = self._embedding_provider.get_vector_name()
        await self._client.upsert(
            collection_name=self._collection_name,
            points=[
                models.PointStruct(
                    id=uuid.uuid4().hex,
                    vector={vector_name: embeddings[0]},
                    payload={"document": information},
                )
            ],
        )

    async def find_memories(self, query: str, limit: int = 10) -> list[str]:
        """
        Find memories in the Qdrant collection. If there are no memories found, an empty list is returned.

        Search hits that carry no payload, or whose payload has no ``"document"``
        entry, are skipped instead of raising.

        :param query: The query to use for the search.
        :param limit: The maximum number of search hits to request (defaults to 10).
        :return: A list of memories found.
        """
        collection_exists = await self._client.collection_exists(self._collection_name)
        if not collection_exists:
            return []

        # Embed the query
        query_vector = await self._embedding_provider.embed_query(query)
        vector_name = self._embedding_provider.get_vector_name()

        # Search in Qdrant
        search_results = await self._client.search(
            collection_name=self._collection_name,
            query_vector=models.NamedVector(name=vector_name, vector=query_vector),
            limit=limit,
        )

        # A hit's payload may be missing or lack the "document" key; skip those
        # hits rather than failing the whole lookup.
        return [
            hit.payload["document"]
            for hit in search_results
            if hit.payload and "document" in hit.payload
        ]