Source code for flash.image.embedding.model

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import Any, Dict, List, Optional

from flash.core.adapter import AdapterTask
from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _VISSL_AVAILABLE, requires
from flash.core.utilities.types import LR_SCHEDULER_TYPE, OPTIMIZER_TYPE
from flash.image.embedding.backbones import IMAGE_EMBEDDER_BACKBONES
from flash.image.embedding.strategies import IMAGE_EMBEDDER_STRATEGIES
from flash.image.embedding.transforms import IMAGE_EMBEDDER_TRANSFORMS

if _VISSL_AVAILABLE:
    import classy_vision
    import classy_vision.generic.distributed_util

    # Patch the reported world size to 1 to avoid ClassyVision/VISSL-based distributed training
    classy_vision.generic.distributed_util.get_world_size = lambda: 1

# Skip doctests if requirements aren't available
__doctest_skip__ = []
if not _VISSL_AVAILABLE:
    __doctest_skip__ += [
        "ImageEmbedder",
        "ImageEmbedder.*",
    ]


class ImageEmbedder(AdapterTask):
    """The ``ImageEmbedder`` is a :class:`~flash.Task` for obtaining feature vectors (embeddings) from images.

    For more details, see :ref:`image_embedder`.

    Args:
        training_strategy: Training strategy from VISSL. Select between ``'simclr'``, ``'swav'``, ``'dino'``,
            ``'moco'``, or ``'barlow_twins'``.
        head: Projection head used for the task. Select between ``'simclr_head'``, ``'swav_head'``,
            ``'dino_head'``, ``'moco_head'``, or ``'barlow_twins_head'``.
        pretraining_transform: Transform applied to the input images when pre-training the SSL model. Select
            between ``'simclr_transform'``, ``'swav_transform'``, ``'dino_transform'``, ``'moco_transform'``,
            or ``'barlow_twins_transform'``.
        backbone: VISSL backbone, defaults to ``resnet``.
        pretrained: Use a pretrained backbone, defaults to ``False``.
        optimizer: Optimizer to use for training.
        lr_scheduler: The LR scheduler to use during training.
        learning_rate: Learning rate to use for training, defaults to ``1e-3``.
        backbone_kwargs: Arguments to be passed to the VISSL backbones, i.e. ``vision_transformer``
            and ``resnet``.
        training_strategy_kwargs: Arguments passed to the VISSL loss function, projection head, and
            training hooks.
        pretraining_transform_kwargs: Arguments passed to the VISSL transforms.
    """

    training_strategies: FlashRegistry = IMAGE_EMBEDDER_STRATEGIES
    backbones: FlashRegistry = IMAGE_EMBEDDER_BACKBONES
    transforms: FlashRegistry = IMAGE_EMBEDDER_TRANSFORMS

    required_extras: List[str] = ["image", "vissl", "fairscale"]

    def __init__(
        self,
        training_strategy: str,
        head: str,
        pretraining_transform: str,
        backbone: str = "resnet",
        pretrained: bool = False,
        optimizer: OPTIMIZER_TYPE = "Adam",
        lr_scheduler: LR_SCHEDULER_TYPE = None,
        learning_rate: Optional[float] = None,
        backbone_kwargs: Optional[Dict[str, Any]] = None,
        training_strategy_kwargs: Optional[Dict[str, Any]] = None,
        pretraining_transform_kwargs: Optional[Dict[str, Any]] = None,
    ):
        self.save_hyperparameters()

        if backbone_kwargs is None:
            backbone_kwargs = {}

        if training_strategy_kwargs is None:
            training_strategy_kwargs = {}

        if pretraining_transform_kwargs is None:
            pretraining_transform_kwargs = {}

        # Resolve the backbone from the registry; returns the module and its feature dimension.
        backbone, num_features = self.backbones.get(backbone)(pretrained=pretrained, **backbone_kwargs)

        # The strategy registry entry provides the loss, projection head, and hooks for the SSL method.
        metadata = self.training_strategies.get(training_strategy, with_metadata=True)
        loss_fn, head, hooks = metadata["fn"](head=head, num_features=num_features, **training_strategy_kwargs)

        adapter = metadata["metadata"]["adapter"].from_task(
            self,
            loss_fn=loss_fn,
            backbone=backbone,
            head=head,
            hooks=hooks,
        )

        super().__init__(
            adapter=adapter,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            learning_rate=learning_rate,
        )

        self.input_transform = self.transforms.get(pretraining_transform)(**pretraining_transform_kwargs)

        warnings.warn(
            "Warning: VISSL ImageEmbedder overrides any user-provided transforms"
            " with pre-defined transforms for the training strategy."
        )

    def on_epoch_start(self) -> None:
        self.adapter.on_epoch_start()

    def on_train_start(self) -> None:
        self.adapter.on_train_start()

    def on_train_epoch_end(self) -> None:
        self.adapter.on_train_epoch_end()

    def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
        self.adapter.on_train_batch_end(outputs, batch, batch_idx, dataloader_idx)
    @classmethod
    @requires(["image", "vissl", "fairscale"])
    def available_training_strategies(cls) -> List[str]:
        """Get the list of available training strategies (passed to the ``training_strategy`` argument) for this
        task.

        Examples
        ________

        .. doctest::

            >>> from flash.image import ImageEmbedder
            >>> ImageEmbedder.available_training_strategies()  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
            ['barlow_twins', ..., 'swav']
        """
        registry: Optional[FlashRegistry] = getattr(cls, "training_strategies", None)
        if registry is None:
            return []
        return registry.available_keys()
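
For orientation, a minimal pre-training sketch follows (it is not part of this module). It assumes torchvision's CIFAR10 dataset and Flash's ``Trainer`` and ``ImageClassificationData`` APIs as of this release; the dataset, batch size, and epoch count are illustrative choices, not required settings.

    # A minimal sketch, assuming torchvision and a working flash installation.
    import flash
    from flash.image import ImageClassificationData, ImageEmbedder
    from torchvision.datasets import CIFAR10

    # Wrap a torchvision dataset; labels are not needed for SSL pre-training.
    datamodule = ImageClassificationData.from_datasets(
        train_dataset=CIFAR10(".", download=True),
        batch_size=16,
    )

    # Pick one of the registered strategy/head/transform triples.
    embedder = ImageEmbedder(
        training_strategy="barlow_twins",
        head="barlow_twins_head",
        pretraining_transform="barlow_twins_transform",
        backbone="resnet",
    )

    # Pre-train the backbone for one epoch (illustrative only) and save it.
    trainer = flash.Trainer(max_epochs=1)
    trainer.fit(embedder, datamodule=datamodule)
    trainer.save_checkpoint("image_embedder_model.pt")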

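Once a model is trained (or constructed with ``pretrained=True``), embeddings can be generated through ``Trainer.predict``. Again a hedged sketch under the same assumptions; the image file paths below are hypothetical placeholders.

    # A minimal sketch; the predict file paths are hypothetical placeholders.
    import flash
    from flash.image import ImageClassificationData, ImageEmbedder

    embedder = ImageEmbedder(
        training_strategy="barlow_twins",
        head="barlow_twins_head",
        pretraining_transform="barlow_twins_transform",
        backbone="resnet",
        pretrained=True,
    )

    datamodule = ImageClassificationData.from_files(
        predict_files=["image_1.png", "image_2.png"],  # hypothetical paths
        batch_size=2,
    )

    trainer = flash.Trainer()
    # Returns one list entry per predict batch, each holding embedding vectors.
    embeddings = trainer.predict(embedder, datamodule=datamodule)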