
Image Embedder

The Task

Image embedding encodes an image into a vector of features that can be used for downstream tasks such as clustering, similarity search, or classification.
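As a rough illustration of the similarity-search case, the sketch below ranks a small gallery of embeddings by cosine similarity to a query embedding. It is plain Python and does not depend on Flash. The file names and 3-dimensional vectors are hypothetical placeholders for real ImageEmbedder outputs, which have many more dimensions.

```python
import math
from typing import Dict, List, Sequence, Tuple


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two embedding vectors (1.0 = same direction)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("cannot compare a zero-length embedding")
    return dot / (norm_a * norm_b)


def most_similar(
    query: Sequence[float],
    gallery: Dict[str, Sequence[float]],
    k: int = 1,
) -> List[Tuple[str, float]]:
    """Return the k gallery items whose embeddings are closest to the query."""
    scored = [(name, cosine_similarity(query, vec)) for name, vec in gallery.items()]
    # Highest similarity first.
    scored.sort(key=lambda item: item[1], reverse=True)
    return scored[:k]


# Hypothetical toy embeddings; real ones come from the trained embedder.
gallery = {
    "ant_1.jpg": [0.9, 0.1, 0.0],
    "ant_2.jpg": [0.8, 0.2, 0.1],
    "bee_1.jpg": [0.1, 0.9, 0.3],
}
print(most_similar([0.85, 0.15, 0.05], gallery, k=2))
```

Cosine similarity ignores vector length, which is a common choice when comparing embeddings. For large galleries, an approximate nearest-neighbour index would usually replace this linear scan.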

The ImageEmbedder internally relies on VISSL.


Example

Let’s see how to configure a training strategy for the ImageEmbedder task. A vanilla DataModule object can be created using standard Datasets, as shown below. The user can then configure the ImageEmbedder task with a training_strategy, backbone, head and pretraining_transform. Additional arguments for the chosen strategy and transform can be passed through training_strategy_kwargs and pretraining_transform_kwargs. The task can then be passed to the fit() method of the Trainer.

Note

Many VISSL loss functions use hard-coded torch.distributed methods, so we recommend using accelerator='ddp' even with a single GPU. Only the barlow_twins training strategy works on the CPU; all other loss functions are configured to run on GPUs.
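The device rule in this note can be captured in a small helper. The sketch below is hypothetical (trainer_device_kwargs is not part of Flash): it returns keyword arguments for the Trainer, requesting DDP whenever GPUs are present and allowing CPU only for barlow_twins. It assumes the gpus and accelerator='ddp' Trainer arguments used in the example on this page, and uses "simclr" only as an illustrative name for a GPU-only strategy.

```python
from typing import Any, Dict

# Per the note, barlow_twins is the only strategy expected to run on CPU.
CPU_COMPATIBLE_STRATEGIES = {"barlow_twins"}


def trainer_device_kwargs(training_strategy: str, num_gpus: int) -> Dict[str, Any]:
    """Hypothetical helper (not a Flash API): choose Trainer device arguments.

    With one or more GPUs, always request DDP, because many VISSL losses call
    torch.distributed directly. Without GPUs, only allow CPU-compatible
    strategies and fail fast otherwise.
    """
    if num_gpus < 0:
        raise ValueError("num_gpus must be non-negative")
    if num_gpus > 0:
        return {"gpus": num_gpus, "accelerator": "ddp"}
    if training_strategy in CPU_COMPATIBLE_STRATEGIES:
        return {"gpus": 0}
    raise ValueError(f"training_strategy={training_strategy!r} requires at least one GPU")


print(trainer_device_kwargs("barlow_twins", 0))
print(trainer_device_kwargs("simclr", 1))  # "simclr" is an illustrative GPU-only name
```

It could then be used as flash.Trainer(max_epochs=3, **trainer_device_kwargs("barlow_twins", torch.cuda.device_count())). Raising a ValueError surfaces an unsupported CPU configuration before any training starts.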

import torch
from torchvision.datasets import CIFAR10

import flash
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageEmbedder

# 1. Download the data and prepare the datamodule
datamodule = ImageClassificationData.from_datasets(
    train_dataset=CIFAR10(".", download=True),
    batch_size=16,
)

# 2. Build the task
embedder = ImageEmbedder(
    backbone="resnet",
    training_strategy="barlow_twins",
    head="simclr_head",
    pretraining_transform="barlow_twins_transform",
    training_strategy_kwargs={"latent_embedding_dim": 128},
    pretraining_transform_kwargs={"size_crops": [196]},
)

# 3. Create the trainer and pre-train the encoder
# use accelerator='ddp' when using GPU(s),
# e.g. flash.Trainer(max_epochs=3, gpus=1, accelerator='ddp')
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.fit(embedder, datamodule=datamodule)

# 4. Save the model!
trainer.save_checkpoint("image_embedder_model.pt")

# 5. Download the downstream prediction dataset and generate embeddings
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/")

datamodule = ImageClassificationData.from_files(
    predict_files=[
        "data/hymenoptera_data/predict/153783656_85f9c3ac70.jpg",
        "data/hymenoptera_data/predict/2039585088_c6f47c592e.jpg",
    ]
)
embeddings = trainer.predict(embedder, datamodule=datamodule)

# list of embeddings for images sent to the predict function
print(embeddings)
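The embeddings returned by trainer.predict can feed the clustering use case mentioned at the top of this page. The exact container returned (for example, a list of per-batch tensors) may vary between Flash versions, so the sketch below assumes the embeddings have already been flattened into plain lists of floats (e.g. with tensor.tolist()). It implements a minimal k-means (Lloyd's algorithm) in plain Python; the toy 2-D points are hypothetical stand-ins for real embeddings.

```python
import math
import random
from typing import List, Sequence, Tuple

Vector = List[float]


def _closest(point: Sequence[float], centroids: Sequence[Sequence[float]]) -> int:
    """Index of the centroid with the smallest Euclidean distance to point."""
    return min(range(len(centroids)), key=lambda i: math.dist(point, centroids[i]))


def kmeans(
    points: Sequence[Sequence[float]], k: int, iterations: int = 20, seed: int = 0
) -> Tuple[List[int], List[Vector]]:
    """Lloyd's k-means: returns (cluster label per point, centroids)."""
    if not 0 < k <= len(points):
        raise ValueError("k must be between 1 and the number of points")
    rng = random.Random(seed)
    # Initialise centroids from k distinct input points.
    centroids = [list(p) for p in rng.sample(list(points), k)]
    labels: List[int] = []
    for _ in range(iterations):
        # Assignment step: attach each embedding to its nearest centroid.
        new_labels = [_closest(p, centroids) for p in points]
        if new_labels == labels:
            break  # converged: assignments no longer change
        labels = new_labels
        # Update step: move each centroid to the mean of its members.
        for c in range(k):
            members = [p for p, label in zip(points, labels) if label == c]
            if members:  # keep the old centroid if a cluster empties out
                centroids[c] = [sum(dim) / len(members) for dim in zip(*members)]
    return labels, centroids


# Hypothetical toy 2-D "embeddings" forming two obvious groups.
points = [[0.0, 0.1], [0.1, 0.0], [5.0, 5.1], [5.1, 4.9]]
labels, centroids = kmeans(points, k=2)
print(labels)
```

In practice a library implementation such as scikit-learn's KMeans would usually be preferred; this version only shows the alternating assign and update steps.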