
Beta

The VISSL integration is currently in Beta. Its API and functionality may change without warning in future releases.

Warning

Multi-GPU training is not currently supported by the ImageEmbedder task.

Image Embedder

The Task

Image embedding encodes an image into a vector of features that can be used for a downstream task, such as clustering, similarity search, or classification.
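As a minimal sketch of one such downstream task, the snippet below ranks a small gallery of embeddings by cosine similarity to a query embedding. It assumes the embeddings are plain PyTorch tensors with one row per image. The function name `top_k_similar` and the toy vectors are illustrative only and are not part of the Flash API.

```python
import torch
import torch.nn.functional as F


def top_k_similar(query: torch.Tensor, gallery: torch.Tensor, k: int = 3):
    """Return (scores, indices) of the k gallery rows most similar to `query`.

    query:   1-D tensor of shape (D,)
    gallery: 2-D tensor of shape (N, D), one embedding per row
    """
    # Normalize to unit length so that a dot product equals cosine similarity.
    query = F.normalize(query, dim=0)
    gallery = F.normalize(gallery, dim=1)
    scores = gallery @ query  # shape (N,)
    k = min(k, gallery.shape[0])
    return torch.topk(scores, k)


# Toy 3-D "embeddings" (made-up data, not real model output).
gallery = torch.tensor([
    [1.0, 0.0, 0.0],
    [0.9, 0.1, 0.0],
    [0.0, 1.0, 0.0],
    [0.0, 0.0, 1.0],
])
query = torch.tensor([0.8, 0.2, 0.0])
scores, indices = top_k_similar(query, gallery, k=2)
print(indices.tolist())  # [1, 0]
```

Normalizing first means the ranking depends only on the direction of each embedding, not its magnitude, which is the usual choice for similarity search over learned features.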

The Flash ImageEmbedder can be trained with self-supervised learning (SSL) to improve the quality of the embeddings it produces for your data. Internally, the ImageEmbedder relies on VISSL; you can read more about our integration with VISSL here: VISSL.
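To give a rough idea of what an SSL objective optimizes, here is a minimal, independent sketch of the Barlow Twins loss (the `barlow_twins` strategy used in the example on this page). It takes embeddings of two augmented views of the same batch of images, standardizes each dimension, and pushes their cross-correlation matrix toward the identity. This is not VISSL's implementation; the function names, the `eps` term and the default `lambd=5e-3` are assumptions for illustration.

```python
import torch


def _standardize(z: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Zero mean and unit (population) variance per embedding dimension, over the batch.
    z = z - z.mean(dim=0)
    return z / torch.sqrt(z.pow(2).mean(dim=0) + eps)


def barlow_twins_loss(z1: torch.Tensor, z2: torch.Tensor, lambd: float = 5e-3) -> torch.Tensor:
    """Sketch of the Barlow Twins objective.

    z1, z2: (N, D) embeddings of two augmented views; row i of each comes from image i.
    lambd:  weight of the off-diagonal (redundancy-reduction) term.
    """
    n = z1.shape[0]
    z1, z2 = _standardize(z1), _standardize(z2)
    c = (z1.T @ z2) / n  # (D, D) cross-correlation matrix between the two views
    diag = torch.diagonal(c)
    on_diag = (diag - 1).pow(2).sum()                # invariance: diagonal -> 1
    off_diag = c.pow(2).sum() - diag.pow(2).sum()    # redundancy: off-diagonal -> 0
    return on_diag + lambd * off_diag


torch.manual_seed(0)
z = torch.randn(64, 8)
print(barlow_twins_loss(z, z))                    # close to 0: the views agree perfectly
print(barlow_twins_loss(z, torch.randn(64, 8)))   # much larger: the views are unrelated
```

The diagonal term rewards the two views for agreeing on each feature, while the off-diagonal term discourages different features from carrying redundant information.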


Example

Let’s see how to configure a training strategy for the ImageEmbedder task. First, we create an ImageClassificationData object from a torchvision Dataset. Next, we configure the ImageEmbedder task with training_strategy, backbone, head, and pretraining_transform. Finally, we construct a Trainer and call fit(). Here’s the full example:

import torch
from torchvision.datasets import CIFAR10

import flash
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageEmbedder

# 1. Download the data and prepare the datamodule
datamodule = ImageClassificationData.from_datasets(
    train_dataset=CIFAR10(".", download=True),
    batch_size=8,
)

# 2. Build the task
embedder = ImageEmbedder(
    backbone="resnet18",
    training_strategy="barlow_twins",
    head="barlow_twins_head",
    pretraining_transform="barlow_twins_transform",
    training_strategy_kwargs={"latent_embedding_dim": 128},
    pretraining_transform_kwargs={"size_crops": [32]},
)

# 3. Create the trainer and pre-train the encoder
# Use at most one GPU: multi-GPU training is not supported by ImageEmbedder
trainer = flash.Trainer(max_epochs=1, gpus=min(1, torch.cuda.device_count()))
trainer.fit(embedder, datamodule=datamodule)

# 4. Save the model!
trainer.save_checkpoint("image_embedder_model.pt")

# 5. Download the downstream prediction dataset and generate embeddings
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/")

datamodule = ImageClassificationData.from_files(
    predict_files=[
        "data/hymenoptera_data/predict/153783656_85f9c3ac70.jpg",
        "data/hymenoptera_data/predict/2039585088_c6f47c592e.jpg",
    ],
    batch_size=2,
)
embeddings = trainer.predict(embedder, datamodule=datamodule)

# list of embeddings for images sent to the predict function
print(embeddings)
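The exact structure returned by trainer.predict can vary with the Flash version and output settings. The helper below is a hypothetical sketch that assumes the result is a list with one entry per batch, each entry being either a 2-D tensor or a list of 1-D embedding tensors, and concatenates everything into a single (N, D) matrix that is convenient for clustering or similarity search. `stack_embeddings` is not a Flash function.

```python
import torch


def stack_embeddings(predictions) -> torch.Tensor:
    """Concatenate per-batch predictions into one (N, D) tensor.

    Assumption: `predictions` is a list with one entry per batch, where each
    entry is either a tensor of shape (B, D) or (D,), or a list of (D,) tensors.
    """
    rows = []
    for batch in predictions:
        if isinstance(batch, torch.Tensor):
            # Reshape to (B, D); a single 1-D embedding becomes (1, D).
            rows.append(batch.reshape(-1, batch.shape[-1]))
        else:
            # A list of per-image 1-D embeddings: stack them into (B, D).
            rows.append(torch.stack(list(batch)))
    return torch.cat(rows, dim=0)


# Dummy predictions: one batch as a list of two vectors, one batch as a (3, 4) tensor.
fake = [[torch.zeros(4), torch.ones(4)], torch.full((3, 4), 2.0)]
matrix = stack_embeddings(fake)
print(matrix.shape)  # torch.Size([5, 4])
```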

To learn how to view the available backbones and heads for this task, see Backbones and Heads. You can list the available training strategies with the available_training_strategies() method.

The head and pretraining_transform arguments must match the chosen training_strategy, as shown in this table:

training_strategy    head                 pretraining_transform
simclr               simclr_head          simclr_transform
barlow_twins         barlow_twins_head    barlow_twins_transform
swav                 swav_head            swav_transform
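If you build the ImageEmbedder arguments programmatically, a small check like the one below can catch mismatches before training starts. `check_embedder_config` and `MATCHING_COMPONENTS` are hypothetical names, not part of Flash; the mapping simply mirrors the table above.

```python
# Hypothetical helper, not part of Flash: mirrors the strategy/head/transform table.
MATCHING_COMPONENTS = {
    "simclr": ("simclr_head", "simclr_transform"),
    "barlow_twins": ("barlow_twins_head", "barlow_twins_transform"),
    "swav": ("swav_head", "swav_transform"),
}


def check_embedder_config(training_strategy: str, head: str, pretraining_transform: str) -> None:
    """Raise ValueError unless head and pretraining_transform match training_strategy."""
    if training_strategy not in MATCHING_COMPONENTS:
        raise ValueError(f"Unknown training_strategy: {training_strategy!r}")
    expected_head, expected_transform = MATCHING_COMPONENTS[training_strategy]
    if (head, pretraining_transform) != (expected_head, expected_transform):
        raise ValueError(
            f"{training_strategy!r} expects head={expected_head!r} and "
            f"pretraining_transform={expected_transform!r}, "
            f"got head={head!r} and pretraining_transform={pretraining_transform!r}"
        )


# Passes silently: this is the combination used in the example above.
check_embedder_config("barlow_twins", "barlow_twins_head", "barlow_twins_transform")
```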
