Source code for gluonnlp.optimizer.bert_adam

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""Weight updating functions."""
import os
import warnings
import numpy
from mxnet.optimizer import Optimizer, register
from mxnet.ndarray import zeros, NDArray, full
from mxnet.ndarray.contrib import mp_adamw_update, adamw_update, \
    multi_mp_adamw_update, multi_adamw_update

__all__ = ['BERTAdam']

@register
class BERTAdam(Optimizer):
    """The Adam optimizer with weight decay regularization for BERT.

    Updates are applied by::

        rescaled_grad = clip(grad * rescale_grad, clip_gradient)
        m = beta1 * m + (1 - beta1) * rescaled_grad
        v = beta2 * v + (1 - beta2) * (rescaled_grad**2)
        w = w - learning_rate * (m / (sqrt(v) + epsilon) + wd * w)

    Note that this is different from `mxnet.optimizer.Adam`, where L2 loss is
    added and accumulated in m and v. In BERTAdam, the weight decay term is
    decoupled from the gradient based update.

    This is also slightly different from the AdamW optimizer described in
    *Fixing Weight Decay Regularization in Adam*, where the schedule multiplier
    and learning rate are decoupled, and the bias-correction terms are removed.
    The BERTAdam optimizer uses the same learning rate to apply gradients
    w.r.t. the loss and weight decay.

    This optimizer accepts the following parameters in addition to those
    accepted by :class:`mxnet.optimizer.Optimizer`.

    Parameters
    ----------
    learning_rate : float, optional, default is 0.001
        Learning rate forwarded to :class:`mxnet.optimizer.Optimizer`.
    beta1 : float, optional, default is 0.9
        Exponential decay rate for the first moment estimates.
    beta2 : float, optional, default is 0.999
        Exponential decay rate for the second moment estimates.
    epsilon : float, optional, default is 1e-6
        Small value to avoid division by 0.

    Notes
    -----
    The number of parameters updated per fused kernel call is read once, at
    construction time, from the ``MXNET_OPTIMIZER_AGGREGATION_SIZE``
    environment variable (default ``4``) and clamped to the range [1, 50].
    A value of 1 disables aggregation.
    """

    def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-6,
                 **kwargs):
        super(BERTAdam, self).__init__(learning_rate=learning_rate, **kwargs)
        self.beta1 = beta1
        self.beta2 = beta2
        self.epsilon = epsilon
        # Clamp the requested aggregation size into [1, 50].
        requested = int(os.getenv('MXNET_OPTIMIZER_AGGREGATION_SIZE', '4'))
        self.aggregate_num = max(1, min(50, requested))

    def create_state_multi_precision(self, index, weight):
        """Create the optimizer state, keeping a float32 master copy if needed.

        Parameters
        ----------
        index : int
            Index of the parameter; forwarded to :meth:`create_state`.
        weight : NDArray
            The parameter the state is created for.

        Returns
        -------
        state
            ``((mean, var), weight_fp32)`` when ``multi_precision`` is enabled
            and ``weight`` is float16; otherwise ``(mean, var)`` as returned by
            :meth:`create_state`.
        """
        if self.multi_precision and weight.dtype == numpy.float16:
            # Moments are allocated with the dtype of the float32 master copy.
            weight_master_copy = weight.astype(numpy.float32)
            return (self.create_state(index, weight_master_copy), weight_master_copy)
        if weight.dtype == numpy.float16 and not self.multi_precision:
            warnings.warn('Accumulating with float16 in optimizer can lead to '
                          'poor accuracy or slow convergence. '
                          'Consider using multi_precision=True option of the '
                          'BERTAdam optimizer')
        return self.create_state(index, weight)

    def create_state(self, _, weight):
        """Create zero-initialised first and second moment buffers.

        Both buffers share the shape, context and dtype of ``weight``.

        Returns
        -------
        tuple of NDArray
            ``(mean, variance)``.
        """
        mean = zeros(weight.shape, weight.context, dtype=weight.dtype)
        variance = zeros(weight.shape, weight.context, dtype=weight.dtype)
        return (mean, variance)

    def update(self, index, weight, grad, state):
        """Apply one update step without a float32 master copy.

        ``index``, ``weight``, ``grad`` and ``state`` may each be a single item
        or parallel lists/tuples (aggregated update).
        """
        self._update_impl(index, weight, grad, state, multi_precision=False)

    def update_multi_precision(self, index, weight, grad, state):
        """Apply one update step, using the float32 master copy for fp16 weights.

        ``weight`` may be a single NDArray or a list/tuple of NDArrays; in the
        latter case the dtype of the first element decides whether the
        multi-precision path is taken.
        """
        # Only index into `weight` when it really is a sequence of arrays.
        # Indexing a bare NDArray would slice out its first row just to read
        # the dtype, and relies on integer indexing being available for the
        # array's storage type.
        first_weight = weight[0] if isinstance(weight, (tuple, list)) else weight
        use_multi_precision = self.multi_precision and first_weight.dtype == numpy.float16
        self._update_impl(index, weight, grad, state,
                          multi_precision=use_multi_precision)

    def _update_impl(self, indices, weight, grad, state, multi_precision=False):
        """Shared implementation of the single and multi-precision updates.

        Parameters
        ----------
        indices : int or list/tuple of int
            Parameter index or indices. When a single index is given, the
            other arguments are wrapped into one-element lists as well.
        weight, grad : NDArray or list of NDArray
            Parameters and their gradients, parallel to ``indices``.
        state : tuple or list of tuple
            ``(mean, var)`` per parameter, or ``((mean, var), weight_fp32)``
            when ``multi_precision`` is True.
        multi_precision : bool
            Whether ``state`` carries a float32 master weight.
        """
        aggregate = self.aggregate_num > 1
        if not isinstance(indices, (tuple, list)):
            indices = [indices]
            weight = [weight]
            grad = [grad]
            state = [state]
        for w_i, g_i in zip(weight, grad):
            assert isinstance(w_i, NDArray)
            assert isinstance(g_i, NDArray)
            # The fused multi-tensor path is only taken when every weight and
            # gradient uses dense ('default') storage.
            aggregate = (aggregate and
                         w_i.stype == 'default' and
                         g_i.stype == 'default')
        self._update_count(indices)
        lrs = self._get_lrs(indices)
        wds = self._get_wds(indices)

        # rescale_grad is handed to the update operators as an NDArray living
        # on the same context as the weights; a plain number is wrapped into a
        # one-element array the first time through.
        # pylint: disable=access-member-before-definition
        if not isinstance(self.rescale_grad, NDArray):
            self.rescale_grad = full(shape=(1,), val=self.rescale_grad,
                                     ctx=weight[0].context)
        else:
            self.rescale_grad = self.rescale_grad.as_in_context(weight[0].context)

        kwargs = {'beta1': self.beta1, 'beta2': self.beta2, 'epsilon': self.epsilon,
                  'rescale_grad': self.rescale_grad}
        if self.clip_gradient:
            kwargs['clip_gradient'] = self.clip_gradient

        # In every operator call below the operator-level learning rate is
        # fixed to 1 and the optimizer's learning rate is passed as the
        # schedule multiplier (eta / etas), so that the same rate scales both
        # the gradient term and the weight decay term.
        if aggregate:
            num_params = len(indices)
            current_index = 0
            while current_index < num_params:
                sidx = current_index
                eidx = min(current_index + self.aggregate_num, num_params)
                chunk_weights = weight[sidx:eidx]
                chunk_grads = grad[sidx:eidx]
                chunk_size = len(chunk_weights)
                unit_lrs = list(numpy.ones(chunk_size))
                if not multi_precision:
                    # state entries are (mean, var)
                    mean, var = list(zip(*state[sidx:eidx]))
                    multi_adamw_update(chunk_weights,
                                       chunk_grads,
                                       mean, var,
                                       out=chunk_weights,
                                       size=chunk_size,
                                       lrs=unit_lrs,
                                       wds=wds[sidx:eidx],
                                       etas=lrs[sidx:eidx],
                                       **kwargs)
                else:
                    # state entries are ((mean, var), weight_fp32)
                    mean_var, master_weights = list(zip(*state[sidx:eidx]))
                    mean, var = list(zip(*mean_var))
                    multi_mp_adamw_update(chunk_weights,
                                          chunk_grads,
                                          mean, var,
                                          master_weights,
                                          out=chunk_weights,
                                          size=chunk_size,
                                          lrs=unit_lrs,
                                          wds=wds[sidx:eidx],
                                          etas=lrs[sidx:eidx],
                                          **kwargs)
                current_index += self.aggregate_num
        else:
            for w_i, g_i, s_i, lr, wd in zip(weight, grad, state, lrs, wds):
                if not multi_precision:
                    mean, var = s_i
                    adamw_update(w_i, g_i, mean, var, out=w_i,
                                 lr=1, wd=wd, eta=lr, **kwargs)
                else:
                    mean, var = s_i[0]
                    mp_adamw_update(w_i, g_i, mean, var, s_i[1], out=w_i,
                                    lr=1, wd=wd, eta=lr, **kwargs)