Frequently Asked Questions#
What is a kernel?#
A kernel in aimz is a user-defined NumPyro model (a stochastic function or Callable) built with primitives like sample() and deterministic().
Its signature and body define the inputs and output (e.g., X, y, …), encoding the probabilistic structure—priors, likelihood, and latent variables.
How do I use different argument names than X and y?#
By default, aimz expects your kernel signature to include parameters named X (input) and y (output).
If you use different names—e.g. features / target or covariates / outcome—declare them when instantiating ImpactModel:
def kernel(features, extra, target=None):
    ...

im = ImpactModel(
    kernel,
    ...,
    param_input="features",
    param_output="target",
)
If you see an error like:
Kernel must accept 'X' and 'y' as argument(s). Modify the kernel signature or set `param_input` and `param_output` accordingly.
it means you neither matched the defaults nor overrode them.
Fix it by renaming your arguments to X / y or supplying param_input / param_output as shown above.
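The check behind this error can be approximated with the standard library alone. The sketch below is not aimz's implementation: `missing_params` is a hypothetical helper that inspects a kernel's signature and reports which expected argument names are absent.

```python
import inspect


def missing_params(kernel, param_input="X", param_output="y"):
    """Return the expected argument names that the kernel does not accept.

    Hypothetical helper for illustration only; not part of aimz.
    """
    accepted = inspect.signature(kernel).parameters
    return [name for name in (param_input, param_output) if name not in accepted]


def kernel(features, extra, target=None):
    ...


# With the default names, neither argument matches the signature.
default_missing = missing_params(kernel)

# Declaring the custom names resolves the mismatch.
custom_missing = missing_params(
    kernel, param_input="features", param_output="target"
)

print(default_missing)  # ['X', 'y']
print(custom_missing)   # []
```

Running a check like this before constructing the model is one way to surface a naming mismatch early.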
Do I need to know NumPyro to use aimz?#
Yes. The aimz package builds on NumPyro’s primitives and effect handlers. You should be comfortable writing a model function, defining a guide (for SVI) or configuring MCMC, and reading model traces. The library focuses on orchestration, not abstracting away core probabilistic modeling concepts.
Can I use aimz with any NumPyro model?#
No. Most conventional models with global latents and a plate-based structure work out of the box. The core requirement is that every sample site in the model must have a static shape that does not change with the data.
On multi-device systems, predict() uses data parallelism: the input data is sharded across devices along axis 0 (the observation dimension), while posterior samples are replicated on every device.
The forward sampler used internally also requires every traced site to have a fixed shape across all sample iterations.
Several modeling patterns can violate this requirement. The two most common are:
Local latent variables
A sample() call inside a plate without obs= produces a posterior whose shape grows with the plate size (e.g., (num_samples, n_obs)). Because n_obs can differ between training and prediction, the replicated-posterior contract breaks.
The scan() primitive
scan() is commonly used for sequential or autoregressive models (e.g., state-space models, time-series forecasting). The number of sites it creates typically grows with the sequence length, making the trace shape dynamic for the same reason.
When this incompatibility is detected, predict() issues a warning and automatically falls back to predict_on_batch(), which processes data in a single batch without sharding.
However, if the posterior shapes are fundamentally incompatible with the new input, the forward pass will still fail with a shape mismatch.
Note that a model with nested plates whose sample() sites are all observed (obs=) or whose latent shapes are fixed remains compatible.
In contrast, a model with a single plate containing one unobserved sample() site triggers the fallback.
The following examples illustrate the compatibility with predict():
import numpyro.distributions as dist
from numpyro import plate, sample

X, y = ...


# Compatible with .predict(): all latents are global
def kernel(X, y=None):
    ...
    alpha = sample("alpha", dist.Normal())
    beta = sample("beta", dist.Normal().expand([X.shape[1]]))
    with plate("obs", X.shape[0]):
        mu = alpha + X @ beta
        sample("y", dist.Normal(mu), obs=y)


# Falls back to .predict_on_batch(): `mu` is a local latent
# with posterior shape (num_samples, X.shape[0])
def kernel(X, y=None):
    ...
    with plate("obs", X.shape[0]):
        mu = sample("mu", dist.Normal())
        sample("y", dist.Normal(mu), obs=y)
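The shape argument behind these two kernels can be illustrated with plain NumPy. This is only a sketch: the arrays are random stand-ins for posterior draws rather than output from aimz or NumPyro, and the sizes (`num_samples`, `n_train`, `n_new`, `d`) are arbitrary assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
num_samples, n_train, n_new, d = 100, 50, 20, 3
X_new = rng.normal(size=(n_new, d))

# Global latents: posterior shapes do not depend on the number of rows,
# so the same draws can be reused with inputs of any length.
alpha = rng.normal(size=(num_samples,))
beta = rng.normal(size=(num_samples, d))
mu_global = alpha[:, None] + beta @ X_new.T  # shape (num_samples, n_new)

# Local latent: one value per training row, so its shape is tied to n_train.
mu_local = rng.normal(size=(num_samples, n_train))
try:
    mu_local + np.zeros(n_new)  # (100, 50) and (20,) cannot broadcast
    local_ok = True
except ValueError:
    local_ok = False

print(mu_global.shape)  # (100, 20)
print(local_ok)         # False
```

The global draws combine with new inputs of any length, while the local draws only line up with inputs that have exactly `n_train` rows.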
If you encounter an unsupported pattern—ideally with a minimal reproducible example—please open an issue or submit a PR. We plan to broaden coverage based on user needs.
Does aimz ship built-in model templates?#
No. This is intentional to keep the library lightweight and avoid prescribing a specific modeling style. Future recipes or example galleries may be provided separately, but the library itself does not include canonical model classes.
What kinds of data can aimz handle?#
aimz accepts NumPy or JAX arrays of any shape with at least one dimension; the leading axis is treated as the sample axis.
This covers tabular inputs ((n, d)), 1D inputs ((n,)), and higher-rank inputs such as sequences ((n, seq_len, d)) or images ((n, h, w, c)).
Multiple named arrays are supported as long as they share the same leading-axis size.
The output variable has the same flexibility — it can be 1D for scalar targets, 2D for multi-output regression, or higher-rank as the model requires — provided its leading axis matches the input.
Ragged or nested structures are not currently supported.
If native support for a specific structure is important for your use case, opening an issue helps prioritize it, and contributions are welcome.
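As a rough illustration of the leading-axis rule described above, the sketch below builds one array per input layout and confirms that they share the same first dimension. `leading_sizes` is a hypothetical helper, not an aimz function, and the array sizes are arbitrary.

```python
import numpy as np


def leading_sizes(**arrays):
    """Map each named array to its leading-axis size (hypothetical helper)."""
    sizes = {}
    for name, arr in arrays.items():
        arr = np.asarray(arr)
        if arr.ndim < 1:
            raise ValueError(f"{name!r} must have at least one dimension")
        sizes[name] = arr.shape[0]
    return sizes


n = 8
sizes = leading_sizes(
    tabular=np.zeros((n, 4)),         # (n, d)
    series=np.zeros(n),               # (n,)
    sequences=np.zeros((n, 12, 4)),   # (n, seq_len, d)
    images=np.zeros((n, 16, 16, 3)),  # (n, h, w, c)
)

# All named arrays must agree on the size of the leading (sample) axis.
consistent = len(set(sizes.values())) == 1
print(consistent)  # True
```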
Can I use aimz for general-purpose Bayesian inference?#
Yes. aimz is a flexible, object-oriented interface to NumPyro and supports a wide range of Bayesian modeling tasks—regression, classification, uncertainty quantification, and predictive simulation—even if your application doesn’t involve interventions or causal analysis.
Can I use posterior samples generated elsewhere?#
Yes. You do not need to fit a model within aimz to obtain posterior samples.
After initializing an ImpactModel with your model, call set_posterior_sample() with a dictionary mapping site names to arrays.
Each array must share the same leading dimension (number of draws), and the dictionary must not be empty.
Once injected, the model is treated as fitted, and the prediction, log-likelihood, and posterior predictive methods will use the supplied samples.
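A minimal sketch of preparing such a dictionary is shown below. The site names `alpha` and `beta` and the zero-filled arrays are placeholders, and `num_draws` is a hypothetical pre-check that mirrors the two constraints above; it is not part of aimz.

```python
import numpy as np


def num_draws(posterior_sample):
    """Return the shared number of draws, or raise if the dict is invalid.

    Hypothetical pre-check for illustration only; not part of aimz.
    """
    if not posterior_sample:
        raise ValueError("posterior_sample must not be empty")
    draws = {np.shape(value)[0] for value in posterior_sample.values()}
    if len(draws) != 1:
        raise ValueError(f"inconsistent number of draws: {sorted(draws)}")
    return draws.pop()


# Stand-in draws for two sites; real values would come from another sampler.
posterior_sample = {
    "alpha": np.zeros(500),
    "beta": np.zeros((500, 3)),
}
n_draws = num_draws(posterior_sample)
print(n_draws)  # 500

# With an initialized ImpactModel `im` (not constructed here):
# im.set_posterior_sample(posterior_sample)
```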
When should I use the *_on_batch variants?#
Use the batch-specific variants only when you need explicit, single-batch control (e.g., custom training loops, micro‑benchmarking, or integrating with external schedulers). The higher-level methods handle internal batching, iteration, shuffling, streaming, and aggregation automatically and are preferred for typical workflows. See Disk-Backed vs. On-Batch Methods for a detailed comparison of both approaches and guidance on when to use each.
How do I control which variables (sites) are sampled?#
By default, prediction and sampling methods use the set of return sites cached in return_sites—typically the model output plus any deterministic sites discovered during the first trace.
To override this behavior, pass return_sites=(...) explicitly to the relevant methods.
How do I ensure reproducible results?#
ImpactModel requires an explicit JAX pseudo-random number generator key for initialization.
Using the same initial key ensures that all subsequent stochastic operations are reproducible.
Stochastic methods accept an optional rng_key for per-call determinism.
If provided, it affects only that call and does not modify the model’s internal key.
If omitted, a new subkey is derived internally, so repeated calls may produce different results.
To fully reproduce results, log the initial seed along with other artifacts.
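The key semantics described above follow general JAX behavior, which the sketch below illustrates with plain `jax.random` calls. It does not show how aimz derives subkeys internally, which the text above does not specify.

```python
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)  # initial seed; log this value to reproduce a run

# Reusing the same key reproduces the same draws.
a = random.normal(key, (3,))
b = random.normal(key, (3,))
same = bool(jnp.array_equal(a, b))

# Splitting yields a fresh subkey, so draws differ from the ones above.
key, subkey = random.split(key)
c = random.normal(subkey, (3,))
different = not bool(jnp.array_equal(a, c))

print(same, different)  # True True
```

Passing a fixed key to a single call therefore pins that call's result, while relying on derived subkeys lets results vary between calls.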
Why do some methods return DataTree?#
A DataTree organizes heterogeneous groups (posterior, posterior_predictive, predictions) with labeled dimensions and coordinates, facilitating I/O, slicing, and downstream analysis.
It can also be easily converted to an arviz.InferenceData object using arviz.from_datatree().
If desired, you can pass return_datatree=False to methods such as predict_on_batch() to return a plain dictionary instead.
Why do I not see a posterior group in the output?#
The posterior group appears in the returned DataTree only if posterior samples are available, either from fitting or from injection via set_posterior_sample().
Where is the on-disk output written?#
All outputs are written under the directory passed via output_dir.
If output_dir=None, a temporary directory is created (accessible via temp_dir) and removed when the model is cleaned up (either explicitly with cleanup() or when the instance is garbage collected).
Each group in the returned DataTree stores its own artifact path in an output_dir attribute, and the root tree includes the top-level path.
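The temporary-directory lifecycle described above resembles that of Python's `tempfile.TemporaryDirectory`, which is used below purely as a stand-in. Whether aimz relies on it internally is an assumption, and the file name is a placeholder.

```python
import os
import tempfile

# Stand-in for the directory created when no output_dir is supplied.
tmp = tempfile.TemporaryDirectory()
path = tmp.name
existed = os.path.isdir(path)

# Write a placeholder artifact under the temporary directory.
with open(os.path.join(path, "artifact.txt"), "w") as f:
    f.write("placeholder")

# Explicit cleanup removes the directory together with its contents.
tmp.cleanup()
removed = not os.path.exists(path)

print(existed, removed)  # True True
```

Anything that must outlive the model instance should therefore be written to an explicit `output_dir` rather than left in the temporary location.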
Does serialization persist the posterior samples?#
Yes.
Pickling (or MLflow integration via aimz.mlflow) preserves the posterior samples (if set) and the cached KernelSpec so retracing / re-fitting is unnecessary upon load.
See Model Persistence or MLflow Integration for more details.