
Training GNMT on IWSLT 2015 Dataset

In this notebook, we are going to train Google NMT (GNMT) on the IWSLT 2015 English-Vietnamese dataset. The building process includes four key steps:

  1. Load and preprocess the dataset

  2. Create a sampler and DataLoader

  3. Build the actual model

  4. Write the training algorithm

This tutorial will guide you through each of the steps and explain briefly how each works. Please remember to click the download button at the top of the page to download the necessary files to follow this tutorial.

Setup

First, we need to set up the environment and import the necessary modules. For this tutorial, a GPU is strongly recommended.

[1]:
import warnings
warnings.filterwarnings('ignore')

import argparse
import time
import random
import os
import io
import logging
import numpy as np
import mxnet as mx
from mxnet import gluon
import gluonnlp as nlp
import nmt
nlp.utils.check_version('0.7.0')

Next, we specify the hyperparameters for the dataset, the model, and for training and testing.

[2]:
np.random.seed(100)
random.seed(100)
mx.random.seed(10000)
ctx = mx.gpu(0)

# parameters for dataset
dataset = 'IWSLT2015'
src_lang, tgt_lang = 'en', 'vi'
src_max_len, tgt_max_len = 50, 50

# parameters for model
num_hidden = 512
num_layers = 2
num_bi_layers = 1
dropout = 0.2

# parameters for training
batch_size, test_batch_size = 128, 32
num_buckets = 5
epochs = 1
clip = 5
lr = 0.001
lr_update_factor = 0.5
log_interval = 10
save_dir = 'gnmt_en_vi_u512'

# parameters for testing
beam_size = 10
lp_alpha = 1.0
lp_k = 5

nmt.utils.logging_config(save_dir)
All Logs will be saved to gnmt_en_vi_u512/<ipython-input-2-4699ac3a1bfb>.log
[2]:
'gnmt_en_vi_u512'

Loading and processing the dataset

The following shows how to process the dataset and cache the processed result for future use. The processing includes the following steps:

  1. Clipping the source and target sequences

  2. Splitting the string input to a list of tokens

  3. Mapping each string token to its integer index in the vocabulary (see the short sketch after this list)

  4. Appending the end-of-sentence (EOS) token to the source sentence and adding both the beginning-of-sentence (BOS) and EOS tokens to the target sentence
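
As a quick illustration of step 3, the short sketch below builds a toy vocabulary and looks up token indices. The names toy_counter and toy_vocab are purely illustrative and not part of the tutorial's data pipeline.

[ ]:
toy_counter = nlp.data.count_tokens('the quick brown fox jumps over the lazy dog'.split())
toy_vocab = nlp.Vocab(toy_counter)  # adds <unk>, <pad>, <bos>, <eos> special tokens by default
print(toy_vocab['the quick fox'.split()])  # list of integer indices for these tokens
print(toy_vocab[toy_vocab.eos_token])      # index of the EOS token appended in step 4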

First, we load and cache the dataset with the two helper functions cache_dataset and load_cached_dataset. The functions are straightforward and well commented, so no further explanation is needed.

[3]:
def cache_dataset(dataset, prefix):
    """Cache the processed npy dataset  the dataset into an npz file

    Parameters
    ----------
    dataset : gluon.data.SimpleDataset
    file_path : str
    """
    if not os.path.exists(nmt._constants.CACHE_PATH):
        os.makedirs(nmt._constants.CACHE_PATH)
    src_data = np.concatenate([e[0] for e in dataset])
    tgt_data = np.concatenate([e[1] for e in dataset])
    src_cumlen = np.cumsum([0]+[len(e[0]) for e in dataset])
    tgt_cumlen = np.cumsum([0]+[len(e[1]) for e in dataset])
    np.savez(os.path.join(nmt._constants.CACHE_PATH, prefix + '.npz'),
             src_data=src_data, tgt_data=tgt_data,
             src_cumlen=src_cumlen, tgt_cumlen=tgt_cumlen)


def load_cached_dataset(prefix):
    cached_file_path = os.path.join(nmt._constants.CACHE_PATH, prefix + '.npz')
    if os.path.exists(cached_file_path):
        print('Load cached data from {}'.format(cached_file_path))
        npz_data = np.load(cached_file_path)
        src_data, tgt_data, src_cumlen, tgt_cumlen = [npz_data[n] for n in
                ['src_data', 'tgt_data', 'src_cumlen', 'tgt_cumlen']]
        src_data = np.array([src_data[low:high] for low, high in zip(src_cumlen[:-1], src_cumlen[1:])])
        tgt_data = np.array([tgt_data[low:high] for low, high in zip(tgt_cumlen[:-1], tgt_cumlen[1:])])
        return gluon.data.ArrayDataset(np.array(src_data), np.array(tgt_data))
    else:
        return None

Next, we write the class TrainValDataTransform, which transforms and clips the source and target sentences. This class also adds the BOS and EOS tokens so that sentence boundaries are explicit. Please refer to the comments in the code for more details.

[4]:
class TrainValDataTransform(object):
    """Transform the machine translation dataset.

    Clip the source and target sentences to the maximum length. For the source sentence, append the
    EOS token. For the target sentence, prepend the BOS token and append the EOS token.

    Parameters
    ----------
    src_vocab : Vocab
    tgt_vocab : Vocab
    src_max_len : int
    tgt_max_len : int
    """

    def __init__(self, src_vocab, tgt_vocab, src_max_len, tgt_max_len):
        # On initialization of the class, we set the class variables
        self._src_vocab = src_vocab
        self._tgt_vocab = tgt_vocab
        self._src_max_len = src_max_len
        self._tgt_max_len = tgt_max_len

    def __call__(self, src, tgt):
        # On actual calling of the class, we perform the clipping then the appending of the EOS and BOS tokens.
        if self._src_max_len > 0:
            src_sentence = self._src_vocab[src.split()[:self._src_max_len]]
        else:
            src_sentence = self._src_vocab[src.split()]
        if self._tgt_max_len > 0:
            tgt_sentence = self._tgt_vocab[tgt.split()[:self._tgt_max_len]]
        else:
            tgt_sentence = self._tgt_vocab[tgt.split()]
        src_sentence.append(self._src_vocab[self._src_vocab.eos_token])
        tgt_sentence.insert(0, self._tgt_vocab[self._tgt_vocab.bos_token])
        tgt_sentence.append(self._tgt_vocab[self._tgt_vocab.eos_token])
        src_npy = np.array(src_sentence, dtype=np.int32)
        tgt_npy = np.array(tgt_sentence, dtype=np.int32)
        return src_npy, tgt_npy
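
To see roughly what the transform produces, here is a hypothetical usage sketch with toy vocabularies (toy_src_vocab, toy_tgt_vocab, and toy_transform are illustrative names, unrelated to the IWSLT vocabularies used later):

[ ]:
toy_src_vocab = nlp.Vocab(nlp.data.count_tokens('hello world !'.split()))
toy_tgt_vocab = nlp.Vocab(nlp.data.count_tokens('xin chào !'.split()))
toy_transform = TrainValDataTransform(toy_src_vocab, toy_tgt_vocab,
                                      src_max_len=50, tgt_max_len=50)
src_ids, tgt_ids = toy_transform('hello world !', 'xin chào !')
print(src_ids)  # source token indices followed by the EOS index
print(tgt_ids)  # BOS index, target token indices, EOS index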

We leverage the class written above to create a helper function that processes the dataset in very few lines of code.

[5]:
def process_dataset(dataset, src_vocab, tgt_vocab, src_max_len=-1, tgt_max_len=-1):
    start = time.time()
    dataset_processed = dataset.transform(TrainValDataTransform(src_vocab, tgt_vocab,
                                                                src_max_len,
                                                                tgt_max_len), lazy=False)
    end = time.time()
    print('Processing time spent: {}'.format(end - start))
    return dataset_processed

Here we define a function load_translation_data that combines all the above steps to load the data, check if it’s been processed, and if not, process the data. The method returns all of the required data for training, validating, and testing our model. Please refer to the comments in the code for more information on what each piece does.

[6]:
def load_translation_data(dataset, src_lang='en', tgt_lang='vi'):
    """Load translation dataset

    Parameters
    ----------
    dataset : str
    src_lang : str, default 'en'
    tgt_lang : str, default 'vi'

    Returns
    -------
    data_train_processed : Dataset
        The preprocessed training sentence pairs
    data_val_processed : Dataset
        The preprocessed validation sentence pairs
    data_test_processed : Dataset
        The preprocessed test sentence pairs
    val_tgt_sentences : list
        The target sentences in the validation set
    test_tgt_sentences : list
        The target sentences in the test set
    src_vocab : Vocab
        Vocabulary of the source language
    tgt_vocab : Vocab
        Vocabulary of the target language
    """
    common_prefix = 'IWSLT2015_{}_{}_{}_{}'.format(src_lang, tgt_lang,
                                                   src_max_len, tgt_max_len)

    # Load the three datasets from files
    data_train = nlp.data.IWSLT2015('train', src_lang=src_lang, tgt_lang=tgt_lang)
    data_val = nlp.data.IWSLT2015('val', src_lang=src_lang, tgt_lang=tgt_lang)
    data_test = nlp.data.IWSLT2015('test', src_lang=src_lang, tgt_lang=tgt_lang)
    src_vocab, tgt_vocab = data_train.src_vocab, data_train.tgt_vocab
    data_train_processed = load_cached_dataset(common_prefix + '_train')

    # Check if each dataset has been processed or not, and if not, process and cache them.
    if not data_train_processed:
        data_train_processed = process_dataset(data_train, src_vocab, tgt_vocab,
                                               src_max_len, tgt_max_len)
        cache_dataset(data_train_processed, common_prefix + '_train')
    data_val_processed = load_cached_dataset(common_prefix + '_val')
    if not data_val_processed:
        data_val_processed = process_dataset(data_val, src_vocab, tgt_vocab)
        cache_dataset(data_val_processed, common_prefix + '_val')
    data_test_processed = load_cached_dataset(common_prefix + '_test')
    if not data_test_processed:
        data_test_processed = process_dataset(data_test, src_vocab, tgt_vocab)
        cache_dataset(data_test_processed, common_prefix + '_test')

    # Pull out the target sentences for both test and validation
    fetch_tgt_sentence = lambda src, tgt: tgt.split()
    val_tgt_sentences = list(data_val.transform(fetch_tgt_sentence))
    test_tgt_sentences = list(data_test.transform(fetch_tgt_sentence))

    # Return all of the necessary pieces we can extract from the data for training our model
    return data_train_processed, data_val_processed, data_test_processed, \
           val_tgt_sentences, test_tgt_sentences, src_vocab, tgt_vocab

We define one last helper function, get_data_lengths, which returns the lengths of the source and target sentences in a dataset; these lengths are used for bucketing later.

[7]:
def get_data_lengths(dataset):
    return list(dataset.transform(lambda src, tgt: (len(src), len(tgt))))

For the final step of processing, we leverage all of our helper functions so that the main code stays concise, at roughly 15-20 lines. This performs all of the aforementioned processing and stores the necessary information in memory for training our model.

[8]:
data_train, data_val, data_test, val_tgt_sentences, test_tgt_sentences, src_vocab, tgt_vocab\
    = load_translation_data(dataset=dataset, src_lang=src_lang, tgt_lang=tgt_lang)
data_train_lengths = get_data_lengths(data_train)
data_val_lengths = get_data_lengths(data_val)
data_test_lengths = get_data_lengths(data_test)

with io.open(os.path.join(save_dir, 'val_gt.txt'), 'w', encoding='utf-8') as of:
    for ele in val_tgt_sentences:
        of.write(' '.join(ele) + '\n')

with io.open(os.path.join(save_dir, 'test_gt.txt'), 'w', encoding='utf-8') as of:
    for ele in test_tgt_sentences:
        of.write(' '.join(ele) + '\n')


data_train = data_train.transform(lambda src, tgt: (src, tgt, len(src), len(tgt)), lazy=False)
data_val = gluon.data.SimpleDataset([(ele[0], ele[1], len(ele[0]), len(ele[1]), i)
                                     for i, ele in enumerate(data_val)])
data_test = gluon.data.SimpleDataset([(ele[0], ele[1], len(ele[0]), len(ele[1]), i)
                                      for i, ele in enumerate(data_test)])
Downloading /root/.mxnet/datasets/iwslt2015/iwslt15.zip from https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/dataset/iwslt2015/iwslt15.zip...
Processing time spent: 4.309457778930664
Processing time spent: 0.04656338691711426
Processing time spent: 0.04290056228637695

Sampler and DataLoader construction

Now that we have obtained and stored all of the relevant data, the next step is to construct the sampler and DataLoader. The first step is to define the batchify function, which pads and stacks sequences to form mini-batches.
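
To get a feel for what Pad and Stack do, here is a rough sketch on toy inputs (toy_src is an illustrative variable, not part of the tutorial):

[ ]:
toy_src = [np.array([1, 2, 3], dtype='int32'), np.array([4, 5], dtype='int32')]
print(nlp.data.batchify.Pad(pad_val=0)(toy_src))         # 2x3 batch; the shorter row is zero-padded
print(nlp.data.batchify.Stack(dtype='float32')([3, 2]))  # the corresponding valid lengths, stacked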

[9]:
train_batchify_fn = nlp.data.batchify.Tuple(nlp.data.batchify.Pad(pad_val=0),
                                            nlp.data.batchify.Pad(pad_val=0),
                                            nlp.data.batchify.Stack(dtype='float32'),
                                            nlp.data.batchify.Stack(dtype='float32'))
test_batchify_fn = nlp.data.batchify.Tuple(nlp.data.batchify.Pad(pad_val=0),
                                           nlp.data.batchify.Pad(pad_val=0),
                                           nlp.data.batchify.Stack(dtype='float32'),
                                           nlp.data.batchify.Stack(dtype='float32'),
                                           nlp.data.batchify.Stack())

We can then construct bucketing samplers, which generate batches by grouping sequences with similar lengths. Here, the bucketing scheme is empirically determined.
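
As a toy illustration of bucketing (toy_lengths is hypothetical), sequences of similar length end up in the same batch:

[ ]:
toy_lengths = [(5, 6), (6, 7), (20, 22), (21, 23), (22, 24)]
toy_sampler = nlp.data.FixedBucketSampler(lengths=toy_lengths, batch_size=2,
                                          num_buckets=2, shuffle=False)
for batch_indices in toy_sampler:
    print(batch_indices)  # indices of samples whose lengths fall into the same bucket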

[10]:
bucket_scheme = nlp.data.ExpWidthBucket(bucket_len_step=1.2)
train_batch_sampler = nlp.data.FixedBucketSampler(lengths=data_train_lengths,
                                                  batch_size=batch_size,
                                                  num_buckets=num_buckets,
                                                  shuffle=True,
                                                  bucket_scheme=bucket_scheme)
logging.info('Train Batch Sampler:\n{}'.format(train_batch_sampler.stats()))
val_batch_sampler = nlp.data.FixedBucketSampler(lengths=data_val_lengths,
                                                batch_size=test_batch_size,
                                                num_buckets=num_buckets,
                                                shuffle=False)
logging.info('Valid Batch Sampler:\n{}'.format(val_batch_sampler.stats()))
test_batch_sampler = nlp.data.FixedBucketSampler(lengths=data_test_lengths,
                                                 batch_size=test_batch_size,
                                                 num_buckets=num_buckets,
                                                 shuffle=False)
logging.info('Test Batch Sampler:\n{}'.format(test_batch_sampler.stats()))
2020-09-10 18:56:00,482 - root - Train Batch Sampler:
FixedBucketSampler:
  sample_num=133166, batch_num=1043
  key=[(9, 10), (16, 17), (26, 27), (37, 38), (51, 52)]
  cnt=[11414, 34897, 37760, 23480, 25615]
  batch_size=[128, 128, 128, 128, 128]
2020-09-10 18:56:00,485 - root - Valid Batch Sampler:
FixedBucketSampler:
  sample_num=1553, batch_num=52
  key=[(22, 28), (40, 52), (58, 76), (76, 100), (94, 124)]
  cnt=[1037, 432, 67, 10, 7]
  batch_size=[32, 32, 32, 32, 32]
2020-09-10 18:56:00,488 - root - Test Batch Sampler:
FixedBucketSampler:
  sample_num=1268, batch_num=42
  key=[(23, 29), (43, 53), (63, 77), (83, 101), (103, 125)]
  cnt=[770, 381, 84, 26, 7]
  batch_size=[32, 32, 32, 32, 32]

Given the samplers, we can create a DataLoader, which is iterable. It is simply an iterator that feeds the model one mini-batch at a time. For more information, refer to this page.

[11]:
train_data_loader = gluon.data.DataLoader(data_train,
                                          batch_sampler=train_batch_sampler,
                                          batchify_fn=train_batchify_fn,
                                          num_workers=4)
val_data_loader = gluon.data.DataLoader(data_val,
                                        batch_sampler=val_batch_sampler,
                                        batchify_fn=test_batchify_fn,
                                        num_workers=4)
test_data_loader = gluon.data.DataLoader(data_test,
                                         batch_sampler=test_batch_sampler,
                                         batchify_fn=test_batchify_fn,
                                         num_workers=4)
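
As a quick sanity check (not part of the original training script), we can pull a single mini-batch from the training DataLoader and inspect its shapes:

[ ]:
for src_seq, tgt_seq, src_valid_length, tgt_valid_length in train_data_loader:
    print(src_seq.shape, tgt_seq.shape, src_valid_length.shape, tgt_valid_length.shape)
    break  # only look at the first mini-batch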

Building the GNMT model

After obtaining the DataLoader, we can finally build the model. The GNMT encoder and decoder can be easily constructed by calling the get_gnmt_encoder_decoder function. Then, we feed the encoder and decoder into NMTModel to construct the GNMT model.

model.hybridize allows computation to be done using the symbolic backend. To understand what it means to be “hybridized,” please refer to this page on MXNet hybridization and its advantages.
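
As a minimal, self-contained illustration of hybridization (independent of this tutorial's model, with toy_net as a made-up example), a tiny network can be hybridized so that later calls run through a cached symbolic graph:

[ ]:
toy_net = gluon.nn.HybridSequential()
toy_net.add(gluon.nn.Dense(4), gluon.nn.Dense(2))
toy_net.initialize()
toy_net.hybridize(static_alloc=True)
print(toy_net(mx.nd.ones((1, 8))))  # the first call builds the cached graph; later calls reuse it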

[12]:
encoder, decoder, one_step_ahead_decoder = nmt.gnmt.get_gnmt_encoder_decoder(
    hidden_size=num_hidden, dropout=dropout, num_layers=num_layers,
    num_bi_layers=num_bi_layers)
model = nlp.model.translation.NMTModel(src_vocab=src_vocab, tgt_vocab=tgt_vocab, encoder=encoder,
                                       decoder=decoder, one_step_ahead_decoder=one_step_ahead_decoder,
                                       embed_size=num_hidden, prefix='gnmt_')
model.initialize(init=mx.init.Uniform(0.1), ctx=ctx)
static_alloc = True
model.hybridize(static_alloc=static_alloc)
logging.info(model)

# Due to the paddings, we need to mask out the losses corresponding to padding tokens.
loss_function = nlp.loss.MaskedSoftmaxCELoss()
loss_function.hybridize(static_alloc=static_alloc)
2020-09-10 18:56:04,955 - root - NMTModel(
  (encoder): GNMTEncoder(
    (dropout_layer): Dropout(p = 0.2, axes=())
    (rnn_cells): HybridSequential(
      (0): BidirectionalCell(forward=LSTMCell(None -> 2048), backward=LSTMCell(None -> 2048))
      (1): LSTMCell(None -> 2048)
    )
  )
  (decoder): GNMTDecoder(
    (attention_cell): DotProductAttentionCell(
      (_dropout_layer): Dropout(p = 0.0, axes=())
      (_proj_query): Dense(None -> 512, linear)
    )
    (dropout_layer): Dropout(p = 0.2, axes=())
    (rnn_cells): HybridSequential(
      (0): LSTMCell(None -> 2048)
      (1): LSTMCell(None -> 2048)
    )
  )
  (one_step_ahead_decoder): GNMTOneStepDecoder(
    (attention_cell): DotProductAttentionCell(
      (_dropout_layer): Dropout(p = 0.0, axes=())
      (_proj_query): Dense(None -> 512, linear)
    )
    (dropout_layer): Dropout(p = 0.2, axes=())
    (rnn_cells): HybridSequential(
      (0): LSTMCell(None -> 2048)
      (1): LSTMCell(None -> 2048)
    )
  )
  (src_embed): HybridSequential(
    (0): Embedding(17191 -> 512, float32)
    (1): Dropout(p = 0.0, axes=())
  )
  (tgt_embed): HybridSequential(
    (0): Embedding(7709 -> 512, float32)
    (1): Dropout(p = 0.0, axes=())
  )
  (tgt_proj): Dense(None -> 7709, linear)
)

Here, we build the BeamSearchTranslator and define a BeamSearchScorer as the heuristic scoring mechanism for the search. For more information on beam search and its applications to NLP, check here.
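
For intuition, the scorer normalizes each candidate's log-probability with a GNMT-style length penalty so that beam search does not systematically prefer short translations. The function below is a rough, hypothetical sketch of such a penalty, not the library's actual implementation:

[ ]:
def toy_length_penalty(length, alpha=lp_alpha, K=lp_k):
    # Longer candidates are divided by a larger factor before comparison.
    return ((K + length) / (K + 1)) ** alpha

print(toy_length_penalty(10))  # penalty divisor for a 10-token candidate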

[13]:
translator = nmt.translation.BeamSearchTranslator(model=model, beam_size=beam_size,
                                                  scorer=nlp.model.BeamSearchScorer(alpha=lp_alpha,
                                                                                    K=lp_k),
                                                  max_length=tgt_max_len + 100)
logging.info('Use beam_size={}, alpha={}, K={}'.format(beam_size, lp_alpha, lp_k))
2020-09-10 18:56:04,964 - root - Use beam_size=10, alpha=1.0, K=5

We define the evaluation function as shown in the code block below. The evaluate function uses the beam search translator to generate outputs for the validation and testing datasets. Please refer to the comments in the code for more information on what each piece does. In addition, we add the write_sentences helper method to easily output the sentences.

[14]:
def evaluate(data_loader):
    """Evaluate given the data loader

    Parameters
    ----------
    data_loader : gluon.data.DataLoader

    Returns
    -------
    avg_loss : float
        Average loss
    real_translation_out : list of list of str
        The translation output
    """
    translation_out = []
    all_inst_ids = []
    avg_loss_denom = 0
    avg_loss = 0.0

    for _, (src_seq, tgt_seq, src_valid_length, tgt_valid_length, inst_ids) \
            in enumerate(data_loader):
        src_seq = src_seq.as_in_context(ctx)
        tgt_seq = tgt_seq.as_in_context(ctx)
        src_valid_length = src_valid_length.as_in_context(ctx)
        tgt_valid_length = tgt_valid_length.as_in_context(ctx)

        # Calculate Loss
        out, _ = model(src_seq, tgt_seq[:, :-1], src_valid_length, tgt_valid_length - 1)
        loss = loss_function(out, tgt_seq[:, 1:], tgt_valid_length - 1).mean().asscalar()
        all_inst_ids.extend(inst_ids.asnumpy().astype(np.int32).tolist())
        avg_loss += loss * (tgt_seq.shape[1] - 1)
        avg_loss_denom += (tgt_seq.shape[1] - 1)

        # Translate the sequences and score them
        samples, _, sample_valid_length =\
            translator.translate(src_seq=src_seq, src_valid_length=src_valid_length)
        max_score_sample = samples[:, 0, :].asnumpy()
        sample_valid_length = sample_valid_length[:, 0].asnumpy()

        # Iterate through the tokens and stitch the tokens together for the sentence
        for i in range(max_score_sample.shape[0]):
            translation_out.append(
                [tgt_vocab.idx_to_token[ele] for ele in
                 max_score_sample[i][1:(sample_valid_length[i] - 1)]])

    # Calculate the average loss and initialize a None-filled translation list
    avg_loss = avg_loss / avg_loss_denom
    real_translation_out = [None for _ in range(len(all_inst_ids))]

    # Combine all the words/tokens into a sentence for the final translation
    for ind, sentence in zip(all_inst_ids, translation_out):
        real_translation_out[ind] = sentence

    # Return the loss and the translation
    return avg_loss, real_translation_out


def write_sentences(sentences, file_path):
    with io.open(file_path, 'w', encoding='utf-8') as of:
        for sent in sentences:
            of.write(' '.join(sent) + '\n')

Training

Before entering the training stage, we need to create a trainer for updating the parameters based on the loss. In the following example, we create a trainer that uses the Adam optimizer.

[15]:
trainer = gluon.Trainer(model.collect_params(), 'adam', {'learning_rate': lr})

We can then write the training loop. During training, we evaluate on the validation and test datasets every epoch, and record the parameters that give the highest BLEU (bilingual evaluation understudy) score on the validation dataset. Before performing forward and backward computation, we first use the as_in_context function to copy the mini-batch to the GPU. The statement with mx.autograd.record() tells Gluon's backend to compute the gradients for the part inside the block.
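
As a tiny, standalone illustration of autograd (unrelated to the translation model itself), the sketch below records a computation and backpropagates through it:

[ ]:
x = mx.nd.array([[1.0, 2.0]], ctx=ctx)
x.attach_grad()             # allocate space for the gradient of x
with mx.autograd.record():  # record the computation graph
    y = (x * x).sum()
y.backward()                # populate x.grad with dy/dx
print(x.grad)               # expected: [[2. 4.]], i.e. 2 * x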

[16]:
best_valid_bleu = 0.0

# Run through each epoch
for epoch_id in range(epochs):
    log_avg_loss = 0
    log_avg_gnorm = 0
    log_wc = 0
    log_start_time = time.time()

    # Iterate through each batch
    for batch_id, (src_seq, tgt_seq, src_valid_length, tgt_valid_length)\
            in enumerate(train_data_loader):

        src_seq = src_seq.as_in_context(ctx)
        tgt_seq = tgt_seq.as_in_context(ctx)
        src_valid_length = src_valid_length.as_in_context(ctx)
        tgt_valid_length = tgt_valid_length.as_in_context(ctx)

        # Compute gradients and losses
        with mx.autograd.record():
            out, _ = model(src_seq, tgt_seq[:, :-1], src_valid_length, tgt_valid_length - 1)
            loss = loss_function(out, tgt_seq[:, 1:], tgt_valid_length - 1).mean()
            loss = loss * (tgt_seq.shape[1] - 1) / (tgt_valid_length - 1).mean()
            loss.backward()

        grads = [p.grad(ctx) for p in model.collect_params().values()]
        gnorm = gluon.utils.clip_global_norm(grads, clip)
        trainer.step(1)
        src_wc = src_valid_length.sum().asscalar()
        tgt_wc = (tgt_valid_length - 1).sum().asscalar()
        step_loss = loss.asscalar()
        log_avg_loss += step_loss
        log_avg_gnorm += gnorm
        log_wc += src_wc + tgt_wc
        if (batch_id + 1) % log_interval == 0:
            wps = log_wc / (time.time() - log_start_time)
            logging.info('[Epoch {} Batch {}/{}] loss={:.4f}, ppl={:.4f}, gnorm={:.4f}, '
                         'throughput={:.2f}K wps, wc={:.2f}K'
                         .format(epoch_id, batch_id + 1, len(train_data_loader),
                                 log_avg_loss / log_interval,
                                 np.exp(log_avg_loss / log_interval),
                                 log_avg_gnorm / log_interval,
                                 wps / 1000, log_wc / 1000))
            log_start_time = time.time()
            log_avg_loss = 0
            log_avg_gnorm = 0
            log_wc = 0

    # Evaluate the losses on validation and test datasets and find the corresponding BLEU score and log it
    valid_loss, valid_translation_out = evaluate(val_data_loader)
    valid_bleu_score, _, _, _, _ = nmt.bleu.compute_bleu([val_tgt_sentences], valid_translation_out)
    logging.info('[Epoch {}] valid Loss={:.4f}, valid ppl={:.4f}, valid bleu={:.2f}'
                 .format(epoch_id, valid_loss, np.exp(valid_loss), valid_bleu_score * 100))
    test_loss, test_translation_out = evaluate(test_data_loader)
    test_bleu_score, _, _, _, _ = nmt.bleu.compute_bleu([test_tgt_sentences], test_translation_out)
    logging.info('[Epoch {}] test Loss={:.4f}, test ppl={:.4f}, test bleu={:.2f}'
                 .format(epoch_id, test_loss, np.exp(test_loss), test_bleu_score * 100))

    # Output the sentences we predicted on the validation and test datasets
    write_sentences(valid_translation_out,
                    os.path.join(save_dir, 'epoch{:d}_valid_out.txt').format(epoch_id))
    write_sentences(test_translation_out,
                    os.path.join(save_dir, 'epoch{:d}_test_out.txt').format(epoch_id))

    # Save the model if the BLEU score is better than the previous best
    if valid_bleu_score > best_valid_bleu:
        best_valid_bleu = valid_bleu_score
        save_path = os.path.join(save_dir, 'valid_best.params')
        logging.info('Save best parameters to {}'.format(save_path))
        model.save_parameters(save_path)

    # Update the learning rate based on the number of epochs that have passed
    if epoch_id + 1 >= (epochs * 2) // 3:
        new_lr = trainer.learning_rate * lr_update_factor
        logging.info('Learning rate change to {}'.format(new_lr))
        trainer.set_learning_rate(new_lr)
2020-09-10 18:56:09,290 - root - [Epoch 0 Batch 10/1043] loss=7.7636, ppl=2353.4034, gnorm=1.5034, throughput=12.71K wps, wc=54.27K
2020-09-10 18:56:10,930 - root - [Epoch 0 Batch 20/1043] loss=6.3430, ppl=568.5267, gnorm=1.5107, throughput=30.62K wps, wc=50.20K
2020-09-10 18:56:13,090 - root - [Epoch 0 Batch 30/1043] loss=6.3594, ppl=577.8955, gnorm=0.7529, throughput=31.39K wps, wc=67.78K
2020-09-10 18:56:15,119 - root - [Epoch 0 Batch 40/1043] loss=6.1655, ppl=476.0433, gnorm=0.5288, throughput=31.16K wps, wc=63.19K
2020-09-10 18:56:17,148 - root - [Epoch 0 Batch 50/1043] loss=6.1833, ppl=484.5670, gnorm=0.5388, throughput=30.55K wps, wc=61.93K
2020-09-10 18:56:19,080 - root - [Epoch 0 Batch 60/1043] loss=6.0925, ppl=442.5096, gnorm=0.7216, throughput=30.65K wps, wc=59.19K
2020-09-10 18:56:21,426 - root - [Epoch 0 Batch 70/1043] loss=6.1401, ppl=464.0799, gnorm=0.4698, throughput=31.14K wps, wc=72.99K
2020-09-10 18:56:23,545 - root - [Epoch 0 Batch 80/1043] loss=6.0602, ppl=428.4726, gnorm=0.3620, throughput=30.48K wps, wc=64.58K
2020-09-10 18:56:25,320 - root - [Epoch 0 Batch 90/1043] loss=5.9279, ppl=375.3819, gnorm=0.3948, throughput=29.91K wps, wc=53.02K
2020-09-10 18:56:27,235 - root - [Epoch 0 Batch 100/1043] loss=5.8780, ppl=357.0978, gnorm=0.4438, throughput=31.04K wps, wc=59.42K
2020-09-10 18:56:29,333 - root - [Epoch 0 Batch 110/1043] loss=5.8629, ppl=351.7384, gnorm=0.3592, throughput=31.26K wps, wc=65.50K
2020-09-10 18:56:31,281 - root - [Epoch 0 Batch 120/1043] loss=5.8458, ppl=345.7877, gnorm=0.3202, throughput=30.00K wps, wc=58.43K
2020-09-10 18:56:33,389 - root - [Epoch 0 Batch 130/1043] loss=5.8980, ppl=364.3204, gnorm=0.3507, throughput=28.20K wps, wc=59.39K
2020-09-10 18:56:35,432 - root - [Epoch 0 Batch 140/1043] loss=5.8481, ppl=346.5848, gnorm=0.2840, throughput=29.98K wps, wc=61.18K
2020-09-10 18:56:37,336 - root - [Epoch 0 Batch 150/1043] loss=5.7623, ppl=318.0699, gnorm=0.2929, throughput=29.63K wps, wc=56.34K
2020-09-10 18:56:39,320 - root - [Epoch 0 Batch 160/1043] loss=5.6958, ppl=297.6206, gnorm=0.3834, throughput=29.22K wps, wc=57.93K
2020-09-10 18:56:41,465 - root - [Epoch 0 Batch 170/1043] loss=5.7127, ppl=302.6945, gnorm=0.3042, throughput=30.08K wps, wc=64.42K
2020-09-10 18:56:42,979 - root - [Epoch 0 Batch 180/1043] loss=5.4089, ppl=223.3751, gnorm=0.3627, throughput=29.31K wps, wc=44.31K
2020-09-10 18:56:45,080 - root - [Epoch 0 Batch 190/1043] loss=5.5708, ppl=262.6376, gnorm=0.3558, throughput=29.77K wps, wc=62.45K
2020-09-10 18:56:46,926 - root - [Epoch 0 Batch 200/1043] loss=5.4712, ppl=237.7463, gnorm=0.3522, throughput=29.40K wps, wc=54.24K
2020-09-10 18:56:48,723 - root - [Epoch 0 Batch 210/1043] loss=5.2892, ppl=198.1918, gnorm=0.4377, throughput=29.34K wps, wc=52.68K
2020-09-10 18:56:50,428 - root - [Epoch 0 Batch 220/1043] loss=5.2762, ppl=195.6183, gnorm=0.3712, throughput=29.65K wps, wc=50.47K
2020-09-10 18:56:52,457 - root - [Epoch 0 Batch 230/1043] loss=5.2838, ppl=197.1127, gnorm=0.4389, throughput=30.42K wps, wc=61.65K
2020-09-10 18:56:54,365 - root - [Epoch 0 Batch 240/1043] loss=5.2746, ppl=195.3190, gnorm=0.3568, throughput=30.60K wps, wc=58.32K
2020-09-10 18:56:56,688 - root - [Epoch 0 Batch 250/1043] loss=5.4011, ppl=221.6552, gnorm=0.2991, throughput=30.52K wps, wc=70.84K
2020-09-10 18:56:58,671 - root - [Epoch 0 Batch 260/1043] loss=5.2482, ppl=190.2232, gnorm=0.3616, throughput=30.40K wps, wc=60.22K
2020-09-10 18:57:01,048 - root - [Epoch 0 Batch 270/1043] loss=5.3489, ppl=210.3777, gnorm=0.2841, throughput=30.71K wps, wc=72.96K
2020-09-10 18:57:03,104 - root - [Epoch 0 Batch 280/1043] loss=5.2222, ppl=185.3480, gnorm=0.2746, throughput=29.59K wps, wc=60.80K
2020-09-10 18:57:04,669 - root - [Epoch 0 Batch 290/1043] loss=4.9363, ppl=139.2498, gnorm=0.3440, throughput=29.27K wps, wc=45.79K
2020-09-10 18:57:06,617 - root - [Epoch 0 Batch 300/1043] loss=5.0548, ppl=156.7787, gnorm=0.3332, throughput=30.33K wps, wc=59.05K
2020-09-10 18:57:08,690 - root - [Epoch 0 Batch 310/1043] loss=5.0704, ppl=159.2324, gnorm=0.3007, throughput=29.73K wps, wc=61.58K
2020-09-10 18:57:10,505 - root - [Epoch 0 Batch 320/1043] loss=4.9469, ppl=140.7447, gnorm=0.3110, throughput=29.27K wps, wc=53.10K
2020-09-10 18:57:12,561 - root - [Epoch 0 Batch 330/1043] loss=4.9999, ppl=148.3954, gnorm=0.2902, throughput=29.87K wps, wc=61.37K
2020-09-10 18:57:14,541 - root - [Epoch 0 Batch 340/1043] loss=5.0215, ppl=151.6404, gnorm=0.2808, throughput=28.74K wps, wc=56.88K
2020-09-10 18:57:16,403 - root - [Epoch 0 Batch 350/1043] loss=4.8833, ppl=132.0599, gnorm=0.3249, throughput=29.49K wps, wc=54.86K
2020-09-10 18:57:18,531 - root - [Epoch 0 Batch 360/1043] loss=4.9959, ppl=147.8075, gnorm=0.2769, throughput=30.36K wps, wc=64.55K
2020-09-10 18:57:20,771 - root - [Epoch 0 Batch 370/1043] loss=4.8776, ppl=131.3112, gnorm=0.3714, throughput=29.92K wps, wc=66.97K
2020-09-10 18:57:22,601 - root - [Epoch 0 Batch 380/1043] loss=4.7877, ppl=120.0269, gnorm=0.3218, throughput=28.87K wps, wc=52.79K
2020-09-10 18:57:24,369 - root - [Epoch 0 Batch 390/1043] loss=4.7544, ppl=116.0908, gnorm=0.3316, throughput=28.83K wps, wc=50.94K
2020-09-10 18:57:26,015 - root - [Epoch 0 Batch 400/1043] loss=4.6091, ppl=100.3901, gnorm=0.3578, throughput=29.31K wps, wc=48.22K
2020-09-10 18:57:27,717 - root - [Epoch 0 Batch 410/1043] loss=4.7704, ppl=117.9686, gnorm=0.3023, throughput=28.40K wps, wc=48.27K
2020-09-10 18:57:29,663 - root - [Epoch 0 Batch 420/1043] loss=4.8054, ppl=122.1683, gnorm=0.2976, throughput=28.88K wps, wc=56.14K
2020-09-10 18:57:31,982 - root - [Epoch 0 Batch 430/1043] loss=4.8537, ppl=128.2162, gnorm=0.3029, throughput=29.91K wps, wc=69.33K
2020-09-10 18:57:34,209 - root - [Epoch 0 Batch 440/1043] loss=4.7528, ppl=115.9037, gnorm=0.3052, throughput=30.14K wps, wc=67.08K
2020-09-10 18:57:36,051 - root - [Epoch 0 Batch 450/1043] loss=4.6582, ppl=105.4495, gnorm=0.3361, throughput=29.15K wps, wc=53.68K
2020-09-10 18:57:37,751 - root - [Epoch 0 Batch 460/1043] loss=4.4439, ppl=85.1083, gnorm=0.3661, throughput=29.66K wps, wc=50.38K
2020-09-10 18:57:39,858 - root - [Epoch 0 Batch 470/1043] loss=4.7184, ppl=111.9923, gnorm=0.3061, throughput=28.84K wps, wc=60.70K
2020-09-10 18:57:41,682 - root - [Epoch 0 Batch 480/1043] loss=4.3705, ppl=79.0805, gnorm=0.3532, throughput=29.66K wps, wc=54.04K
2020-09-10 18:57:43,326 - root - [Epoch 0 Batch 490/1043] loss=4.4818, ppl=88.3892, gnorm=0.4521, throughput=28.20K wps, wc=46.32K
2020-09-10 18:57:45,133 - root - [Epoch 0 Batch 500/1043] loss=4.5235, ppl=92.1537, gnorm=0.3182, throughput=26.89K wps, wc=48.55K
2020-09-10 18:57:46,596 - root - [Epoch 0 Batch 510/1043] loss=4.3002, ppl=73.7157, gnorm=0.3437, throughput=28.48K wps, wc=41.62K
2020-09-10 18:57:48,008 - root - [Epoch 0 Batch 520/1043] loss=4.2065, ppl=67.1245, gnorm=0.3895, throughput=28.34K wps, wc=39.98K
2020-09-10 18:57:50,010 - root - [Epoch 0 Batch 530/1043] loss=4.6181, ppl=101.2972, gnorm=0.3040, throughput=29.28K wps, wc=58.58K
2020-09-10 18:57:51,792 - root - [Epoch 0 Batch 540/1043] loss=4.4807, ppl=88.2937, gnorm=0.3201, throughput=28.47K wps, wc=50.72K
2020-09-10 18:57:53,920 - root - [Epoch 0 Batch 550/1043] loss=4.5101, ppl=90.9332, gnorm=0.3302, throughput=29.60K wps, wc=62.95K
2020-09-10 18:57:55,528 - root - [Epoch 0 Batch 560/1043] loss=4.2907, ppl=73.0205, gnorm=0.3507, throughput=28.71K wps, wc=46.13K
2020-09-10 18:57:57,576 - root - [Epoch 0 Batch 570/1043] loss=4.3227, ppl=75.3938, gnorm=0.3194, throughput=29.86K wps, wc=61.12K
2020-09-10 18:57:59,487 - root - [Epoch 0 Batch 580/1043] loss=4.3229, ppl=75.4078, gnorm=0.3018, throughput=29.04K wps, wc=55.43K
2020-09-10 18:58:01,944 - root - [Epoch 0 Batch 590/1043] loss=4.5578, ppl=95.3729, gnorm=0.2665, throughput=30.11K wps, wc=73.93K
2020-09-10 18:58:03,894 - root - [Epoch 0 Batch 600/1043] loss=4.4568, ppl=86.2083, gnorm=0.2772, throughput=28.62K wps, wc=55.80K
2020-09-10 18:58:05,704 - root - [Epoch 0 Batch 610/1043] loss=4.2865, ppl=72.7086, gnorm=0.3502, throughput=28.91K wps, wc=52.28K
2020-09-10 18:58:08,140 - root - [Epoch 0 Batch 620/1043] loss=4.5301, ppl=92.7671, gnorm=0.2773, throughput=29.73K wps, wc=72.39K
2020-09-10 18:58:09,401 - root - [Epoch 0 Batch 630/1043] loss=4.0129, ppl=55.3092, gnorm=0.3310, throughput=27.33K wps, wc=34.44K
2020-09-10 18:58:11,374 - root - [Epoch 0 Batch 640/1043] loss=4.3137, ppl=74.7192, gnorm=0.4117, throughput=29.07K wps, wc=57.32K
2020-09-10 18:58:13,633 - root - [Epoch 0 Batch 650/1043] loss=4.4500, ppl=85.6263, gnorm=0.2950, throughput=29.42K wps, wc=66.42K
2020-09-10 18:58:15,218 - root - [Epoch 0 Batch 660/1043] loss=4.1775, ppl=65.2026, gnorm=0.3733, throughput=28.06K wps, wc=44.42K
2020-09-10 18:58:17,834 - root - [Epoch 0 Batch 670/1043] loss=4.5783, ppl=97.3454, gnorm=0.2769, throughput=29.72K wps, wc=77.68K
2020-09-10 18:58:19,894 - root - [Epoch 0 Batch 680/1043] loss=4.4025, ppl=81.6578, gnorm=0.2940, throughput=28.66K wps, wc=59.02K
2020-09-10 18:58:21,662 - root - [Epoch 0 Batch 690/1043] loss=4.1894, ppl=65.9830, gnorm=0.3172, throughput=28.61K wps, wc=50.54K
2020-09-10 18:58:23,512 - root - [Epoch 0 Batch 700/1043] loss=4.2774, ppl=72.0513, gnorm=0.3043, throughput=28.38K wps, wc=52.45K
2020-09-10 18:58:25,010 - root - [Epoch 0 Batch 710/1043] loss=4.1563, ppl=63.8353, gnorm=0.3532, throughput=27.60K wps, wc=41.32K
2020-09-10 18:58:26,776 - root - [Epoch 0 Batch 720/1043] loss=4.2351, ppl=69.0701, gnorm=0.3096, throughput=28.39K wps, wc=50.09K
2020-09-10 18:58:28,614 - root - [Epoch 0 Batch 730/1043] loss=4.1712, ppl=64.7908, gnorm=0.3094, throughput=28.33K wps, wc=52.02K
2020-09-10 18:58:30,719 - root - [Epoch 0 Batch 740/1043] loss=4.3100, ppl=74.4400, gnorm=0.2919, throughput=28.46K wps, wc=59.86K
2020-09-10 18:58:32,599 - root - [Epoch 0 Batch 750/1043] loss=4.1652, ppl=64.4067, gnorm=0.3143, throughput=28.39K wps, wc=53.35K
2020-09-10 18:58:34,950 - root - [Epoch 0 Batch 760/1043] loss=4.2704, ppl=71.5524, gnorm=0.2906, throughput=29.84K wps, wc=70.09K
2020-09-10 18:58:36,515 - root - [Epoch 0 Batch 770/1043] loss=4.0868, ppl=59.5473, gnorm=0.3133, throughput=27.64K wps, wc=43.21K
2020-09-10 18:58:39,078 - root - [Epoch 0 Batch 780/1043] loss=4.2848, ppl=72.5854, gnorm=0.2839, throughput=29.64K wps, wc=75.93K
2020-09-10 18:58:40,752 - root - [Epoch 0 Batch 790/1043] loss=4.1226, ppl=61.7192, gnorm=0.3077, throughput=27.99K wps, wc=46.81K
2020-09-10 18:58:42,821 - root - [Epoch 0 Batch 800/1043] loss=4.1815, ppl=65.4671, gnorm=0.3218, throughput=28.39K wps, wc=58.72K
2020-09-10 18:58:44,808 - root - [Epoch 0 Batch 810/1043] loss=4.0632, ppl=58.1578, gnorm=0.3073, throughput=28.89K wps, wc=57.38K
2020-09-10 18:58:46,817 - root - [Epoch 0 Batch 820/1043] loss=3.9488, ppl=51.8745, gnorm=0.3230, throughput=29.15K wps, wc=58.52K
2020-09-10 18:58:48,816 - root - [Epoch 0 Batch 830/1043] loss=4.1032, ppl=60.5355, gnorm=0.3281, throughput=28.66K wps, wc=57.24K
2020-09-10 18:58:50,671 - root - [Epoch 0 Batch 840/1043] loss=4.0580, ppl=57.8557, gnorm=0.3068, throughput=28.36K wps, wc=52.57K
2020-09-10 18:58:52,850 - root - [Epoch 0 Batch 850/1043] loss=4.1134, ppl=61.1517, gnorm=0.3005, throughput=29.45K wps, wc=64.14K
2020-09-10 18:58:54,775 - root - [Epoch 0 Batch 860/1043] loss=4.1108, ppl=60.9972, gnorm=0.2982, throughput=28.42K wps, wc=54.64K
2020-09-10 18:58:57,035 - root - [Epoch 0 Batch 870/1043] loss=4.1379, ppl=62.6711, gnorm=0.3149, throughput=29.13K wps, wc=65.81K
2020-09-10 18:58:58,831 - root - [Epoch 0 Batch 880/1043] loss=4.0776, ppl=59.0042, gnorm=0.3166, throughput=28.01K wps, wc=50.27K
2020-09-10 18:59:00,831 - root - [Epoch 0 Batch 890/1043] loss=4.1361, ppl=62.5599, gnorm=0.2971, throughput=28.07K wps, wc=56.11K
2020-09-10 18:59:02,941 - root - [Epoch 0 Batch 900/1043] loss=4.1666, ppl=64.4983, gnorm=0.2897, throughput=28.88K wps, wc=60.91K
2020-09-10 18:59:04,782 - root - [Epoch 0 Batch 910/1043] loss=3.9679, ppl=52.8745, gnorm=0.3130, throughput=28.08K wps, wc=51.65K
2020-09-10 18:59:06,904 - root - [Epoch 0 Batch 920/1043] loss=4.1545, ppl=63.7171, gnorm=0.2874, throughput=28.53K wps, wc=60.52K
2020-09-10 18:59:08,516 - root - [Epoch 0 Batch 930/1043] loss=3.9605, ppl=52.4814, gnorm=0.3001, throughput=27.00K wps, wc=43.51K
2020-09-10 18:59:10,278 - root - [Epoch 0 Batch 940/1043] loss=3.9206, ppl=50.4301, gnorm=0.3329, throughput=28.24K wps, wc=49.71K
2020-09-10 18:59:12,687 - root - [Epoch 0 Batch 950/1043] loss=4.1886, ppl=65.9326, gnorm=0.2828, throughput=29.45K wps, wc=70.92K
2020-09-10 18:59:15,208 - root - [Epoch 0 Batch 960/1043] loss=4.1424, ppl=62.9518, gnorm=0.2878, throughput=29.39K wps, wc=74.06K
2020-09-10 18:59:16,820 - root - [Epoch 0 Batch 970/1043] loss=3.8549, ppl=47.2223, gnorm=0.3436, throughput=27.33K wps, wc=44.03K
2020-09-10 18:59:19,295 - root - [Epoch 0 Batch 980/1043] loss=4.1688, ppl=64.6365, gnorm=0.2710, throughput=29.41K wps, wc=72.73K
2020-09-10 18:59:21,561 - root - [Epoch 0 Batch 990/1043] loss=4.0817, ppl=59.2452, gnorm=0.2949, throughput=28.87K wps, wc=65.39K
2020-09-10 18:59:23,607 - root - [Epoch 0 Batch 1000/1043] loss=3.9732, ppl=53.1566, gnorm=0.3070, throughput=27.23K wps, wc=55.65K
2020-09-10 18:59:25,893 - root - [Epoch 0 Batch 1010/1043] loss=4.0693, ppl=58.5149, gnorm=0.2771, throughput=28.74K wps, wc=65.61K
2020-09-10 18:59:27,539 - root - [Epoch 0 Batch 1020/1043] loss=3.8083, ppl=45.0741, gnorm=0.3299, throughput=27.87K wps, wc=45.72K
2020-09-10 18:59:29,536 - root - [Epoch 0 Batch 1030/1043] loss=3.9724, ppl=53.1106, gnorm=0.3136, throughput=28.07K wps, wc=55.97K
2020-09-10 18:59:31,403 - root - [Epoch 0 Batch 1040/1043] loss=3.9367, ppl=51.2513, gnorm=0.3453, throughput=28.57K wps, wc=53.27K
2020-09-10 18:59:55,191 - root - [Epoch 0] valid Loss=2.8493, valid ppl=17.2761, valid bleu=3.09
2020-09-10 19:00:15,663 - root - [Epoch 0] test Loss=2.9845, test ppl=19.7764, test bleu=2.59
2020-09-10 19:00:15,669 - root - Save best parameters to gnmt_en_vi_u512/valid_best.params
2020-09-10 19:00:15,859 - root - Learning rate change to 0.0005

Conclusion

In this notebook, we have shown how to train a GNMT model on the IWSLT 2015 English-Vietnamese dataset using the Gluon NLP toolkit. The complete training script can be found here. The code sequence to reproduce the results can be seen on the machine translation page.