# Source code for flash.image.embedding.model
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import Any, Dict, List, Optional
from flash.core.adapter import AdapterTask
from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _VISSL_AVAILABLE, requires
from flash.core.utilities.types import LR_SCHEDULER_TYPE, OPTIMIZER_TYPE
from flash.image.embedding.backbones import IMAGE_EMBEDDER_BACKBONES
from flash.image.embedding.strategies import IMAGE_EMBEDDER_STRATEGIES
from flash.image.embedding.transforms import IMAGE_EMBEDDER_TRANSFORMS
if _VISSL_AVAILABLE:
    import classy_vision
    import classy_vision.generic.distributed_util

    # Patch this to avoid classy vision / VISSL based distributed training:
    # ``get_world_size`` is replaced with a stub that always reports a single
    # process. NOTE: this is a process-wide monkey-patch applied at import time,
    # so every other user of ``classy_vision.generic.distributed_util`` in the
    # same interpreter sees the stub as well.
    classy_vision.generic.distributed_util.get_world_size = lambda: 1
# Skip doctests if requirements aren't available.
# The list starts empty and is only extended with the ``ImageEmbedder`` entries
# (the class itself plus every member, via the ``ImageEmbedder.*`` pattern) when
# VISSL is not installed.
__doctest_skip__ = []
if not _VISSL_AVAILABLE:
    __doctest_skip__ += [
        "ImageEmbedder",
        "ImageEmbedder.*",
    ]
class ImageEmbedder(AdapterTask):
    """The ``ImageEmbedder`` is a :class:`~flash.Task` for obtaining feature vectors (embeddings) from images. For
    more details, see :ref:`image_embedder`.

    Args:
        training_strategy: Training strategy from VISSL,
            select between 'simclr', 'swav', 'dino', 'moco', or 'barlow_twins'.
        head: projection head used for task, select between
            'simclr_head', 'swav_head', 'dino_head', 'moco_head', or 'barlow_twins_head'.
        pretraining_transform: transform applied to input image for pre-training SSL model.
            Select between 'simclr_transform', 'swav_transform', 'dino_transform',
            'moco_transform', or 'barlow_twins_transform'.
        backbone: VISSL backbone, defaults to ``resnet``.
        pretrained: Use a pretrained backbone, defaults to ``False``.
        optimizer: Optimizer to use for training, defaults to ``"Adam"``.
        lr_scheduler: The LR scheduler to use during training, defaults to ``None``.
        learning_rate: Learning rate to use for training. Defaults to ``None``; the value is forwarded
            unchanged to the parent task's constructor.
        backbone_kwargs: arguments to be passed to VISSL backbones, i.e. ``vision_transformer`` and ``resnet``.
        training_strategy_kwargs: arguments passed to VISSL loss function, projection head and training hooks.
        pretraining_transform_kwargs: arguments passed to VISSL transforms.
    """

    # Registries used to resolve the string arguments accepted by ``__init__``.
    training_strategies: FlashRegistry = IMAGE_EMBEDDER_STRATEGIES
    backbones: FlashRegistry = IMAGE_EMBEDDER_BACKBONES
    transforms: FlashRegistry = IMAGE_EMBEDDER_TRANSFORMS

    required_extras: List[str] = ["image", "vissl", "fairscale"]

    def __init__(
        self,
        training_strategy: str,
        head: str,
        pretraining_transform: str,
        backbone: str = "resnet",
        pretrained: bool = False,
        optimizer: OPTIMIZER_TYPE = "Adam",
        lr_scheduler: LR_SCHEDULER_TYPE = None,
        learning_rate: Optional[float] = None,
        backbone_kwargs: Optional[Dict[str, Any]] = None,
        training_strategy_kwargs: Optional[Dict[str, Any]] = None,
        pretraining_transform_kwargs: Optional[Dict[str, Any]] = None,
    ):
        # Called first, while ``backbone`` and ``head`` still hold the user-supplied
        # strings (both names are rebound to constructed objects further down).
        self.save_hyperparameters()

        # ``None`` stands in for "no extra keyword arguments"; a fresh dict is
        # created per call so nothing is shared between instances.
        if backbone_kwargs is None:
            backbone_kwargs = {}

        if training_strategy_kwargs is None:
            training_strategy_kwargs = {}

        if pretraining_transform_kwargs is None:
            pretraining_transform_kwargs = {}

        # The backbone factory returns the backbone together with its feature size,
        # which the training strategy needs in order to build the head.
        backbone, num_features = self.backbones.get(backbone)(pretrained=pretrained, **backbone_kwargs)

        # The registry entry carries both the strategy factory ("fn") and the
        # adapter class to wrap it with ("metadata" -> "adapter").
        metadata = self.training_strategies.get(training_strategy, with_metadata=True)
        loss_fn, head, hooks = metadata["fn"](head=head, num_features=num_features, **training_strategy_kwargs)

        adapter = metadata["metadata"]["adapter"].from_task(
            self,
            loss_fn=loss_fn,
            backbone=backbone,
            head=head,
            hooks=hooks,
        )

        super().__init__(
            adapter=adapter,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            learning_rate=learning_rate,
        )

        # The input transform is always the one registered for the requested
        # pre-training transform name; the warning below tells users about that.
        self.input_transform = self.transforms.get(pretraining_transform)(**pretraining_transform_kwargs)

        warnings.warn(
            "Warning: VISSL ImageEmbedder overrides any user provided transforms"
            " with pre-defined transforms for the training strategy."
        )

    def on_epoch_start(self) -> None:
        """Forward the ``on_epoch_start`` hook to the adapter."""
        self.adapter.on_epoch_start()

    def on_train_start(self) -> None:
        """Forward the ``on_train_start`` hook to the adapter."""
        self.adapter.on_train_start()

    def on_train_epoch_end(self) -> None:
        """Forward the ``on_train_epoch_end`` hook to the adapter."""
        self.adapter.on_train_epoch_end()

    def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
        """Forward the ``on_train_batch_end`` hook, with all of its arguments, to the adapter."""
        self.adapter.on_train_batch_end(outputs, batch, batch_idx, dataloader_idx)

    @classmethod
    @requires(["image", "vissl", "fairscale"])
    def available_training_strategies(cls) -> List[str]:
        """Get the list of available training strategies (passed to the ``training_strategy`` argument) for this
        task.

        Returns:
            The keys of the ``training_strategies`` registry, or an empty list if the class has no such registry.

        Examples
        ________

        .. doctest::

            >>> from flash.image import ImageEmbedder
            >>> ImageEmbedder.available_training_strategies()  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
            ['barlow_twins', ..., 'swav']
        """
        # ``getattr`` with a default tolerates subclasses that set the registry to
        # ``None`` or remove it altogether.
        registry: Optional[FlashRegistry] = getattr(cls, "training_strategies", None)
        if registry is None:
            return []
        return registry.available_keys()