LARS

class flash.core.optimizers.LARS(params, lr=<required parameter>, momentum=0, dampening=0, weight_decay=0, nesterov=False, trust_coefficient=0.001, eps=1e-08)[source]

Extends SGD in PyTorch with LARS scaling.

See the paper "Large batch training of Convolutional Networks" (https://arxiv.org/abs/1708.03888).

Parameters
  • params (iterable) – iterable of parameters to optimize or dicts defining parameter groups

  • lr (float) – learning rate

  • momentum (float, optional) – momentum factor (default: 0)

  • weight_decay (float, optional) – weight decay (L2 penalty) (default: 0)

  • dampening (float, optional) – dampening for momentum (default: 0)

  • nesterov (bool, optional) – enables Nesterov momentum (default: False)

  • trust_coefficient (float, optional) – trust coefficient for computing LR (default: 0.001)

  • eps (float, optional) – small value added to the denominator for numerical stability (default: 1e-8)

Example

>>> from torch import nn
>>> from flash.core.optimizers import LARS
>>> model = nn.Linear(10, 1)
>>> optimizer = LARS(model.parameters(), lr=0.1, momentum=0.9)
>>> optimizer.zero_grad()
>>> # loss_fn(model(input), target).backward()
>>> optimizer.step()

Note

In the SGD part, momentum is applied following the PyTorch convention. LARS scaling fits into the update equations as follows.

\begin{aligned}
    g_{t+1} & = \text{lars\_lr} * (\beta * p_{t} + g_{t+1}), \\
    v_{t+1} & = \mu * v_{t} + g_{t+1}, \\
    p_{t+1} & = p_{t} - \text{lr} * v_{t+1},
\end{aligned}

where p, g, v, \mu, and \beta denote the parameters, gradient, velocity, momentum, and weight decay, respectively. The lars_lr term is defined by Eq. 6 in the paper. The Nesterov version is modified analogously.
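
To make the equations concrete, here is a minimal sketch of a single-tensor update they imply. lars_update is a hypothetical helper written for illustration, not the library's actual step() (which, per the warning below, also skips scaling when weight decay is 0):

>>> import torch
>>> def lars_update(p, grad, v, lr, momentum, weight_decay, trust_coefficient=0.001, eps=1e-8):
...     # Layer-wise trust ratio (Eq. 6): lars_lr = eta * ||p|| / (||g|| + beta * ||p|| + eps)
...     p_norm, g_norm = torch.norm(p), torch.norm(grad)
...     lars_lr = trust_coefficient * p_norm / (g_norm + weight_decay * p_norm + eps)
...     g = lars_lr * (weight_decay * p + grad)  # g_{t+1} = lars_lr * (beta * p_t + g_{t+1})
...     v = momentum * v + g                     # v_{t+1} = mu * v_t + g_{t+1}
...     return p - lr * v, v                     # p_{t+1} = p_t - lr * v_{t+1}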

Warning

Parameters with weight decay set to 0 will automatically be excluded from layer-wise LR scaling. This is to ensure consistency with papers like SimCLR and BYOL.
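One way to exploit this behavior, assuming the standard torch.optim parameter-group API (which LARS inherits from SGD): place parameters you want excluded from LARS scaling, such as biases, in a group with weight_decay set to 0. A sketch:

>>> from torch import nn
>>> model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 1))
>>> decay = [p for n, p in model.named_parameters() if not n.endswith("bias")]
>>> no_decay = [p for n, p in model.named_parameters() if n.endswith("bias")]
>>> optimizer = LARS(
...     [{"params": decay, "weight_decay": 1e-6},
...      {"params": no_decay, "weight_decay": 0}],  # biases: plain SGD, no LARS scaling
...     lr=0.1, momentum=0.9)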

step(closure=None)[source]

Performs a single optimization step.

Parameters

closure (callable, optional) – A closure that reevaluates the model and returns the loss.
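
Following the standard torch.optim closure contract, a sketch of how step(closure) can be called; the closure clears the gradients, recomputes the loss, backpropagates, and returns the loss:

>>> import torch
>>> from torch import nn
>>> model = nn.Linear(10, 1)
>>> optimizer = LARS(model.parameters(), lr=0.1, momentum=0.9)
>>> inputs, target = torch.randn(4, 10), torch.randn(4, 1)
>>> def closure():
...     optimizer.zero_grad()
...     loss = nn.functional.mse_loss(model(inputs), target)
...     loss.backward()
...     return loss
>>> loss = optimizer.step(closure)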
