Source code for gluonnlp.model.train.cache

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Cache model."""
__all__ = ['CacheCell']

import mxnet as mx
from mxnet.gluon import HybridBlock

[docs]class CacheCell(HybridBlock): r"""Cache language model. We implement the neural cache language model proposed in the following work:: @article{grave2016improving, title={Improving neural language models with a continuous cache}, author={Grave, Edouard and Joulin, Armand and Usunier, Nicolas}, journal={ICLR}, year={2017} } Parameters ---------- lm_model : gluonnlp.model.StandardRNN or gluonnlp.model.AWDRNN The type of RNN to use. Options are 'gluonnlp.model.StandardRNN', 'gluonnlp.model.AWDRNN'. vocab_size : int Size of the input vocabulary. window : int Size of cache window theta : float The scala controls the flatness of the cache distribution that predict the next word as shown below: .. math:: p_{cache} \propto \sum_{i=1}^{t-1} \mathbb{1}_{w=x_{i+1}} exp(\theta {h_t}^T h_i) where :math:`p_{cache}` is the cache distribution, :math:`\mathbb{1}` is the identity function, and :math:`h_i` is the output of timestep i. lambdas : float Linear scalar between only cache and vocab distribution, the formulation is as below: .. math:: p = (1 - \lambda) p_{vocab} + \lambda p_{cache} where :math:`p_{vocab}` is the vocabulary distribution and :math:`p_{cache}` is the cache distribution. """ def __init__(self, lm_model, vocab_size, window, theta, lambdas, **kwargs): super(CacheCell, self).__init__(**kwargs) self._vocab_size = vocab_size self._window = window self._theta = theta self._lambdas = lambdas with self.name_scope(): self.lm_model = lm_model
[docs] def save_parameters(self, filename, deduplicate=False): """Save parameters to file. filename : str Path to file. deduplicate : bool, default False If True, save shared parameters only once. Otherwise, if a Block contains multiple sub-blocks that share parameters, each of the shared parameters will be separately saved for every sub-block. """ self.lm_model.save_parameters(filename, deduplicate=deduplicate)
[docs] def load_parameters(self, filename, ctx=mx.cpu()): # pylint: disable=arguments-differ """Load parameters from file. filename : str Path to parameter file. ctx : Context or list of Context, default cpu() Context(s) initialize loaded parameters on. """ self.lm_model.load_parameters(filename, ctx=ctx)
[docs] def begin_state(self, *args, **kwargs): """Initialize the hidden states. """ return self.lm_model.begin_state(*args, **kwargs)
def __call__(self, inputs, target, next_word_history, cache_history, begin_state=None): # pylint: disable=arguments-differ """Defines the forward computation for cache cell. Arguments can be either :py:class:`NDArray` or :py:class:`Symbol`. Parameters ---------- inputs: NDArray or Symbol The input data target: NDArray or Symbol The label next_word_history: NDArray or Symbol The next word in memory cache_history: NDArray or Symbol The hidden state in cache history begin_state: list of NDArray or Symbol, optional The begin states. Returns -------- out: NDArray or Symbol The linear interpolation of the cache language model with the regular word-level language model next_word_history: NDArray or Symbol The next words to be kept in the memory for look up (size is equal to the window size) cache_history: NDArray or Symbol The hidden states to be kept in the memory for look up (size is equal to the window size) """ return super(CacheCell, self).__call__(inputs, target, next_word_history, cache_history, begin_state)
[docs] def hybrid_forward(self, F, inputs, target, next_word_history, cache_history, begin_state=None): # pylint: disable=arguments-differ """Defines the forward computation for cache cell. Arguments can be either :py:class:`NDArray` or :py:class:`Symbol`. Parameters ---------- inputs: NDArray or Symbol The input data target: NDArray or Symbol The label next_word_history: NDArray or Symbol The next word in memory cache_history: NDArray or Symbol The hidden state in cache history begin_state: list of NDArray or Symbol, optional The begin states. Returns -------- out: NDArray or Symbol The linear interpolation of the cache language model with the regular word-level language model next_word_history: NDArray or Symbol The next words to be kept in the memory for look up (size is equal to the window size) cache_history: NDArray or Symbol The hidden states to be kept in the memory for look up (size is equal to the window size) """ output, hidden, encoder_hs, _ = super(self.lm_model.__class__, self.lm_model).\ hybrid_forward(F, inputs, begin_state) encoder_h = encoder_hs[-1].reshape(-3, -2) output = output.reshape(-1, self._vocab_size) start_idx = len(next_word_history) \ if next_word_history is not None else 0 next_word_history = F.concat(*[F.one_hot(t[0], self._vocab_size, on_value=1, off_value=0) for t in target], dim=0) if next_word_history is None \ else F.concat(next_word_history, F.concat(*[F.one_hot(t[0], self._vocab_size, on_value=1, off_value=0) for t in target], dim=0), dim=0) cache_history = encoder_h if cache_history is None \ else F.concat(cache_history, encoder_h, dim=0) out = None softmax_output = F.softmax(output) for idx, vocab_L in enumerate(softmax_output): joint_p = vocab_L if start_idx + idx > self._window: valid_next_word = next_word_history[start_idx + idx - self._window:start_idx + idx] valid_cache_history = cache_history[start_idx + idx - self._window:start_idx + idx] logits = F.dot(valid_cache_history, encoder_h[idx]) cache_attn = F.softmax(self._theta * logits).reshape(-1, 1) cache_dist = (cache_attn.broadcast_to(valid_next_word.shape) * valid_next_word).sum(axis=0) joint_p = self._lambdas * cache_dist + (1 - self._lambdas) * vocab_L out = joint_p[target[idx]] if out is None \ else F.concat(out, joint_p[target[idx]], dim=0) next_word_history = next_word_history[-self._window:] cache_history = cache_history[-self._window:] return out, next_word_history, cache_history, hidden