aimz.utils.data.ArrayDataset
- class aimz.utils.data.ArrayDataset(*, to_jax=False, **arrays)
Dataset for named arrays.
Arrays are stored as-is by default, preserving whichever backend (NumPy or JAX) the caller supplied. Pass
to_jax=True to convert all arrays to JAX arrays at construction time.
- Parameters:
to_jax (bool)
arrays (Array | npt.NDArray | None)
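The storage rule above can be illustrated with a small standalone sketch. It does not call aimz: `store_arrays` is a hypothetical helper written only for illustration, and dropping None entries is an assumption, not documented behavior. JAX is imported only when to_jax=True, so the example runs with NumPy alone.

```python
import numpy as np


def store_arrays(*, to_jax=False, **arrays):
    """Hypothetical helper mirroring the documented storage rule of ArrayDataset.

    Not the aimz implementation. Assumption: None entries are dropped here.
    """
    stored = {}
    for name, arr in arrays.items():
        if arr is None:
            continue  # assumption: None entries are simply not stored
        if to_jax:
            import jax.numpy as jnp  # lazy import so NumPy-only use needs no JAX

            arr = jnp.asarray(arr)  # convert to a JAX array at construction time
        stored[name] = arr  # with to_jax=False the caller's object is kept untouched
    return stored


x = np.arange(6).reshape(3, 2)
y = np.array([0.0, 1.0, 2.0])
out = store_arrays(x=x, y=y, w=None)
print(sorted(out))    # ['x', 'y']
print(out["x"] is x)  # True: stored as-is, no copy or conversion
```

Keeping the original object when to_jax=False avoids an implicit host-to-device transfer; callers who want JAX arrays opt in explicitly.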
- __init__(*, to_jax=False, **arrays)
Initialize an ArrayDataset instance.
- Parameters:
to_jax (bool) – Whether to convert the input arrays to JAX arrays.
**arrays (Array | npt.NDArray | None) – Named JAX arrays, NumPy arrays, or
None. At least one non-None array must be provided. All non-None arrays must have the same length.
- Raises:
ValueError – If no non-None arrays are provided or if the non-None arrays do not have the same length.
- Return type:
None
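The two ValueError conditions can be sketched as a standalone check. `validate_arrays` is a hypothetical helper, not part of aimz, and measuring length with len() (the size of the first axis) is an assumption about what "same length" means here.

```python
import numpy as np


def validate_arrays(**arrays):
    """Hypothetical check mirroring the documented ValueError conditions.

    Assumption: an array's length is len(arr), i.e. the size of its first axis.
    Returns the common length of all non-None arrays.
    """
    lengths = {name: len(arr) for name, arr in arrays.items() if arr is not None}
    if not lengths:
        raise ValueError("At least one non-None array must be provided.")
    if len(set(lengths.values())) > 1:
        raise ValueError(f"All non-None arrays must have the same length, got {lengths}.")
    return next(iter(lengths.values()))


print(validate_arrays(x=np.zeros((4, 3)), y=np.ones(4), w=None))  # 4

# Both documented failure modes: only None values, and mismatched lengths.
for bad in ({"w": None}, {"x": np.zeros(4), "y": np.zeros(5)}):
    try:
        validate_arrays(**bad)
    except ValueError as err:
        print("ValueError:", err)
```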
Methods
__init__(*[, to_jax]) – Initialize an ArrayDataset instance.