# Source code for gluonnlp.data.candidate_sampler
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Candidate samplers.

Defines :class:`UnigramCandidateSampler`, a Gluon ``HybridBlock`` that draws
class indices from a weighted (unigram) distribution using the alias method.
"""
__all__ = ['UnigramCandidateSampler']
import mxnet as mx
import numpy as np
class UnigramCandidateSampler(mx.gluon.HybridBlock):
    """Unigram Candidate Sampler

    Draw random samples from a unigram distribution with specified weights
    using the alias method.

    The alias tables (an acceptance probability and an alias index per class)
    are built once in ``__init__`` and stored as constant parameters of the
    block, so every draw costs a fixed number of operations regardless of the
    number of classes. Samples are produced on the context of the
    ``candidates_like`` input passed to the block.

    Parameters
    ----------
    weights : mx.nd.NDArray
        Unnormalized class probabilities. Must be non-negative and finite with
        a positive sum. A multi-dimensional array is flattened, so class ``i``
        corresponds to the ``i``-th element in row-major order.
    dtype : str or np.dtype, default 'float32'
        Data type of the candidates. Make sure that the dtype precision is
        large enough to represent the size of your weights array precisely. For
        example, float32 can not distinguish 2**24 from 2**24 + 1.

    Raises
    ------
    ValueError
        If ``dtype`` is float32/float16 and there are too many classes for it
        to represent every index exactly, or if ``weights`` contains negative
        or non-finite values, is empty, or sums to zero.
    """

    def __init__(self, weights, dtype='float32'):
        super(UnigramCandidateSampler, self).__init__()
        self._dtype = dtype
        self.N = weights.size
        # Candidates are computed in float64 and only cast to ``dtype`` at the
        # end of hybrid_forward, so every index in [0, N) must be exactly
        # representable in ``dtype``.
        if (np.dtype(dtype) == np.float32 and weights.size > 2**24) or \
           (np.dtype(dtype) == np.float16 and weights.size > 2**11):
            s = 'dtype={dtype} can not represent all weights'
            raise ValueError(s.format(dtype=dtype))

        # Normalize on the host in float64: the alias table is built with
        # Python floats and stored as float64, so normalizing in the weights'
        # own (possibly float32) dtype would only throw precision away.
        flat_weights = weights.asnumpy().astype(np.float64).reshape(-1)
        total_weights = flat_weights.sum()
        # Zero/NaN totals would otherwise yield NaN probabilities (silently
        # degrading to uniform sampling) and negative weights a wrong table.
        if not np.isfinite(total_weights) or total_weights <= 0 \
                or (flat_weights < 0).any():
            raise ValueError('weights must be non-negative and finite '
                             'with a positive sum')

        # Scale so the probabilities average to 1 (each column holds mass 1).
        scaled = (flat_weights * self.N / total_weights).tolist()
        prob, alias = self._build_alias_table(scaled)

        # store
        prob = mx.nd.array(prob, dtype='float64')
        alias = mx.nd.array(alias, dtype='float64')
        self.prob = self.params.get_constant('prob', prob)
        self.alias = self.params.get_constant('alias', alias)

    @staticmethod
    def _build_alias_table(prob):
        """Build the alias table (Vose's two-worklist variant).

        Parameters
        ----------
        prob : list of float
            Class probabilities multiplied by the number of classes, so that
            they average to 1. Modified in place.

        Returns
        -------
        (list of float, list of int)
            ``prob[i]`` is the probability of keeping class ``i`` once column
            ``i`` has been picked; otherwise ``alias[i]`` is returned.
        """
        n = len(prob)
        alias = [0] * n
        # Sort the outcomes into those with probability lower than 1/N
        # ("low") and at least 1/N ("high"), in ascending index order.
        low = []
        high = []
        for i in range(n):
            if prob[i] < 1.0:
                low.append(i)
            else:
                high.append(i)
        # Fill each under-full column with mass taken from an over-full class;
        # the donor is re-filed depending on how much mass it has left.
        while low and high:
            small = low.pop()
            large = high.pop()
            alias[small] = large
            prob[large] = prob[large] - (1.0 - prob[small])
            if prob[large] < 1.0:
                low.append(large)
            else:
                high.append(large)
        # Whatever remains is, up to floating-point error, exactly full.
        for i in low + high:
            prob[i] = 1
            alias[i] = i
        return prob, alias

    def __repr__(self):
        s = '{block_name}({len_weights}, {dtype})'
        return s.format(block_name=self.__class__.__name__, len_weights=self.N,
                        dtype=self._dtype)

    # pylint: disable=arguments-differ, unused-argument
    def hybrid_forward(self, F, candidates_like, prob, alias):
        """Draw samples from the unigram distribution and return them.

        Parameters
        ----------
        candidates_like: mxnet.nd.NDArray or mxnet.sym.Symbol
            This input specifies the shape of the to be sampled candidates.
            Only its shape (and context) is used; its values are ignored.

        Returns
        -------
        samples: mxnet.nd.NDArray or mxnet.sym.Symbol
            The sampled candidates of shape candidates_like.shape. Candidates
            are sampled based on the weights specified on creation of the
            UnigramCandidateSampler and are cast to its ``dtype``.
        """
        candidates_flat = candidates_like.reshape((-1, )).astype('float64')
        # Pick a column of the alias table uniformly at random.
        idx = F.random.uniform_like(candidates_flat, low=0, high=self.N).floor()
        # Defensive: keep idx inside [0, N-1] in case the uniform sampler ever
        # returns its upper bound, which would make gather_nd read past the
        # end of the tables.
        idx = F.clip(idx, 0, self.N - 1)
        prob = F.gather_nd(prob, idx.reshape((1, -1)))
        alias = F.gather_nd(alias, idx.reshape((1, -1)))
        # Biased coin flip: keep the column's own class with probability
        # prob[idx], otherwise return its alias.
        where = F.random.uniform_like(candidates_flat) < prob
        hit = idx * where
        alt = alias * (1 - where)
        candidates = (hit + alt).reshape_like(candidates_like)
        return candidates.astype(self._dtype)