LARS

class flash.core.optimizers.LARS(params, lr=torch.optim.optimizer.required, momentum=0, dampening=0, weight_decay=0, nesterov=False, trust_coefficient=0.001, eps=1e-08)

Extends SGD in PyTorch with LARS scaling from the paper Large Batch Training of Convolutional Networks.

Parameters
  • params (iterable) – iterable of parameters to optimize or dicts defining parameter groups

  • lr (float) – learning rate

  • momentum (float, optional) – momentum factor (default: 0)

  • weight_decay (float, optional) – weight decay (L2 penalty) (default: 0)

  • dampening (float, optional) – dampening for momentum (default: 0)

  • nesterov (bool, optional) – enables Nesterov momentum (default: False)

  • trust_coefficient (float, optional) – trust coefficient for computing LR (default: 0.001)

  • eps (float, optional) – epsilon added to the denominator for numerical stability (default: 1e-8)

Example

>>> import torch
>>> from torch import nn
>>> from flash.core.optimizers import LARS
>>> model = nn.Linear(10, 1)
>>> loss_fn, input, target = nn.MSELoss(), torch.randn(4, 10), torch.randn(4, 1)
>>> optimizer = LARS(model.parameters(), lr=0.1, momentum=0.9)
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()

Note

Momentum in the SGD part is applied following the PyTorch convention. LARS scaling fits into the update equations as follows.

\begin{aligned}
    g_{t+1} & = \text{lars\_lr} * (\beta * p_{t} + g_{t+1}), \\
    v_{t+1} & = \mu * v_{t} + g_{t+1}, \\
    p_{t+1} & = p_{t} - \text{lr} * v_{t+1},
\end{aligned}

where p, g, v, \mu, and \beta denote the parameters, gradient, velocity, momentum, and weight decay, respectively. The lars_lr term is defined by Eq. 6 in the paper. The Nesterov version is modified analogously.
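
For reference, a plausible expansion of lars_lr in terms of the constructor arguments above (a sketch following Eq. 6 of the paper; the implementation may additionally guard against zero parameter or gradient norms):

\text{lars\_lr} = \text{trust\_coefficient} \times \frac{\lVert p_{t} \rVert}{\lVert g_{t+1} \rVert + \text{weight\_decay} \times \lVert p_{t} \rVert + \text{eps}}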

Warning

Parameters with weight decay set to 0 will automatically be excluded from layer-wise LR scaling. This is to ensure consistency with papers like SimCLR and BYOL.
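
For example, biases and normalization parameters are commonly placed in their own parameter group with weight_decay=0, which also exempts them from LARS scaling. A minimal sketch, assuming the model from the example above (the name-based split is illustrative, not part of the class API):

>>> decay = [p for n, p in model.named_parameters() if not n.endswith("bias")]
>>> no_decay = [p for n, p in model.named_parameters() if n.endswith("bias")]
>>> optimizer = LARS(
...     [{"params": decay, "weight_decay": 1e-6},
...      {"params": no_decay, "weight_decay": 0}],  # excluded from LARS scaling
...     lr=0.1, momentum=0.9)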

step(closure=None)

Performs a single optimization step.

Parameters

closure (callable, optional) – A closure that reevaluates the model and returns the loss.
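
As with other PyTorch optimizers, the closure re-evaluates the loss so the optimizer can use it during the step. A minimal sketch reusing model, loss_fn, input, and target from the example above:

>>> def closure():
...     optimizer.zero_grad()
...     loss = loss_fn(model(input), target)
...     loss.backward()
...     return loss
>>> loss = optimizer.step(closure)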
