GluonNLP provides specialized optimizers for training natural language processing models.

BERTAdam Optimizer

The Adam optimizer with weight decay regularization for BERT.


API Reference

NLP optimizer.

class gluonnlp.optimizer.BERTAdam(learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-06, **kwargs)

The Adam optimizer with weight decay regularization for BERT.

Updates are applied by:

rescaled_grad = clip(grad * rescale_grad, clip_gradient)
m = beta1 * m + (1 - beta1) * rescaled_grad
v = beta2 * v + (1 - beta2) * (rescaled_grad**2)
w = w - learning_rate * (m / (sqrt(v) + epsilon) + wd * w)
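The four update equations can be traced step by step in a short NumPy sketch. The function name bertadam_step is hypothetical, clip is assumed to clamp each element to [-clip_gradient, clip_gradient], and the code illustrates the math only; it is not GluonNLP's implementation, which operates on MXNet NDArrays.

```python
import numpy as np


def bertadam_step(w, grad, m, v, learning_rate=0.001, wd=0.0,
                  beta1=0.9, beta2=0.999, epsilon=1e-6,
                  rescale_grad=1.0, clip_gradient=None):
    """One BERTAdam step following the update equations.

    Returns the new (w, m, v); the input arrays are not modified.
    """
    rescaled_grad = grad * rescale_grad
    if clip_gradient is not None:
        # Assumed meaning of clip(x, c): clamp each element to [-c, c].
        rescaled_grad = np.clip(rescaled_grad, -clip_gradient, clip_gradient)
    m = beta1 * m + (1 - beta1) * rescaled_grad
    v = beta2 * v + (1 - beta2) * rescaled_grad ** 2
    # The weight decay term sits outside the m / sqrt(v) ratio (decoupled).
    w = w - learning_rate * (m / (np.sqrt(v) + epsilon) + wd * w)
    return w, m, v


w = np.array([1.0, -2.0])
m = np.zeros_like(w)
v = np.zeros_like(w)
grad = np.array([0.5, 0.5])
w, m, v = bertadam_step(w, grad, m, v, learning_rate=0.1, wd=0.01)
print(w, m, v)
```

Returning new arrays keeps the sketch easy to check against the equations; an optimizer would normally update the arrays in place.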

Note that this differs from mxnet.optimizer.Adam, where the L2 regularization term (wd * w) is added to the gradient and is therefore accumulated in m and v. In BERTAdam, the weight decay term is decoupled from the gradient-based update.
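To make the difference concrete, the sketch below applies one step of each scheme to the same weights with a zero loss gradient, so only the regularization term moves them. The coupled_l2_step function folds wd * w into the gradient, as the note describes for mxnet.optimizer.Adam, but it is a simplification (MXNet's bias correction and other details are omitted); both function names are hypothetical.

```python
import numpy as np

BETA1, BETA2, EPS = 0.9, 0.999, 1e-6


def coupled_l2_step(w, grad, m, v, lr, wd):
    """L2 term added to the gradient, so it passes through m and v."""
    g = grad + wd * w
    m = BETA1 * m + (1 - BETA1) * g
    v = BETA2 * v + (1 - BETA2) * g ** 2
    return w - lr * m / (np.sqrt(v) + EPS), m, v


def decoupled_step(w, grad, m, v, lr, wd):
    """BERTAdam-style: weight decay applied to w outside m and v."""
    m = BETA1 * m + (1 - BETA1) * grad
    v = BETA2 * v + (1 - BETA2) * grad ** 2
    return w - lr * (m / (np.sqrt(v) + EPS) + wd * w), m, v


w = np.array([10.0, 0.1])  # one large weight, one small weight
zeros = np.zeros_like(w)
w_coupled, _, _ = coupled_l2_step(w, zeros, zeros, zeros, lr=0.01, wd=0.01)
w_decoupled, _, _ = decoupled_step(w, zeros, zeros, zeros, lr=0.01, wd=0.01)
print(w - w_coupled)    # shrinkage caused by the coupled L2 term
print(w - w_decoupled)  # shrinkage caused by decoupled weight decay
```

In the coupled version both weights shrink by almost the same amount, because dividing by sqrt(v) normalizes the wd * w term away. In the decoupled version the shrinkage stays proportional to w.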

This is also slightly different from the AdamW optimizer described in Fixing Weight Decay Regularization in Adam. AdamW decouples the schedule multiplier from the learning rate, whereas BERTAdam uses the same learning rate to scale both the gradient-based update and the weight decay term. BERTAdam also omits the bias-correction terms that AdamW applies to m and v.
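A sketch of both updates side by side may help. adamw_step follows the form given in the AdamW paper, with a schedule multiplier eta scaling both terms, a step size alpha scaling only the gradient term, and bias-corrected moments; bertadam_step follows the equations above. The function names and the values alpha = 0.001, eta = 1 are illustrative assumptions. The example compares the very first step from a zero state.

```python
import numpy as np


def adamw_step(w, grad, m, v, t, alpha, eta, wd,
               beta1=0.9, beta2=0.999, epsilon=1e-6):
    """AdamW as in the paper; t is the 1-based step count."""
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad ** 2
    m_hat = m / (1 - beta1 ** t)  # bias-corrected first moment
    v_hat = v / (1 - beta2 ** t)  # bias-corrected second moment
    w = w - eta * (alpha * m_hat / (np.sqrt(v_hat) + epsilon) + wd * w)
    return w, m, v


def bertadam_step(w, grad, m, v, learning_rate, wd,
                  beta1=0.9, beta2=0.999, epsilon=1e-6):
    """BERTAdam: one learning rate for both terms, no bias correction."""
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad ** 2
    w = w - learning_rate * (m / (np.sqrt(v) + epsilon) + wd * w)
    return w, m, v


w = np.array([1.0])
grad = np.array([0.5])
zeros = np.zeros_like(w)
# First step (t = 1) with no weight decay, so only the gradient term acts.
w_adamw, _, _ = adamw_step(w, grad, zeros, zeros, t=1, alpha=0.001, eta=1.0, wd=0.0)
w_bert, _, _ = bertadam_step(w, grad, zeros, zeros, learning_rate=0.001, wd=0.0)
print(w - w_adamw)  # about 0.001: bias correction makes m_hat / sqrt(v_hat) close to 1
print(w - w_bert)   # about 0.00316: without correction the ratio is 0.1 / sqrt(0.001)
```

As training proceeds, beta1 ** t and beta2 ** t approach 0 and the bias-correction factors approach 1, so this particular difference matters mostly in early steps.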

This optimizer accepts the following parameters in addition to those accepted by mxnet.optimizer.Optimizer.

  • beta1 (float, optional, default is 0.9) – Exponential decay rate for the first moment estimates.

  • beta2 (float, optional, default is 0.999) – Exponential decay rate for the second moment estimates.

  • epsilon (float, optional, default is 1e-6) – Small value to avoid division by 0.

create_state(_, weight)

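State creation function. Creates the first and second moment estimates (m and v) for weight; the first argument, the parameter index, is unused.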
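A minimal NumPy sketch of such a state, assuming both moment estimates start at zero with the weight's shape and dtype. The function name create_adam_state is hypothetical; the library itself creates MXNet NDArrays.

```python
import numpy as np


def create_adam_state(weight):
    """Return (m, v): zero-initialized first and second moment estimates
    with the same shape and dtype as weight."""
    return np.zeros_like(weight), np.zeros_like(weight)


weight = np.ones((2, 3), dtype=np.float32)
mean, var = create_adam_state(weight)
print(mean.shape, var.dtype)  # (2, 3) float32
```

Two separate arrays are returned so that m and v can later be updated in place independently.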

create_state_multi_precision(index, weight)

Multi-precision state creation function. Variant of create_state used for mixed-precision (for example, float16) training.
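Multi-precision training usually keeps a float32 master copy of each float16 weight so that small updates are not lost to float16 rounding. The NumPy sketch below shows one such state layout: ((m, v), master_weight) for float16 weights and plain (m, v) otherwise. The layout, the multi_precision flag handling, and the function name create_mp_state are assumptions for illustration, not a description of GluonNLP's internals.

```python
import numpy as np


def create_mp_state(weight, multi_precision=True):
    """Sketch: float16 weights get float32 moments plus a float32 master
    copy of the weight; other weights get the plain (m, v) state."""
    if multi_precision and weight.dtype == np.float16:
        master = weight.astype(np.float32)
        return (np.zeros_like(master), np.zeros_like(master)), master
    return np.zeros_like(weight), np.zeros_like(weight)


w16 = np.full(4, 0.1, dtype=np.float16)
(mean, var), master = create_mp_state(w16)
print(mean.dtype, master.dtype)  # float32 float32
```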

update(index, weight, grad, state)

Update function. Applies one BERTAdam step to weight in place, using grad and the moment estimates held in state.
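An optimizer's update call typically modifies weight in place, so the moment estimates in state must also be updated in place to carry over to the next call. The NumPy sketch below uses the same argument order and assumes state is the (m, v) pair. The function name bertadam_update and its lr and wd keyword arguments are hypothetical (a real optimizer reads these from its own settings), and index is accepted only to mirror the signature.

```python
import numpy as np


def bertadam_update(index, weight, grad, state, lr=0.001, wd=0.01,
                    beta1=0.9, beta2=0.999, epsilon=1e-6):
    """In-place BERTAdam step; index is unused in this sketch."""
    mean, var = state
    # Update the moment estimates in place so the state persists.
    mean *= beta1
    mean += (1 - beta1) * grad
    var *= beta2
    var += (1 - beta2) * grad ** 2
    # Decoupled weight decay; the result is written back into weight.
    weight -= lr * (mean / (np.sqrt(var) + epsilon) + wd * weight)


weight = np.array([1.0, -1.0])
state = (np.zeros(2), np.zeros(2))
for _ in range(3):
    bertadam_update(0, weight, np.array([0.2, -0.2]), state)
print(weight, state[0])
```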

update_multi_precision(index, weight, grad, state)

Multi-precision update function. Variant of update used for mixed-precision (for example, float16) training.
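A multi-precision update can run the BERTAdam step on a float32 master copy with float32 moments, then cast the result back into the float16 weight. The NumPy sketch below assumes a hypothetical ((m, v), master) state layout and the hypothetical function name mp_update. It shows that many steps too small to register in float16 still accumulate in the master copy and eventually change the float16 weight.

```python
import numpy as np


def mp_update(weight16, grad, state, lr, wd=0.0,
              beta1=0.9, beta2=0.999, epsilon=1e-6):
    """Sketch: BERTAdam step on a float32 master copy, then cast to float16.

    state is ((mean, var), master), all float32 (hypothetical layout).
    """
    (mean, var), master = state
    g = grad.astype(np.float32)
    # Update moments and master weight in place so the state persists.
    mean[:] = beta1 * mean + (1 - beta1) * g
    var[:] = beta2 * var + (1 - beta2) * g ** 2
    master -= lr * (mean / (np.sqrt(var) + epsilon) + wd * master)
    # Write the float32 result back into the float16 weight in place.
    weight16[:] = master.astype(np.float16)


weight16 = np.ones(1, dtype=np.float16)
master = weight16.astype(np.float32)
state = ((np.zeros_like(master), np.zeros_like(master)), master)
grad = np.ones(1, dtype=np.float16)

# Each step here moves the weight by roughly 3e-5 to 7e-5, well below the
# float16 spacing just under 1.0 (about 4.9e-4).
for _ in range(100):
    mp_update(weight16, grad, state, lr=1e-5)

# Applied directly in float16, a single such step is rounded away.
direct = np.float16(1.0) - np.float16(5e-5)
print(direct, master, weight16)
```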