Image Embedder
The Task
Image embedding encodes an image into a vector of features that can be used for downstream tasks such as clustering, similarity search, or classification.
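To make "similarity search" concrete, here is a minimal sketch in plain Python, assuming embeddings are available as lists of floats. The function names and the toy 3-dimensional vectors are hypothetical; real embeddings have many more dimensions, and in practice a vectorized library would replace these loops.

```python
import math
from typing import List, Sequence, Tuple


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two non-zero embedding vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def most_similar(
    query: Sequence[float], gallery: List[Sequence[float]], k: int = 1
) -> List[Tuple[int, float]]:
    """Return (index, score) for the k gallery vectors closest to the query."""
    scored = [(i, cosine_similarity(query, vec)) for i, vec in enumerate(gallery)]
    scored.sort(key=lambda pair: pair[1], reverse=True)  # highest similarity first
    return scored[:k]


# Toy "embeddings" for three gallery images and one query image.
gallery = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.7, 0.7, 0.0]]
query = [0.9, 0.1, 0.0]
print(most_similar(query, gallery, k=2))  # gallery index 0 ranks first, then index 2
```

Cosine similarity ignores vector length, which is a common choice when comparing learned embeddings; note that it is undefined for all-zero vectors.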
The ImageEmbedder internally relies on VISSL.
Example
Let’s see how to configure a training strategy for the ImageEmbedder task.
A vanilla DataModule object can be created using standard Datasets, as shown below.
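The example passes torchvision's CIFAR10 to from_datasets. As a sketch of what a "standard Dataset" means here, the class below follows the map-style protocol (`__len__` plus `__getitem__` returning an (image, label) pair, as CIFAR10 does). `ToyImageDataset` is hypothetical, and it is an assumption that any dataset following this protocol can be substituted; a real one would return PIL images or tensors that the configured transforms can process, rather than nested lists.

```python
from typing import List, Tuple

Image = List[List[int]]  # toy stand-in for an image: rows of pixel values


class ToyImageDataset:
    """Hypothetical map-style dataset yielding (image, label) pairs."""

    def __init__(self, images: List[Image], labels: List[int]) -> None:
        if len(images) != len(labels):
            raise ValueError("images and labels must have the same length")
        self.images = images
        self.labels = labels

    def __len__(self) -> int:
        # Number of samples; DataLoaders use this to build index ranges.
        return len(self.images)

    def __getitem__(self, index: int) -> Tuple[Image, int]:
        # Same (input, target) shape as CIFAR10's __getitem__.
        return self.images[index], self.labels[index]


dataset = ToyImageDataset(images=[[[0, 255], [255, 0]], [[10, 20], [30, 40]]], labels=[0, 1])
image, label = dataset[1]
print(len(dataset), label)  # 2 1
```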
Then the user can configure the ImageEmbedder task with training_strategy, backbone, head, and pretraining_transform.
Additional arguments for these selections can be passed as dictionaries, for example training_strategy_kwargs and pretraining_transform_kwargs.
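The "select by name, then override with a kwargs dictionary" pattern can be illustrated with a small lookup-and-merge helper. This is a sketch of the general design only, not Flash's or VISSL's internal code: `resolve`, `TRANSFORMS`, and the default values (including `example_flag`) are hypothetical. Only the name `barlow_twins_transform` and the override `{"size_crops": [196]}` come from the example.

```python
from typing import Any, Dict, Optional


def resolve(
    name: str,
    registry: Dict[str, Dict[str, Any]],
    overrides: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Look up a named default config and merge user-supplied overrides into it."""
    if name not in registry:
        raise KeyError(f"unknown option {name!r}; choose from {sorted(registry)}")
    config = dict(registry[name])  # copy so the registry's defaults stay untouched
    config.update(overrides or {})  # user values win over defaults
    return config


# Hypothetical defaults, for illustration only (not VISSL's real values).
TRANSFORMS = {"barlow_twins_transform": {"size_crops": [224], "example_flag": True}}

cfg = resolve("barlow_twins_transform", TRANSFORMS, {"size_crops": [196]})
print(cfg)  # {'size_crops': [196], 'example_flag': True}
```

Copying before merging keeps one task's overrides from leaking into the shared defaults used by the next task.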
This task can now be passed to the fit() method of the Trainer.
Note
Many VISSL loss functions use hard-coded torch.distributed methods, so we recommend using accelerator="ddp" even with a single GPU.
Only the barlow_twins training strategy works on the CPU; all other loss functions are configured to work on GPUs.
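The recommendation in this note can be encoded in a small helper that picks Trainer keyword arguments. `suggested_trainer_kwargs` is hypothetical; the keyword names `gpus` and `accelerator` mirror the flash.Trainer call used in the example, and the CPU check reflects the statement that only barlow_twins runs on the CPU.

```python
from typing import Any, Dict


def suggested_trainer_kwargs(training_strategy: str, num_gpus: int) -> Dict[str, Any]:
    """Hypothetical helper: Trainer kwargs following the VISSL note."""
    if num_gpus > 0:
        # VISSL losses call torch.distributed directly, so use DDP even on one GPU.
        return {"gpus": num_gpus, "accelerator": "ddp"}
    if training_strategy != "barlow_twins":
        # Other strategies' losses are configured for GPUs only.
        raise ValueError(f"{training_strategy!r} needs a GPU; only 'barlow_twins' runs on the CPU")
    return {"gpus": 0}


print(suggested_trainer_kwargs("barlow_twins", 0))  # {'gpus': 0}
print(suggested_trainer_kwargs("barlow_twins", 1))  # {'gpus': 1, 'accelerator': 'ddp'}
```

With it, the trainer could be built as `flash.Trainer(max_epochs=3, **suggested_trainer_kwargs("barlow_twins", torch.cuda.device_count()))`, so the CPU/GPU decision fails early instead of inside a loss function.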
```python
import torch
from torchvision.datasets import CIFAR10

import flash
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageEmbedder

# 1. Download the data and prepare the datamodule
datamodule = ImageClassificationData.from_datasets(
    train_dataset=CIFAR10(".", download=True),
    batch_size=16,
)

# 2. Build the task
embedder = ImageEmbedder(
    backbone="resnet",
    training_strategy="barlow_twins",
    head="simclr_head",
    pretraining_transform="barlow_twins_transform",
    training_strategy_kwargs={"latent_embedding_dim": 128},
    pretraining_transform_kwargs={"size_crops": [196]},
)

# 3. Create the trainer and pre-train the encoder
# Use accelerator='ddp' when training on GPU(s),
# e.g. flash.Trainer(max_epochs=3, gpus=1, accelerator='ddp')
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.fit(embedder, datamodule=datamodule)

# 4. Save the model!
trainer.save_checkpoint("image_embedder_model.pt")

# 5. Download the downstream prediction dataset and generate embeddings
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/")
datamodule = ImageClassificationData.from_files(
    predict_files=[
        "data/hymenoptera_data/predict/153783656_85f9c3ac70.jpg",
        "data/hymenoptera_data/predict/2039585088_c6f47c592e.jpg",
    ]
)
embeddings = trainer.predict(embedder, datamodule=datamodule)

# List of embeddings for the images sent to the predict function
print(embeddings)
```
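Before using the embeddings downstream, it is often convenient to flatten the per-batch output of trainer.predict into one list of vectors and L2-normalize them. The sketch below assumes each element of `embeddings` is one batch that iterates over per-image vectors (tensors are converted via `.tolist()`); the exact output structure is an assumption, and `flatten_batches` and `l2_normalize` are hypothetical helpers. The toy `batches` value stands in for real predictions.

```python
import math
from typing import Any, List


def flatten_batches(batches: List[Any]) -> List[List[float]]:
    """Turn a list of per-batch outputs into one flat list of embedding vectors."""
    vectors: List[List[float]] = []
    for batch in batches:
        for vec in batch:
            # Tensors (torch or NumPy) expose .tolist(); plain sequences are copied.
            vectors.append(vec.tolist() if hasattr(vec, "tolist") else list(vec))
    return vectors


def l2_normalize(vec: List[float]) -> List[float]:
    """Scale a vector to unit length; all-zero vectors are returned unchanged."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm > 0 else list(vec)


# Toy stand-in for trainer.predict output: two batches of 2-dimensional vectors.
batches = [[[3.0, 4.0]], [[0.0, 2.0], [1.0, 0.0]]]
vectors = [l2_normalize(v) for v in flatten_batches(batches)]
print(vectors)  # [[0.6, 0.8], [0.0, 1.0], [1.0, 0.0]]
```

After normalization, a plain dot product between two vectors equals their cosine similarity, which simplifies similarity search and clustering.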