aimz.utils.data.ArrayLoader

class aimz.utils.data.ArrayLoader(dataset, rng_key, *, batch_size=32, shuffle=False, device=None)

Data loader for batching and padding arrays.

Parameters:
  • dataset (ArrayDataset)

  • rng_key (Array)

  • batch_size (int)

  • shuffle (bool)

  • device (Sharding | None)

__init__(dataset, rng_key, *, batch_size=32, shuffle=False, device=None)

Initialize an ArrayLoader instance.

Parameters:
  • dataset (ArrayDataset) – The dataset to load.

  • rng_key (Array) – A pseudo-random number generator key.

  • batch_size (int) – The number of samples per batch.

  • shuffle (bool) – Whether to shuffle the dataset before batching.

  • device (Sharding | None) – The device or sharding specification to which the data should be moved. If None (the default), no device transfer is applied. When the loader is passed as input to a model, the model's device setting overrides this value.

Return type:

None
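The following minimal sketch illustrates how batch_size and shuffle typically interact in a loader of this kind. It is an assumption-laden illustration, not ArrayLoader's implementation: the dataset is modeled as a plain dict of NumPy arrays that share a leading sample dimension (a stand-in for ArrayDataset), and NumPy's random Generator replaces the JAX rng_key (with JAX, jax.random.permutation(rng_key, n) plays the same role). iterate_batches is a hypothetical helper, not part of aimz. This page does not document how ArrayLoader handles a final partial batch, so the sketch simply leaves it short.

```python
import numpy as np


def iterate_batches(arrays, batch_size=32, shuffle=False, seed=0):
    """Yield dicts of mini-batches from arrays sharing a leading dimension.

    Hypothetical helper for illustration; not part of aimz.
    """
    sizes = {a.shape[0] for a in arrays.values()}
    if len(sizes) != 1:
        raise ValueError("all arrays must share the same leading dimension")
    n = sizes.pop()
    if shuffle:
        # NumPy stand-in for jax.random.permutation(rng_key, n).
        order = np.random.default_rng(seed).permutation(n)
    else:
        order = np.arange(n)
    for start in range(0, n, batch_size):
        idx = order[start:start + batch_size]
        # When n is not a multiple of batch_size, the last batch is smaller.
        yield {name: a[idx] for name, a in arrays.items()}


# Usage: 10 samples in batches of 4, without shuffling.
X = np.arange(20).reshape(10, 2)
y = np.arange(10)
print([b["y"].tolist() for b in iterate_batches({"X": X, "y": y}, batch_size=4)])
# [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```

Shuffling permutes sample indices once and then slices them, so each sample appears exactly once per pass and rows stay aligned across all arrays in the batch.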

Methods

__init__(dataset, rng_key, *[, batch_size, ...])

Initialize an ArrayLoader instance.

pad_array(x, n_pad[, axis])

Pad an array to ensure compatibility with sharding.
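The following hedged illustration shows why padding helps with sharding. An array split along an axis across D devices or shards typically needs that axis's length to be divisible by D, so the length is rounded up to the next multiple of D. The sketch assumes zero padding appended at the end of the axis and a default of axis=0. The names pad_along_axis and padding_needed are hypothetical, and aimz's pad_array may choose the fill value, padding position, or default axis differently. jax.numpy.pad accepts the same pad_width argument as numpy.pad, so the same logic carries over to JAX arrays.

```python
import numpy as np


def pad_along_axis(x, n_pad, axis=0):
    """Zero-pad ``x`` with ``n_pad`` trailing entries along ``axis``.

    Hypothetical stand-in for ArrayLoader.pad_array; the zero fill value and
    the default axis=0 are assumptions.
    """
    if n_pad <= 0:
        return x
    pad_width = [(0, 0)] * x.ndim
    pad_width[axis] = (0, n_pad)  # pad after the last element only
    return np.pad(x, pad_width)  # default mode="constant" fills with 0


def padding_needed(n, n_shards):
    """Smallest n_pad such that n + n_pad is divisible by n_shards."""
    return (-n) % n_shards


# Usage: 10 samples spread over 4 shards need 2 padding rows.
x = np.ones((10, 3))
n_pad = padding_needed(x.shape[0], 4)
print(n_pad, pad_along_axis(x, n_pad).shape)
# 2 (12, 3)
```

Padded rows carry no real data, so downstream computations would need to discard or mask them.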