Shortcuts

Source code for flash.core.optimizers.lars

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
# Implemented by @ananyahjha93
# also found at: https://github.com/gridai-labs/aavae/tree/main/src/optimizers
# References:
#     - https://arxiv.org/pdf/1708.03888.pdf
#     - https://github.com/pytorch/pytorch/blob/master/torch/optim/sgd.py
import torch
from torch import nn
from torch.optim.optimizer import Optimizer, required

from flash.core.utilities.imports import _CORE_TESTING

# Skip the doctests of the names listed below when `_CORE_TESTING` is falsy,
# i.e. when the requirements needed to run them aren't available.
# NOTE(review): `__doctest_skip__` is presumably read by the doctest collection
# plugin used by the test suite -- confirm against the test configuration.
if not _CORE_TESTING:
    __doctest_skip__ = ["LARS"]


class LARS(Optimizer):
    r"""Extends SGD in PyTorch with LARS scaling from the paper
    `Large batch training of Convolutional Networks <https://arxiv.org/pdf/1708.03888.pdf>`_.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float): learning rate
        momentum (float, optional): momentum factor (default: 0)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        dampening (float, optional): dampening for momentum (default: 0)
        nesterov (bool, optional): enables Nesterov momentum (default: False)
        trust_coefficient (float, optional): trust coefficient for computing LR (default: 0.001)
        eps (float, optional): eps for division denominator (default: 1e-8)

    Raises:
        ValueError: If ``lr``, ``momentum``, ``weight_decay``, ``trust_coefficient`` or
            ``eps`` is negative, or if ``nesterov`` is enabled without a positive
            ``momentum`` and zero ``dampening``.

    Example:
        >>> model = nn.Linear(10, 1)
        >>> optimizer = LARS(model.parameters(), lr=0.1, momentum=0.9)
        >>> optimizer.zero_grad()
        >>> # loss_fn(model(input), target).backward()
        >>> optimizer.step()

    .. note::
        The application of momentum in the SGD part is modified according to
        the PyTorch standards. LARS scaling fits into the equation in the
        following fashion.

        .. math::
            \begin{aligned}
                g_{t+1} & = \text{lars\_lr} * (\beta * p_{t} + g_{t+1}), \\
                v_{t+1} & = \mu * v_{t} + g_{t+1}, \\
                p_{t+1} & = p_{t} - \text{lr} * v_{t+1},
            \end{aligned}

        where :math:`p`, :math:`g`, :math:`v`, :math:`\mu` and :math:`\beta` denote the
        parameters, gradient, velocity, momentum, and weight decay respectively.
        The :math:`\text{lars\_lr}` is defined by Eq. 6 in the paper; here it is computed as
        :math:`\text{trust\_coefficient} * \|p\| / (\|g\| + \beta * \|p\| + \text{eps})`.
        The Nesterov version is analogously modified.

    .. note::
        ``trust_coefficient`` and ``eps`` are stored as attributes of the optimizer
        instance rather than in the parameter groups, so the same values are used
        for every parameter group.

    .. warning::
        Parameters with weight decay set to 0 will automatically be excluded from
        layer-wise LR scaling. This is to ensure consistency with papers like SimCLR
        and BYOL. Parameters whose own norm or gradient norm is exactly zero are
        likewise updated with the raw gradient (no weight decay, no scaling).
    """

    def __init__(
        self,
        params,
        lr=required,
        momentum: float = 0,
        dampening: float = 0,
        weight_decay: float = 0,
        nesterov: bool = False,
        trust_coefficient: float = 0.001,
        eps: float = 1e-8,
    ):
        if lr is not required and lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if momentum < 0.0:
            raise ValueError(f"Invalid momentum value: {momentum}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        # A negative trust coefficient would flip the sign of the update and a
        # negative eps could zero out the denominator of the LARS ratio.
        if trust_coefficient < 0.0:
            raise ValueError(f"Invalid trust_coefficient value: {trust_coefficient}")
        if eps < 0.0:
            raise ValueError(f"Invalid eps value: {eps}")
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")

        defaults = dict(
            lr=lr,
            momentum=momentum,
            dampening=dampening,
            weight_decay=weight_decay,
            nesterov=nesterov,
        )

        # Kept on the instance (not in ``defaults``) so they are shared by all groups.
        self.eps = eps
        self.trust_coefficient = trust_coefficient

        super().__init__(params, defaults)

    def __setstate__(self, state):
        """Restore the optimizer state, defaulting ``nesterov`` to ``False`` for groups that lack it."""
        super().__setstate__(state)

        for group in self.param_groups:
            group.setdefault("nesterov", False)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure was given.
        """
        loss = None
        if closure is not None:
            # The closure has to run backward, so gradients are re-enabled for it.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            weight_decay = group["weight_decay"]
            momentum = group["momentum"]
            dampening = group["dampening"]
            nesterov = group["nesterov"]
            lr = group["lr"]

            for p in group["params"]:
                if p.grad is None:
                    continue

                d_p = p.grad

                # lars scaling + weight decay part
                # Params with 0 weight decay are excluded from scaling, so the
                # norms are only computed when they are actually needed.
                if weight_decay != 0:
                    p_norm = torch.norm(p)
                    g_norm = torch.norm(d_p)

                    if p_norm != 0 and g_norm != 0:
                        lars_lr = p_norm / (g_norm + p_norm * weight_decay + self.eps)
                        lars_lr *= self.trust_coefficient

                        # ``add`` is out-of-place, so ``p.grad`` itself is left untouched.
                        d_p = d_p.add(p, alpha=weight_decay)
                        d_p *= lars_lr

                # sgd part
                if momentum != 0:
                    param_state = self.state[p]
                    if "momentum_buffer" not in param_state:
                        # First step for this parameter: seed the velocity with the update.
                        buf = param_state["momentum_buffer"] = torch.clone(d_p).detach()
                    else:
                        buf = param_state["momentum_buffer"]
                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)

                    if nesterov:
                        d_p = d_p.add(buf, alpha=momentum)
                    else:
                        d_p = buf

                p.add_(d_p, alpha=-lr)

        return loss

© Copyright 2020-2021, PyTorch Lightning. Revision da42a635.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: stable
Versions
latest
stable
0.8.0
0.7.5
0.7.4
0.7.3
0.7.2
0.7.1
0.7.0
0.6.0
0.5.2
0.5.1
0.5.0
0.4.0
0.3.2
0.3.1
0.3.0
0.2.3
0.2.2
0.2.1
0.2.0
0.1.0post1
Downloads
html
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.