
Source code for flash.core.optimizers.lr_scheduler

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
# Implemented by @ananyahjha93
# also found at: https://github.com/Lightning-AI/lightning-bolts/blob/master/pl_bolts/optimizers/lr_scheduler.py
import math
import warnings
from typing import List

from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler

from flash.core.utilities.imports import _TOPIC_CORE_AVAILABLE

# Skip doctests if requirements aren't available
if not _TOPIC_CORE_AVAILABLE:
    __doctest_skip__ = ["LinearWarmupCosineAnnealingLR"]


class LinearWarmupCosineAnnealingLR(_LRScheduler):
    """Sets the learning rate of each parameter group to follow a linear warmup schedule between warmup_start_lr
    and base_lr followed by a cosine annealing schedule between base_lr and eta_min.

    .. warning::
        It is recommended to call :func:`.step()` for :class:`LinearWarmupCosineAnnealingLR` after each iteration
        as calling it after each epoch will keep the starting lr at warmup_start_lr for the first epoch which is 0
        in most cases.

    .. warning::
        passing epoch to :func:`.step()` is being deprecated and comes with an EPOCH_DEPRECATION_WARNING. It calls
        the :func:`_get_closed_form_lr()` method for this scheduler instead of :func:`get_lr()`. Though this does
        not change the behavior of the scheduler, when passing epoch param to :func:`.step()`, the user should call
        the :func:`.step()` function before calling train and validation methods.

    Example:
        >>> from torch import nn
        >>> from torch.optim import Adam
        >>> layer = nn.Linear(10, 1)
        >>> optimizer = Adam(layer.parameters(), lr=0.02)
        >>> scheduler = LinearWarmupCosineAnnealingLR(optimizer, warmup_epochs=10, max_epochs=40)
        >>> #
        >>> # the default case
        >>> for epoch in range(40):
        ...     # train(...)
        ...     # validate(...)
        ...     scheduler.step()
        >>> #
        >>> # passing epoch param case
        >>> for epoch in range(40):
        ...     scheduler.step(epoch)
        ...     # train(...)
        ...     # validate(...)

    """

    def __init__(
        self,
        optimizer: Optimizer,
        warmup_epochs: int,
        max_epochs: int,
        warmup_start_lr: float = 0.0,
        eta_min: float = 0.0,
        last_epoch: int = -1,
    ) -> None:
        """
        Args:
            optimizer (Optimizer): Wrapped optimizer.
            warmup_epochs (int): Maximum number of iterations for linear warmup
            max_epochs (int): Maximum number of iterations
            warmup_start_lr (float): Learning rate to start the linear warmup. Default: 0.
            eta_min (float): Minimum learning rate. Default: 0.
            last_epoch (int): The index of last epoch. Default: -1.
        """
        self.warmup_epochs = warmup_epochs
        self.max_epochs = max_epochs
        self.warmup_start_lr = warmup_start_lr
        self.eta_min = eta_min

        super().__init__(optimizer, last_epoch)

    def get_lr(self) -> List[float]:
        """Compute learning rate using chainable form of the scheduler."""
        if not self._get_lr_called_within_step:
            warnings.warn(
                "To get the last learning rate computed by the scheduler, please use `get_last_lr()`.",
                UserWarning,
            )

        if self.last_epoch == self.warmup_epochs:
            return self.base_lrs
        if self.last_epoch == 0:
            return [self.warmup_start_lr] * len(self.base_lrs)
        if self.last_epoch < self.warmup_epochs:
            return [
                group["lr"] + (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1)
                for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
            ]
        if (self.last_epoch - 1 - self.max_epochs) % (2 * (self.max_epochs - self.warmup_epochs)) == 0:
            return [
                group["lr"]
                + (base_lr - self.eta_min) * (1 - math.cos(math.pi / (self.max_epochs - self.warmup_epochs))) / 2
                for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
            ]

        return [
            (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs)))
            / (
                1
                + math.cos(
                    math.pi * (self.last_epoch - self.warmup_epochs - 1) / (self.max_epochs - self.warmup_epochs)
                )
            )
            * (group["lr"] - self.eta_min)
            + self.eta_min
            for group in self.optimizer.param_groups
        ]

    def _get_closed_form_lr(self) -> List[float]:
        """Called when epoch is passed as a param to the `step` function of the scheduler."""
        if self.last_epoch < self.warmup_epochs:
            return [
                self.warmup_start_lr
                + self.last_epoch * (base_lr - self.warmup_start_lr) / max(1, self.warmup_epochs - 1)
                for base_lr in self.base_lrs
            ]

        return [
            self.eta_min
            + 0.5
            * (base_lr - self.eta_min)
            * (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs)))
            for base_lr in self.base_lrs
        ]
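
The example in the docstring steps the scheduler once per epoch for brevity. Following the first warning above, a per-iteration schedule simply treats warmup_epochs and max_epochs as iteration counts. The sketch below is a minimal, self-contained illustration (not part of the module source) that prints the resulting learning rates; it assumes the class is re-exported as flash.core.optimizers.LinearWarmupCosineAnnealingLR, and the layer, optimizer settings, and step counts are arbitrary placeholders.

import torch.nn as nn
from torch.optim import Adam

from flash.core.optimizers import LinearWarmupCosineAnnealingLR  # assumed re-export path

layer = nn.Linear(10, 1)                       # throwaway model
optimizer = Adam(layer.parameters(), lr=0.02)  # base lr of every param group
scheduler = LinearWarmupCosineAnnealingLR(
    optimizer,
    warmup_epochs=100,   # interpreted as iterations when stepping per batch
    max_epochs=1000,     # total number of scheduler steps
    warmup_start_lr=1e-4,
    eta_min=1e-5,
)

for step in range(1000):
    # forward/backward for one batch would go here
    optimizer.step()     # no-op without gradients; avoids PyTorch's scheduler-before-optimizer warning
    scheduler.step()
    if step % 100 == 0:
        # first ~100 steps: linear ramp from warmup_start_lr up to 0.02,
        # afterwards: cosine decay from 0.02 down to eta_min
        print(f"step {step:4d}: lr = {scheduler.get_last_lr()[0]:.6f}")

Note that get_lr() implements the chainable form (each new value is derived from the current group["lr"]), while _get_closed_form_lr() evaluates the same warmup/cosine curve directly from last_epoch; the closed form is only used when an explicit epoch is passed to step().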
