
LARS

class flash.core.optimizers.LARS(params, lr=torch.optim.optimizer.required, momentum=0, dampening=0, weight_decay=0, nesterov=False, trust_coefficient=0.001, eps=1e-08)[source]

Extends SGD in PyTorch with LARS scaling from the paper Large Batch Training of Convolutional Networks.

Parameters
  • params (iterable) – iterable of parameters to optimize or dicts defining parameter groups

  • lr (float) – learning rate

  • momentum (float, optional) – momentum factor (default: 0)

  • weight_decay (float, optional) – weight decay (L2 penalty) (default: 0)

  • dampening (float, optional) – dampening for momentum (default: 0)

  • nesterov (bool, optional) – enables Nesterov momentum (default: False)

  • trust_coefficient (float, optional) – trust coefficient for computing LR (default: 0.001)

  • eps (float, optional) – eps for division denominator (default: 1e-8)

Example

>>> import torch
>>> from torch import nn
>>> from flash.core.optimizers import LARS
>>> model = nn.Linear(10, 1)
>>> optimizer = LARS(model.parameters(), lr=0.1, momentum=0.9)
>>> optimizer.zero_grad()
>>> # loss_fn(model(input), target).backward()
>>> optimizer.step()

Note

The application of momentum in the SGD part is modified to match the PyTorch convention for SGD with momentum. LARS scaling fits into the update equations as follows.

\begin{aligned}
    g_{t+1} & = \text{lars\_lr} * (\beta * p_{t} + g_{t+1}), \\
    v_{t+1} & = \mu * v_{t} + g_{t+1}, \\
    p_{t+1} & = p_{t} - \text{lr} * v_{t+1},
\end{aligned}

where p, g, v, \mu, and \beta denote the parameters, gradient, velocity, momentum, and weight decay, respectively. The lars_lr term is defined by Eq. 6 in the paper. The Nesterov version is modified analogously.
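As a rough illustration of the lars_lr term, the sketch below computes a layer-wise scaling from a parameter tensor and its gradient using the trust_coefficient, weight_decay, and eps arguments described above. This is a minimal sketch following Eq. 6 of the paper, not the optimizer's exact code path; the function name lars_scaling is only illustrative.

import torch

def lars_scaling(p, grad, trust_coefficient=0.001, weight_decay=0.0, eps=1e-8):
    # Hypothetical helper: approximates the layer-wise lars_lr term (Eq. 6).
    p_norm = torch.norm(p)      # ||p||, norm of the layer's parameters
    g_norm = torch.norm(grad)   # ||g||, norm of the layer's gradient
    if p_norm == 0 or g_norm == 0:
        # Degenerate layers fall back to plain SGD scaling.
        return 1.0
    return trust_coefficient * p_norm / (g_norm + weight_decay * p_norm + eps)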

Warning

Parameters with weight decay set to 0 will automatically be excluded from layer-wise LR scaling. This is to ensure consistency with papers like SimCLR and BYOL.
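For example, a common pattern in SimCLR/BYOL-style setups is to place biases and normalization parameters in a parameter group with weight_decay set to 0 so that they bypass the layer-wise scaling. A minimal sketch (the model and hyperparameter values are placeholders):

from torch import nn
from flash.core.optimizers import LARS

model = nn.Sequential(nn.Linear(10, 10), nn.BatchNorm1d(10), nn.Linear(10, 1))

# Weights receive weight decay (and therefore LARS scaling);
# biases and norm parameters use weight_decay=0 and are excluded from scaling.
decay, no_decay = [], []
for name, param in model.named_parameters():
    (no_decay if param.ndim == 1 else decay).append(param)

optimizer = LARS(
    [
        {"params": decay, "weight_decay": 1e-6},
        {"params": no_decay, "weight_decay": 0.0},
    ],
    lr=0.1,
    momentum=0.9,
)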

step(closure=None)

Performs a single optimization step.

Parameters

closure (callable, optional) – A closure that reevaluates the model and returns the loss.
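As with other PyTorch optimizers, passing a closure lets step() re-evaluate the model and return the loss. A minimal sketch (the input and target tensors are placeholders):

import torch
from torch import nn
from flash.core.optimizers import LARS

model = nn.Linear(10, 1)
optimizer = LARS(model.parameters(), lr=0.1, momentum=0.9)
inputs, targets = torch.randn(4, 10), torch.randn(4, 1)

def closure():
    # Re-evaluate the model and return the loss for the optimizer step.
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(inputs), targets)
    loss.backward()
    return loss

loss = optimizer.step(closure)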
