Source code for gluonnlp.data.bert.glue

"""Utility functions for BERT glue data preprocessing"""

__all__ = ['truncate_seqs_equal', 'concat_sequences']

import collections
import itertools
import numpy.ma as ma


def truncate_seqs_equal(sequences, max_len):
    """Truncate a list of sequences equally so that the total length does not exceed max_len.

    Parameters
    ----------
    sequences : list of list of object
        Sequences of tokens, each of which is an iterable of tokens.
    max_len : int
        Max length to be truncated to.

    Returns
    -------
    list : list of truncated sequences keeping the original order

    Examples
    --------
    >>> seqs = [[1, 2, 3], [4, 5, 6]]
    >>> truncate_seqs_equal(seqs, 6)
    [[1, 2, 3], [4, 5, 6]]
    >>> seqs = [[1, 2, 3], [4, 5, 6]]
    >>> truncate_seqs_equal(seqs, 4)
    [[1, 2], [4, 5]]
    >>> seqs = [[1, 2, 3], [4, 5, 6]]
    >>> truncate_seqs_equal(seqs, 3)
    [[1, 2], [4]]
    """
    assert isinstance(sequences, list)
    lens = list(map(len, sequences))
    if sum(lens) <= max_len:
        return sequences

    lens = ma.masked_array(lens, mask=[0] * len(lens))
    while True:
        argmin = lens.argmin()
        minval = lens[argmin]
        quotient, remainder = divmod(max_len, len(lens) - sum(lens.mask))
        if minval <= quotient:  # Ignore values that don't need truncation
            lens.mask[argmin] = 1
            max_len -= minval
        else:  # Truncate all remaining sequences
            lens.data[~lens.mask] = [
                quotient + 1 if i < remainder else quotient for i in range(lens.count())
            ]
            break
    sequences = [seq[:length] for (seq, length) in zip(sequences, lens.data.tolist())]
    return sequences
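
# Illustrative aside (not part of the original gluonnlp module): the masked-array
# loop in ``truncate_seqs_equal`` can be hard to follow, so the plain-Python
# sketch below restates the same budget-splitting idea. The helper name
# ``_truncate_seqs_equal_reference`` is hypothetical and only meant for reading.
def _truncate_seqs_equal_reference(sequences, max_len):
    """Plain-Python restatement of ``truncate_seqs_equal`` (illustration only)."""
    lens = [len(seq) for seq in sequences]
    if sum(lens) <= max_len:
        return sequences
    allowed = {}                          # index -> allowed length after truncation
    remaining = list(range(len(lens)))    # indices still competing for the budget
    while remaining:
        quotient, remainder = divmod(max_len, len(remaining))
        shortest = min(remaining, key=lambda i: lens[i])
        if lens[shortest] <= quotient:
            # Short enough to keep whole: remove it from the pool and
            # subtract its length from the remaining budget.
            allowed[shortest] = lens[shortest]
            max_len -= lens[shortest]
            remaining.remove(shortest)
        else:
            # Every remaining sequence is too long: split the budget equally,
            # handing out the remainder one extra token at a time.
            for j, idx in enumerate(remaining):
                allowed[idx] = quotient + 1 if j < remainder else quotient
            break
    return [seq[:allowed[i]] for i, seq in enumerate(sequences)]
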
def concat_sequences(seqs, separators, seq_mask=0, separator_mask=1):
    """Concatenate sequences in a list into a single sequence, using specified separators.

    Example 1:
    seqs: [['is', 'this', 'jacksonville', '?'], ['no', 'it', 'is', 'not', '.']]
    separators: [[SEP], [SEP], [CLS]]
    seq_mask: 0
    separator_mask: 1

    Returns:
       tokens:      is this jacksonville ? [SEP] no it is not . [SEP] [CLS]
       segment_ids: 0  0    0            0 0     1  1  1  1   1 1     2
       p_mask:      0  0    0            0 1     0  0  0  0   0 1     1

    Example 2: separator_mask can also be a list.
    seqs: [['is', 'this', 'jacksonville', '?'], ['no', 'it', 'is', 'not', '.']]
    separators: [[SEP], [SEP], [CLS]]
    seq_mask: 0
    separator_mask: [[1], [1], [0]]

    Returns:
       tokens:      is this jacksonville ? [SEP] no it is not . [SEP] [CLS]
       segment_ids: 0  0    0            0 0     1  1  1  1   1 1     2
       p_mask:      0  0    0            0 1     0  0  0  0   0 1     0

    Example 3: seq_mask can also be a list.
    seqs: [['is', 'this', 'jacksonville', '?'], ['no', 'it', 'is', 'not', '.']]
    separators: [[SEP], [SEP], [CLS]]
    seq_mask: [[1, 1, 1, 1], [0, 0, 0, 0, 0]]
    separator_mask: [[1], [1], [0]]

    Returns:
       tokens:      is this jacksonville ? [SEP] no it is not . [SEP] [CLS]
       segment_ids: 0  0    0            0 0     1  1  1  1   1 1     2
       p_mask:      1  1    1            1 1     0  0  0  0   0 1     0

    Parameters
    ----------
    seqs : list of list of object
        Sequences to be concatenated.
    separators : list of list of object
        The special tokens to separate sequences.
    seq_mask : int or list of list of int
        A single mask value for all sequence items or a list of values for each item in seqs.
    separator_mask : int or list of list of int
        A single mask value for all separators or a list of values for each separator.

    Returns
    -------
    list : the concatenated sequence of tokens
    list : the segment id of each token, indicating which input sequence it belongs to
    list : the mask marking special (separator) tokens
    """
    assert isinstance(seqs, collections.abc.Iterable) and len(seqs) > 0
    assert isinstance(seq_mask, (list, int))
    assert isinstance(separator_mask, (list, int))
    concat = sum((seq + sep for sep, seq in
                  itertools.zip_longest(separators, seqs, fillvalue=[])), [])
    segment_ids = sum(
        ([i] * (len(seq) + len(sep)) for i, (sep, seq) in
         enumerate(itertools.zip_longest(separators, seqs, fillvalue=[]))), [])
    if isinstance(seq_mask, int):
        seq_mask = [[seq_mask] * len(seq) for seq in seqs]
    if isinstance(separator_mask, int):
        separator_mask = [[separator_mask] * len(sep) for sep in separators]
    p_mask = sum((s_mask + mask for sep, seq, s_mask, mask in itertools.zip_longest(
        separators, seqs, seq_mask, separator_mask, fillvalue=[])), [])
    return concat, segment_ids, p_mask
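
# Minimal usage sketch (illustrative only; the '[SEP]'/'[CLS]' strings and the
# length budget below are assumptions, not fixed by this module). It truncates a
# hypothetical GLUE sentence pair and concatenates it in the layout shown in the
# docstring examples above. Guarded so it never runs on import.
if __name__ == '__main__':
    seqs = [['is', 'this', 'jacksonville', '?'],
            ['no', 'it', 'is', 'not', '.']]
    # Keep at most 7 tokens in total, leaving room for the three separators below.
    seqs = truncate_seqs_equal(seqs, 7)
    tokens, segment_ids, p_mask = concat_sequences(
        seqs, [['[SEP]'], ['[SEP]'], ['[CLS]']])
    print(tokens)       # ['is', 'this', 'jacksonville', '?', '[SEP]', 'no', 'it', 'is', '[SEP]', '[CLS]']
    print(segment_ids)  # [0, 0, 0, 0, 0, 1, 1, 1, 1, 2]
    print(p_mask)       # [0, 0, 0, 0, 1, 0, 0, 0, 1, 1]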