Source code for gluoncv.utils.sync_loader_helper

"""Dataloader helper functions. Synchronize slices for both data and label."""
__all__ = ['split_data', 'split_and_load']

from mxnet import ndarray

def split_data(data, num_slice, batch_axis=0, even_split=True, multiplier=1):
    """Splits an NDArray into `num_slice` slices along `batch_axis`.
    Usually used for data parallelism where each slice is sent
    to one device (i.e. GPU).

    Parameters
    ----------
    data : NDArray
        A batch of data.
    num_slice : int
        Number of desired slices.
    batch_axis : int, default 0
        The axis along which to slice.
    even_split : bool, default True
        Whether to force all slices to have the same number of elements.
        If `True`, an error will be raised when `num_slice` does not evenly
        divide `data.shape[batch_axis]`.
    multiplier : int, default 1
        The batch size has to be a multiple of `multiplier`.

    Returns
    -------
    list of NDArray
        Return value is a list even if `num_slice` is 1.
    """
    size = data.shape[batch_axis]
    if even_split and size % num_slice != 0:
        raise ValueError(
            "data with shape %s cannot be evenly split into %d slices along axis %d. "
            "Use a batch size that's a multiple of %d or set even_split=False to allow "
            "uneven partitioning of data." % (
                str(data.shape), num_slice, batch_axis, num_slice))

    step = (int(size / multiplier) // num_slice) * multiplier

    # If size < num_slice, make fewer slices
    if not even_split and size < num_slice:
        step = 1
        num_slice = size

    if batch_axis == 0:
        slices = [data[i*step:(i+1)*step] if i < num_slice - 1 else data[i*step:size]
                  for i in range(num_slice)]
    elif even_split:
        slices = ndarray.split(data, num_outputs=num_slice, axis=batch_axis)
    else:
        slices = [ndarray.slice_axis(data, batch_axis, i*step, (i+1)*step)
                  if i < num_slice - 1 else
                  ndarray.slice_axis(data, batch_axis, i*step, size)
                  for i in range(num_slice)]
    return slices
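
# Usage sketch (illustrative, not part of the original module). The shapes below
# are assumptions; the resulting slice sizes follow from the code above:
#
#     from mxnet import nd
#     x = nd.zeros((8, 3, 224, 224))
#     even = split_data(x, num_slice=4)
#     # -> 4 NDArrays, each of shape (2, 3, 224, 224)
#     uneven = split_data(nd.zeros((6, 10)), num_slice=4, even_split=False)
#     # -> slices of sizes 1, 1, 1 and 3 along axis 0
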
def split_and_load(data, ctx_list, batch_axis=0, even_split=True, multiplier=1):
    """Splits an NDArray into `len(ctx_list)` slices along `batch_axis` and loads
    each slice to one context in `ctx_list`.

    Parameters
    ----------
    data : NDArray
        A batch of data.
    ctx_list : list of Context
        A list of Contexts.
    batch_axis : int, default 0
        The axis along which to slice.
    even_split : bool, default True
        Whether to force all slices to have the same number of elements.
    multiplier : int, default 1
        The batch size has to be a multiple of the channel `multiplier`.
        Need to investigate further.

    Returns
    -------
    list of NDArray
        Each corresponds to a context in `ctx_list`.
    """
    if not isinstance(data, ndarray.NDArray):
        data = ndarray.array(data, ctx=ctx_list[0])
    if len(ctx_list) == 1:
        return [data.as_in_context(ctx_list[0])]

    slices = split_data(data, len(ctx_list), batch_axis, even_split, multiplier)
    return [i.as_in_context(ctx) for i, ctx in zip(slices, ctx_list)]
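
# Usage sketch (illustrative, not part of the original module). CPU contexts are
# an assumption here; with GPUs one would pass e.g. [mx.gpu(0), mx.gpu(1)]:
#
#     import mxnet as mx
#     batch = mx.nd.zeros((8, 3, 32, 32))
#     parts = split_and_load(batch, [mx.cpu(0), mx.cpu(1)])
#     # -> two NDArrays of shape (4, 3, 32, 32), one loaded onto each context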