MCMC Support
While aimz is primarily designed around variational inference and predictive sampling, it also provides support for MCMC methods via the NumPyro backend, using the same aimz interface (e.g., fit_on_batch() and predict_on_batch()).
This lets users apply MCMC to more complex models for which variational inference may be less effective, provided the dataset is relatively small.
import logging
import jax.numpy as jnp
import numpyro.distributions as dist
from jax import random
from jax.typing import ArrayLike
from numpyro import plate, sample
from numpyro.infer import MCMC, NUTS
from aimz import ImpactModel
logging.basicConfig(level=logging.INFO, force=True)
Model and Data
We set up a linear regression model and create synthetic data for both features and targets as an example.
def model(X: ArrayLike, y: ArrayLike | None = None) -> None:
    """Linear regression model."""
    w = sample("w", dist.Normal().expand((X.shape[1],)))
    b = sample("b", dist.Normal())
    mu = jnp.dot(X, w) + b
    sigma = sample("sigma", dist.Exponential())
    with plate("data", size=X.shape[0]):
        sample("y", dist.Normal(mu, sigma), obs=y)
rng_key = random.key(42)
rng_key, rng_key_w, rng_key_b, rng_key_x, rng_key_e = random.split(rng_key, 5)
w = random.normal(rng_key_w, (10,))
b = random.normal(rng_key_b)
X = random.normal(rng_key_x, (1000, 10))
e = random.normal(rng_key_e, (1000,))
y = jnp.dot(X, w) + b + e
MCMC Sampling and Prediction
MCMC sampling can be performed with the ImpactModel class by passing an MCMC instance as the inference argument.
Users can configure the sampler, warm-up steps, and other MCMC-specific parameters.
Calling fit_on_batch() initiates the sampling process.
Internally, aimz executes the sampler via the run() method and stores the posterior samples using get_samples().
Note that calling fit() with MCMC as the inference method will raise a TypeError, as this method is intended for mini-batch training or subsampling.
Regardless of the number of chains (num_chains) used, the posterior samples are combined across chains to ensure compatibility with the rest of the aimz interface.
Posterior predictive sampling can be performed using the predict() or predict_on_batch() methods.
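To make the chain-combining step concrete, here is a minimal NumPy sketch of what merging chains means in terms of array shapes. It is not the aimz implementation. The arrays, the chain count, and the helper name combine_chains are hypothetical stand-ins chosen for illustration. In NumPyro, MCMC.get_samples() returns samples in this flattened form by default (group_by_chain=False).

```python
import numpy as np

# Hypothetical stand-in for samples grouped by chain: 4 chains, 500 draws each.
num_chains, num_draws, num_features = 4, 500, 10
rng = np.random.default_rng(0)
grouped = {
    "b": rng.normal(size=(num_chains, num_draws)),
    "w": rng.normal(size=(num_chains, num_draws, num_features)),
}


def combine_chains(samples):
    """Merge the leading (chain, draw) axes into a single sample axis."""
    return {
        name: value.reshape((-1,) + value.shape[2:])
        for name, value in samples.items()
    }


combined = combine_chains(grouped)
print({name: value.shape for name, value in combined.items()})
# {'b': (2000,), 'w': (2000, 10)}
```

After merging, every site has one leading sample axis of length num_chains * num_draws, so downstream code does not need to know how many chains were run.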
rng_key, rng_subkey = random.split(rng_key)
im = ImpactModel(
    model,
    rng_key=rng_subkey,
    inference=MCMC(NUTS(model), num_warmup=500, num_samples=500),
)
im.fit_on_batch(X, y)
im.inference.print_summary()
im.predict_on_batch(X)
                mean       std    median      5.0%     95.0%     n_eff     r_hat
         b      0.41      0.03      0.41      0.36      0.47   1354.28      1.00
     sigma      1.03      0.02      1.03      1.00      1.07   1086.22      1.00
      w[0]      0.59      0.04      0.59      0.54      0.66   1220.29      1.00
      w[1]      0.86      0.03      0.86      0.81      0.92   1337.06      1.00
      w[2]     -0.90      0.03     -0.90     -0.96     -0.85   1167.08      1.00
      w[3]     -0.60      0.04     -0.60     -0.65     -0.54    830.57      1.00
      w[4]     -1.24      0.03     -1.24     -1.29     -1.18    754.35      1.00
      w[5]     -0.80      0.03     -0.80     -0.85     -0.75   1840.03      1.00
      w[6]     -0.51      0.03     -0.51     -0.56     -0.46    972.23      1.00
      w[7]     -1.22      0.03     -1.22     -1.27     -1.16   1357.39      1.00
      w[8]     -0.16      0.03     -0.16     -0.22     -0.11   1130.10      1.00
      w[9]     -0.10      0.03     -0.10     -0.15     -0.05    880.53      1.00
Number of divergences: 0
<xarray.DataTree 'root'>
Group: /
├── Group: /posterior
│ Dimensions: (chain: 1, draw: 500, w_dim_0: 10)
│ Coordinates:
│ * chain (chain) int64 8B 0
│ * draw (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
│ * w_dim_0 (w_dim_0) int64 80B 0 1 2 3 4 5 6 7 8 9
│ Data variables:
│ b (chain, draw) float32 2kB 0.386 0.4236 0.4178 ... 0.3851 0.4542
│ sigma (chain, draw) float32 2kB 1.033 1.033 1.01 ... 1.018 1.034 1.015
│ w (chain, draw, w_dim_0) float32 20kB 0.6264 0.8762 ... -0.08495
│ Attributes:
│ created_at: 2026-05-24T03:03:12.206758+00:00
│ aimz_version: 0.12.0
└── Group: /posterior_predictive
Dimensions: (chain: 1, draw: 500, y_dim_0: 1000)
Coordinates:
* chain (chain) int64 8B 0
* draw (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
* y_dim_0 (y_dim_0) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
Data variables:
y (chain, draw, y_dim_0) float32 2MB 3.052 -1.501 ... -4.549 -1.457
Attributes:
created_at: 2026-05-24T03:03:12.209104+00:00
aimz_version: 0.12.0
Using External MCMC Samples
Users can run MCMC sampling directly using NumPyro and then insert the posterior samples into an ImpactModel instance using the set_posterior_sample() method for downstream analysis.
For example:
mcmc = MCMC(NUTS(model), num_warmup=1000, num_samples=1000)
rng_key, rng_subkey = random.split(rng_key)
mcmc.run(rng_subkey, X, y)
im.set_posterior_sample(mcmc.get_samples())
im.predict_on_batch(X)
<xarray.DataTree 'root'>
Group: /
├── Group: /posterior
│ Dimensions: (chain: 1, draw: 1000, w_dim_0: 10)
│ Coordinates:
│ * chain (chain) int64 8B 0
│ * draw (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
│ * w_dim_0 (w_dim_0) int64 80B 0 1 2 3 4 5 6 7 8 9
│ Data variables:
│ b (chain, draw) float32 4kB 0.441 0.43 0.412 ... 0.3979 0.3862 0.4424
│ sigma (chain, draw) float32 4kB 1.024 1.02 1.038 ... 1.006 1.038 1.04
│ w (chain, draw, w_dim_0) float32 40kB 0.6524 0.8598 ... -0.08973
│ Attributes:
│ created_at: 2026-05-24T03:03:15.947980+00:00
│ aimz_version: 0.12.0
└── Group: /posterior_predictive
Dimensions: (chain: 1, draw: 1000, y_dim_0: 1000)
Coordinates:
* chain (chain) int64 8B 0
* draw (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
* y_dim_0 (y_dim_0) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
Data variables:
y (chain, draw, y_dim_0) float32 4MB 2.16 -0.8406 ... -2.185 -2.508
Attributes:
created_at: 2026-05-24T03:03:15.950461+00:00
aimz_version: 0.12.0
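For the linear regression model on this page, the posterior predictive group shown in the outputs can be understood as one simulated y per posterior draw and per row of X. The NumPy sketch below illustrates that idea only. It is not how aimz or NumPyro compute predictions. The posterior dictionary, X_new, and the helper posterior_predictive are hypothetical stand-ins filled with random values rather than real MCMC output.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-ins: 1000 posterior draws for a 10-feature model, 50 rows.
num_draws, num_features, num_rows = 1000, 10, 50
posterior = {
    "w": rng.normal(size=(num_draws, num_features)),
    "b": rng.normal(size=(num_draws,)),
    "sigma": rng.exponential(size=(num_draws,)),
}
X_new = rng.normal(size=(num_rows, num_features))


def posterior_predictive(posterior, X, rng):
    """Draw one y per (posterior draw, row) pair from Normal(mu, sigma)."""
    # mu[d, n] = X[n] . w[d] + b[d]
    mu = posterior["w"] @ X.T + posterior["b"][:, None]
    noise = rng.normal(size=mu.shape)
    return mu + posterior["sigma"][:, None] * noise


y_draws = posterior_predictive(posterior, X_new, rng)
print(y_draws.shape)  # (1000, 50)
```

The leading axis matches the draw dimension and the trailing axis matches the rows of the input, mirroring the (draw, y_dim_0) layout of the posterior_predictive group above.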