Source code for gluonnlp.model.seq2seq_encoder_decoder
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Encoder and decoder used in sequence-to-sequence learning."""
__all__ = ['Seq2SeqEncoder']
import mxnet as mx
from mxnet.gluon.block import Block
def _nested_sequence_last(data, valid_length):
"""
Parameters
----------
data : nested container of NDArrays/Symbols
The input data. Each element will have shape (batch_size, ...)
valid_length : NDArray or Symbol
Valid length of the sequences. Shape (batch_size,)
Returns
-------
data_last: nested container of NDArrays/Symbols
The last valid element in the sequence.
"""
assert isinstance(data, list)
if isinstance(data[0], (mx.sym.Symbol, mx.nd.NDArray)):
F = mx.sym if isinstance(data[0], mx.sym.Symbol) else mx.ndarray
return F.SequenceLast(F.stack(*data, axis=0),
sequence_length=valid_length,
use_sequence_length=True)
elif isinstance(data[0], list):
ret = []
for i in range(len(data[0])):
ret.append(_nested_sequence_last([ele[i] for ele in data], valid_length))
return ret
else:
raise NotImplementedError
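# Illustrative sketch (not part of the original module): a toy usage of
# _nested_sequence_last on a plain list of per-step NDArrays. The helper name
# below is hypothetical and exists only for illustration.
def _example_nested_sequence_last():
    """Pick the last valid step from three per-step NDArrays of shape (2, 1)."""
    steps = [mx.nd.full((2, 1), t) for t in range(3)]   # values 0, 1, 2 at steps 0, 1, 2
    valid_length = mx.nd.array([2, 3])                  # sample 0 ends at step 1, sample 1 at step 2
    last = _nested_sequence_last(steps, valid_length)
    # last.asnumpy() -> [[1.], [2.]]
    return last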
class Seq2SeqEncoder(Block):
    r"""Base class of the encoders in sequence to sequence learning models."""

    def __call__(self, inputs, valid_length=None, states=None):  # pylint: disable=arguments-differ
        """Encode the input sequence.

        Parameters
        ----------
        inputs : NDArray
            The input sequence, Shape (batch_size, length, C_in).
        valid_length : NDArray or None, default None
            The valid length of the input sequence, Shape (batch_size,). This is used when the
            input sequences are padded. If set to None, all elements in the sequence are used.
        states : list of NDArrays or None, default None
            List that contains the initial states of the encoder.

        Returns
        -------
        outputs : list
            Outputs of the encoder.
        """
        return super(Seq2SeqEncoder, self).__call__(inputs, valid_length, states)

    def forward(self, inputs, valid_length=None, states=None):  # pylint: disable=arguments-differ
        raise NotImplementedError
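# Illustrative sketch (not part of the original module): a minimal, hypothetical
# concrete encoder showing how a subclass is expected to fill in forward(). It
# wraps a single GRU layer; real gluonnlp encoders handle multiple layers,
# dropout, attention cells, etc.
class _ExampleRNNEncoder(Seq2SeqEncoder):
    def __init__(self, hidden_size=64, prefix=None, params=None):
        super(_ExampleRNNEncoder, self).__init__(prefix=prefix, params=params)
        with self.name_scope():
            self._rnn = mx.gluon.rnn.GRU(hidden_size, layout='NTC')

    def forward(self, inputs, valid_length=None, states=None):  # pylint: disable=arguments-differ
        # inputs: (batch_size, length, C_in) -> outputs: (batch_size, length, hidden_size)
        if states is None:
            states = self._rnn.begin_state(batch_size=inputs.shape[0], ctx=inputs.context)
        outputs, states = self._rnn(inputs, states)
        if valid_length is not None:
            # Zero out padded positions so downstream layers can ignore them.
            outputs = mx.nd.SequenceMask(outputs, sequence_length=valid_length,
                                         use_sequence_length=True, axis=1)
        return [outputs, states]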
class Seq2SeqDecoder(Block):
    """Base class of the decoders for sequence to sequence learning models.

    Given the inputs and the context computed by the encoder, generate the new
    states. Used in the training phase where we set the inputs to be the target
    sequence.
    """

    def init_state_from_encoder(self, encoder_outputs, encoder_valid_length=None):
        r"""Generates the initial decoder states based on the encoder outputs.

        Parameters
        ----------
        encoder_outputs : list of NDArrays
        encoder_valid_length : NDArray or None

        Returns
        -------
        decoder_states : list
        """
        raise NotImplementedError

    def forward(self, step_input, states, valid_length=None):  # pylint: disable=arguments-differ
        """Given the inputs and the context computed by the encoder, generate the new states.

        Used in the training phase where we set the inputs to be the target
        sequence.

        Parameters
        ----------
        step_input : NDArray
            The input embeddings. Shape (batch_size, length, C_in)
        states : list
            The initial states of the decoder.
        valid_length : NDArray or None
            Valid length of the inputs. Shape (batch_size,)

        Returns
        -------
        output : NDArray
            The output of the decoder. Shape is (batch_size, length, C_out)
        states : list
            The new states of the decoder
        additional_outputs : list
            Additional outputs of the decoder, e.g., the attention weights
        """
        raise NotImplementedError
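# Illustrative sketch (not part of the original module): the intended
# training-time (teacher forcing) calling pattern for a Seq2SeqDecoder
# subclass. `encoder` and `decoder` stand for hypothetical concrete
# implementations of the two base classes; the embeddings are precomputed.
def _example_training_step(encoder, decoder, src_embed, tgt_embed,
                           src_valid_length=None, tgt_valid_length=None):
    """Run one teacher-forced decoding pass and return the decoder output."""
    # Encode the full source sequence: (batch_size, src_length, C_in).
    encoder_outputs = encoder(src_embed, src_valid_length)
    # Turn the encoder outputs into the initial decoder states.
    decoder_states = decoder.init_state_from_encoder(encoder_outputs, src_valid_length)
    # Feed the whole target sequence at once and read off the per-step outputs.
    output, decoder_states, additional_outputs = decoder(tgt_embed, decoder_states,
                                                         tgt_valid_length)
    return output, additional_outputs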
class Seq2SeqOneStepDecoder(Block):
    r"""Base class of the one-step decoders in sequence to sequence learning models.

    In the forward function, it generates the one-step-ahead decoding output.
    """

    def init_state_from_encoder(self, encoder_outputs, encoder_valid_length=None):
        r"""Generates the initial decoder states based on the encoder outputs.

        Parameters
        ----------
        encoder_outputs : list of NDArrays
        encoder_valid_length : NDArray or None

        Returns
        -------
        decoder_states : list
        """
        raise NotImplementedError

    def forward(self, step_input, states):  # pylint: disable=arguments-differ
        """One-step decoding of the input.

        Parameters
        ----------
        step_input : NDArray
            Shape (batch_size, C_in)
        states : list
            The previous states of the decoder

        Returns
        -------
        step_output : NDArray
            Shape (batch_size, C_out)
        states : list
            The new states of the decoder
        step_additional_outputs : list
            Additional outputs of the step, e.g., the attention weights
        """
        raise NotImplementedError
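# Illustrative sketch (not part of the original module): a minimal greedy
# decoding loop driven by a Seq2SeqOneStepDecoder subclass. `one_step_decoder`,
# `embedding` and `proj` are hypothetical blocks; gluonnlp's actual inference
# utilities (e.g. the beam search tools) are considerably more elaborate.
def _example_greedy_decode(one_step_decoder, embedding, proj, encoder_outputs,
                           bos_ids, max_length=50):
    """Greedily decode token ids one step at a time."""
    states = one_step_decoder.init_state_from_encoder(encoder_outputs)
    step_ids = bos_ids                        # (batch_size,)
    samples = []
    for _ in range(max_length):
        step_input = embedding(step_ids)      # (batch_size, C_in)
        step_output, states, _ = one_step_decoder(step_input, states)
        scores = proj(step_output)            # (batch_size, vocab_size)
        step_ids = scores.argmax(axis=1)      # most likely next token per sample
        samples.append(step_ids)
    return mx.nd.stack(*samples, axis=1)      # (batch_size, max_length)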