aimz.utils.data.ArrayDataset

class aimz.utils.data.ArrayDataset(*, to_jax=False, **arrays)

Dataset for named arrays.

Arrays are stored as-is by default, preserving whichever backend (NumPy or JAX) the caller supplied. Pass to_jax=True to convert all arrays to JAX arrays at construction time.

Parameters:
  • to_jax (bool)

  • arrays (Array | npt.NDArray | None)

__init__(*, to_jax=False, **arrays)

Initialize an ArrayDataset instance.

Parameters:
  • to_jax (bool) – Whether to convert the input arrays to JAX arrays at construction time. Defaults to False.

  • **arrays (Array | npt.NDArray | None) – Named JAX arrays, NumPy arrays, or None. At least one non-None array must be provided. All non-None arrays must have the same length.

Raises:

ValueError – If no non-None arrays are provided or if the arrays do not have the same length.

Return type:

None
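The validation rules listed under Raises can be sketched in plain Python. The helper `check_named_arrays` and its error messages are hypothetical stand-ins written for this illustration; they are not aimz's actual implementation and only mirror the documented contract (at least one non-None array, all non-None arrays of equal length).

```python
def check_named_arrays(**arrays):
    """Hypothetical stand-in for the documented checks; not aimz's code."""
    # Ignore entries that were explicitly passed as None.
    present = {name: a for name, a in arrays.items() if a is not None}
    if not present:
        raise ValueError("At least one non-None array must be provided.")
    # Compare lengths along the first axis of every remaining array.
    lengths = {name: len(a) for name, a in present.items()}
    if len(set(lengths.values())) > 1:
        raise ValueError(f"Arrays must have the same length, got {lengths}.")
    return present


def raises_value_error(**arrays):
    """Return True if the check rejects the given arrays."""
    try:
        check_named_arrays(**arrays)
    except ValueError:
        return True
    return False


# A None entry is allowed as long as another array is present.
ok = check_named_arrays(X=[[1.0, 2.0], [3.0, 4.0]], y=[0, 1], weights=None)
```

Because the sketch relies only on `len()`, it treats lists, NumPy arrays, and JAX arrays alike; the real class may perform additional checks not described here.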

Methods

__init__(*[, to_jax]) – Initialize an ArrayDataset instance.