Shortcuts

Source code for flash.image.embedding.model

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import Any, Dict, List, Optional

from pytorch_lightning.utilities import rank_zero_warn
from torch import Tensor

from flash.core.adapter import AdapterTask
from flash.core.data.io.input import DataKeys
from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _VISSL_AVAILABLE, requires
from flash.core.utilities.types import LR_SCHEDULER_TYPE, OPTIMIZER_TYPE
from flash.image.classification.backbones import IMAGE_CLASSIFIER_BACKBONES
from flash.image.embedding.strategies import IMAGE_EMBEDDER_STRATEGIES
from flash.image.embedding.transforms import IMAGE_EMBEDDER_TRANSFORMS

# classy_vision is only imported (and patched) when `_VISSL_AVAILABLE` is truthy.
if _VISSL_AVAILABLE:
    import classy_vision
    import classy_vision.generic.distributed_util

    # Replace classy_vision's ``get_world_size`` with a stub that always returns 1,
    # to avoid classy vision/vissl based distributed training.
    # NOTE(review): this is a process-wide monkey patch applied at import time of
    # this module; any other user of classy_vision in the same process sees it too.
    classy_vision.generic.distributed_util.get_world_size = lambda: 1

# Doctest skip list: empty when VISSL is available, otherwise the
# ``ImageEmbedder`` class and all of its members.
__doctest_skip__ = [] if _VISSL_AVAILABLE else ["ImageEmbedder", "ImageEmbedder.*"]

# Deprecated backbone aliases mapped to the backbone name that replaces each one.
_deprecated_backbones = dict(
    resnet="resnet50",
    vision_transformer="vit_small_patch16_224",
)


class ImageEmbedder(AdapterTask):
    """The ``ImageEmbedder`` is a :class:`~flash.Task` for obtaining feature vectors (embeddings) from images. For
    more details, see :ref:`image_embedder`.

    Args:
        training_strategy: Name of the training strategy, looked up in ``ImageEmbedder.training_strategies``.
            Defaults to ``"default"``. The VISSL strategies are 'simclr', 'swav', and 'barlow_twins'.
        head: Projection head used for the task, select between 'simclr_head', 'swav_head', or
            'barlow_twins_head'. It is forwarded to the training strategy.
        pretraining_transform: Transform applied to the input image for pre-training the SSL model. Select between
            'simclr_transform', 'swav_transform', or 'barlow_twins_transform'. Required when the training strategy
            is provided by VISSL. When set, it overrides any transforms from the ``DataModule``.
        backbone: Name of the backbone, looked up in ``ImageEmbedder.backbones``. Defaults to ``resnet18``. The
            deprecated names ``resnet`` and ``vision_transformer`` are still accepted; they are replaced by
            ``resnet50`` and ``vit_small_patch16_224`` respectively and a ``FutureWarning`` is emitted.
        pretrained: Use a pretrained backbone, defaults to ``False``.
        optimizer: Optimizer to use for training, defaults to ``"Adam"``.
        lr_scheduler: The LR scheduler to use during training.
        learning_rate: Learning rate to use for training, defaults to ``None``.
        backbone_kwargs: Extra keyword arguments passed to the backbone constructor.
        training_strategy_kwargs: Extra keyword arguments passed to the training strategy (for VISSL: the loss
            function, projection head and training hooks).
        pretraining_transform_kwargs: Extra keyword arguments passed to the pretraining transform.

    Raises:
        ValueError: If the selected training strategy is provided by VISSL and ``pretraining_transform`` is
            ``None``.
    """

    training_strategies: FlashRegistry = IMAGE_EMBEDDER_STRATEGIES
    backbones: FlashRegistry = IMAGE_CLASSIFIER_BACKBONES
    transforms: FlashRegistry = IMAGE_EMBEDDER_TRANSFORMS

    required_extras: str = "image"

    def __init__(
        self,
        training_strategy: str = "default",
        head: Optional[str] = None,
        pretraining_transform: Optional[str] = None,
        backbone: str = "resnet18",
        pretrained: bool = False,
        optimizer: OPTIMIZER_TYPE = "Adam",
        lr_scheduler: LR_SCHEDULER_TYPE = None,
        learning_rate: Optional[float] = None,
        backbone_kwargs: Optional[Dict[str, Any]] = None,
        training_strategy_kwargs: Optional[Dict[str, Any]] = None,
        pretraining_transform_kwargs: Optional[Dict[str, Any]] = None,
    ):
        # Kept as the very first statement, before any argument is rebound below.
        self.save_hyperparameters()

        if backbone_kwargs is None:
            backbone_kwargs = {}

        if training_strategy_kwargs is None:
            training_strategy_kwargs = {}

        if pretraining_transform_kwargs is None:
            pretraining_transform_kwargs = {}

        # Transparently map deprecated backbone aliases to their replacements.
        if backbone in _deprecated_backbones:
            rank_zero_warn(
                f"The '{backbone}' backbone for the `ImageEmbedder` is deprecated in v0.8 and will be removed "
                f"in v0.9. Use '{_deprecated_backbones[backbone]}' instead.",
                category=FutureWarning,
            )
            backbone = _deprecated_backbones[backbone]

        model, num_features = self.backbones.get(backbone)(pretrained=pretrained, **backbone_kwargs)

        metadata = self.training_strategies.get(training_strategy, with_metadata=True)
        strategy_metadata = metadata["metadata"]

        # Fail fast: VISSL strategies cannot work without a pretraining transform, so reject the
        # configuration before building the strategy, the adapter and the parent task.
        if (
            "providers" in strategy_metadata
            and strategy_metadata["providers"].name == "Facebook Research/vissl"
            and pretraining_transform is None
        ):
            raise ValueError("Correct pretraining_transform must be set to use VISSL")

        # NOTE: ``head`` is rebound here from the user-supplied name to whatever the strategy returns.
        loss_fn, head, hooks = metadata["fn"](head=head, num_features=num_features, **training_strategy_kwargs)

        adapter = strategy_metadata["adapter"].from_task(
            task=self,
            loss_fn=loss_fn,
            backbone=model,
            head=head,
            hooks=hooks,
        )

        super().__init__(
            adapter=adapter,
            model=model,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            learning_rate=learning_rate,
        )

        if pretraining_transform is not None:
            warnings.warn(
                "Overriding any transforms from the `DataModule` with the pretraining transform: "
                f"{pretraining_transform}."
            )
            self.input_transform = self.transforms.get(pretraining_transform)(**pretraining_transform_kwargs)

    def forward(self, x: Tensor) -> Any:
        """Run ``x`` through the wrapped model (the backbone) and return its output."""
        return self.model(x)

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
        """Return the result of :meth:`forward` for the ``DataKeys.INPUT`` entry of ``batch``."""
        return self(batch[DataKeys.INPUT])

    def on_epoch_start(self) -> None:
        """Forward the hook to the adapter."""
        self.adapter.on_epoch_start()

    def on_train_start(self) -> None:
        """Forward the hook to the adapter."""
        self.adapter.on_train_start()

    def on_train_epoch_end(self) -> None:
        """Forward the hook to the adapter."""
        self.adapter.on_train_epoch_end()

    def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int, *args) -> None:
        """Forward the hook, with all of its arguments, to the adapter."""
        self.adapter.on_train_batch_end(outputs, batch, batch_idx, *args)

    @classmethod
    @requires("image", "vissl", "fairscale")
    def available_training_strategies(cls) -> List[str]:
        """Get the list of available training strategies (passed to the ``training_strategy`` argument) for this
        task.

        Returns:
            The keys of the ``training_strategies`` registry, or an empty list if the class has no such registry.

        Examples
        ________

        .. doctest::

            >>> from flash.image import ImageEmbedder
            >>> ImageEmbedder.available_training_strategies()  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
            ['barlow_twins', ..., 'swav']
        """
        registry: Optional[FlashRegistry] = getattr(cls, "training_strategies", None)
        if registry is None:
            return []
        return registry.available_keys()

© Copyright 2020-2021, PyTorch Lightning. Revision da42a635.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: stable
Versions
latest
stable
0.8.0
0.7.5
0.7.4
0.7.3
0.7.2
0.7.1
0.7.0
0.6.0
0.5.2
0.5.1
0.5.0
0.4.0
0.3.2
0.3.1
0.3.0
0.2.3
0.2.2
0.2.1
0.2.0
0.1.0post1
Downloads
html
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.