Beta
The VISSL integration is currently in Beta. The API and functionality may change without warning in future releases.
Warning
Multi-GPU training is not currently supported by the ImageEmbedder task.
Image Embedder
The Task
Image embedding encodes an image into a feature vector that can be used in downstream tasks such as clustering, similarity search, or classification.
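Because embeddings are plain vectors, a basic similarity search can rank stored vectors by their cosine similarity to a query vector. The sketch below is a pure-Python illustration that assumes the embeddings are already available as lists of floats; the helper names `cosine_similarity` and `most_similar`, the file names, and the toy three-dimensional vectors are hypothetical and not part of Flash.

```python
import math
from typing import Dict, List, Sequence, Tuple


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two embedding vectors (1.0 = same direction)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def most_similar(
    query: Sequence[float],
    index: Dict[str, Sequence[float]],
    k: int = 1,
) -> List[Tuple[str, float]]:
    """Return the k entries of `index` closest to `query`, best match first."""
    scored = [(name, cosine_similarity(query, vec)) for name, vec in index.items()]
    scored.sort(key=lambda item: item[1], reverse=True)
    return scored[:k]


# Hypothetical toy "embeddings"; real ones come from the embedder and are much longer.
index = {
    "ant_1.jpg": [0.9, 0.1, 0.0],
    "ant_2.jpg": [0.8, 0.2, 0.1],
    "bee_1.jpg": [0.1, 0.9, 0.3],
}
query = [0.85, 0.15, 0.05]

# The two ant vectors point in nearly the same direction as the query, so they rank first.
print(most_similar(query, index, k=2))
```

For large collections, a dedicated nearest-neighbour index would replace the linear scan, but the ranking criterion stays the same.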
The Flash ImageEmbedder can be trained with self-supervised learning (SSL) to improve the quality of the embeddings it produces for your data.
The ImageEmbedder internally relies on VISSL.
You can read more about our integration with VISSL here: VISSL.
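The SSL objective itself is computed by VISSL, and the example on this page selects the `barlow_twins` strategy. To give a rough idea of what such an objective measures, here is a minimal pure-Python sketch of the Barlow Twins loss: it standardizes two batches of embeddings (two augmented views of the same images), computes their cross-correlation matrix, pushes the diagonal towards 1 (invariance) and the off-diagonal entries towards 0 (redundancy reduction). This is an illustration of the technique, not VISSL's implementation; the function names and the `lambd` weight are illustrative assumptions.

```python
import math
from typing import List

Matrix = List[List[float]]  # rows = images in the batch, columns = embedding dimensions


def standardize(z: Matrix) -> Matrix:
    """Normalize each embedding dimension to zero mean and unit variance over the batch."""
    n, d = len(z), len(z[0])
    out = [[0.0] * d for _ in range(n)]
    for j in range(d):
        col = [z[i][j] for i in range(n)]
        mean = sum(col) / n
        # Fall back to 1.0 for a constant column to avoid dividing by zero.
        std = math.sqrt(sum((x - mean) ** 2 for x in col) / n) or 1.0
        for i in range(n):
            out[i][j] = (z[i][j] - mean) / std
    return out


def barlow_twins_loss(z_a: Matrix, z_b: Matrix, lambd: float = 5e-3) -> float:
    """Barlow Twins objective for two batches of embeddings of the same images."""
    a, b = standardize(z_a), standardize(z_b)
    n, d = len(a), len(a[0])
    loss = 0.0
    for i in range(d):
        for j in range(d):
            # Cross-correlation between dimension i of view A and dimension j of view B.
            c_ij = sum(a[k][i] * b[k][j] for k in range(n)) / n
            if i == j:
                loss += (1.0 - c_ij) ** 2  # invariance term: diagonal should be 1
            else:
                loss += lambd * c_ij ** 2  # redundancy-reduction term: off-diagonal should be 0
    return loss


views = [[1.0, 0.0], [-1.0, 0.0], [0.0, 1.0], [0.0, -1.0]]
# Identical views with decorrelated dimensions give a loss close to 0.
print(barlow_twins_loss(views, views))
```

In practice the two batches come from the backbone and head applied to two different augmentations of each image, which is the role of the `pretraining_transform` argument.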
Example
Let’s see how to configure a training strategy for the ImageEmbedder task.
First, we create an ImageClassificationData object using a Dataset from torchvision.
Next, we configure the ImageEmbedder task with training_strategy, backbone, head and pretraining_transform.
Finally, we construct a Trainer and call fit().
Here’s the full example:
```python
import flash
import torch
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageEmbedder
from torchvision.datasets import CIFAR10

# 1. Download the data and prepare the datamodule
datamodule = ImageClassificationData.from_datasets(
    train_dataset=CIFAR10(".", download=True),
    batch_size=8,
)

# 2. Build the task
embedder = ImageEmbedder(
    backbone="resnet18",
    training_strategy="barlow_twins",
    head="barlow_twins_head",
    pretraining_transform="barlow_twins_transform",
    training_strategy_kwargs={"latent_embedding_dim": 128},
    pretraining_transform_kwargs={"size_crops": [32]},
)

# 3. Create the trainer and pre-train the encoder
# Multi-GPU training is not supported by ImageEmbedder, so request at most one GPU.
trainer = flash.Trainer(max_epochs=1, gpus=min(1, torch.cuda.device_count()))
trainer.fit(embedder, datamodule=datamodule)

# 4. Save the model!
trainer.save_checkpoint("image_embedder_model.pt")

# 5. Download the downstream prediction dataset and generate embeddings
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/")

datamodule = ImageClassificationData.from_files(
    predict_files=[
        "data/hymenoptera_data/predict/153783656_85f9c3ac70.jpg",
        "data/hymenoptera_data/predict/2039585088_c6f47c592e.jpg",
    ],
    batch_size=2,
)
embeddings = trainer.predict(embedder, datamodule=datamodule)

# list of embeddings for images sent to the predict function
print(embeddings)
```
To learn how to view the available backbones and heads for this task, see Backbones and Heads.
You can view the available training strategies with the available_training_strategies() method.
The head and pretraining_transform arguments should match the choice of training_strategy following this table:
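| training_strategy | head | pretraining_transform |
|---|---|---|
| barlow_twins | barlow_twins_head | barlow_twins_transform |
| | | |
| | | |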