# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Implements the beam search sampler."""
__all__ = ['BeamSearchScorer', 'BeamSearchSampler', 'HybridBeamSearchSampler', 'SequenceSampler']
from typing import TypeVar
import numpy as np
import mxnet as mx
from mxnet.gluon import HybridBlock
from .._constants import LARGE_NEGATIVE_FLOAT
__T = TypeVar('__T')
class BeamSearchScorer(HybridBlock):
r"""Score function used in beam search.
Implements the length-penalized score function used in the GNMT paper::
scores = (log_probs + scores) / length_penalty
length_penalty = (K + length)^\alpha / (K + 1)^\alpha
Parameters
----------
alpha : float, default 1.0
K : float, default 5.0
from_logits : bool, default True
Whether input is a log probability (usually from log_softmax) instead
of unnormalized numbers.
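
    Examples
    --------
    A minimal sketch; the shapes and random inputs below are illustrative
    assumptions, not part of the original documentation:

    >>> scorer = BeamSearchScorer(alpha=0.6, K=5.0)
    >>> log_probs = mx.nd.log_softmax(mx.nd.random.uniform(shape=(2, 4, 10)))
    >>> scores = mx.nd.zeros((2, 4))
    >>> step = mx.nd.array([1.0])
    >>> scorer(log_probs, scores, step).shape
    (2, 4, 10)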
"""
def __init__(self, alpha=1.0, K=5.0, from_logits=True, **kwargs):
super(BeamSearchScorer, self).__init__(**kwargs)
self._alpha = alpha
self._K = K
self._from_logits = from_logits
def __call__(self, outputs, scores, step): # pylint: disable=arguments-differ
"""Compute new scores of each candidate
Parameters
----------
outputs : NDArray or Symbol
If from_logits is True, outputs is the log probabilities of the candidates.
Shape (d1, d2, ..., dn, V).
Otherwise, outputs is the unnormalized outputs from predictor of the same shape,
before softmax/log_softmax.
scores : NDArray or Symbol
The original scores of the beams. Shape (d1, d2, ..., dn)
step : NDArray or Symbol
Step to calculate the score function. It starts from 1. Shape (1,)
Returns
-------
candidate_scores : NDArray or Symbol
The scores of all the candidates. Shape (d1, d2, ..., dn, V), where V is the size
of the vocabulary.
"""
return super(BeamSearchScorer, self).__call__(outputs, scores, step)
    def hybrid_forward(self, F, outputs, scores, step):  # pylint: disable=arguments-differ
        if not self._from_logits:
            outputs = outputs.log_softmax()
        # Undo the length penalty applied at the previous step; at step 1 there
        # is no previous penalty, so the multiplier falls back to 1.
        prev_lp = (self._K + step - 1) ** self._alpha / (self._K + 1) ** self._alpha
        prev_lp = prev_lp * (step != 1) + (step == 1)
        scores = F.broadcast_mul(scores, prev_lp)
        # Apply the length penalty of the current step.
        lp = (self._K + step) ** self._alpha / (self._K + 1) ** self._alpha
        candidate_scores = F.broadcast_add(outputs, F.expand_dims(scores, axis=-1))
        candidate_scores = F.broadcast_div(candidate_scores, lp)
        return candidate_scores
def _extract_and_flatten_nested_structure(data, flattened=None):
"""Flatten the structure of a nested container to a list.
Parameters
----------
data : A single NDArray/Symbol or nested container with NDArrays/Symbol.
The nested container to be flattened.
    flattened : list or None
        The list that holds the flattened results.
Returns
-------
structure : An integer or a nested container with integers.
The extracted structure of the container of `data`.
    flattened : (optional) list
        The list that holds the flattened results.
        It is returned only when the input argument `flattened` is not given.
"""
if flattened is None:
flattened = []
structure = _extract_and_flatten_nested_structure(data, flattened)
return structure, flattened
if isinstance(data, list):
return list(_extract_and_flatten_nested_structure(x, flattened) for x in data)
elif isinstance(data, tuple):
return tuple(_extract_and_flatten_nested_structure(x, flattened) for x in data)
elif isinstance(data, dict):
        return {k: _extract_and_flatten_nested_structure(v, flattened) for k, v in data.items()}
elif isinstance(data, (mx.sym.Symbol, mx.nd.NDArray)):
flattened.append(data)
return len(flattened) - 1
else:
raise NotImplementedError
def _reconstruct_flattened_structure(structure, flattened):
"""Reconstruct the flattened list back to (possibly) nested structure.
Parameters
----------
structure : An integer or a nested container with integers.
The extracted structure of the container of `data`.
    flattened : list
        The list that holds the flattened NDArrays/Symbols.
Returns
-------
data : A single NDArray/Symbol or nested container with NDArrays/Symbol.
The nested container that was flattened.
"""
if isinstance(structure, list):
return list(_reconstruct_flattened_structure(x, flattened) for x in structure)
elif isinstance(structure, tuple):
return tuple(_reconstruct_flattened_structure(x, flattened) for x in structure)
elif isinstance(structure, dict):
return {k: _reconstruct_flattened_structure(v, flattened) for k, v in structure.items()}
elif isinstance(structure, int):
return flattened[structure]
else:
raise NotImplementedError
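# A round-trip sketch for the two helpers above; the nested layout is an
# arbitrary assumption used only for illustration:
#
# >>> data = {'h': mx.nd.zeros((2, 3)), 'c': [mx.nd.ones((2,))]}
# >>> structure, flat = _extract_and_flatten_nested_structure(data)
# >>> rebuilt = _reconstruct_flattened_structure(structure, flat)
# >>> sorted(rebuilt.keys())
# ['c', 'h']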
def _expand_to_beam_size(data: __T, beam_size, batch_size, state_info=None) -> __T:
"""Tile all the states to have batch_size * beam_size on the batch axis.
Parameters
----------
data : A single NDArray/Symbol or nested container with NDArrays/Symbol
Each NDArray/Symbol should have shape (N, ...) when state_info is None,
or same as the layout in state_info when it's not None.
beam_size : int
Beam size
batch_size : int
Batch size
state_info : Nested structure of dictionary, default None.
Descriptors for states, usually from decoder's ``state_info()``.
When None, this method assumes that the batch axis is the first dimension.
Returns
-------
new_states : Object that contains NDArrays/Symbols
Each NDArray/Symbol should have shape batch_size * beam_size on the batch axis.
"""
    assert not state_info or isinstance(state_info, (type(data), dict)), \
        'data and state_info don\'t match, ' \
        'got: {} vs {}.'.format(type(state_info), type(data))
if isinstance(data, list):
if not state_info:
state_info = [None] * len(data)
return [_expand_to_beam_size(d, beam_size, batch_size, s)
for d, s in zip(data, state_info)]
elif isinstance(data, tuple):
if not state_info:
state_info = [None] * len(data)
state_info = tuple(state_info)
return tuple(_expand_to_beam_size(d, beam_size, batch_size, s)
for d, s in zip(data, state_info))
elif isinstance(data, dict):
if not state_info:
state_info = {k: None for k in data.keys()}
return {k: _expand_to_beam_size(v, beam_size, batch_size, state_info[k])
for k, v in data.items()}
elif isinstance(data, mx.nd.NDArray):
if not state_info:
batch_axis = 0
else:
batch_axis = state_info['__layout__'].find('N')
if data.shape[batch_axis] != batch_size:
            raise ValueError('The batch dimension of all the inner elements in states must be '
                             '{}. Found shape={}'.format(batch_size, data.shape))
new_shape = list(data.shape)
new_shape[batch_axis] = batch_size * beam_size
new_shape = tuple(new_shape)
return data.expand_dims(batch_axis+1)\
.broadcast_axes(axis=batch_axis+1, size=beam_size)\
.reshape(new_shape)
elif isinstance(data, mx.sym.Symbol):
if not state_info:
batch_axis = 0
else:
batch_axis = state_info['__layout__'].find('N')
new_shape = (0, ) * batch_axis + (-3, -2)
return data.expand_dims(batch_axis+1)\
.broadcast_axes(axis=batch_axis+1, size=beam_size)\
.reshape(new_shape)
elif data is None:
return None
else:
raise NotImplementedError
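# A shape sketch for _expand_to_beam_size; the values are illustrative:
#
# >>> state = mx.nd.arange(6).reshape((2, 3))   # (batch_size=2, hidden=3)
# >>> _expand_to_beam_size(state, beam_size=2, batch_size=2).shape
# (4, 3)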
def _choose_states(F, states, state_info, indices):
"""
Parameters
----------
F : ndarray or symbol
states : Object contains NDArrays/Symbols
Each NDArray/Symbol should have shape (N, ...) when state_info is None,
or same as the layout in state_info when it's not None.
state_info : Nested structure of dictionary, default None.
Descriptors for states, usually from decoder's ``state_info()``.
When None, this method assumes that the batch axis is the first dimension.
indices : NDArray or Symbol
Indices of the states to take. Shape (N,).
Returns
-------
new_states : Object contains NDArrays/Symbols
Each NDArray/Symbol should have shape (N, ...).
"""
assert not state_info or isinstance(state_info, (type(states), dict)), \
'states and state_info don\'t match'
if isinstance(states, list):
if not state_info:
state_info = [None] * len(states)
return [_choose_states(F, d, s, indices) for d, s in zip(states, state_info)]
elif isinstance(states, tuple):
if not state_info:
state_info = [None] * len(states)
state_info = tuple(state_info)
return tuple(_choose_states(F, d, s, indices) for d, s in zip(states, state_info))
elif isinstance(states, dict):
if not state_info:
state_info = {k: None for k in states.keys()}
return {k: _choose_states(F, v, state_info[k], indices)
for k, v in states.items()}
elif isinstance(states, (mx.nd.NDArray, mx.sym.Symbol)):
if not state_info:
batch_axis = 0
else:
batch_axis = state_info['__layout__'].find('N')
states = F.take(states, indices, axis=batch_axis)
return states
else:
raise NotImplementedError
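# A gather sketch for _choose_states; the values are illustrative:
#
# >>> states = {'h': mx.nd.arange(8).reshape((4, 2))}   # batch axis first
# >>> indices = mx.nd.array([0, 0, 3, 2])
# >>> _choose_states(mx.nd, states, None, indices)['h'].shape
# (4, 2)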
class _BeamSearchStepUpdate(HybridBlock):
def __init__(self, beam_size, eos_id, scorer, state_info, single_step=False, \
prefix=None, params=None):
super(_BeamSearchStepUpdate, self).__init__(prefix, params)
self._beam_size = beam_size
self._eos_id = eos_id
self._scorer = scorer
self._state_info = state_info
self._single_step = single_step
assert eos_id >= 0, 'eos_id cannot be negative! Received eos_id={}'.format(eos_id)
def hybrid_forward(self, F, samples, valid_length, outputs, scores, step, beam_alive_mask, # pylint: disable=arguments-differ
states, vocab_size, batch_shift):
"""
Parameters
----------
F
samples : NDArray or Symbol
The current samples generated by beam search.
When single_step is True, (batch_size, beam_size, max_length).
When single_step is False, (batch_size, beam_size, L).
valid_length : NDArray or Symbol
The current valid lengths of the samples
outputs : NDArray or Symbol
Outputs from predictor. If from_logits was set to True in scorer, then it's the
log probability of the current step. Else, it's the unnormalized outputs before
softmax or log_softmax. Shape (batch_size * beam_size, V).
scores : NDArray or Symbol
The previous scores. Shape (batch_size, beam_size)
step : NDArray or Symbol
The current step for doing beam search. Begins from 1. Shape (1,)
beam_alive_mask : NDArray or Symbol
Shape (batch_size, beam_size)
states : nested structure of NDArrays/Symbols
Each NDArray/Symbol should have shape (N, ...) when state_info is None,
or same as the layout in state_info when it's not None.
vocab_size : NDArray or Symbol
Shape (1,)
batch_shift : NDArray or Symbol
Contains [0, beam_size, 2 * beam_size, ..., (batch_size - 1) * beam_size].
Shape (batch_size,)
Returns
-------
new_samples : NDArray or Symbol or an empty list
The updated samples.
When single_step is True, it is an empty list.
When single_step is False, shape (batch_size, beam_size, L + 1)
new_valid_length : NDArray or Symbol
Valid lengths of the samples. Shape (batch_size, beam_size)
new_scores : NDArray or Symbol
Shape (batch_size, beam_size)
chosen_word_ids : NDArray or Symbol
The chosen word ids of the step. Shape (batch_size, beam_size). If it's negative,
no word will be appended to the beam.
beam_alive_mask : NDArray or Symbol
Shape (batch_size, beam_size)
new_states : nested structure of NDArrays/Symbols
Inner NDArrays have shape (batch_size * beam_size, ...)
"""
beam_size = self._beam_size
beam_alive_mask_bcast = F.expand_dims(beam_alive_mask, axis=2).astype(np.float32)
candidate_scores = self._scorer(outputs.reshape(shape=(-4, -1, beam_size, 0)),
scores, step)
# Concat the candidate scores and the scores of the finished beams
# The resulting candidate score will have shape (batch_size, beam_size * |V| + beam_size)
candidate_scores = F.broadcast_mul(beam_alive_mask_bcast, candidate_scores) + \
F.broadcast_mul(1 - beam_alive_mask_bcast,
F.ones_like(candidate_scores) * LARGE_NEGATIVE_FLOAT)
finished_scores = F.where(beam_alive_mask,
F.ones_like(scores) * LARGE_NEGATIVE_FLOAT, scores)
candidate_scores = F.concat(candidate_scores.reshape(shape=(0, -1)),
finished_scores, dim=1)
        # Get the top K scores
        new_scores, indices = F.topk(candidate_scores, axis=1, k=beam_size, ret_typ='both')
        indices = indices.astype(np.int32)
        # Indices >= beam_size * vocab_size point at the appended finished-beam
        # scores; all other indices decode into (beam_id, word_id) pairs.
        use_prev = F.broadcast_greater_equal(indices, beam_size * vocab_size)
        chosen_word_ids = F.broadcast_mod(indices, vocab_size)
        beam_ids = F.where(use_prev,
                           F.broadcast_minus(indices, beam_size * vocab_size),
                           F.floor(F.broadcast_div(indices, vocab_size)))
        batch_beam_indices = F.broadcast_add(beam_ids, F.expand_dims(batch_shift, axis=1))
        # Finished beams get a negative word id so nothing is appended to them.
        chosen_word_ids = F.where(use_prev,
                                  -F.ones_like(indices),
                                  chosen_word_ids)
        # Update the samples and valid_length
selected_samples = F.take(samples.reshape(shape=(-3, 0)),
batch_beam_indices.reshape(shape=(-1,)))
new_samples = F.concat(selected_samples,
chosen_word_ids.reshape(shape=(-1, 1)), dim=1)\
.reshape(shape=(-4, -1, beam_size, 0))
if self._single_step:
new_samples = new_samples.slice_axis(axis=2, begin=1, end=None)
new_valid_length = F.take(valid_length.reshape(shape=(-1,)),
batch_beam_indices.reshape(shape=(-1,))).reshape((-1, beam_size))\
+ 1 - use_prev
# Update the states
new_states = _choose_states(F, states, self._state_info, batch_beam_indices.reshape((-1,)))
# Update the alive mask.
beam_alive_mask = F.take(beam_alive_mask.reshape(shape=(-1,)),
batch_beam_indices.reshape(shape=(-1,)))\
.reshape(shape=(-1, beam_size)) * (chosen_word_ids != self._eos_id)
return new_samples, new_valid_length, new_scores,\
chosen_word_ids, beam_alive_mask, new_states
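# Shape walkthrough for one _BeamSearchStepUpdate step (illustrative numbers:
# batch_size=2, beam_size=3, V=10): candidate_scores is (2, 3, 10), reshaped to
# (2, 30) and concatenated with the finished-beam scores into (2, 33). topk then
# picks 3 indices per row; an index >= 30 selects a finished beam, while an
# index < 30 decodes into (beam_id, word_id) = (index // 10, index % 10).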
class _SamplingStepUpdate(HybridBlock):
def __init__(self, beam_size, eos_id, temperature=1.0, top_k=None, prefix=None, params=None):
super(_SamplingStepUpdate, self).__init__(prefix, params)
self._beam_size = beam_size
self._eos_id = eos_id
self._temperature = temperature
self._top_k = top_k
assert eos_id >= 0, 'eos_id cannot be negative! Received eos_id={}'.format(eos_id)
# pylint: disable=arguments-differ
def hybrid_forward(self, F, samples, valid_length, outputs, scores, beam_alive_mask, states):
"""
Parameters
----------
F
samples : NDArray or Symbol
The current samples generated by beam search. Shape (batch_size, beam_size, L)
valid_length : NDArray or Symbol
The current valid lengths of the samples
outputs: NDArray or Symbol
Decoder output (unnormalized) scores of the current step.
Shape (batch_size * beam_size, V)
scores : NDArray or Symbol
The previous scores. Shape (batch_size, beam_size)
beam_alive_mask : NDArray or Symbol
Shape (batch_size, beam_size)
states : nested structure of NDArrays/Symbols
Inner NDArrays have shape (batch_size * beam_size, ...)
Returns
-------
new_samples : NDArray or Symbol
The updated samples. Shape (batch_size, beam_size, L + 1)
new_valid_length : NDArray or Symbol
Valid lengths of the samples. Shape (batch_size, beam_size)
new_scores : NDArray or Symbol
Shape (batch_size, beam_size)
chosen_word_ids : NDArray or Symbol
The chosen word ids of the step. Shape (batch_size, beam_size). If it's negative,
no word will be appended to the beam.
beam_alive_mask : NDArray or Symbol
Shape (batch_size, beam_size)
new_states : nested structure of NDArrays/Symbols
Inner NDArrays have shape (batch_size * beam_size, ...)
"""
beam_size = self._beam_size
# outputs: (batch_size, beam_size, vocab_size)
outputs = outputs.reshape(shape=(-4, -1, beam_size, 0))
        if self._top_k:
            # argsort of an argsort yields the rank of every element, so
            # `ranks < top_k` keeps exactly the top-k logits per position.
            ranks = outputs.argsort(is_ascend=False, dtype='int32').argsort(dtype='int32')
            outputs = F.where(ranks < self._top_k, outputs, F.ones_like(outputs) * -99999)
smoothed_probs = (outputs / self._temperature).softmax(axis=2)
log_probs = F.log_softmax(outputs, axis=2).reshape(-3, -1)
# (batch_size, beam_size)
chosen_word_ids = F.sample_multinomial(smoothed_probs, dtype=np.int32)
chosen_word_ids = F.where(beam_alive_mask,
chosen_word_ids,
-1*F.ones_like(beam_alive_mask))
        # The NDArray advanced indexing below keeps this block imperative-only,
        # which is why SequenceSampler does not hybridize this updater.
        chosen_word_log_probs = log_probs[mx.nd.arange(log_probs.shape[0]),
                                          chosen_word_ids.reshape(-1)].reshape(-4, -1, beam_size)
# Don't update for finished beams
new_scores = scores + F.where(beam_alive_mask,
chosen_word_log_probs,
F.zeros_like(chosen_word_log_probs))
new_valid_length = valid_length + beam_alive_mask
        # Update the samples and valid_length
new_samples = F.concat(samples, chosen_word_ids.expand_dims(2), dim=2)
# Update the states
new_states = states
# Update the alive mask.
beam_alive_mask = beam_alive_mask * (chosen_word_ids != self._eos_id)
return new_samples, new_valid_length, new_scores,\
chosen_word_ids, beam_alive_mask, new_states
class BeamSearchSampler:
r"""Draw samples from the decoder by beam search.
Parameters
----------
beam_size : int
The beam size.
decoder : callable
Function of the one-step-ahead decoder, should have the form::
outputs, new_states = decoder(step_input, states)
        The inputs and outputs should follow these rules:
- step_input has shape (batch_size,),
- outputs has shape (batch_size, V),
- states and new_states have the same structure and the leading
dimension of the inner NDArrays is the batch dimension.
eos_id : int
        Id of the EOS token. Once a sample reaches eos_id, no further tokens are appended to it.
scorer : BeamSearchScorer, default BeamSearchScorer(alpha=1.0, K=5)
The score function used in beam search.
max_length : int, default 100
The maximum search length.
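
    Examples
    --------
    A minimal end-to-end sketch with a toy, stateless decoder; everything below
    is an illustrative assumption, not part of the original documentation:

    >>> class ToyDecoder:
    ...     def __call__(self, step_input, states):
    ...         # Random log-probabilities over a toy vocabulary of size 8.
    ...         logits = mx.nd.random.uniform(shape=(step_input.shape[0], 8))
    ...         return logits.log_softmax(), states
    >>> sampler = BeamSearchSampler(beam_size=4, decoder=ToyDecoder(), eos_id=0,
    ...                             max_length=10)
    >>> inputs = mx.nd.array([1, 2])        # batch_size = 2
    >>> states = mx.nd.zeros((2, 5))        # toy state with the batch axis first
    >>> samples, scores, valid_length = sampler(inputs, states)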
"""
def __init__(self, beam_size, decoder, eos_id, scorer=BeamSearchScorer(alpha=1.0, K=5),
max_length=100):
self._beam_size = beam_size
assert beam_size > 0,\
'beam_size must be larger than 0. Received beam_size={}'.format(beam_size)
self._decoder = decoder
self._eos_id = eos_id
assert eos_id >= 0, 'eos_id cannot be negative! Received eos_id={}'.format(eos_id)
self._max_length = max_length
self._scorer = scorer
if hasattr(decoder, 'state_info'):
state_info = decoder.state_info()
else:
state_info = None
self._updater = _BeamSearchStepUpdate(beam_size=beam_size, eos_id=eos_id, scorer=scorer,
state_info=state_info)
self._updater.hybridize()
def __call__(self, inputs, states):
"""Sample by beam search.
Parameters
----------
inputs : NDArray
The initial input of the decoder. Shape is (batch_size,).
states : Object that contains NDArrays
The initial states of the decoder.
Returns
-------
samples : NDArray
            Samples drawn by beam search. Shape (batch_size, beam_size, length). dtype is int32.
scores : NDArray
Scores of the samples. Shape (batch_size, beam_size). We make sure that scores[i, :] are
in descending order.
valid_length : NDArray
The valid length of the samples. Shape (batch_size, beam_size). dtype will be int32.
"""
batch_size = inputs.shape[0]
beam_size = self._beam_size
ctx = inputs.context
# Tile the states and inputs to have shape (batch_size * beam_size, ...)
if hasattr(self._decoder, 'state_info'):
state_info = self._decoder.state_info(batch_size)
else:
state_info = None
states = _expand_to_beam_size(states, beam_size=beam_size, batch_size=batch_size,
state_info=state_info)
step_input = _expand_to_beam_size(inputs, beam_size=beam_size,
batch_size=batch_size).astype(np.int32)
# All beams are initialized to alive
# Generated samples are initialized to be the inputs
        # Scores of the first beam are set to zero; all other beams get -inf scores.
# Valid length is initialized to be 1
beam_alive_mask = mx.nd.ones(shape=(batch_size, beam_size), ctx=ctx, dtype=np.int32)
valid_length = mx.nd.ones(shape=(batch_size, beam_size), ctx=ctx, dtype=np.int32)
scores = mx.nd.zeros(shape=(batch_size, beam_size), ctx=ctx)
if beam_size > 1:
scores[:, 1:beam_size] = LARGE_NEGATIVE_FLOAT
samples = step_input.reshape((batch_size, beam_size, 1))
for i in range(self._max_length):
log_probs, new_states = self._decoder(step_input, states)
vocab_size_nd = mx.nd.array([log_probs.shape[1]], ctx=ctx, dtype=np.int32)
batch_shift_nd = mx.nd.arange(0, batch_size * beam_size, beam_size, ctx=ctx,
dtype=np.int32)
step_nd = mx.nd.array([i + 1], ctx=ctx)
samples, valid_length, scores, chosen_word_ids, beam_alive_mask, states = \
self._updater(samples, valid_length, log_probs, scores, step_nd, beam_alive_mask,
new_states, vocab_size_nd, batch_shift_nd)
step_input = mx.nd.relu(chosen_word_ids).reshape((-1,))
if mx.nd.sum(beam_alive_mask).asscalar() == 0:
return samples, scores, valid_length
final_word = mx.nd.where(beam_alive_mask,
mx.nd.full(shape=(batch_size, beam_size),
val=self._eos_id, ctx=ctx, dtype=np.int32),
mx.nd.full(shape=(batch_size, beam_size),
val=-1, ctx=ctx, dtype=np.int32))
samples = mx.nd.concat(samples, final_word.reshape((0, 0, 1)), dim=2)
valid_length += beam_alive_mask
return samples, scores, valid_length
class HybridBeamSearchSampler(HybridBlock):
r"""Draw samples from the decoder by beam search.
Parameters
----------
batch_size : int
The batch size.
beam_size : int
The beam size.
decoder : callable, must be hybridizable
Function of the one-step-ahead decoder, should have the form::
outputs, new_states = decoder(step_input, states)
        The inputs and outputs should follow these rules:
- step_input has shape (batch_size,),
- outputs has shape (batch_size, V),
- states and new_states have the same structure and the leading
dimension of the inner NDArrays is the batch dimension.
eos_id : int
        Id of the EOS token. Once a sample reaches eos_id, no further tokens are appended to it.
scorer : BeamSearchScorer, default BeamSearchScorer(alpha=1.0, K=5), must be hybridizable
The score function used in beam search.
max_length : int, default 100
The maximum search length.
    vocab_size : int, default None
        The vocabulary size. When None, falls back to ``decoder._vocab_size``.
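
    Examples
    --------
    A minimal sketch with a hybridizable toy decoder; all names, shapes, and
    values below are illustrative assumptions:

    >>> class ToyDecoder(mx.gluon.HybridBlock):
    ...     _vocab_size = 8
    ...     def hybrid_forward(self, F, step_input, states):
    ...         # Uniform log-probabilities over a toy vocabulary of size 8.
    ...         return F.log_softmax(F.zeros((2 * 4, 8))), states
    >>> sampler = HybridBeamSearchSampler(batch_size=2, beam_size=4,
    ...                                   decoder=ToyDecoder(), eos_id=0,
    ...                                   max_length=10)
    >>> samples, scores, valid_length = sampler(mx.nd.array([1, 2]),
    ...                                         mx.nd.zeros((2, 5)))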
"""
def __init__(self, batch_size, beam_size, decoder, eos_id,
scorer=BeamSearchScorer(alpha=1.0, K=5),
max_length=100, vocab_size=None,
prefix=None, params=None):
super(HybridBeamSearchSampler, self).__init__(prefix, params)
self._batch_size = batch_size
self._beam_size = beam_size
assert beam_size > 0,\
'beam_size must be larger than 0. Received beam_size={}'.format(beam_size)
self._decoder = decoder
self._eos_id = eos_id
assert eos_id >= 0, 'eos_id cannot be negative! Received eos_id={}'.format(eos_id)
self._max_length = max_length
self._scorer = scorer
self._state_info_func = getattr(decoder, 'state_info', lambda _=None: None)
self._updater = _BeamSearchStepUpdate(beam_size=beam_size, eos_id=eos_id, scorer=scorer,
single_step=True, state_info=self._state_info_func())
self._updater.hybridize()
self._vocab_size = vocab_size or getattr(decoder, '_vocab_size', None)
assert self._vocab_size is not None,\
'Please provide vocab_size or define decoder._vocab_size'
assert not hasattr(decoder, '_vocab_size') or decoder._vocab_size == self._vocab_size, \
'Provided vocab_size={} is not equal to decoder._vocab_size={}'\
.format(self._vocab_size, decoder._vocab_size)
    def hybrid_forward(self, F, inputs, states):  # pylint: disable=arguments-differ
"""Sample by beam search.
Parameters
----------
F
inputs : NDArray or Symbol
The initial input of the decoder. Shape is (batch_size,).
states : Object that contains NDArrays or Symbols
The initial states of the decoder.
Returns
-------
samples : NDArray or Symbol
            Samples drawn by beam search. Shape (batch_size, beam_size, length). dtype is int32.
scores : NDArray or Symbol
Scores of the samples. Shape (batch_size, beam_size). We make sure that scores[i, :] are
in descending order.
valid_length : NDArray or Symbol
The valid length of the samples. Shape (batch_size, beam_size). dtype will be int32.
"""
batch_size = self._batch_size
beam_size = self._beam_size
vocab_size = self._vocab_size
# Tile the states and inputs to have shape (batch_size * beam_size, ...)
state_info = self._state_info_func(batch_size)
step_input = _expand_to_beam_size(inputs, beam_size=beam_size,
batch_size=batch_size).astype(np.int32)
states = _expand_to_beam_size(states, beam_size=beam_size, batch_size=batch_size,
state_info=state_info)
state_structure, states = _extract_and_flatten_nested_structure(states)
if beam_size == 1:
init_scores = F.zeros(shape=(batch_size, 1))
else:
init_scores = F.concat(
F.zeros(shape=(batch_size, 1)),
F.full(shape=(batch_size, beam_size - 1), val=LARGE_NEGATIVE_FLOAT),
dim=1)
vocab_size = F.full(shape=(1,), val=vocab_size, dtype=np.int32)
batch_shift = F.arange(0, batch_size * beam_size, beam_size, dtype=np.int32)
def _loop_cond(_i, _samples, _indices, _step_input, _valid_length, _scores, \
beam_alive_mask, *_states):
return F.sum(beam_alive_mask) > 0
def _loop_func(i, samples, indices, step_input, valid_length, scores, \
beam_alive_mask, *states):
outputs, new_states = self._decoder(
step_input, _reconstruct_flattened_structure(state_structure, states))
step = i + 1
new_samples, new_valid_length, new_scores, \
chosen_word_ids, new_beam_alive_mask, new_new_states = \
self._updater(samples, valid_length, outputs, scores, step.astype(np.float32),
beam_alive_mask,
_extract_and_flatten_nested_structure(new_states)[-1],
vocab_size, batch_shift)
new_step_input = F.relu(chosen_word_ids).reshape((-1,))
# We are doing `new_indices = indices[1 : ] + indices[ : 1]`
new_indices = F.concat(
indices.slice_axis(axis=0, begin=1, end=None),
indices.slice_axis(axis=0, begin=0, end=1),
dim=0)
return [], (step, new_samples, new_indices, new_step_input, new_valid_length, \
new_scores, new_beam_alive_mask) + tuple(new_new_states)
_, pad_samples, indices, _, new_valid_length, new_scores, new_beam_alive_mask = \
F.contrib.while_loop(
cond=_loop_cond, func=_loop_func, max_iterations=self._max_length,
loop_vars=(
F.zeros(shape=(1,), dtype=np.int32), # i
F.zeros(shape=(batch_size, beam_size, self._max_length),
dtype=np.int32), # samples
F.arange(start=0, stop=self._max_length, dtype=np.int32), # indices
step_input, # step_input
F.ones(shape=(batch_size, beam_size), dtype=np.int32), # valid_length
init_scores, # scores
F.ones(shape=(batch_size, beam_size), dtype=np.int32), # beam_alive_mask
) + tuple(states)
            )[1][:7]  # slice instead of extended unpacking, which Python 2 lacks
samples = pad_samples.take(indices, axis=2)
def _then_func():
new_samples = F.concat(
step_input.reshape((batch_size, beam_size, 1)),
samples,
F.full(shape=(batch_size, beam_size, 1), val=-1, dtype=np.int32),
dim=2,
name='concat3')
new_new_valid_length = new_valid_length
return new_samples, new_new_valid_length
def _else_func():
final_word = F.where(new_beam_alive_mask,
F.full(shape=(batch_size, beam_size), val=self._eos_id,
dtype=np.int32),
F.full(shape=(batch_size, beam_size), val=-1, dtype=np.int32))
new_samples = F.concat(
step_input.reshape((batch_size, beam_size, 1)),
samples,
final_word.reshape((0, 0, 1)),
dim=2)
new_new_valid_length = new_valid_length + new_beam_alive_mask
return new_samples, new_new_valid_length
new_samples, new_new_valid_length = \
F.contrib.cond(F.sum(new_beam_alive_mask) == 0, _then_func, _else_func)
return new_samples, new_scores, new_new_valid_length
class SequenceSampler:
r"""Draw samples from the decoder according to the step-wise distribution.
Parameters
----------
beam_size : int
The beam size.
decoder : callable
Function of the one-step-ahead decoder, should have the form::
outputs, new_states = decoder(step_input, states)
        The inputs and outputs should follow these rules:
- step_input has shape (batch_size,)
- outputs is the unnormalized prediction before softmax with shape (batch_size, V)
- states and new_states have the same structure and the leading
dimension of the inner NDArrays is the batch dimension.
eos_id : int
        Id of the EOS token. Once a sample reaches eos_id, no further tokens are appended to it.
max_length : int, default 100
The maximum search length.
temperature : float, default 1.0
Softmax temperature.
top_k : int or None, default None
Sample only from the top-k candidates. If None, all candidates are considered.
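
    Examples
    --------
    A minimal sketch with a toy decoder; the names, shapes, and values below are
    illustrative assumptions:

    >>> class ToyDecoder:
    ...     def __call__(self, step_input, states):
    ...         # Unnormalized scores over a toy vocabulary of size 8.
    ...         logits = mx.nd.random.uniform(shape=(step_input.shape[0], 8))
    ...         return logits, states
    >>> sampler = SequenceSampler(beam_size=4, decoder=ToyDecoder(), eos_id=0,
    ...                           temperature=0.8, top_k=5)
    >>> samples, scores, valid_length = sampler(mx.nd.array([1, 2]),
    ...                                         mx.nd.zeros((2, 5)))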
"""
def __init__(self, beam_size, decoder, eos_id, max_length=100, temperature=1.0, top_k=None):
self._beam_size = beam_size
self._decoder = decoder
self._eos_id = eos_id
assert eos_id >= 0, 'eos_id cannot be negative! Received eos_id={}'.format(eos_id)
self._max_length = max_length
self._top_k = top_k
self._updater = _SamplingStepUpdate(beam_size=beam_size,
eos_id=eos_id,
temperature=temperature,
top_k=top_k)
def __call__(self, inputs, states):
"""Sample by beam search.
Parameters
----------
inputs : NDArray
The initial input of the decoder. Shape is (batch_size,).
states : Object that contains NDArrays
The initial states of the decoder.
Returns
-------
samples : NDArray
            The sampled sequences. Shape (batch_size, beam_size, length). dtype is int32.
scores : NDArray
            Scores of the samples. Shape (batch_size, beam_size). Unlike beam search,
            scores[i, :] are not guaranteed to be in any particular order.
valid_length : NDArray
The valid length of the samples. Shape (batch_size, beam_size). dtype will be int32.
"""
batch_size = inputs.shape[0]
beam_size = self._beam_size
ctx = inputs.context
# Tile the states and inputs to have shape (batch_size * beam_size, ...)
if hasattr(self._decoder, 'state_info'):
state_info = self._decoder.state_info(batch_size)
else:
state_info = None
states = _expand_to_beam_size(states, beam_size=beam_size, batch_size=batch_size,
state_info=state_info)
step_input = _expand_to_beam_size(inputs, beam_size=beam_size, batch_size=batch_size)
        # All beams are initialized to alive
        # Generated samples are initialized to be the inputs
        # All scores start at zero (no -inf masking is needed for sampling)
        # Valid length is initialized to be 1
beam_alive_mask = mx.nd.ones(shape=(batch_size, beam_size), ctx=ctx, dtype=np.int32)
valid_length = mx.nd.ones(shape=(batch_size, beam_size), ctx=ctx, dtype=np.int32)
scores = mx.nd.zeros(shape=(batch_size, beam_size), ctx=ctx)
samples = step_input.reshape((batch_size, beam_size, 1)).astype(np.int32)
for _ in range(self._max_length):
outputs, new_states = self._decoder(step_input, states)
samples, valid_length, scores, chosen_word_ids, beam_alive_mask, states = \
self._updater(samples, valid_length, outputs, scores, beam_alive_mask, new_states)
step_input = mx.nd.relu(chosen_word_ids).reshape((-1,))
if mx.nd.sum(beam_alive_mask).asscalar() == 0:
return samples, scores, valid_length
final_word = mx.nd.where(beam_alive_mask,
mx.nd.full(shape=(batch_size, beam_size),
val=self._eos_id, ctx=ctx, dtype=np.int32),
mx.nd.full(shape=(batch_size, beam_size),
val=-1, ctx=ctx, dtype=np.int32))
samples = mx.nd.concat(samples, final_word.reshape((0, 0, 1)), dim=2)
valid_length += beam_alive_mask
return samples, scores, valid_length