chainer.optimizer_hooks.GradientLARS
class chainer.optimizer_hooks.GradientLARS(threshold=0.01, weight_decay=0.0, eps=1e-09)

Optimizer/UpdateRule hook function for layer-wise adaptive rate scaling (LARS).
See: Large Batch Training of Convolutional Networks.
See: Convergence Analysis of Gradient Descent Algorithms with Proportional Updates.
This hook function scales all gradient arrays to fit the weight norm.
In <https://arxiv.org/abs/1708.03888>, the update is

$$v_{t+1} = m * v_t + \gamma * \lambda * (\nabla L(w_t) + \beta w_t),$$
$$w_{t+1} = w_t - v_{t+1},$$

where

$\gamma$ : learning_rate
$m$ : momentum
$\beta$ : weight_decay
$\eta$ : lars_coefficient
$\lambda$ : local_lr $= \eta * \frac{\|w_t\|}{\|\nabla L(w_t)\| + \beta * \|w_t\|}$
As lr in chainer.optimizers.SGD or chainer.optimizers.MomentumSGD corresponds to $\gamma * \eta$, we define

$$clip\_rate = \frac{\|w_t\|}{\|\nabla L(w_t)\| + \beta * \|w_t\|}$$

and reformulate the update above as

$$v_{t+1} = m * v_t + lr * clip\_rate * (\nabla L(w_t) + \beta w_t),$$

which is how this hook is implemented. You therefore do not need to set lars_coefficient yourself.
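To make the scaling concrete, here is a minimal NumPy sketch of the per-parameter rule described above. This is an illustration, not Chainer's actual implementation; the function name lars_scaled_grad and the exact placement of eps in the denominator are assumptions.

```python
import numpy as np

def lars_scaled_grad(w, grad, threshold=0.01, weight_decay=0.0, eps=1e-09):
    """Illustrative sketch of LARS-style gradient scaling for one parameter."""
    w_norm = np.linalg.norm(w)
    g_norm = np.linalg.norm(grad)
    grad = grad + weight_decay * w  # the beta * w_t weight-decay term
    if w_norm > threshold:
        # clip_rate = ||w|| / (||grad|| + beta * ||w|| + eps); eps placement assumed
        clip_rate = w_norm / (g_norm + weight_decay * w_norm + eps)
        grad = grad * clip_rate  # the update rule then multiplies by lr
    return grad
```

Parameters whose norm is at or below threshold are left unscaled, which is why small or freshly initialized weights are not affected by the clipping.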
- Parameters

threshold (float) – If the weight norm exceeds this threshold, this function scales the gradient array to fit the weight norm. (See <https://arxiv.org/abs/1801.03137>)
weight_decay (float) – Coefficient of weight decay.
eps (float) – Small value added for numerical stability. (See <https://arxiv.org/abs/1801.03137>)
- Variables

threshold (float) – If the weight norm exceeds this threshold, this function scales the gradient array to fit the weight norm. (See <https://arxiv.org/abs/1801.03137>)
weight_decay (float) – Coefficient of weight decay.
eps (float) – Small value added for numerical stability. (See <https://arxiv.org/abs/1801.03137>)
timing (string) – Specifies when this hook should be called by the Optimizer/UpdateRule. Valid values are 'pre' (before any updates) and 'post' (after any updates).
call_for_each_param (bool) – Specifies whether this hook is called for each parameter (True) or only once per optimizer (False). This hook does not expect users to change the value from the default, which is True.
Methods
Attributes
call_for_each_param = True

name = 'GradientLARS'

timing = 'pre'
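For reference, a minimal registration sketch follows; the model architecture and hyperparameter values are hypothetical, chosen only to show the call.

```python
import chainer
from chainer import optimizers, optimizer_hooks

# Hypothetical model, used only to illustrate hook registration.
model = chainer.Sequential(
    chainer.links.Linear(784, 100),
    chainer.functions.relu,
    chainer.links.Linear(100, 10),
)
optimizer = optimizers.MomentumSGD(lr=0.1, momentum=0.9)
optimizer.setup(model)
# timing='pre' and call_for_each_param=True mean the hook rescales each
# parameter's gradient just before its update rule applies the update.
optimizer.add_hook(optimizer_hooks.GradientLARS(threshold=0.01,
                                                weight_decay=0.0001))
```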