LARS#
- class stable_ssl.optim.LARS(params, lr=<required parameter>, momentum=0, dampening=0, weight_decay=0, nesterov=False, eta=0.001, eps=1e-08, clip_lr=False, exclude_bias_n_norm=False)[source]#
Bases: Optimizer
Extends SGD in PyTorch with LARS scaling.
Implementation based on the paper Large Batch Training of Convolutional Networks.
- Parameters:
params (iterable) – iterable of parameters to optimize or dicts defining parameter groups
lr (float) – learning rate
momentum (float, optional) – momentum factor (default: 0)
weight_decay (float, optional) – weight decay (L2 penalty) (default: 0)
dampening (float, optional) – dampening for momentum (default: 0)
nesterov (bool, optional) – enables Nesterov momentum (default: False)
eta (float, optional) – trust coefficient for computing LR (default: 0.001)
eps (float, optional) – eps for division denominator (default: 1e-8)
Example
>>> import torch
>>> model = torch.nn.Linear(10, 1)
>>> input = torch.Tensor(10)
>>> target = torch.Tensor([1.])
>>> loss_fn = lambda input, target: (input - target) ** 2
>>> #
>>> optimizer = LARS(model.parameters(), lr=0.1, momentum=0.9)
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
Note
The application of momentum in the SGD part is modified according to the PyTorch standards. LARS scaling fits into the equation in the following fashion:

.. math::

    \begin{aligned}
    g_{t+1} & = \text{lars\_lr} * (\beta * p_{t} + g_{t+1}), \\
    v_{t+1} & = \mu * v_{t} + g_{t+1}, \\
    p_{t+1} & = p_{t} - \text{lr} * v_{t+1},
    \end{aligned}

where \(p\), \(g\), \(v\), \(\mu\), and \(\beta\) denote the parameters, gradient, velocity, momentum, and weight decay respectively. \(\text{lars\_lr}\) is defined by Eq. 6 in the paper. The Nesterov version is analogously modified.
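For intuition, here is a minimal sketch of the per-layer trust ratio from Eq. 6; the helper name lars_lr and the zero-norm fallback are illustrative only and not part of this optimizer's API:

>>> import torch
>>> def lars_lr(p, g, eta=0.001, weight_decay=0.0, eps=1e-8):
...     # Trust ratio from Eq. 6: eta * ||p|| / (||g|| + weight_decay * ||p|| + eps)
...     p_norm = torch.norm(p)
...     g_norm = torch.norm(g)
...     if p_norm == 0 or g_norm == 0:
...         # Degenerate layer: fall back to no layer-wise scaling.
...         return torch.tensor(1.0)
...     return eta * p_norm / (g_norm + weight_decay * p_norm + eps)

The resulting scalar multiplies the weight-decayed gradient of that layer before the momentum update, so layers with small gradients relative to their weights take proportionally larger steps.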
Warning
Parameters with weight decay set to 0 will automatically be excluded from layer-wise LR scaling. This is to ensure consistency with papers like SimCLR and BYOL.
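As a sketch of how to use this exclusion rule, the grouping below (purely illustrative; the ".bias" name check is an assumption about your model) places biases in a weight_decay=0.0 group so they receive plain SGD updates, while the remaining parameters get layer-wise LR scaling:

>>> decay, no_decay = [], []
>>> for name, param in model.named_parameters():
...     # Hypothetical split: biases go to the no-decay (unscaled) group.
...     (no_decay if name.endswith(".bias") else decay).append(param)
>>> optimizer = LARS(
...     [{"params": decay, "weight_decay": 1e-6},
...      {"params": no_decay, "weight_decay": 0.0}],
...     lr=0.1,
...     momentum=0.9,
... )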