Source code for gluonnlp.data.dataset
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=undefined-all-variable
"""NLP Toolkit Dataset API. It allows easy and customizable loading of corpora and dataset files.
Files can be loaded into formats that are immediately ready for training and evaluation."""
__all__ = ['TextLineDataset', 'CorpusDataset', 'ConcatDataset', 'TSVDataset', 'NumpyDataset']
import bisect
import collections
import io
import json
import os
import warnings
import numpy as np
from mxnet.gluon.data import ArrayDataset, Dataset, SimpleDataset
from .utils import (Splitter, concat_sequence, line_splitter,
                    whitespace_splitter)


class ConcatDataset(Dataset):
    """Dataset that concatenates a list of datasets.

    Parameters
    ----------
    datasets : list
        List of datasets.
    """
    def __init__(self, datasets):
        self.datasets = datasets
        self.cum_sizes = np.cumsum([0] + [len(d) for d in datasets])

    def __getitem__(self, i):
        dataset_id = bisect.bisect_right(self.cum_sizes, i)
        sample_id = i - self.cum_sizes[dataset_id - 1]
        return self.datasets[dataset_id - 1][sample_id]

    def __len__(self):
        return self.cum_sizes[-1]
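

# Usage sketch: indexing into a ConcatDataset falls through to the dataset that owns the
# requested index via the cumulative sizes computed above. The helper name
# `_example_concat_dataset` and the toy data are illustrative assumptions, not part of the API.
def _example_concat_dataset():
    first = SimpleDataset([10, 11, 12])
    second = SimpleDataset([20, 21])
    combined = ConcatDataset([first, second])
    assert len(combined) == 5        # 3 + 2 samples in total
    assert combined[3] == 20         # index 3 maps to second[0]
    return combined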


class TextLineDataset(SimpleDataset):
    """Dataset that comprises lines in a file. Each line will be stripped.

    Parameters
    ----------
    filename : str
        Path to the input text file.
    encoding : str, default 'utf8'
        File encoding format.
    """
    def __init__(self, filename, encoding='utf8'):
        lines = []
        with io.open(filename, 'r', encoding=encoding) as in_file:
            for line in in_file:
                lines.append(line.strip())
        super(TextLineDataset, self).__init__(lines)
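

# Usage sketch: each line of the file becomes one sample, with surrounding whitespace
# stripped. The file path 'example.txt' and the helper name are assumptions for illustration.
def _example_text_line_dataset(path='example.txt'):
    with io.open(path, 'w', encoding='utf8') as f:
        f.write(u'hello world\nfoo bar\n')
    dataset = TextLineDataset(path)
    assert dataset[0] == 'hello world'
    assert dataset[1] == 'foo bar'   # trailing newlines are stripped
    return dataset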


def _corpus_dataset_process(s, bos, eos):
    tokens = [bos] if bos else []
    tokens.extend(s)
    if eos:
        tokens.append(eos)
    return tokens


class TSVDataset(SimpleDataset):
    """Common tab-separated text dataset that reads text fields based on the provided
    sample splitter and field separator.

    The returned dataset includes samples, each of which can either be a list of text fields
    if field_separator is specified, or otherwise a single string segment produced by the
    sample_splitter.

    Example::

        # assume `test.tsv` contains the following content:
        # Id\tFirstName\tLastName
        # a\tJiheng\tJiang
        # b\tLaoban\tZha
        # discard the first line and select the 0th and 2nd fields
        dataset = data.TSVDataset('test.tsv', num_discard_samples=1, field_indices=[0, 2])
        assert dataset[0] == ['a', 'Jiang']
        assert dataset[1] == ['b', 'Zha']

    Parameters
    ----------
    filename : str or list of str
        Path to the input text file or list of paths to the input text files.
    encoding : str, default 'utf8'
        File encoding format.
    sample_splitter : function, default str.splitlines
        A function that splits the dataset string into samples.
    field_separator : function or None, default Splitter('\t')
        A function that splits each sample string into a list of text fields.
        If None, raw samples are returned according to `sample_splitter`.
    num_discard_samples : int, default 0
        Number of samples discarded at the head of the first file.
    field_indices : list of int or None, default None
        If set, for each sample, only fields with the provided indices are selected as the
        output. Otherwise all fields are returned.
    allow_missing : bool, default False
        If set to True, no exception will be thrown if the number of fields is smaller than
        the maximum field index provided.
    """
    def __init__(self, filename, encoding='utf8',
                 sample_splitter=line_splitter, field_separator=Splitter('\t'),
                 num_discard_samples=0, field_indices=None, allow_missing=False):
        assert sample_splitter, 'sample_splitter must be specified.'

        if not isinstance(filename, (tuple, list)):
            filename = (filename, )

        self._filenames = [os.path.expanduser(f) for f in filename]
        self._encoding = encoding
        self._sample_splitter = sample_splitter
        self._field_separator = field_separator
        self._num_discard_samples = num_discard_samples
        self._field_indices = field_indices
        self._allow_missing = allow_missing
        super(TSVDataset, self).__init__(self._read())

    def _should_discard(self):
        discard = self._num_discard_samples > 0
        self._num_discard_samples -= 1
        return discard

    def _field_selector(self, fields):
        if not self._field_indices:
            return fields
        try:
            result = [fields[i] for i in self._field_indices]
        except IndexError as e:
            raise(IndexError('%s. Fields = %s'%(str(e), str(fields))))
        return result

    def _read(self):
        all_samples = []
        for filename in self._filenames:
            with io.open(filename, 'r', encoding=self._encoding) as fin:
                content = fin.read()
            samples = (s for s in self._sample_splitter(content) if not self._should_discard())
            if self._field_separator:
                if not self._allow_missing:
                    samples = [self._field_selector(self._field_separator(s)) for s in samples]
                else:
                    selected_samples = []
                    num_missing = 0
                    for s in samples:
                        try:
                            fields = self._field_separator(s)
                            selected_samples.append(self._field_selector(fields))
                        except IndexError:
                            num_missing += 1
                    if num_missing > 0:
                        warnings.warn('%d incomplete samples in %s'%(num_missing, filename))
                    samples = selected_samples
            all_samples += samples
        return all_samples
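

# Usage sketch for allow_missing: rows that are too short for the requested field_indices are
# dropped with a warning instead of raising IndexError. The file name 'ratings.tsv' and the
# helper name are assumptions for illustration only.
def _example_tsv_allow_missing(path='ratings.tsv'):
    with io.open(path, 'w', encoding='utf8') as f:
        f.write(u'a\t1\t2\nb\t3\n')           # second row is missing the third field
    dataset = TSVDataset(path, field_indices=[0, 2], allow_missing=True)
    assert list(dataset) == [['a', '2']]      # the incomplete row 'b\t3' is skipped
    return dataset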


class CorpusDataset(SimpleDataset):
    """Common text dataset that reads a whole corpus based on the provided sample splitter
    and word tokenizer.

    The returned dataset includes samples, each of which can either be a list of tokens if
    tokenizer is specified, or otherwise a single string segment produced by the
    sample_splitter.

    Parameters
    ----------
    filename : str or list of str
        Path to the input text file or list of paths to the input text files.
    encoding : str, default 'utf8'
        File encoding format.
    flatten : bool, default False
        Whether to return all samples as flattened tokens. If True, each sample is a token.
    skip_empty : bool, default True
        Whether to skip the empty samples produced by the sample_splitter. If False, `bos` and
        `eos` will be added to empty samples.
    sample_splitter : function, default str.splitlines
        A function that splits the dataset string into samples.
    tokenizer : function or None, default str.split
        A function that splits each sample string into a list of tokens. If None, raw samples
        are returned according to `sample_splitter`.
    bos : str or None, default None
        The token to add at the beginning of each sequence. If None, or if tokenizer is not
        specified, then nothing is added.
    eos : str or None, default None
        The token to add at the end of each sequence. If None, or if tokenizer is not
        specified, then nothing is added.
    """
    def __init__(self, filename, encoding='utf8', flatten=False, skip_empty=True,
                 sample_splitter=line_splitter, tokenizer=whitespace_splitter,
                 bos=None, eos=None):
        assert sample_splitter, 'sample_splitter must be specified.'

        if not isinstance(filename, (tuple, list)):
            filename = (filename, )

        self._filenames = [os.path.expanduser(f) for f in filename]
        self._encoding = encoding
        self._flatten = flatten
        self._skip_empty = skip_empty
        self._sample_splitter = sample_splitter
        self._tokenizer = tokenizer
        self._bos = bos
        self._eos = eos
        super(CorpusDataset, self).__init__(self._read())

    def _read(self):
        all_samples = []
        for filename in self._filenames:
            with io.open(filename, 'r', encoding=self._encoding) as fin:
                content = fin.read()
            samples = (s.strip() for s in self._sample_splitter(content))
            if self._tokenizer:
                samples = [
                    _corpus_dataset_process(self._tokenizer(s), self._bos, self._eos)
                    for s in samples if s or not self._skip_empty
                ]
                if self._flatten:
                    samples = concat_sequence(samples)
            elif self._skip_empty:
                samples = [s for s in samples if s]
            all_samples += samples
        return all_samples
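

# Usage sketch for bos/eos handling: each non-empty line is tokenized on whitespace and
# wrapped with the given begin/end-of-sequence tokens; passing flatten=True would instead
# concatenate all samples into one token stream. The file name 'corpus.txt' and the helper
# name are assumptions for illustration.
def _example_corpus_dataset(path='corpus.txt'):
    with io.open(path, 'w', encoding='utf8') as f:
        f.write(u'hello world\n\nbye\n')
    dataset = CorpusDataset(path, bos='<bos>', eos='<eos>')
    assert dataset[0] == ['<bos>', 'hello', 'world', '<eos>']
    assert dataset[1] == ['<bos>', 'bye', '<eos>']   # the empty line is skipped by default
    return dataset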


class NumpyDataset(ArrayDataset):
    """A dataset wrapping over a Numpy binary (.npy, .npz) file.

    If the file is a .npy file, then a single numpy array is loaded.
    If the file is a .npz file with multiple arrays, then a list of
    numpy arrays is loaded, ordered by their keys in the archive.

    Sparse matrices are not yet supported.

    Parameters
    ----------
    filename : str
        Path to the .npy or .npz file.
    kwargs
        Keyword arguments are passed to np.load.

    Properties
    ----------
    keys : list of str or None
        The list of keys loaded from the .npz file.
    """
    def __init__(self, filename, **kwargs):
        arrs = np.load(filename, **kwargs)
        keys = None
        data = []
        if filename.endswith('.npy'):
            data.append(arrs)
        elif filename.endswith('.npz'):
            keys = sorted(arrs.keys())
            for key in keys:
                data.append(arrs[key])
        else:
            raise ValueError('Unsupported extension: %s'%filename)
        self._keys = keys
        super(NumpyDataset, self).__init__(*data)

    @property
    def keys(self):
        return self._keys

    def get_field(self, field):
        """Return the dataset that corresponds to the provided field name.

        Example::

            a = np.ones((2, 2))
            b = np.zeros((2, 2))
            np.savez('data.npz', a=a, b=b)
            dataset = NumpyDataset('data.npz')
            data_a = dataset.get_field('a')
            data_b = dataset.get_field('b')

        Parameters
        ----------
        field : str
            The name of the field to retrieve.
        """
        idx = self._keys.index(field)
        return self._data[idx]
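

# Usage sketch for the .npy branch: a single array yields one sample per row, because
# ArrayDataset indexes along the first axis. The file name 'embeddings.npy' and the helper
# name are assumptions for illustration.
def _example_numpy_dataset(path='embeddings.npy'):
    np.save(path, np.arange(6).reshape(3, 2))
    dataset = NumpyDataset(path)
    assert len(dataset) == 3          # three rows -> three samples
    assert dataset.keys is None       # keys are only populated for .npz archives
    return dataset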


class _JsonlDataset(SimpleDataset):
    """A dataset wrapping over a jsonlines (.jsonl) file, where each line is a JSON object.

    Parameters
    ----------
    filename : str or list of str
        Path to the .jsonl file or list of paths to the .jsonl files.
    encoding : str, default 'utf8'
        File encoding format.
    """
    def __init__(self, filename, encoding='utf8'):
        if not isinstance(filename, (tuple, list)):
            filename = (filename, )

        self._filenames = [os.path.expanduser(f) for f in filename]
        self._encoding = encoding
        super(_JsonlDataset, self).__init__(self._read())

    def _read(self):
        all_samples = []
        for filename in self._filenames:
            samples = []
            with open(filename, 'r', encoding=self._encoding) as fin:
                for line in fin.readlines():
                    samples.append(json.loads(line, object_pairs_hook=collections.OrderedDict))
            samples = self._read_samples(samples)
            all_samples += samples
        return all_samples

    def _read_samples(self, samples):
        raise NotImplementedError
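

# Usage sketch: _JsonlDataset is abstract; subclasses implement _read_samples to post-process
# the parsed JSON objects. The subclass name and the 'text' field below are hypothetical,
# chosen only to illustrate the extension point.
class _ExampleJsonlDataset(_JsonlDataset):
    """Keep only the 'text' field of every JSON line (hypothetical schema)."""
    def _read_samples(self, samples):
        return [sample['text'] for sample in samples if 'text' in sample]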