Source code for gluonnlp.data.dataset

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# pylint: disable=undefined-all-variable
"""NLP Toolkit Dataset API. It allows easy and customizable loading of corpora and dataset files.
Files can be loaded into formats that are immediately ready for training and evaluation."""
__all__ = ['TextLineDataset', 'CorpusDataset', 'ConcatDataset', 'TSVDataset', 'NumpyDataset']

import bisect
import collections
import io
import json
import os
import warnings

import numpy as np

from mxnet.gluon.data import ArrayDataset, Dataset, SimpleDataset

from .utils import (Splitter, concat_sequence, line_splitter,
                    whitespace_splitter)


class ConcatDataset(Dataset):
    """Dataset that concatenates a list of datasets.

    Parameters
    ----------
    datasets : list
        List of datasets.
    """
    def __init__(self, datasets):
        self.datasets = datasets
        self.cum_sizes = np.cumsum([0] + [len(d) for d in datasets])

    def __getitem__(self, i):
        dataset_id = bisect.bisect_right(self.cum_sizes, i)
        sample_id = i - self.cum_sizes[dataset_id - 1]
        return self.datasets[dataset_id - 1][sample_id]

    def __len__(self):
        return self.cum_sizes[-1]
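
# Illustrative usage sketch (not part of the original module). It builds two
# small in-memory SimpleDataset instances and shows how ConcatDataset maps a
# global index onto the underlying dataset via the cumulative sizes.
def _example_concat_dataset():
    part_a = SimpleDataset(['a0', 'a1'])
    part_b = SimpleDataset(['b0', 'b1', 'b2'])
    combined = ConcatDataset([part_a, part_b])
    assert len(combined) == 5
    assert combined[0] == 'a0'   # index 0 falls in the first dataset
    assert combined[2] == 'b0'   # index 2 falls in the second dataset
    return combined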
class TextLineDataset(SimpleDataset):
    """Dataset that comprises lines in a file. Each line will be stripped.

    Parameters
    ----------
    filename : str
        Path to the input text file.
    encoding : str, default 'utf8'
        File encoding format.
    """
    def __init__(self, filename, encoding='utf8'):
        lines = []
        with io.open(filename, 'r', encoding=encoding) as in_file:
            for line in in_file:
                lines.append(line.strip())
        super(TextLineDataset, self).__init__(lines)
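
# Illustrative usage sketch (not part of the original module). The file name
# 'corpus.txt' is hypothetical; each stripped line of the file becomes one
# sample that can be indexed like any other Gluon dataset.
def _example_text_line_dataset():
    dataset = TextLineDataset('corpus.txt')   # hypothetical path
    first_line = dataset[0]                   # first stripped line of the file
    num_lines = len(dataset)
    return first_line, num_lines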
def _corpus_dataset_process(s, bos, eos):
    tokens = [bos] if bos else []
    tokens.extend(s)
    if eos:
        tokens.append(eos)
    return tokens
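
# Illustrative sketch (not part of the original module) showing how the helper
# above wraps a token sequence with optional `bos`/`eos` markers.
def _example_corpus_dataset_process():
    assert _corpus_dataset_process(['hello', 'world'], '<bos>', '<eos>') == \
        ['<bos>', 'hello', 'world', '<eos>']
    assert _corpus_dataset_process(['hello', 'world'], None, None) == ['hello', 'world']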
class TSVDataset(SimpleDataset):
    """Common tab separated text dataset that reads text fields based on provided sample splitter
    and field separator.

    The returned dataset includes samples, each of which can either be a list of text fields
    if field_separator is specified, or otherwise a single string segment produced by the
    sample_splitter.

    Example::

        # assume `test.tsv` contains the following content:
        # Id\tFirstName\tLastName
        # a\tJiheng\tJiang
        # b\tLaoban\tZha
        # discard the first line and select the 0th and 2nd fields
        dataset = data.TSVDataset('test.tsv', num_discard_samples=1, field_indices=[0, 2])
        assert dataset[0] == ['a', 'Jiang']
        assert dataset[1] == ['b', 'Zha']

    Parameters
    ----------
    filename : str or list of str
        Path to the input text file or list of paths to the input text files.
    encoding : str, default 'utf8'
        File encoding format.
    sample_splitter : function, default str.splitlines
        A function that splits the dataset string into samples.
    field_separator : function or None, default Splitter('\t')
        A function that splits each sample string into list of text fields.
        If None, raw samples are returned according to `sample_splitter`.
    num_discard_samples : int, default 0
        Number of samples discarded at the head of the first file.
    field_indices : list of int or None, default None
        If set, for each sample, only fields with provided indices are selected as the output.
        Otherwise all fields are returned.
    allow_missing : bool, default False
        If set to True, no exception will be thrown if the number of fields is smaller than
        the maximum field index provided.
    """
    def __init__(self, filename, encoding='utf8', sample_splitter=line_splitter,
                 field_separator=Splitter('\t'), num_discard_samples=0, field_indices=None,
                 allow_missing=False):
        assert sample_splitter, 'sample_splitter must be specified.'

        if not isinstance(filename, (tuple, list)):
            filename = (filename, )

        self._filenames = [os.path.expanduser(f) for f in filename]
        self._encoding = encoding
        self._sample_splitter = sample_splitter
        self._field_separator = field_separator
        self._num_discard_samples = num_discard_samples
        self._field_indices = field_indices
        self._allow_missing = allow_missing
        super(TSVDataset, self).__init__(self._read())

    def _should_discard(self):
        discard = self._num_discard_samples > 0
        self._num_discard_samples -= 1
        return discard

    def _field_selector(self, fields):
        if not self._field_indices:
            return fields
        try:
            result = [fields[i] for i in self._field_indices]
        except IndexError as e:
            raise(IndexError('%s. Fields = %s'%(str(e), str(fields))))
        return result

    def _read(self):
        all_samples = []
        for filename in self._filenames:
            with io.open(filename, 'r', encoding=self._encoding) as fin:
                content = fin.read()
            samples = (s for s in self._sample_splitter(content)
                       if not self._should_discard())
            if self._field_separator:
                if not self._allow_missing:
                    samples = [self._field_selector(self._field_separator(s)) for s in samples]
                else:
                    selected_samples = []
                    num_missing = 0
                    for s in samples:
                        try:
                            fields = self._field_separator(s)
                            selected_samples.append(self._field_selector(fields))
                        except IndexError:
                            num_missing += 1
                    if num_missing > 0:
                        warnings.warn('%d incomplete samples in %s'%(num_missing, filename))
                    samples = selected_samples
            all_samples += samples
        return all_samples
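
# Illustrative usage sketch (not part of the original module). The file name is
# hypothetical. num_discard_samples drops the header row, field_indices selects
# two columns, and allow_missing=True skips rows that lack the requested field
# (a warning reports how many rows were skipped per file).
def _example_tsv_dataset():
    dataset = TSVDataset('predictions.tsv',      # hypothetical path
                         num_discard_samples=1,  # drop the header row
                         field_indices=[0, 2],   # keep the 0th and 2nd fields
                         allow_missing=True)
    return dataset[0]                            # e.g. a list of two field strings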
class CorpusDataset(SimpleDataset):
    """Common text dataset that reads a whole corpus based on provided sample splitter
    and word tokenizer.

    The returned dataset includes samples, each of which can either be a list of tokens if
    tokenizer is specified, or otherwise a single string segment produced by the sample_splitter.

    Parameters
    ----------
    filename : str or list of str
        Path to the input text file or list of paths to the input text files.
    encoding : str, default 'utf8'
        File encoding format.
    flatten : bool, default False
        Whether to return all samples as flattened tokens. If True, each sample is a token.
    skip_empty : bool, default True
        Whether to skip the empty samples produced from sample_splitters. If False, `bos` and `eos`
        will be added in empty samples.
    sample_splitter : function, default str.splitlines
        A function that splits the dataset string into samples.
    tokenizer : function or None, default str.split
        A function that splits each sample string into list of tokens. If None, raw samples are
        returned according to `sample_splitter`.
    bos : str or None, default None
        The token to add at the beginning of each sequence. If None, or if tokenizer is not
        specified, then nothing is added.
    eos : str or None, default None
        The token to add at the end of each sequence. If None, or if tokenizer is not
        specified, then nothing is added.
    """
    def __init__(self, filename, encoding='utf8', flatten=False, skip_empty=True,
                 sample_splitter=line_splitter, tokenizer=whitespace_splitter,
                 bos=None, eos=None):
        assert sample_splitter, 'sample_splitter must be specified.'

        if not isinstance(filename, (tuple, list)):
            filename = (filename, )

        self._filenames = [os.path.expanduser(f) for f in filename]
        self._encoding = encoding
        self._flatten = flatten
        self._skip_empty = skip_empty
        self._sample_splitter = sample_splitter
        self._tokenizer = tokenizer
        self._bos = bos
        self._eos = eos
        super(CorpusDataset, self).__init__(self._read())

    def _read(self):
        all_samples = []
        for filename in self._filenames:
            with io.open(filename, 'r', encoding=self._encoding) as fin:
                content = fin.read()
            samples = (s.strip() for s in self._sample_splitter(content))
            if self._tokenizer:
                samples = [
                    _corpus_dataset_process(self._tokenizer(s), self._bos, self._eos)
                    for s in samples if s or not self._skip_empty
                ]
                if self._flatten:
                    samples = concat_sequence(samples)
            elif self._skip_empty:
                samples = [s for s in samples if s]

            all_samples += samples
        return all_samples
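
# Illustrative usage sketch (not part of the original module). The file name is
# hypothetical. Each non-empty line is tokenized on whitespace and wrapped with
# the optional bos/eos markers; flatten=True instead returns one long stream in
# which every sample is a single token.
def _example_corpus_dataset():
    dataset = CorpusDataset('corpus.txt',        # hypothetical path
                            bos='<bos>', eos='<eos>')
    first_sentence = dataset[0]                  # e.g. ['<bos>', 'hello', 'world', '<eos>']
    flat = CorpusDataset('corpus.txt', flatten=True)
    return first_sentence, len(flat)             # len(flat) counts individual tokens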
class NumpyDataset(ArrayDataset):
    """A dataset wrapping over a Numpy binary (.npy, .npz) file.

    If the file is a .npy file, then a single numpy array is loaded.
    If the file is a .npz file with multiple arrays, then a list of numpy arrays is loaded,
    ordered by their key in the archive.

    Sparse matrix is not yet supported.

    Parameters
    ----------
    filename : str
        Path to the .npy or .npz file.
    kwargs
        Keyword arguments are passed to np.load.

    Properties
    ----------
    keys: list of str or None
        The list of keys loaded from the .npz file.
    """
    def __init__(self, filename, **kwargs):
        arrs = np.load(filename, **kwargs)
        keys = None
        data = []
        if filename.endswith('.npy'):
            data.append(arrs)
        elif filename.endswith('.npz'):
            keys = sorted(arrs.keys())
            for key in keys:
                data.append(arrs[key])
        else:
            raise ValueError('Unsupported extension: %s'%filename)
        self._keys = keys
        super(NumpyDataset, self).__init__(*data)

    @property
    def keys(self):
        return self._keys
    def get_field(self, field):
        """Return the dataset that corresponds to the provided key.

        Example::

            a = np.ones((2,2))
            b = np.zeros((2,2))
            np.savez('data.npz', a=a, b=b)
            dataset = NumpyDataset('data.npz')
            data_a = dataset.get_field('a')
            data_b = dataset.get_field('b')

        Parameters
        ----------
        field : str
            The name of the field to retrieve.
        """
        idx = self._keys.index(field)
        return self._data[idx]
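
# Illustrative usage sketch (not part of the original module), following the
# example in the docstring above: save two arrays into an .npz archive and
# retrieve them either by archive key or by row index.
def _example_numpy_dataset():
    a = np.ones((2, 2))
    b = np.zeros((2, 2))
    np.savez('data.npz', a=a, b=b)
    dataset = NumpyDataset('data.npz')
    assert dataset.keys == ['a', 'b']            # keys are sorted
    data_a = dataset.get_field('a')              # whole array looked up by key
    row_a, row_b = dataset[0]                    # ArrayDataset indexing: one row per array
    return data_a, row_a, row_b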
class _JsonlDataset(SimpleDataset):
    """A dataset wrapping over a jsonlines (.jsonl) file, where each line is a json object.

    Parameters
    ----------
    filename : str
        Path to the .jsonl file.
    encoding : str, default 'utf8'
        File encoding format.
    """
    def __init__(self, filename, encoding='utf8'):
        if not isinstance(filename, (tuple, list)):
            filename = (filename, )

        self._filenames = [os.path.expanduser(f) for f in filename]
        self._encoding = encoding
        super(_JsonlDataset, self).__init__(self._read())

    def _read(self):
        all_samples = []
        for filename in self._filenames:
            samples = []
            with open(filename, 'r', encoding=self._encoding) as fin:
                for line in fin.readlines():
                    samples.append(json.loads(line, object_pairs_hook=collections.OrderedDict))
            samples = self._read_samples(samples)
            all_samples += samples
        return all_samples

    def _read_samples(self, samples):
        raise NotImplementedError
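
# Illustrative sketch (not part of the original module). _JsonlDataset is an
# abstract helper: subclasses override _read_samples to post-process the list
# of parsed json objects. The subclass name and 'text' field are hypothetical.
class _ExampleJsonlDataset(_JsonlDataset):
    def _read_samples(self, samples):
        # keep only records that carry a 'text' field (hypothetical schema)
        return [s for s in samples if 'text' in s]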