Disk-Backed vs. On-Batch Methods#
This page explains and compares the two complementary execution styles provided by `ImpactModel`:
- Disk-backed (default) methods iterate over the input in chunks, materialize results incrementally, and persist structured artifacts (a Zarr-backed `xarray.DataTree` plus metadata) to a temporary or user-specified output directory.
- On-batch (`*_on_batch` suffix) methods execute a single, fully in-memory pass and can optionally return a plain `dict` instead of an `xarray.DataTree`. The naming mirrors the Keras convention to signal an immediate, single-batch, memory-resident operation.
Why Disk-Backed by Default#
The non-`*_on_batch` methods default to a disk-backed (chunked) execution model for several reasons:
- Posterior predictive and prior predictive tensors can scale as `(#samples x #dims x #posterior_samples x ...)`. Even moderate increases in any axis (time, spatial units, parameter samples) can exceed host or accelerator RAM.
- Using `batch_size` with chunked iteration limits peak memory and helps prevent out-of-memory errors.
- Persisted Zarr arrays with metadata (coords, dims, attrs) create an artifact you can reopen without rerunning inference.
- The `xarray.DataTree` + Zarr format integrates with scientific Python tools such as Dask and ArviZ.
- Summaries (means, HDIs, residual PPC stats) can be computed lazily over chunked storage without first materializing dense arrays.
- One API works for both small experiments and large-scale use cases.
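To make the scaling argument concrete, the following back-of-the-envelope sketch estimates the size of a dense predictive tensor and compares it with the size of a single chunk. The helper `tensor_bytes` and the example sizes are illustrative assumptions, not part of aimz, and the estimate ignores any working memory the sampler itself needs.

```python
from math import prod


def tensor_bytes(shape: tuple[int, ...], itemsize: int = 4) -> int:
    """Return the size in bytes of a dense array (float32 by default)."""
    return prod(shape) * itemsize


# Hypothetical sizes chosen for illustration only.
num_samples = 1_000    # posterior draws
num_rows = 1_000_000   # observations in the input
batch_size = 10_000    # rows processed per chunk

full = tensor_bytes((num_samples, num_rows))
chunk = tensor_bytes((num_samples, batch_size))

print(f"full tensor: {full / 1e9:.1f} GB")   # full tensor: 4.0 GB
print(f"one chunk:   {chunk / 1e6:.1f} MB")  # one chunk:   40.0 MB
```

Under these assumptions the full tensor would need about 4 GB at once, while a chunked pass only ever holds about 40 MB of output.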
Comparison#
Disk-backed variants target larger datasets, enable chunked processing, multi-device parallelism, and stable artifact generation.
These methods build internal data loaders, iterate in chunks, and decouple sampling from file I/O, enabling concurrent execution.
Outputs consolidate into a single xarray.DataTree backed by Zarr files for post-hoc analysis.
On-batch variants, in contrast, favor minimal overhead, immediate return, and greater flexibility when posterior sample shapes are not shard-friendly.
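The contrast between the two styles can be sketched with a toy example. This is a conceptual stand-in only: JSON files in a temporary directory take the place of Zarr storage, and `run_disk_backed`, `run_on_batch`, and `double` are hypothetical names that do not reflect how aimz is implemented.

```python
import json
import tempfile
from collections.abc import Callable, Iterator, Sequence
from pathlib import Path


def iter_chunks(rows: Sequence[float], batch_size: int) -> Iterator[Sequence[float]]:
    """Yield consecutive slices of at most `batch_size` rows."""
    for start in range(0, len(rows), batch_size):
        yield rows[start : start + batch_size]


def run_disk_backed(
    fn: Callable[[float], float],
    rows: Sequence[float],
    batch_size: int,
    out_dir: Path,
) -> list[Path]:
    """Process chunk by chunk, persisting each result before moving on."""
    paths = []
    for i, chunk in enumerate(iter_chunks(rows, batch_size)):
        path = out_dir / f"chunk_{i:04d}.json"
        path.write_text(json.dumps([fn(x) for x in chunk]))
        paths.append(path)
    return paths


def run_on_batch(fn: Callable[[float], float], rows: Sequence[float]) -> dict[str, list[float]]:
    """Process everything in one in-memory pass and return a plain dict."""
    return {"y": [fn(x) for x in rows]}


def double(x: float) -> float:
    return 2.0 * x


rows = [float(i) for i in range(10)]

with tempfile.TemporaryDirectory() as tmp:
    paths = run_disk_backed(double, rows, batch_size=4, out_dir=Path(tmp))
    # Reopen the persisted chunks and stitch them back together.
    from_disk = [y for p in paths for y in json.loads(p.read_text())]

in_memory = run_on_batch(double, rows)["y"]
assert from_disk == in_memory
print(len(paths), "chunks written")  # 3 chunks written
```

Both paths produce the same values. The difference lies in where the results live while the computation runs, and in whether an artifact remains on disk afterwards.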
Feature Summary#
| Feature | Disk-backed (default) | On-batch (`*_on_batch`) |
|---|---|---|
| Typical dataset size | Medium → large | Small → moderate |
| Supported use cases | Standard models | Broader model support |
| Peak memory usage | Chunk-bounded | Full batch resident |
| Writes to disk | Yes | No |
| Return type | `xarray.DataTree` | `xarray.DataTree` or `dict` |
| Custom batch sizing | Yes (`batch_size`) | No (single pass) |
| Device parallelism (sharding) | Yes | No |
| Automatic fallback | Yes (may auto-delegate to on-batch) | No (final mode) |
| Latency (small data) | Higher (I/O + orchestration) | Minimal |
Capability Matrix#
| Capability | Disk-backed (default) | On-batch (`*_on_batch`) |
|---|---|---|
| Full dataset training | `fit()` | `fit_on_batch()` |
| Single training step | N/A | `train_on_batch()` |
| Prior predictive sampling | | |
| Posterior sampling | N/A | `sample()` |
| Posterior predictive sampling | `predict()` | `predict_on_batch()` |
| Log-likelihood computation | N/A | |
| Effect estimation | (consumes outputs above) | |
Quick Recommendations#
- Moderate or large data, or need persisted outputs: use disk-backed (e.g., `fit()`, `predict()`).
- Small data, rapid iteration, CI, or read-only / ephemeral filesystem: use on-batch (`*_on_batch`).
- If `predict()` issues a fallback warning, call `predict_on_batch()` directly. This occurs when the model or posterior sample shapes are incompatible with shard-based chunked execution.
- Custom training loop: iterate with `train_on_batch()`.
- Need multi-device (sharding) execution: disk-backed.
- Need raw NumPy/dict outputs (no `xarray.DataTree`): on-batch with `return_datatree=False`.
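One way to notice the fallback warning mentioned in the list above, for example in CI, is to record warnings around the call and then switch to the on-batch method explicitly. The sketch below uses only the standard-library `warnings` module. The functions `predict` and `predict_on_batch` here are hypothetical stand-ins, not the real `ImpactModel` methods, and the warning text is abbreviated.

```python
import warnings


def predict_on_batch(x: list[float]) -> dict[str, list[float]]:
    """Stand-in for an on-batch method: one in-memory pass."""
    return {"y": [2.0 * v for v in x]}


def predict(x: list[float]) -> dict[str, list[float]]:
    """Stand-in for a disk-backed method that falls back with a warning."""
    warnings.warn(
        "falling back to `.predict_on_batch()`", UserWarning, stacklevel=2
    )
    return predict_on_batch(x)


x = [1.0, 2.0, 3.0]

# Record warnings instead of printing them so the fallback can be detected.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = predict(x)

fell_back = any(
    issubclass(w.category, UserWarning) and "falling back" in str(w.message)
    for w in caught
)
if fell_back:
    # Call the on-batch variant directly from now on to skip the detour.
    result = predict_on_batch(x)
```

In a test suite, `warnings.simplefilter("error", UserWarning)` inside the same context manager would instead turn the fallback into an exception, which makes an unintended mode switch fail loudly.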
Note
For MCMC inference, only fit_on_batch() or sample() is supported for training and posterior sampling,
as MCMC is incompatible with epoch-based or chunked batch processing. See MCMC Support for more details.
Example: predict() with Fallback Warning#
A common scenario for the fallback warning occurs when the model contains local latent variables, which can make posterior sample shapes incompatible with shard-based parallel execution. The example below illustrates this case.
import logging

import jax.numpy as jnp
import numpyro.distributions as dist
from jax import random
from jax.typing import ArrayLike
from numpyro import optim, plate, sample
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal

from aimz import ImpactModel

logging.basicConfig(level=logging.INFO, force=True)


def model(X: ArrayLike, y: ArrayLike | None = None) -> None:
    # Model includes a local latent variable
    sigma = sample("sigma", dist.Exponential().expand((X.shape[0],)))
    with plate("data", size=X.shape[0]):
        sample("y", dist.Normal(0.0, sigma), obs=y)


rng_key = random.key(42)
rng_key, rng_key_X, rng_key_y = random.split(rng_key, 3)
X = random.normal(rng_key_X, (100, 2))
y = random.normal(rng_key_y, (100,))

im = ImpactModel(
    model,
    rng_key=rng_key,
    inference=SVI(
        model,
        guide=AutoNormal(model),
        optim=optim.Adam(step_size=1e-3),
        loss=Trace_ELBO(),
    ),
    # This internally calls the `.run()` method of `SVI`
).fit_on_batch(X, y)

# Calling `.predict()` triggers a fallback warning
im.predict(X)
/tmp/ipykernel_829/1252916572.py:2: UserWarning: One or more posterior sample shapes are not compatible with `.predict()` under sharded parallelism; falling back to `.predict_on_batch()`.
  im.predict(X)
<xarray.DataTree 'root'>
Group: /
├── Group: /posterior
│ Dimensions: (chain: 1, draw: 1000, sigma_dim_0: 100)
│ Coordinates:
│ * chain (chain) int64 8B 0
│ * draw (draw) int64 8kB 0 1 2 3 4 5 6 ... 993 994 995 996 997 998 999
│ * sigma_dim_0 (sigma_dim_0) int64 800B 0 1 2 3 4 5 6 ... 93 94 95 96 97 98 99
│ Data variables:
│ sigma (chain, draw, sigma_dim_0) float32 400kB 1.615 0.8619 ... 2.558
│ Attributes:
│ created_at: 2026-05-24T03:02:13.883093+00:00
│ aimz_version: 0.12.0
└── Group: /posterior_predictive
Dimensions: (chain: 1, draw: 1000, y_dim_0: 100)
Coordinates:
* chain (chain) int64 8B 0
* draw (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
* y_dim_0 (y_dim_0) int64 800B 0 1 2 3 4 5 6 7 8 ... 92 93 94 95 96 97 98 99
Data variables:
y (chain, draw, y_dim_0) float32 400kB -3.284 -0.4314 ... 3.398
Attributes:
created_at: 2026-05-24T03:02:13.885558+00:00
aimz_version: 0.12.0
Performance Tips#
- Tune `batch_size` appropriately; it also determines the chunk size for Zarr-backed arrays.
- Monitor disk usage, as chunk sizes scale with `batch_size` and `num_samples`.
- Reduce `num_samples` first for faster iteration.
- Use on-batch methods in tests to minimize I/O overhead.
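As a rough aid for the first three tips, the sketch below derives a `batch_size` from a per-chunk memory budget and shows how reducing `num_samples` changes it. It assumes one float32 output value per row and per sample. The helpers `max_batch_size` and `num_chunks` and the 256 MiB budget are hypothetical, and real runs may need additional buffers beyond this estimate.

```python
def max_batch_size(budget_bytes: int, num_samples: int, itemsize: int = 4) -> int:
    """Largest row count whose (num_samples x rows) chunk fits the budget."""
    per_row = num_samples * itemsize
    return max(1, budget_bytes // per_row)


def num_chunks(num_rows: int, batch_size: int) -> int:
    """Number of chunks needed to cover all rows (ceiling division)."""
    return -(-num_rows // batch_size)


budget = 256 * 1024**2  # hypothetical budget of 256 MiB per chunk
num_rows = 1_000_000    # hypothetical dataset size

for num_samples in (1_000, 100):
    bs = max_batch_size(budget, num_samples)
    print(f"num_samples={num_samples}: batch_size={bs}, chunks={num_chunks(num_rows, bs)}")
```

With the same budget, cutting `num_samples` by a factor of ten allows roughly ten times more rows per chunk, so fewer chunks have to be processed and written.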