Source code for gluonnlp.utils.parallel
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Utility functions for parallel processing."""
import queue
import threading

__all__ = ['Parallelizable', 'Parallel']


class Parallelizable:
    """Base class for a parallelizable unit of work, which can be invoked by
    `Parallel`. Subclasses must implement the `forward_backward` method and
    are used together with `Parallel`. For example::

        class ParallelNet(Parallelizable):
            def __init__(self):
                self._net = Model()
                self._loss = gluon.loss.SoftmaxCrossEntropyLoss()

            def forward_backward(self, x):
                data, label = x
                with mx.autograd.record():
                    out = self._net(data)
                    loss = self._loss(out, label)
                loss.backward()
                return loss

        net = ParallelNet()
        ctx = [mx.gpu(0), mx.gpu(1)]
        parallel = Parallel(len(ctx), net)

        for batch in batches:
            for x in gluon.utils.split_and_load(batch, ctx):
                parallel.put(x)
            losses = [parallel.get() for _ in ctx]
            trainer.step()
    """

    def forward_backward(self, x):
        """Forward and backward computation."""
        raise NotImplementedError()
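

# A minimal sketch of the `Parallelizable` contract, kept here purely for
# illustration (hypothetical; `_SquareWork` is not part of the gluonnlp API):
# the only requirement is that `forward_backward` maps one work item to one
# result, so the unit of work need not involve MXNet at all.
class _SquareWork(Parallelizable):
    """Toy work item that squares its input (illustration only)."""

    def forward_backward(self, x):
        # Stand-in for an expensive forward/backward computation.
        return x * x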


class Parallel:
    """Class for parallel processing with `Parallelizable`s. It invokes a
    `Parallelizable` with multiple Python threads. For example::

        class ParallelNet(Parallelizable):
            def __init__(self):
                self._net = Model()
                self._loss = gluon.loss.SoftmaxCrossEntropyLoss()

            def forward_backward(self, x):
                data, label = x
                with mx.autograd.record():
                    out = self._net(data)
                    loss = self._loss(out, label)
                loss.backward()
                return loss

        net = ParallelNet()
        ctx = [mx.gpu(0), mx.gpu(1)]
        parallel = Parallel(len(ctx), net)

        for batch in batches:
            for x in gluon.utils.split_and_load(batch, ctx):
                parallel.put(x)
            losses = [parallel.get() for _ in ctx]
            trainer.step()

    Parameters
    ----------
    num_workers : int
        Number of worker threads. If set to 0, the main thread is used as the
        worker, which is useful for debugging.
    parallizable : Parallelizable
        Parallelizable net whose `forward_backward` method is invoked by the
        worker threads.
    serial_init : bool, default True
        Execute the first `num_workers` inputs on the main thread, so that any
        `Block` used in `parallizable` is initialized serially. Initializing a
        `Block` from multiple threads may cause unexpected behavior.
    """

    class _StopSignal:
        """Internal class used to signal a worker thread to stop."""
        def __init__(self, msg):
            self._msg = msg

    def __init__(self, num_workers, parallizable, serial_init=True):
        self._in_queue = queue.Queue(-1)
        self._out_queue = queue.Queue(-1)
        self._num_workers = num_workers
        self._threads = []
        self._parallizable = parallizable
        self._num_serial = num_workers if serial_init else 0

        def _worker(in_queue, out_queue, parallel):
            # Each worker loops forever: it pulls an input from the shared
            # input queue, runs forward_backward on it, and pushes the result
            # to the output queue, until it receives a stop signal.
            while True:
                x = in_queue.get()
                if isinstance(x, Parallel._StopSignal):
                    return
                out = parallel.forward_backward(x)
                out_queue.put(out)

        arg = (self._in_queue, self._out_queue, self._parallizable)
        for _ in range(num_workers):
            thread = threading.Thread(target=_worker, args=arg)
            self._threads.append(thread)
            thread.start()

    def put(self, x):
        """Assign input `x` to an available worker and invoke
        `parallizable.forward_backward` with `x`."""
        if self._num_serial > 0 or len(self._threads) == 0:
            # Run on the caller's thread, either because serial initialization
            # is still in progress or because there are no worker threads
            # (num_workers == 0).
            self._num_serial -= 1
            out = self._parallizable.forward_backward(x)
            self._out_queue.put(out)
        else:
            self._in_queue.put(x)
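
    # With `num_workers == 0` the serial branch in `put` always runs, so `put`
    # degrades to a synchronous call on the caller's thread. A hedged sketch
    # using the hypothetical `_SquareWork` defined above::
    #
    #     p = Parallel(0, _SquareWork())
    #     p.put(3)
    #     assert p.get() == 9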

    def get(self):
        """Get an output of a previous `parallizable.forward_backward` call.
        This method blocks until one of the previous
        `parallizable.forward_backward` calls has returned a result."""
        return self._out_queue.get()

    def __del__(self):
        # Ask every live worker to exit, then wait (up to 10 seconds each)
        # for the threads to finish.
        for thread in self._threads:
            if thread.is_alive():
                self._in_queue.put(self._StopSignal('stop'))
        for thread in self._threads:
            thread.join(10)
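

# A minimal smoke test of `Parallel`, runnable as a script. This is an
# illustrative sketch, not part of the original module; it relies on the
# hypothetical `_SquareWork` defined above. With more than one worker the
# outputs may arrive in any order, so results are sorted before comparison.
if __name__ == '__main__':
    parallel = Parallel(2, _SquareWork())
    inputs = [1, 2, 3, 4]
    for x in inputs:
        parallel.put(x)
    results = sorted(parallel.get() for _ in inputs)
    assert results == [x * x for x in inputs]
    print('results:', results)
    # Drop the last reference so __del__ stops the non-daemon worker threads;
    # otherwise they would block interpreter exit.
    del parallel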