# Source code for gluonnlp.model.seq2seq_encoder_decoder

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Encoder and decoder used in sequence-to-sequence learning."""
# Names exported by a star-import of this module. Only the encoder base class
# is listed; the decoder base classes defined in this module are not part of
# the star-import surface and have to be imported by name.
__all__ = ['Seq2SeqEncoder']

import mxnet as mx
from mxnet.gluon.block import Block

def _nested_sequence_last(data, valid_length):
    """Select the last valid step from a (possibly nested) list of per-step values.

    Parameters
    ----------
    data : nested container of NDArrays/Symbols
        A list with one entry per step. Each entry is either an NDArray/Symbol
        of shape (batch_size, ...) or a list of such entries; in the latter
        case every position of the inner lists is reduced independently.
    valid_length : NDArray or Symbol
        Valid length of the sequences. Shape (batch_size,)

    Returns
    -------
    data_last: nested container of NDArrays/Symbols
        The last valid element in the sequence. For nested input the result is
        a list with one reduced value per position of the first entry.

    Raises
    ------
    NotImplementedError
        If the first entry of `data` is neither an NDArray/Symbol nor a list.
    """
    assert isinstance(data, list)
    first = data[0]

    if isinstance(first, (mx.sym.Symbol, mx.nd.NDArray)):
        # The type of the first entry decides between the symbolic and the
        # imperative operator namespace.
        if isinstance(first, mx.sym.Symbol):
            F = mx.sym
        else:
            F = mx.ndarray
        # Stack the per-step values along a new leading axis, then let
        # SequenceLast pick the entry at each sample's valid length.
        stacked = F.stack(*data, axis=0)
        return F.SequenceLast(stacked,
                              sequence_length=valid_length,
                              use_sequence_length=True)

    if isinstance(first, list):
        # Regroup position-wise: gather the pos-th item of every step and
        # reduce that sub-sequence recursively.
        return [_nested_sequence_last([step[pos] for step in data], valid_length)
                for pos in range(len(first))]

    raise NotImplementedError


class Seq2SeqEncoder(Block):
    r"""Base class of the encoders in sequence to sequence learning models.

    Subclasses implement `forward`; this base class only fixes the call
    signature ``(inputs, valid_length=None, states=None)``.
    """

    def __call__(self, inputs, valid_length=None, states=None):  #pylint: disable=arguments-differ
        """Encode the input sequence.

        Parameters
        ----------
        inputs : NDArray
            The input sequence, Shape (batch_size, length, C_in).
        valid_length : NDArray or None, default None
            The valid length of the input sequence, Shape (batch_size,). This is used when the
            input sequences are padded. If set to None, all elements in the sequence are used.
        states : list of NDArrays or None, default None
            List that contains the initial states of the encoder.

        Returns
        -------
        outputs : list
            Outputs of the encoder.
        """
        # All three arguments are handed on positionally to the parent
        # class's __call__.
        return super(Seq2SeqEncoder, self).__call__(inputs, valid_length, states)

    def forward(self, inputs, valid_length=None, states=None):  #pylint: disable=arguments-differ
        """Compute the encoder outputs; must be overridden by subclasses.

        Parameters
        ----------
        inputs : NDArray
            The input sequence, Shape (batch_size, length, C_in).
        valid_length : NDArray or None, default None
            The valid length of the input sequence, Shape (batch_size,).
        states : list of NDArrays or None, default None
            List that contains the initial states of the encoder.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError
class Seq2SeqDecoder(Block):
    """Base class of the decoders for sequence to sequence learning models.

    Given the inputs and the context computed by the encoder, generate the new
    states. Used in the training phase where we set the inputs to be the target
    sequence.
    """

    def init_state_from_encoder(self, encoder_outputs, encoder_valid_length=None):
        r"""Generates the initial decoder states based on the encoder outputs.

        Parameters
        ----------
        encoder_outputs : list of NDArrays
            Outputs of the encoder.
        encoder_valid_length : NDArray or None
            Valid length associated with the encoder outputs, if any.

        Returns
        -------
        decoder_states : list
            The initial states of the decoder.

        Raises
        ------
        NotImplementedError
            Always, in this base class; subclasses must override this method.
        """
        raise NotImplementedError

    def forward(self, step_input, states, valid_length=None):  #pylint: disable=arguments-differ
        """Given the inputs and the context computed by the encoder,
        generate the new states. Used in the training phase where we set the inputs to be
        the target sequence.

        Parameters
        ----------
        step_input : NDArray
            The input embeddings. Shape (batch_size, length, C_in).
            Despite the parameter name, this is documented as the whole
            input sequence rather than a single step.
        states : list
            The initial states of the decoder.
        valid_length : NDArray or None
            valid length of the inputs. Shape (batch_size,)

        Returns
        -------
        output : NDArray
            The output of the decoder. Shape is (batch_size, length, C_out)
        states: list
            The new states of the decoder
        additional_outputs : list
            Additional outputs of the decoder, e.g, the attention weights

        Raises
        ------
        NotImplementedError
            Always, in this base class; subclasses must override this method.
        """
        raise NotImplementedError


class Seq2SeqOneStepDecoder(Block):
    r"""Base class of the decoders in sequence to sequence learning models.

    In the forward function, it generates the one-step-ahead decoding output.
    """

    def init_state_from_encoder(self, encoder_outputs, encoder_valid_length=None):
        r"""Generates the initial decoder states based on the encoder outputs.

        Parameters
        ----------
        encoder_outputs : list of NDArrays
            Outputs of the encoder.
        encoder_valid_length : NDArray or None
            Valid length associated with the encoder outputs, if any.

        Returns
        -------
        decoder_states : list
            The initial states of the decoder.

        Raises
        ------
        NotImplementedError
            Always, in this base class; subclasses must override this method.
        """
        raise NotImplementedError

    def forward(self, step_input, states):  #pylint: disable=arguments-differ
        """One-step decoding of the input

        Parameters
        ----------
        step_input : NDArray
            Shape (batch_size, C_in)
        states : list
            The previous states of the decoder

        Returns
        -------
        step_output : NDArray
            Shape (batch_size, C_out)
        states : list
            The new states of the decoder
        step_additional_outputs : list
            Additional outputs of the step, e.g, the attention weights

        Raises
        ------
        NotImplementedError
            Always, in this base class; subclasses must override this method.
        """
        raise NotImplementedError