Disk-Backed vs. On-Batch Methods#

This page explains and compares the two complementary execution styles provided by ImpactModel:

  • Disk-backed (default) methods iterate over the input in chunks, materialize results incrementally, and persist structured artifacts (Zarr-backed xarray.DataTree plus metadata) to a temporary or user-specified output directory.

  • On-batch (*_on_batch suffix) methods execute a single, fully in-memory pass and can optionally return a plain dict instead of an xarray.DataTree. The naming mirrors the Keras convention to signal an immediate, single-batch, memory-resident operation.

Why Disk-Backed by Default#

The non-*_on_batch methods default to a disk-backed (chunked) execution model for several reasons:

  • Posterior predictive and prior predictive tensors can scale as (#samples x #dims x #posterior_samples x ...). Even moderate increases in any axis (time, spatial units, parameter samples) can exceed host or accelerator RAM.

  • Using batch_size with chunked iteration limits peak memory and prevents out-of-memory errors.

  • Persisted Zarr arrays with metadata (coords, dims, attrs) create an artifact you can reopen without rerunning inference.

  • The xarray.DataTree + Zarr format integrates with scientific Python tools such as Dask and ArviZ.

  • Summaries (means, HDIs, residual PPC stats) can be computed lazily over chunked storage without first materializing dense arrays.

  • One API works for both small experiments and large-scale use cases.
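
To make the memory argument concrete, the sketch below estimates the size of a dense posterior predictive array and compares it with the working set of a single chunk. The helper name dense_nbytes, the example shapes, and the float32 storage assumption are illustrative choices, not part of the ImpactModel API.

```python
from math import prod

BYTES_PER_FLOAT32 = 4  # assumption: outputs are stored as float32


def dense_nbytes(shape: tuple[int, ...], itemsize: int = BYTES_PER_FLOAT32) -> int:
    """Return the bytes needed to hold a dense array of the given shape."""
    return prod(shape) * itemsize


# Hypothetical workload: 1,000 posterior draws for 5,000,000 rows, 1 output.
num_samples, n_rows, n_outputs = 1_000, 5_000_000, 1
batch_size = 10_000  # rows processed per chunk

full = dense_nbytes((num_samples, n_rows, n_outputs))
chunk = dense_nbytes((num_samples, batch_size, n_outputs))

print(f"full array: {full / 1e9:.1f} GB")   # full array: 20.0 GB
print(f"one chunk:  {chunk / 1e6:.1f} MB")  # one chunk:  40.0 MB
```

Under these assumed shapes, the dense array would not fit in the memory of most single hosts, while each chunk stays small enough to process and persist comfortably.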

Comparison#

Disk-backed variants target larger datasets and enable chunked processing, multi-device parallelism, and stable artifact generation. These methods build internal data loaders, iterate in chunks, and decouple sampling from file I/O so that the two can run concurrently. Outputs are consolidated into a single xarray.DataTree backed by Zarr files for post-hoc analysis. On-batch variants, in contrast, favor minimal overhead, immediate return, and greater flexibility when posterior sample shapes are not shard-friendly.
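
The chunked iteration described above can be pictured with a small, library-free sketch. It is not how ImpactModel is implemented; iter_batches and process_in_chunks are hypothetical helpers, and the doubling step merely stands in for sampling.

```python
from collections.abc import Iterator


def iter_batches(n_rows: int, batch_size: int) -> Iterator[slice]:
    """Yield consecutive row slices that together cover range(n_rows)."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    for start in range(0, n_rows, batch_size):
        yield slice(start, min(start + batch_size, n_rows))


def process_in_chunks(rows: list[float], batch_size: int) -> list[float]:
    """Apply a stand-in computation chunk by chunk and collect the results."""
    out: list[float] = []
    for rows_slice in iter_batches(len(rows), batch_size):
        chunk = rows[rows_slice]  # only this chunk is worked on at a time
        # Stand-in for sampling. A disk-backed method would persist each
        # chunk's result here instead of accumulating it in memory.
        out.extend(2.0 * x for x in chunk)
    return out


rows = [float(i) for i in range(10)]
slices = list(iter_batches(len(rows), batch_size=4))
result = process_in_chunks(rows, batch_size=4)
print(slices)  # [slice(0, 4, None), slice(4, 8, None), slice(8, 10, None)]
```

An on-batch method corresponds to the degenerate case of a single slice covering every row, which removes the loop and the I/O but keeps the whole result in memory.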

Feature Summary#

| Feature | Disk-backed (default) | On-batch (*_on_batch) |
| --- | --- | --- |
| Typical dataset size | Medium → large | Small → moderate |
| Supported use cases | Standard models | Broader model support |
| Peak memory usage | Chunk-bounded | Full batch resident |
| Writes to disk | Yes | No |
| Return type | xarray.DataTree | xarray.DataTree or dict (via return_datatree=False) |
| Custom batch sizing | Yes (batch_size) | No (single pass) |
| Device parallelism (sharding) | Yes | No |
| Automatic fallback | Yes (may auto-delegate to on-batch) | No (final mode) |
| Latency (small data) | Higher (I/O + orchestration) | Minimal |

Capability Matrix#

| Capability | Disk-backed (default) | On-batch (*_on_batch) |
| --- | --- | --- |
| Full dataset training | fit() | fit_on_batch() |
| Single training step | N/A | train_on_batch() |
| Prior predictive sampling | sample_prior_predictive() | sample_prior_predictive_on_batch() |
| Posterior sampling | sample() | N/A |
| Posterior predictive sampling | predict() or sample_posterior_predictive() | predict_on_batch() or sample_posterior_predictive_on_batch() |
| Log-likelihood computation | log_likelihood() | N/A |
| Effect estimation | estimate_effect() | (consumes outputs above) |

Quick Recommendations#

  • Moderate or large data, or need persisted outputs: use disk-backed (e.g., fit(), predict()).

  • Small data, rapid iteration, CI, or read-only / ephemeral filesystem: use on-batch (*_on_batch).

  • If predict() issues a fallback warning, call predict_on_batch() directly. This occurs when the model or posterior sample shapes are incompatible with shard-based chunked execution.

  • Custom training loop: iterate with train_on_batch().

  • Need multi-device (sharding) execution: disk-backed.

  • Need raw NumPy/dict outputs (no xarray.DataTree): on-batch with return_datatree=False.

Note

For MCMC inference, only fit_on_batch() and sample() are supported for training and posterior sampling, because MCMC is incompatible with epoch-based or chunked batch processing. See MCMC Support for more details.

Example: predict() with Fallback Warning#

A common scenario for the fallback warning occurs when the model contains local latent variables, which can make posterior sample shapes incompatible with shard-based parallel execution. The example below illustrates this case.

import logging

import jax.numpy as jnp
import numpyro.distributions as dist
from jax import random
from jax.typing import ArrayLike
from numpyro import optim, plate, sample
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal

from aimz import ImpactModel

logging.basicConfig(level=logging.INFO, force=True)


def model(X: ArrayLike, y: ArrayLike | None = None) -> None:
    # Model includes a local latent variable
    sigma = sample("sigma", dist.Exponential().expand((X.shape[0],)))
    with plate("data", size=X.shape[0]):
        sample("y", dist.Normal(0.0, sigma), obs=y)


rng_key = random.key(42)
rng_key, rng_key_X, rng_key_y = random.split(rng_key, 3)
X = random.normal(rng_key_X, (100, 2))
y = random.normal(rng_key_y, (100,))


im = ImpactModel(
    model,
    rng_key=rng_key,
    inference=SVI(
        model,
        guide=AutoNormal(model),
        optim=optim.Adam(step_size=1e-3),
        loss=Trace_ELBO(),
    ),
).fit_on_batch(X, y)  # Internally calls the `.run()` method of `SVI`
# Calling `.predict()` triggers a fallback warning
im.predict(X)
/tmp/ipykernel_830/1252916572.py:2: UserWarning: One or more posterior sample shapes are not compatible with `.predict()` under sharded parallelism; falling back to `.predict_on_batch()`.
  im.predict(X)
<xarray.DataTree 'root'>
Group: /
├── Group: /posterior
│       Dimensions:      (chain: 1, draw: 1000, sigma_dim_0: 100)
│       Coordinates:
│         * chain        (chain) int64 8B 0
│         * draw         (draw) int64 8kB 0 1 2 3 4 5 6 ... 993 994 995 996 997 998 999
│         * sigma_dim_0  (sigma_dim_0) int64 800B 0 1 2 3 4 5 6 ... 93 94 95 96 97 98 99
│       Data variables:
│           sigma        (chain, draw, sigma_dim_0) float32 400kB 1.615 0.8619 ... 2.558
│       Attributes:
│           created_at:    2026-05-24T02:55:52.818277+00:00
│           aimz_version:  0.12.0
└── Group: /posterior_predictive
        Dimensions:  (chain: 1, draw: 1000, y_dim_0: 100)
        Coordinates:
          * chain    (chain) int64 8B 0
          * draw     (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
          * y_dim_0  (y_dim_0) int64 800B 0 1 2 3 4 5 6 7 8 ... 92 93 94 95 96 97 98 99
        Data variables:
            y        (chain, draw, y_dim_0) float32 400kB -3.284 -0.4314 ... 3.398
        Attributes:
            created_at:    2026-05-24T02:55:52.820719+00:00
            aimz_version:  0.12.0
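
Because the fallback is reported as a UserWarning (as the output above shows), it can be detected programmatically with the standard warnings module. The sketch below uses a hypothetical stand-in function, predict_with_fallback, in place of a fitted model so that it runs without aimz installed; only the warning category is taken from the output above.

```python
import warnings


def predict_with_fallback(x: list[float]) -> list[float]:
    """Hypothetical stand-in that warns before taking a fallback path."""
    warnings.warn(
        "falling back to the on-batch path",
        UserWarning,
        stacklevel=2,
    )
    return list(x)


# Record warnings instead of printing them, so the fallback can be detected.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # make sure no warning is suppressed
    out = predict_with_fallback([1.0, 2.0])

fell_back = any(issubclass(w.category, UserWarning) for w in caught)
print(fell_back)  # True
```

In a test suite, warnings.simplefilter("error", UserWarning) inside the same context manager would turn the fallback into an exception, which is one way to notice that a call site should switch to predict_on_batch() directly.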

Performance Tips#

  • Tune batch_size to fit available memory; it also determines the chunk size for Zarr-backed arrays.

  • Monitor disk usage, as chunk sizes scale with batch_size and num_samples.

  • Reduce num_samples first for faster iteration.

  • Use on-batch methods in tests to minimize I/O overhead.