Training & Inference with Data Loaders#
This guide explains how to use the built-in ArrayDataset and ArrayLoader, how they integrate with high-level methods like fit() / predict(), and how to construct fully custom training or inference loops (e.g., integrating a PyTorch DataLoader).
Built-in Dataset & Loader#
ArrayDataset#
ArrayDataset wraps one or more named arrays passed as keyword-only arguments.
All arrays must share the same leading-axis size (the sample axis).
By default arrays are stored as supplied; pass to_jax=True to convert to JAX arrays at construction.
from aimz.utils.data import ArrayDataset
X, y = ... # X and y are array-like
dataset = ArrayDataset(X=X, y=y)
len(dataset) # total number of samples
sample = dataset[0] # {'X': X[0], 'y': y[0]} (dict of field -> element)
ArrayLoader#
It consumes an ArrayDataset and produces an iterator of (batch_dict, n_pad) pairs:
batch_dict maps each field name to a (possibly padded) mini-batch array.
n_pad is the number of synthetic examples added so the (possibly last) batch size is divisible by the number of local devices.
When no device is specified (device=None), no padding is performed and n_pad is always 0. If set, the batch is padded (if needed) and then moved via jax.device_put(). Padding uses jax.numpy.pad() with mode="edge" (it repeats the last row) so shapes align for sharded computations; callers can ignore those rows (track them via n_pad).
The batch_size must be a positive integer, and if using a device or sharding it is best to choose a multiple of jax.local_device_count() to avoid padding.
import jax
from jax import random
from aimz.utils.data import ArrayDataset, ArrayLoader
# Suppose local_device_count() == 8 and batch_size == 10 -> padded to 16
loader = ArrayLoader(
    ArrayDataset(X=X, y=y),
    rng_key=random.key(0),
    batch_size=10,
    shuffle=True,
    device=jax.devices()[0],  # or a Sharding spec
)
for batch, n_pad in loader:
    # batch is a dict: {'X': ..., 'y': ...}; n_pad == 6
    ...
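The padding arithmetic described above can be checked with NumPy alone. The sketch below is not aimz's implementation: pad_to_multiple is a hypothetical helper, and numpy.pad with mode="edge" stands in for jax.numpy.pad, which accepts the same mode.

```python
import numpy as np


def pad_to_multiple(batch, multiple):
    """Edge-pad every array in `batch` along axis 0 to a multiple of `multiple`.

    Hypothetical helper for illustration; returns (padded_batch, n_pad).
    """
    n = len(next(iter(batch.values())))
    n_pad = (-n) % multiple  # rows needed to reach the next multiple
    if n_pad == 0:
        return batch, 0
    padded = {}
    for name, arr in batch.items():
        # Pad only the leading (sample) axis; leave other axes untouched.
        pad_width = [(0, n_pad)] + [(0, 0)] * (arr.ndim - 1)
        padded[name] = np.pad(arr, pad_width, mode="edge")
    return padded, n_pad


# A batch of 10 rows padded for 8 devices, as in the example above.
batch = {"X": np.arange(20.0).reshape(10, 2), "y": np.arange(10.0)}
padded, n_pad = pad_to_multiple(batch, multiple=8)
print(padded["X"].shape, padded["y"].shape, n_pad)
```

The padded rows are copies of the last real row, so any per-row result computed on them should be discarded using n_pad.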
Note
ArrayDataset and ArrayLoader are lightweight utilities for working with in-memory arrays.
They are intentionally minimal and primarily used internally to enable batching, optional shuffling, and (when required) padding for device sharding.
You can use them directly, but they are not meant to be a comprehensive data pipeline abstraction.
For out-of-core datasets, implement a generator that streams data in chunks from disk or cloud storage.
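One possible shape for such a generator is sketched below. It assumes the data lives in two .npy files and reads them through NumPy memory mapping; stream_batches, the file names, and the on-disk layout are illustrative assumptions, not part of aimz.

```python
import os
import tempfile

import numpy as np


def stream_batches(x_path, y_path, batch_size):
    """Yield {'X': ..., 'y': ...} batches from .npy files without loading them fully."""
    # mmap_mode="r" maps the files lazily; only the sliced rows are read.
    X = np.load(x_path, mmap_mode="r")
    y = np.load(y_path, mmap_mode="r")
    for start in range(0, len(X), batch_size):
        stop = start + batch_size
        # np.array copies the slice into an ordinary in-memory array.
        yield {"X": np.array(X[start:stop]), "y": np.array(y[start:stop])}


# Small demonstration with temporary files standing in for a large dataset.
with tempfile.TemporaryDirectory() as tmp:
    x_path = os.path.join(tmp, "X.npy")
    y_path = os.path.join(tmp, "y.npy")
    np.save(x_path, np.random.default_rng(0).normal(size=(25, 3)))
    np.save(y_path, np.arange(25.0))
    shapes = [b["X"].shape for b in stream_batches(x_path, y_path, batch_size=10)]
```

Each yielded item is a mapping from field name to array, so it can be unpacked into a per-batch call in a custom loop.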
Storage and Device Transfer#
When raw arrays are passed to high-level methods like fit() or predict(), aimz stores them as NumPy arrays in host memory and transfers one batch at a time to the device during iteration.
JAX arrays passed in are converted to NumPy at this stage; their original device placement is not preserved.
This allows datasets larger than device memory to be processed without modification.
To keep arrays on device instead, construct the ArrayDataset with to_jax=True and pass an explicit loader:
loader = ArrayLoader(
    ArrayDataset(X=X, y=y, to_jax=True),
    rng_key=random.key(0),
    batch_size=batch_size,
)
im.predict(loader)
Integration with High-Level Methods#
High-level methods (fit(), predict()) accept either raw arrays (X, y, etc.) or an ArrayLoader.
Passing a loader gives finer control over batch size, ordering, shuffling, and storage backend (see above).
Any model-level device or sharding configuration takes precedence over the loader’s device argument.
If you pass raw arrays instead, fit() may internally construct a temporary loader with heuristic batching.
from numpyro.infer import SVI
from aimz import ImpactModel
# Set up variational inference strategy
vi = SVI(model, ...)
# Initialize ImpactModel with a model, random key, and SVI object
im = ImpactModel(model, rng_key=random.key(0), svi=vi)
# Use a prepared ArrayLoader for explicit batching/shuffling
im.fit(loader, epochs=10)
# Predictions also accept a loader for consistent batching
preds = im.predict(loader)
Custom Training Loops with train_on_batch()#
For fine-grained control (e.g., custom scheduling, gradient accumulation, or early stopping), a custom training loop can be built with train_on_batch().
im = ImpactModel(...)
for epoch in range(num_epochs):
    for batch, n_pad in loader:  # `n_pad` may be > 0 when padded
        if n_pad > 0:
            # Optionally handle or ignore the extra padded rows
            ...
        # Perform one update step on this batch
        im.train_on_batch(**batch)
    ...
    # (Optional) validation, logging, early stop checks
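As one way to fill in the early-stop check in a loop like this, the sketch below tracks the best epoch loss and stops after a fixed number of epochs without improvement. EarlyStopper, patience, and min_delta are hypothetical names chosen for illustration, and the loss values are made up; nothing here is provided by aimz.

```python
class EarlyStopper:
    """Hypothetical helper: stop when the epoch loss stops improving."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_epochs = 0

    def should_stop(self, loss):
        if loss < self.best - self.min_delta:
            # New best loss: remember it and reset the counter.
            self.best = loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Stand-in epoch losses; in a real loop these would be aggregated from
# the per-batch losses collected during each epoch.
epoch_losses = [5.0, 4.0, 3.5, 3.6, 3.55, 3.7, 3.2]
stopper = EarlyStopper(patience=3)
stopped_at = None
for epoch, loss in enumerate(epoch_losses):
    if stopper.should_stop(loss):
        stopped_at = epoch
        break
```

In a training loop, the should_stop call would sit after the inner batch loop, followed by a break out of the epoch loop.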
Using Other DataLoader Implementations#
You are not restricted to the built-in loader.
Any iterable that yields a mapping (field name → array) per batch works with a custom loop, provided the arrays are convertible via jax.numpy.asarray().
import jax.numpy as jnp
from torch.utils.data import DataLoader, TensorDataset

im = ImpactModel(...)
# PyTorch DataLoader example (CPU → JAX conversion per batch)
dataset = TensorDataset(X, y)  # X and y are torch tensors here
loader = DataLoader(dataset, batch_size=10, shuffle=True)
losses = []
for epoch in range(num_epochs):
    for X_batch, y_batch in loader:
        # Convert each batch to JAX arrays before the update step
        batch = {"X": jnp.asarray(X_batch), "y": jnp.asarray(y_batch)}
        _, loss = im.train_on_batch(**batch)
        losses.append(loss)
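The same pattern does not require PyTorch. The sketch below is a minimal NumPy-only batch iterator that yields one mapping per batch; iterate_batches is a hypothetical helper written for this illustration, not an aimz function.

```python
import numpy as np


def iterate_batches(arrays, batch_size, rng=None):
    """Yield dict batches from a dict of equally sized arrays.

    Hypothetical helper; shuffles once per pass when `rng` is given.
    """
    n = len(next(iter(arrays.values())))
    if any(len(a) != n for a in arrays.values()):
        raise ValueError("All arrays must share the same leading-axis size.")
    order = rng.permutation(n) if rng is not None else np.arange(n)
    for start in range(0, n, batch_size):
        idx = order[start:start + batch_size]
        # Index every field with the same rows so samples stay aligned.
        yield {name: a[idx] for name, a in arrays.items()}


data = {"X": np.arange(14.0).reshape(7, 2), "y": np.arange(7.0)}
batches = list(iterate_batches(data, batch_size=3, rng=np.random.default_rng(0)))
```

Each element of batches is a dict keyed by field name, so it can be unpacked with ** in the same way as the batches in the loop above.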
After a manual training loop you can populate the model state so downstream calls (prediction, posterior predictive sampling) work the same as after fit():
Set vi_result to a structure containing the final parameters and loss history.
Draw posterior samples with sample() (return_datatree=False to get a raw dictionary instead of a DataTree).
Register the samples via set_posterior_sample().
from typing import NamedTuple
from jax import Array
class SVIRunResult(NamedTuple):
    params: dict[str, Array]
    losses: list[float]
# Store final VI parameters and the collected loss trace (assumes `losses` list built above)
im.vi_result = SVIRunResult(im.inference.get_params(im._vi_state), losses)
# Obtain posterior samples
posterior_sample = im.sample(return_datatree=False)
# Register the samples so predictive methods can use them
im.set_posterior_sample(posterior_sample)
You can reuse the same loop pattern for prediction or likelihood evaluation:
# Collect per-batch posterior predictive means for target 'y'
batch_means = []
for X_batch, _ in loader:
    preds = im.predict_on_batch(X_batch, return_datatree=False)
    # preds['y'] shape: (num_draws, batch_size, ...); average over draws
    batch_means.append(preds["y"].mean(axis=0))
# Stitch back together along the sample axis
posterior_predictive_mean = jnp.concatenate(batch_means, axis=0)
# ... further metrics / evaluation
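When batches come from a loader that pads (n_pad > 0), the padded rows should be dropped before the per-batch results are concatenated. The sketch below shows that step on stand-in NumPy arrays, assuming draws lie on axis 0 and samples on axis 1 as in the shape comment above; trim_padding and the fake predictions are illustrative assumptions.

```python
import numpy as np


def trim_padding(preds, n_pad, sample_axis=1):
    """Drop the trailing `n_pad` padded rows along `sample_axis`."""
    if n_pad == 0:
        return preds
    n_keep = preds.shape[sample_axis] - n_pad
    return np.take(preds, np.arange(n_keep), axis=sample_axis)


# Stand-in predictions: 4 draws, two batches of 16 rows each,
# where the second batch carries 6 padded rows.
rng = np.random.default_rng(0)
fake_batches = [(rng.normal(size=(4, 16)), 0), (rng.normal(size=(4, 16)), 6)]

batch_means = []
for preds, n_pad in fake_batches:
    preds = trim_padding(preds, n_pad)  # remove padded rows first
    batch_means.append(preds.mean(axis=0))  # average over draws
posterior_predictive_mean = np.concatenate(batch_means, axis=0)
```

Trimming before concatenation keeps the stitched result aligned with the original sample order and length.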
See Also#
PyTorch DataLoader – Widely used reference implementation.
Grain – JAX-native scalable input pipeline.
Dataloader for JAX – Minimal NumPy/JAX DataLoader.