Training & Inference with Data Loaders

This guide explains how to use the built-in ArrayDataset and ArrayLoader, how they integrate with high-level methods like fit() / predict(), and how to construct fully custom training or inference loops (e.g., integrating a PyTorch DataLoader).

Built-in Dataset & Loader

ArrayDataset

ArrayDataset wraps one or more named arrays, passed as keyword-only arguments. All arrays must share the same leading-axis size (the sample axis). By default, arrays are stored as supplied; pass to_jax=True to convert them to JAX arrays at construction.

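import numpy as np

from aimz.utils.data import ArrayDataset

# Example data: any array-like inputs work; axis 0 (100 here) is the sample axis
X = np.random.rand(100, 3)
y = np.random.rand(100)

dataset = ArrayDataset(X=X, y=y)
len(dataset)         # 100 (total number of samples)
sample = dataset[0]  # {'X': X[0], 'y': y[0]} (dict mapping field name -> element)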
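To make the contract above concrete, here is a minimal sketch of a dataset with the same described behaviour: keyword-only named arrays, a shared leading axis, and dict-valued indexing. MinimalArrayDataset is a hypothetical stand-in written for illustration, not aimz's implementation, and it omits the to_jax option.

```python
import numpy as np


class MinimalArrayDataset:
    """Hypothetical stand-in illustrating the ArrayDataset contract (not aimz code)."""

    def __init__(self, **arrays):
        # **arrays forces keyword-only, named fields
        if not arrays:
            raise ValueError("At least one named array is required.")
        sizes = {name: np.shape(a)[0] for name, a in arrays.items()}
        if len(set(sizes.values())) != 1:
            raise ValueError(f"Leading-axis sizes differ: {sizes}")
        self._arrays = arrays
        self._n = next(iter(sizes.values()))

    def __len__(self):
        # Number of samples along the shared leading axis
        return self._n

    def __getitem__(self, idx):
        # Index every field with the same sample index -> dict of field -> element
        return {name: a[idx] for name, a in self._arrays.items()}


X = np.arange(12.0).reshape(6, 2)
y = np.arange(6.0)
ds = MinimalArrayDataset(X=X, y=y)
print(len(ds))  # 6
```

Validating the leading axis once at construction means every later index lookup can assume all fields line up.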

ArrayLoader

ArrayLoader consumes an ArrayDataset and yields (batch_dict, n_pad) pairs:

  • batch_dict maps each field name to a (possibly padded) mini-batch array.

  • n_pad is the number of synthetic rows appended so that the batch size is divisible by the number of local devices. When no device is specified (device=None), no padding is performed and n_pad is always 0.

When device is set (to a device or a sharding), each batch is padded if needed and then moved with jax.device_put(). Padding uses jax.numpy.pad() with mode="edge", which repeats the last row, so shapes align for sharded computations; use n_pad to identify and ignore those rows. batch_size must be a positive integer. With a device or sharding, choose a multiple of jax.local_device_count(): full batches then need no padding, and only a smaller final batch can be padded.

import jax
from jax import random
from aimz.utils.data import ArrayDataset, ArrayLoader

# Suppose local_device_count() == 8 and batch_size == 10 -> each full batch is padded to 16
loader = ArrayLoader(
    ArrayDataset(X=X, y=y),
    rng_key=random.key(0),
    batch_size=10,
    shuffle=True,
    device=jax.devices()[0],  # or a Sharding spec
)

for batch, n_pad in loader:
    # batch is a dict: {'X': ..., 'y': ...}; n_pad == 6 for each full batch of 10
    ...
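The padding arithmetic can be sketched with NumPy and jax.numpy.pad() alone. The helpers pad_to_multiple() and drop_padding() are hypothetical illustrations of the behaviour described above, not aimz functions; they assume every field in a batch shares the same leading-axis size. Note that slicing with x[:-n_pad] is a common mistake, since it returns an empty array when n_pad is 0.

```python
import jax.numpy as jnp
import numpy as np


def pad_to_multiple(batch, n_devices):
    """Hypothetical helper: edge-pad every field along axis 0 to a multiple of n_devices."""
    size = next(iter(batch.values())).shape[0]
    n_pad = (-size) % n_devices  # 0 when size is already divisible
    padded = {
        # Pad only the leading (sample) axis; mode="edge" repeats the last row
        name: jnp.pad(arr, [(0, n_pad)] + [(0, 0)] * (arr.ndim - 1), mode="edge")
        for name, arr in batch.items()
    }
    return padded, n_pad


def drop_padding(x, n_pad):
    # Explicit stop index: x[:-n_pad] would return an empty array when n_pad == 0
    return x[: x.shape[0] - n_pad]


batch = {"X": np.arange(20.0).reshape(10, 2), "y": np.arange(10.0)}
padded, n_pad = pad_to_multiple(batch, n_devices=8)
print(padded["X"].shape, n_pad)  # (16, 2) 6
```

Applying drop_padding() to per-row outputs (for example, predictions) after a sharded computation recovers exactly the real rows.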

Note

ArrayDataset and ArrayLoader are lightweight utilities for working with in-memory arrays. They are intentionally minimal and are primarily used internally to enable batching, optional shuffling, and (when required) padding for device sharding. You can use them directly, but they are not meant to be a comprehensive data pipeline abstraction. For out-of-core datasets that do not fit in memory, implement a generator that streams data in chunks from disk or cloud storage.
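As a starting point for out-of-core data, the generator below memory-maps .npy files with numpy.load(mmap_mode="r") so that only one batch at a time is read from disk. It is a sketch under the assumption that your data is stored as one .npy file per field; stream_npy_batches() is a hypothetical name, and the yielded dicts are intended for a custom loop such as the train_on_batch() pattern.

```python
import os
import tempfile

import numpy as np


def stream_npy_batches(x_path, y_path, batch_size):
    """Hypothetical generator yielding {'X': ..., 'y': ...} batches from .npy files.

    mmap_mode="r" memory-maps the files, so only the slices requested for
    each batch are read from disk.
    """
    X = np.load(x_path, mmap_mode="r")
    y = np.load(y_path, mmap_mode="r")
    if X.shape[0] != y.shape[0]:
        raise ValueError("X and y must have the same number of samples")
    for start in range(0, X.shape[0], batch_size):
        stop = start + batch_size
        # np.array copies the memory-mapped slice into an ordinary in-memory array
        yield {"X": np.array(X[start:stop]), "y": np.array(y[start:stop])}


# Example: write small arrays to disk, then stream them back in batches of 4
tmp = tempfile.mkdtemp()
x_path, y_path = os.path.join(tmp, "X.npy"), os.path.join(tmp, "y.npy")
np.save(x_path, np.arange(20.0).reshape(10, 2))
np.save(y_path, np.arange(10.0))
sizes = [b["y"].shape[0] for b in stream_npy_batches(x_path, y_path, 4)]
print(sizes)  # [4, 4, 2]
```

The same shape of generator works for cloud storage if the slicing step is replaced by a chunked read from your storage client.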

Storage and Device Transfer

When raw arrays are passed to high-level methods like fit() or predict(), aimz stores them as NumPy arrays in host memory and transfers one batch at a time to the device during iteration. JAX arrays passed in are converted to NumPy at this stage; their original device placement is not preserved. This allows datasets larger than device memory (but within host memory) to be processed without code changes. If your data fits in device memory, you can keep it there by constructing a loader explicitly with to_jax=True:

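batch_size = 32  # example value

# Arrays are converted to JAX arrays once, when the dataset is constructed
loader = ArrayLoader(
    ArrayDataset(X=X, y=y, to_jax=True),
    rng_key=random.key(0),
    batch_size=batch_size,
)
im.predict(loader)  # `im` is an ImpactModel, constructed as in the next section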
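The host-storage pattern can be illustrated with plain NumPy and JAX: converting a JAX array with numpy.asarray() copies it to host memory, and each batch is then placed on a device with jax.device_put(). This is a minimal sketch of the idea, not aimz's internal code; the batch size of 4 is arbitrary.

```python
import jax
import jax.numpy as jnp
import numpy as np

# A JAX array that lives on the default device
X_dev = jnp.arange(12.0).reshape(6, 2)

# Converting to NumPy copies the data to host memory; device placement is lost
X_host = np.asarray(X_dev)

device = jax.devices()[0]
batch_sums = []
for start in range(0, X_host.shape[0], 4):
    # Only this slice is transferred to the device
    batch = jax.device_put(X_host[start:start + 4], device)
    batch_sums.append(float(batch.sum()))

print(type(X_host).__name__, batch_sums)  # ndarray [28.0, 38.0]
```

At any moment only one batch occupies device memory, which is why this layout scales past device capacity at the cost of a host-to-device copy per batch.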

Integration with High-Level Methods

High-level methods (fit(), predict()) accept either raw arrays (X, y, etc.) or an ArrayLoader. Passing a loader gives finer control over batch size, ordering, shuffling, and storage backend (see above). Any model-level device or sharding configuration takes precedence over the loader’s device argument. If you pass raw arrays instead, fit() may internally construct a temporary loader with heuristic batching.

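from jax import random
from numpyro.infer import SVI

from aimz import ImpactModel

# `model` is your NumPyro model; `loader` is an ArrayLoader such as the one above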

# Set up variational inference strategy
vi = SVI(model, ...)

# Initialize ImpactModel with a model, random key, and SVI object
im = ImpactModel(model, rng_key=random.key(0), svi=vi)

# Use a prepared ArrayLoader for explicit batching/shuffling
im.fit(loader, epochs=10)

# Predictions also accept a loader for consistent batching
preds = im.predict(loader)

Custom Training Loops with train_on_batch()

For fine-grained control (e.g., custom scheduling, gradient accumulation, or early stopping), a custom training loop can be built with train_on_batch().

im = ImpactModel(...)

for epoch in range(num_epochs):
    for batch, n_pad in loader:  # `n_pad` may be > 0 when padded
        if n_pad > 0:
            # Optionally handle or ignore the extra padded rows
            ...

        # Perform one update step on this batch
        im.train_on_batch(**batch)
        ...

    # (Optional) validation, logging, early stop checks
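For the early-stopping case, one option is a small patience counter fed with the mean loss of each epoch (for example, the average of the losses returned by train_on_batch()). EarlyStopping is a hypothetical helper, not part of aimz, and the loss values below are made up purely to exercise it.

```python
import math


class EarlyStopping:
    """Hypothetical patience-based early-stopping helper (not part of aimz).

    Call step() once per epoch with a monitored value; it returns True once
    the value has failed to improve by more than min_delta for `patience`
    consecutive epochs.
    """

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = math.inf
        self.bad_epochs = 0

    def step(self, value):
        if value < self.best - self.min_delta:
            # Improvement: remember the new best and reset the counter
            self.best = value
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Made-up mean epoch losses standing in for values collected in the loop above
epoch_losses = [10.0, 8.0, 7.5, 7.6, 7.7, 7.55, 7.0]
stopper = EarlyStopping(patience=3)
for epoch, loss in enumerate(epoch_losses):
    if stopper.step(loss):
        print(f"stopping after epoch {epoch}")  # stopping after epoch 5
        break
```

Monitoring a validation metric instead of the training loss follows the same pattern; only the value passed to step() changes.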

Using Other DataLoader Implementations

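You are not restricted to the built-in loader. In a custom loop, any iterable works as long as each batch can be turned into a mapping (field name → array) whose arrays are convertible via jax.numpy.asarray(). For example, a PyTorch DataLoader yields tuples of tensors, which you can convert to a dict per batch: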

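import jax.numpy as jnp
import torch
from torch.utils.data import DataLoader, TensorDataset

im = ImpactModel(...)

# PyTorch DataLoader example (CPU tensor -> NumPy -> JAX conversion per batch)
dataset = TensorDataset(torch.as_tensor(X), torch.as_tensor(y))
loader = DataLoader(dataset, batch_size=10, shuffle=True)

losses = []
for epoch in range(num_epochs):
    for X_batch, y_batch in loader:
        batch = {"X": jnp.asarray(X_batch.numpy()), "y": jnp.asarray(y_batch.numpy())}
        _, loss = im.train_on_batch(**batch)
        losses.append(float(loss))  # keep the loss history as plain Python floats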
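If you use several tuple-yielding loaders, the per-batch conversion can be factored into a small adapter. as_named_batches() is a hypothetical helper (not an aimz or PyTorch API); it zips field names with each batch tuple and converts every element via numpy.asarray() followed by jax.numpy.asarray(), which also covers CPU PyTorch tensors. The example feeds it a plain list of NumPy batches so it runs without PyTorch installed.

```python
import jax.numpy as jnp
import numpy as np


def as_named_batches(loader, names):
    """Hypothetical adapter: turn an iterable of per-batch tuples into dicts of JAX arrays."""
    for items in loader:
        if len(items) != len(names):
            raise ValueError(f"expected {len(names)} fields, got {len(items)}")
        # np.asarray handles NumPy arrays and CPU PyTorch tensors alike
        yield {name: jnp.asarray(np.asarray(item)) for name, item in zip(names, items)}


# Any iterable of tuples works; here, a plain list of NumPy (X, y) batches
raw_batches = [(np.ones((4, 3)), np.zeros(4)), (np.ones((2, 3)), np.zeros(2))]
for batch in as_named_batches(raw_batches, names=("X", "y")):
    print({k: v.shape for k, v in batch.items()})
# {'X': (4, 3), 'y': (4,)}
# {'X': (2, 3), 'y': (2,)}
```

Inside a training loop, each yielded dict can then be unpacked directly with im.train_on_batch(**batch).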

After a manual training loop, you can populate the model state so that downstream calls (prediction, posterior predictive sampling) behave the same as after fit():

  1. Set vi_result to a structure containing the final parameters and loss history.

  2. Draw posterior samples with sample() (pass return_datatree=False to get a plain dictionary instead of a DataTree).

  3. Register the samples via set_posterior_sample().

from typing import NamedTuple

from jax import Array


class SVIRunResult(NamedTuple):
    params: dict[str, Array]
    losses: list[float]

# Store final VI parameters and the collected loss trace (assumes `losses` list built above)
im.vi_result = SVIRunResult(im.inference.get_params(im._vi_state), losses)

# Obtain posterior samples
posterior_sample = im.sample(return_datatree=False)

# Register the samples so predictive methods can use them
im.set_posterior_sample(posterior_sample)

You can reuse the same loop pattern for prediction or likelihood evaluation. Iterate without shuffling so that the stitched results line up with the rows of X:

# A non-shuffled loader keeps batches in the original sample order
eval_loader = DataLoader(dataset, batch_size=10, shuffle=False)

# Collect per-batch posterior predictive means for target 'y'
batch_means = []
for X_batch, _ in eval_loader:
    preds = im.predict_on_batch(jnp.asarray(X_batch.numpy()), return_datatree=False)
    # preds['y'] shape: (num_draws, batch_size, ...); average over draws
    batch_means.append(preds["y"].mean(axis=0))

# Stitch back together along the sample axis
posterior_predictive_mean = jnp.concatenate(batch_means, axis=0)
# ... further metrics / evaluation
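Because the mean over draws is computed independently for each sample, averaging batch by batch and concatenating along the sample axis gives the same result as averaging the full array, as long as the batches are visited in their original order. The sketch below checks this on synthetic draws (random normal values standing in for predict_on_batch() output); it does not call aimz.

```python
import jax.numpy as jnp
import numpy as np

rng = np.random.default_rng(0)
# Synthetic posterior predictive draws: (num_draws=100, num_samples=10)
draws = jnp.asarray(rng.normal(size=(100, 10)))

batch_size = 4
batch_means = []
for start in range(0, draws.shape[1], batch_size):
    # Each "batch" holds all draws for a contiguous block of samples
    batch_draws = draws[:, start:start + batch_size]
    batch_means.append(batch_draws.mean(axis=0))

stitched = jnp.concatenate(batch_means, axis=0)
full = draws.mean(axis=0)
print(stitched.shape, bool(jnp.allclose(stitched, full)))  # (10,) True
```

The same reasoning applies to any per-sample statistic (quantiles, standard deviations), but not to statistics that pool across samples, which must be accumulated differently.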

See Also