# Source code for flash.core.optimizers.lars
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
# Implemented by @ananyahjha93
# also found at: https://github.com/gridai-labs/aavae/tree/main/src/optimizers
# References:
# - https://arxiv.org/pdf/1708.03888.pdf
# - https://github.com/pytorch/pytorch/blob/master/torch/optim/sgd.py
import torch
from torch import nn
from torch.optim.optimizer import Optimizer, required
class LARS(Optimizer):
    r"""Extends SGD in PyTorch with LARS scaling from the paper
    `Large batch training of Convolutional Networks <https://arxiv.org/pdf/1708.03888.pdf>`_.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float): learning rate
        momentum (float, optional): momentum factor (default: 0)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        dampening (float, optional): dampening for momentum (default: 0)
        nesterov (bool, optional): enables Nesterov momentum (default: False)
        trust_coefficient (float, optional): trust coefficient for computing LR (default: 0.001)
        eps (float, optional): eps for division denominator (default: 1e-8)

    Raises:
        ValueError: if ``lr``, ``momentum`` or ``weight_decay`` is negative, or if
            ``nesterov`` is requested without a positive ``momentum`` and zero ``dampening``.

    Example:
        >>> model = nn.Linear(10, 1)
        >>> optimizer = LARS(model.parameters(), lr=0.1, momentum=0.9)
        >>> optimizer.zero_grad()
        >>> # loss_fn(model(input), target).backward()
        >>> optimizer.step()

    .. note::
        The application of momentum in the SGD part is modified according to
        the PyTorch standards. LARS scaling fits into the equation in the
        following fashion.

        .. math::
            \begin{aligned}
                g_{t+1} & = \text{lars\_lr} * (\beta * p_{t} + g_{t+1}), \\
                v_{t+1} & = \mu * v_{t} + g_{t+1}, \\
                p_{t+1} & = p_{t} - \text{lr} * v_{t+1},
            \end{aligned}

        where :math:`p`, :math:`g`, :math:`v`, :math:`\mu` and :math:`\beta` denote the
        parameters, gradient, velocity, momentum, and weight decay respectively.
        The :math:`\text{lars\_lr}` is defined by Eq. 6 in the paper.
        The Nesterov version is analogously modified.

    .. note::
        ``trust_coefficient`` and ``eps`` are stored on the optimizer instance, not in the
        parameter groups, so a single value of each applies to every group.

    .. warning::
        Parameters with weight decay set to 0 will automatically be excluded from
        layer-wise LR scaling. This is to ensure consistency with papers like SimCLR
        and BYOL. Likewise, when either the parameter norm or the gradient norm is
        exactly zero, both the weight decay term and the LARS scaling are skipped for
        that parameter and the raw gradient is used.
    """

    def __init__(
        self,
        params,
        lr=required,
        momentum: float = 0,
        dampening: float = 0,
        weight_decay: float = 0,
        nesterov: bool = False,
        trust_coefficient: float = 0.001,
        eps: float = 1e-8,
    ):
        if lr is not required and lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if momentum < 0.0:
            raise ValueError(f"Invalid momentum value: {momentum}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")

        defaults = dict(lr=lr, momentum=momentum, dampening=dampening, weight_decay=weight_decay, nesterov=nesterov)

        # Instance-wide LARS settings; set before the base constructor runs,
        # matching the original initialisation order.
        self.eps = eps
        self.trust_coefficient = trust_coefficient

        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)
        # Restored parameter groups may lack the "nesterov" key; default it to off.
        for group in self.param_groups:
            group.setdefault("nesterov", False)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            # The closure has to rebuild the graph, so gradients are re-enabled for it.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            weight_decay = group["weight_decay"]
            momentum = group["momentum"]
            dampening = group["dampening"]
            nesterov = group["nesterov"]
            lr = group["lr"]

            for p in group["params"]:
                if p.grad is None:
                    continue

                d_p = p.grad

                # lars scaling + weight decay part
                # (scaling is excluded for params with 0 weight decay, so the norms
                # are only computed when they are actually needed)
                if weight_decay != 0:
                    p_norm = torch.norm(p)
                    g_norm = torch.norm(d_p)

                    if p_norm != 0 and g_norm != 0:
                        lars_lr = p_norm / (g_norm + p_norm * weight_decay + self.eps)
                        lars_lr *= self.trust_coefficient

                        # Out-of-place add yields a fresh tensor, so the in-place
                        # scaling below never modifies ``p.grad`` itself.
                        d_p = d_p.add(p, alpha=weight_decay)
                        d_p *= lars_lr

                # sgd part
                if momentum != 0:
                    param_state = self.state[p]
                    if "momentum_buffer" not in param_state:
                        # First step: the buffer starts as a copy of the update.
                        buf = param_state["momentum_buffer"] = torch.clone(d_p).detach()
                    else:
                        buf = param_state["momentum_buffer"]
                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)

                    if nesterov:
                        d_p = d_p.add(buf, alpha=momentum)
                    else:
                        d_p = buf

                p.add_(d_p, alpha=-lr)

        return loss