# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=ungrouped-imports
"""DatasetLoader. An extension of Gluon data loader that allows
reading and processing multiple files on-the-fly.
"""
__all__ = ['DatasetLoader']
import io
import pickle
import warnings
import multiprocessing
from functools import partial
from mxnet import context
from mxnet.gluon.data.dataloader import ForkingPickler, _as_in_context
from mxnet.gluon.data.dataloader import default_mp_batchify_fn, default_batchify_fn
from .stream import _PathDataset
# manager for creating shared object
_manager = None
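# per-worker reference to the most recently constructed dataset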
_dataset = None
def _initialize_dataset_worker(manager):
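    """Pool initializer that stores the shared multiprocessing.Manager in each dataset worker."""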
global _manager
_manager = manager
def _dataset_worker_fn(urls, dataset_fn, batch_sampler_fn):
"""Function to generate datasets and batch sampler for each worker."""
global _manager, _dataset
dataset = dataset_fn(urls)
batch_sampler = batch_sampler_fn(dataset)
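    # when a Manager is available (multi-process mode), convert the dataset into a
    # Manager-backed list of per-sample tuples so batch worker processes can index it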
if _manager:
dataset = _manager.list(zip(*dataset._data))
_dataset = dataset
return dataset, batch_sampler
def _batch_worker_fn(samples, batchify_fn, dataset=None, counter=None):
"""Function for processing data in worker process."""
# pylint: disable=unused-argument
    # each worker process is required to fork a new MXIndexedRecordIO handle
    # keeping the dataset as a global variable saves a lot of overhead and is safe in a new process
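    # dataset entries may hold one field or several fields, and `samples` may be a single
    # list of indices or a list of shards (lists of indices); handle all four combinations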
if len(dataset[0]) > 1:
if isinstance(samples[0], (list, tuple)):
batch = [batchify_fn([dataset[i] for i in shard]) for shard in samples]
else:
batch = batchify_fn([dataset[i] for i in samples])
else:
if isinstance(samples[0], (list, tuple)):
batch = [batchify_fn([dataset[i][0] for i in shard]) for shard in samples]
else:
batch = batchify_fn([dataset[i][0] for i in samples])
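    # pickle the batch with ForkingPickler, which knows how to pass shared-memory NDArrays
    # back to the main process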
buf = io.BytesIO()
ForkingPickler(buf, pickle.HIGHEST_PROTOCOL).dump(batch)
return buf.getvalue(), counter
class _MultiBatchWorkerIter:
"""Internal multi-worker iterator for DataLoader."""
def __init__(self, worker_pool, batchify_fn, dataset_iter=None,
pin_memory=False, worker_fn=_batch_worker_fn, prefetch=0,
manager=None):
self._worker_pool = worker_pool
self._batchify_fn = batchify_fn
self._data_buffer = {}
self._rcvd_idx = 0
self._sent_idx = 0
self._dataset_iter = iter(dataset_iter)
self._worker_fn = worker_fn
self._pin_memory = pin_memory
self._prefetch = prefetch
self._dataset = None
self._batch_iter = None
self._manager = manager
# datasets reference list
self._dataset_refs = []
# counter reference dict
self._counter_ref = {}
# pre-fetch
for _ in range(self._prefetch):
self._push_next()
def _count_dataset_ref(self, new_dataset):
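        """Release references to datasets whose outstanding batch jobs have all finished,
        while keeping alive any dataset that batch workers may still be using."""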
dataset_refs = []
for dataset in self._dataset_refs:
if self._counter_ref[id(dataset)].value > 0:
dataset_refs.append(dataset)
else:
del self._counter_ref[id(dataset)]
if self._dataset:
if self._counter_ref[id(self._dataset)].value > 0:
if id(new_dataset) != id(self._dataset):
dataset_refs.append(self._dataset)
else:
del self._counter_ref[id(self._dataset)]
self._dataset_refs = dataset_refs
def _next_dataset(self):
try:
dataset, batch_sampler = next(self._dataset_iter)
except StopIteration:
return None
return dataset, batch_sampler
def _push_next(self):
"""Assign next batch workload to workers."""
if self._batch_iter is not None:
r = next(self._batch_iter, None)
else:
r = None
if r is None:
result = self._next_dataset()
if result is None:
return
else:
dataset, batch_sampler = result
                # Without checking the reference counts of previous datasets in the master process,
                # a KeyError can occasionally be triggered. This may be a bug in Python.
self._count_dataset_ref(dataset)
self._dataset = dataset
# initialize reference counter
if id(dataset) not in self._counter_ref:
self._counter_ref[id(dataset)] = self._manager.Value('i', 0)
self._batch_iter = iter(batch_sampler)
self._push_next()
else:
counter = self._counter_ref[id(self._dataset)]
counter.value += 1
async_ret = self._worker_pool.apply_async(
self._worker_fn, (r, self._batchify_fn, self._dataset, counter))
self._data_buffer[self._sent_idx] = async_ret
self._sent_idx += 1
def __next__(self):
self._push_next()
if self._rcvd_idx == self._sent_idx:
assert not self._data_buffer, 'Data buffer should be empty at this moment'
raise StopIteration
assert self._rcvd_idx < self._sent_idx, 'rcvd_idx must be smaller than sent_idx'
assert self._rcvd_idx in self._data_buffer, 'fatal error with _push_next, rcvd_idx missing'
ret = self._data_buffer.pop(self._rcvd_idx)
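        # each worker returns the pickled batch together with the shared counter
        # of the dataset it was built from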
batch, counter = ret.get()
batch = pickle.loads(batch)
counter.value -= 1
if self._pin_memory:
batch = _as_in_context(batch, context.cpu_pinned())
self._rcvd_idx += 1
return batch
def next(self):
return self.__next__()
def __iter__(self):
return self
class _MultiDatasetWorkerIter:
"""Internal multi-worker iterator for DataLoader."""
def __init__(self, worker_pool, file_sampler,
dataset_fn, batch_sampler_fn,
worker_fn=_dataset_worker_fn,
prefetch=0, dataset=None, circle_length=1,
cached=False, num_max_cached=0):
if cached:
assert num_max_cached > 0,\
'When cached is turned on, num_max_cached must be positive.'
self._worker_pool = worker_pool
self._dataset_fn = dataset_fn
self._batch_sampler_fn = batch_sampler_fn
self._worker_fn = worker_fn
self._prefetch = prefetch
self._circle_length = circle_length
self._cached = cached
self._num_max_cached = num_max_cached
# send and receive index for datasets
self._rcvd_idx = 0
self._sent_idx = 0
self._data_buffer = {}
self._dataset = [dataset[i] for i in iter(file_sampler)]
self._num_datasets = len(self._dataset)
# construct cached list
self._cached_dataset = []
# pre-fetch
for _ in range(self._prefetch):
self._push_next_dataset()
def _push_next_dataset(self):
"""Assign next dataset workload to workers."""
current_dataset_idx = self._sent_idx * self._circle_length
if current_dataset_idx < self._num_datasets:
circle_length = min(self._circle_length,
self._num_datasets - current_dataset_idx)
urls = [self._dataset[current_dataset_idx + i] for i in range(circle_length)]
else:
return
# push to worker asynchronously
async_ret = self._worker_pool.apply_async(
self._worker_fn, (urls, self._dataset_fn, self._batch_sampler_fn))
# data buffer stores the async result
self._data_buffer[self._sent_idx] = async_ret
self._sent_idx += 1
def _next_dataset(self):
"""Retrieve the next dataset. Returns None if no dataset is available."""
if self._rcvd_idx == self._sent_idx:
assert not self._data_buffer, 'Data buffer should be empty at this moment'
return None
assert self._rcvd_idx < self._sent_idx, \
'rcvd_idx must be smaller than sent_idx'
assert self._rcvd_idx in self._data_buffer, \
'fatal error with _next_dataset, rcvd_idx missing'
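        # hand out a freshly processed dataset if the cache is empty or the next result is
        # already ready; otherwise reuse a cached dataset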
if len(self._cached_dataset) == 0 or self._data_buffer[self._rcvd_idx].ready():
ret = self._data_buffer.pop(self._rcvd_idx)
dataset, batch_sampler = ret.get()
self._rcvd_idx += 1
if self._cached and len(self._cached_dataset) < self._num_max_cached:
self._cached_dataset.append((dataset, batch_sampler))
else:
dataset, batch_sampler = self._cached_dataset.pop(0)
return dataset, batch_sampler
def __next__(self):
"""Next dataset"""
self._push_next_dataset()
result = self._next_dataset()
if result is None:
raise StopIteration
return result
def next(self):
"""Next dataset"""
return self.__next__()
def __iter__(self):
"""Returns the iterator object"""
return self
class DatasetLoader:
"""Loads data from a list of datasets and returns mini-batches of data.
One dataset is loaded at a time.
Parameters
----------
    file_patterns : str
        Path or glob pattern of the input text files.
file_sampler : str or gluon.data.Sampler, defaults to 'random'
The sampler used to sample a file. The following string values are supported:
- 'sequential': SequentialSampler
- 'random': RandomSampler
dataset_fn : DatasetFn, callable
Callable object to generate a gluon.data.Dataset given a url.
batch_sampler_fn : SamplerFn, callable
Callable object to generate a gluon.data.sampler.Sampler given a dataset.
dataset_params : dict, default is None
Dictionary of parameters passed to dataset_fn.
batch_sampler_params : dict, default is None
Dictionary of parameters passed to batch_sampler_fn.
batchify_fn : callable
Callback function to allow users to specify how to merge samples
into a batch. Defaults to `default_batchify_fn`::
def default_batchify_fn(data):
if isinstance(data[0], nd.NDArray):
return nd.stack(*data)
elif isinstance(data[0], tuple):
data = zip(*data)
return [default_batchify_fn(i) for i in data]
else:
data = np.asarray(data)
return nd.array(data, dtype=data.dtype)
    num_dataset_workers : int
        Number of worker processes for dataset creation.
    num_batch_workers : int
        Number of worker processes for batch creation.
    pin_memory : boolean, default False
        If ``True``, the dataloader will copy NDArrays into pinned memory
        before returning them. Copying from CPU pinned memory to GPU is faster
        than from normal CPU memory, at the cost of extra pinned host memory.
    circle_length : int, default is 1
        The number of files to be read at the same time. When `circle_length` is larger
        than 1, `circle_length` files are merged into a single dataset.
    dataset_prefetch : int, default is `num_dataset_workers`
        The number of datasets to prefetch. It only takes effect when
        `num_dataset_workers` > 0. If positive, worker processes build this many
        datasets ahead of the iterator. A larger value gives smoother bootstrapping
        performance but consumes more memory; a value that is too small may forfeit
        the benefit of multiple worker processes, in which case consider reducing
        `num_dataset_workers`.
    batch_prefetch : int, default is `num_batch_workers * 2`
        The number of batches to prefetch. It only takes effect when
        `num_batch_workers` > 0. If positive, worker processes prepare this many
        batches ahead of the iterator. A larger value gives smoother bootstrapping
        performance but consumes more shared memory; a value that is too small may
        forfeit the benefit of multiple worker processes, in which case consider
        reducing `num_batch_workers`.
    dataset_cached : bool, default is False
        Whether to cache the last processed datasets. Each processed dataset can
        only be cached once. When no newly processed dataset is available to be
        fetched, a cached dataset is popped and reused.
    num_max_dataset_cached : int, default is 0
        Maximum number of cached datasets. Only valid when `dataset_cached` is True.
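
    Examples
    --------
    A minimal single-process sketch. The file pattern and the two helper callables below are
    illustrative placeholders (they assume text files containing one number per line), not part
    of this module::

        import glob
        from mxnet.gluon.data import SimpleDataset, SequentialSampler, BatchSampler

        def my_dataset_fn(url):
            # build a dataset from a single file: one float per line
            with open(url) as f:
                return SimpleDataset([float(line) for line in f if line.strip()])

        def my_batch_sampler_fn(dataset):
            # mini-batches of 8 sequential samples
            return BatchSampler(SequentialSampler(len(dataset)), batch_size=8, last_batch='keep')

        num_files = len(glob.glob('data/part-*.txt'))
        loader = DatasetLoader('data/part-*.txt',
                               file_sampler=SequentialSampler(num_files),
                               dataset_fn=my_dataset_fn,
                               batch_sampler_fn=my_batch_sampler_fn)
        for batch in loader:
            pass  # `batch` is the output of `batchify_fn` for one mini-batch

    Setting `num_dataset_workers` and `num_batch_workers` to positive values moves dataset
    construction and batchification into worker processes.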
"""
def __init__(self, file_patterns, file_sampler,
dataset_fn=None, batch_sampler_fn=None,
dataset_params=None, batch_sampler_params=None, batchify_fn=None,
num_dataset_workers=0, num_batch_workers=0,
pin_memory=False, circle_length=1,
dataset_prefetch=None, batch_prefetch=None,
dataset_cached=False, num_max_dataset_cached=0):
assert num_dataset_workers >= 0, \
'num_dataset_workers must be non-negative'
assert num_batch_workers >= 0, \
'num_batch_workers must be non-negative'
if num_batch_workers > 0:
assert num_dataset_workers > 0, \
'num_dataset_workers must be positive when num_batch_workers > 0'
else:
if num_dataset_workers > 0:
warnings.warn('The multi-processing functionalities for both dataset and'
' batch sampling are disabled when num_batch_workers=0 though '
'num_dataset_workers={} > 0'.format(num_dataset_workers))
assert circle_length >= 1, \
'circle_length must be larger than or equal to 1'
if dataset_cached:
assert num_max_dataset_cached > 0, \
'When dataset_cached is True, num_max_dataset_cached must be positive'
self._dataset = _PathDataset(file_patterns)
self._file_sampler = file_sampler
assert dataset_fn is not None, 'dataset_fn is not given.'
assert batch_sampler_fn is not None, 'batch_sampler_fn is not given.'
if dataset_params is not None:
self._dataset_fn = partial(dataset_fn, **dataset_params)
else:
self._dataset_fn = dataset_fn
if batch_sampler_params is not None:
self._batch_sampler_fn = partial(batch_sampler_fn, **batch_sampler_params)
else:
self._batch_sampler_fn = batch_sampler_fn
self._num_dataset_workers = num_dataset_workers
self._num_batch_workers = num_batch_workers
self._dataset_prefetch = max(0, int(dataset_prefetch) \
if dataset_prefetch is not None else self._num_dataset_workers)
self._batch_prefetch = max(0, int(batch_prefetch) \
if batch_prefetch is not None else 2 * self._num_batch_workers)
self._pin_memory = pin_memory
self._circle_length = circle_length
self._dataset_cached = dataset_cached
self._num_max_dataset_cached = num_max_dataset_cached
self._manager = None
self._dataset_worker_pool = None
if self._num_dataset_workers > 0:
self._manager = multiprocessing.Manager()
self._dataset_worker_pool = multiprocessing.Pool(self._num_dataset_workers,
initializer=_initialize_dataset_worker,
initargs=[self._manager])
self._batch_worker_pool = None
if self._num_batch_workers > 0:
self._batch_worker_pool = multiprocessing.Pool(self._num_batch_workers)
if batchify_fn is None:
if self._num_batch_workers > 0:
self._batchify_fn = default_mp_batchify_fn
else:
self._batchify_fn = default_batchify_fn
else:
self._batchify_fn = batchify_fn
def __iter__(self):
if self._num_dataset_workers == 0:
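            # single-process path: files are read, processed and batchified in the main process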
def _same_process_iter():
urls = []
dataset = [self._dataset[i] for i in iter(self._file_sampler)]
for i, url in enumerate(dataset):
urls.append(url)
if i < len(dataset) - 1:
if len(urls) < self._circle_length:
continue
if self._circle_length == 1:
urls = urls[0]
dataset, batch_sampler = _dataset_worker_fn(urls, self._dataset_fn,
self._batch_sampler_fn)
for batch in batch_sampler:
ret = self._batchify_fn([dataset[idx] for idx in batch])
if self._pin_memory:
ret = _as_in_context(ret, context.cpu_pinned())
yield ret
urls = []
return _same_process_iter()
# multi-worker
dataset_iter = _MultiDatasetWorkerIter(self._dataset_worker_pool,
worker_fn=_dataset_worker_fn,
dataset=self._dataset,
file_sampler=self._file_sampler,
dataset_fn=self._dataset_fn,
batch_sampler_fn=self._batch_sampler_fn,
prefetch=self._dataset_prefetch,
circle_length=self._circle_length,
cached=self._dataset_cached,
num_max_cached=self._num_max_dataset_cached)
return _MultiBatchWorkerIter(self._batch_worker_pool, self._batchify_fn, dataset_iter,
pin_memory=self._pin_memory, worker_fn=_batch_worker_fn,
prefetch=self._batch_prefetch, manager=self._manager)
def __del__(self):
if self._dataset_worker_pool:
# manually terminate due to a bug that pool is not automatically terminated
# https://bugs.python.org/issue34172
assert isinstance(self._dataset_worker_pool, multiprocessing.pool.Pool)
self._dataset_worker_pool.terminate()
if self._batch_worker_pool:
assert isinstance(self._batch_worker_pool, multiprocessing.pool.Pool)
self._batch_worker_pool.terminate()