Interventions & Effect Estimation#
This guide covers two closely related functionalities:
The intervention argument available in predictive methods like predict(), predict_on_batch(), and their posterior predictive counterparts. Internally, this enables hard (do) interventions on specified sample sites using NumPyro's do effect handler to generate counterfactual draws without rewriting the model.
The estimate_effect() method, which computes the elementwise difference between an intervention (counterfactual) scenario and a baseline (factual) scenario to quantify causal or policy impact.
Typical workflow:
Generate one predictive result under factual conditions (optionally also using intervention if you want to hold certain sites at specific values).
Generate another predictive result under a modified intervention mapping.
Pass both results (or the argument dictionaries to generate them lazily) to estimate_effect() to obtain the effect output.
Each scenario is a DataTree produced by the prediction API or materialized on-demand via argument dictionaries.
Interventions#
The intervention argument is a mapping (dict[str, ArrayLike]) from sample site name to a replacement value; during predictive sampling each listed site is fixed, enabling counterfactual or policy analysis.
Values must broadcast to the site’s per‑observation shape (e.g., intervening on a length‑N vector site generally requires shape (N,)).
You can modify multiple sites at once; any not specified follow their posterior (or prior) distribution.
Setting in_sample=True stores draws under posterior_predictive, while in_sample=False stores them under predictions. The group must match between the baseline and intervention scenarios when computing effects.
Deterministic downstream sites automatically reflect the intervened values.
# Minimal sketch of a model exposing a stochastic site 'z'
def model(X, Z, y=None):
    ...
    # site we may choose to override at prediction time
    z = numpyro.sample("z", ...)
    ...

# Fit (details elided)
im = ImpactModel(model, ...).fit_on_batch(...)

# Baseline scenario: set 'z' to its observed/factual value Z
baseline = im.predict_on_batch(X, intervention={"z": Z})

# Modified scenario: counterfactual where we overwrite 'z' with zeros
modified = im.predict_on_batch(
    X,
    intervention={"z": jnp.zeros_like(Z)},
)
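To make the semantics of a hard intervention concrete, here is a minimal sketch in plain NumPy rather than NumPyro. The function simulate and its toy structural equations are hypothetical and not part of the library. The sketch only illustrates two points from the list above: an intervened site is replaced by a value that broadcasts to the per-observation shape, and deterministic downstream quantities are computed from the replacement.

```python
import numpy as np


def simulate(rng, x, intervention=None):
    """Toy structural model x -> z -> mu (hypothetical, for illustration only)."""
    intervention = intervention or {}
    n = x.shape[0]
    if "z" in intervention:
        # Hard (do) intervention: ignore z's own distribution and fix its value.
        # The value must broadcast to the per-observation shape (n,).
        z = np.broadcast_to(np.asarray(intervention["z"], dtype=float), (n,))
    else:
        # No intervention: z is drawn from its usual distribution given x.
        z = np.exp(0.5 * x + rng.normal(size=n))
    # Deterministic downstream site: computed from whichever z is in effect.
    mu = 1.0 + 2.0 * z
    return {"z": z, "mu": mu}


rng = np.random.default_rng(0)
x = np.arange(5, dtype=float)

factual = simulate(rng, x)
# A scalar broadcasts to shape (5,), so every observation gets z = 0.
counterfactual = simulate(rng, x, intervention={"z": 0.0})
```

In the counterfactual call, mu equals the intercept for every observation because z is held at zero, while sites that are not listed in the mapping keep their usual behavior.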
Effect Estimation#
The estimate_effect() method computes an elementwise difference between two predictive scenarios (intervention - baseline) and returns a single-group DataTree that preserves sampling dimensions.
One baseline and one intervention scenario must be provided, either eagerly (output_baseline / output_intervention) or lazily through argument dictionaries (args_baseline / args_intervention).
Mixing is allowed; for example, a precomputed baseline can be supplied with output_baseline while the intervention is generated lazily with args_intervention (or the reverse).
Both scenarios must come from the same predictive group (both posterior_predictive or both predictions) with matching variable sets and shapes.
The result contains that shared group name, and each variable is the elementwise difference, retaining the leading chain / draw dimensions.
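The rules above can be sketched with plain dictionaries of NumPy arrays standing in for the DataTree groups. The helper elementwise_effect is hypothetical and is not the library's implementation. It only mirrors the described behavior: one shared group, matching variable sets and shapes, and a difference computed as intervention minus baseline.

```python
import numpy as np


def elementwise_effect(baseline, intervention):
    """Return intervention - baseline for two {group: {variable: array}} mappings.

    Hypothetical helper that mirrors the checks described above.
    """
    if len(baseline) != 1 or set(baseline) != set(intervention):
        raise ValueError("Both scenarios must share exactly one predictive group.")
    (group,) = baseline
    base_vars, int_vars = baseline[group], intervention[group]
    if set(base_vars) != set(int_vars):
        raise ValueError("Variable sets differ between scenarios.")
    effect = {}
    for name in base_vars:
        if base_vars[name].shape != int_vars[name].shape:
            raise ValueError(f"Shape mismatch for variable {name!r}.")
        # Subtraction keeps the leading (chain, draw) dimensions intact.
        effect[name] = int_vars[name] - base_vars[name]
    return {group: effect}


rng = np.random.default_rng(0)
# Stand-in draws with dimensions (chain, draw, observation).
baseline = {"predictions": {"y": rng.normal(size=(2, 50, 4))}}
intervention = {"predictions": {"y": rng.normal(loc=1.0, size=(2, 50, 4))}}

effect = elementwise_effect(baseline, intervention)
```

Passing one scenario stored under posterior_predictive and the other under predictions would raise an error in this sketch, which corresponds to the group-matching requirement.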
Eager (precomputed scenarios):
effect = im.estimate_effect(
    output_baseline=baseline,
    output_intervention=modified,
)
Lazy (defer prediction):
effect = im.estimate_effect(
    args_baseline={
        "X": X,
        "intervention": {"z": Z},
        "in_sample": False,
    },
    args_intervention={
        "X": X,
        "intervention": {"z": jnp.zeros_like(Z)},
        "in_sample": False,
    },
)
Mixed (precomputed baseline, lazy intervention):
effect = im.estimate_effect(
    output_baseline=baseline,
    args_intervention={
        "X": X,
        "intervention": {"z": jnp.zeros_like(Z)},
        "in_sample": False,
    },
)
The returned DataTree captures the elementwise difference for every variable present in the predictive group.
Any subsequent summary (e.g. mean, intervals) can be computed using Xarray, ArviZ, or standard NumPy / JAX utilities.
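As a rough illustration of such summaries, the sketch below uses a random NumPy array as a stand-in for one effect variable with dimensions (chain, draw, observation). The array contents and the names effect_y, mean_per_obs, and avg_per_draw are assumptions made for the example; with a real result, the array would be extracted from the returned DataTree first.

```python
import numpy as np

rng = np.random.default_rng(0)
# Stand-in for an effect variable with dimensions (chain, draw, observation).
effect_y = rng.normal(loc=0.3, scale=1.0, size=(1, 1000, 100))

# Collapse the sampling dimensions to summarize each observation.
mean_per_obs = effect_y.mean(axis=(0, 1))
lower, upper = np.percentile(effect_y, [2.5, 97.5], axis=(0, 1))

# Average over observations within each draw, then summarize across draws.
avg_per_draw = effect_y.mean(axis=-1).reshape(-1)
avg_mean = avg_per_draw.mean()
avg_interval = np.percentile(avg_per_draw, [2.5, 97.5])
```

Averaging within each draw before summarizing keeps the uncertainty of the average effect tied to the posterior draws instead of mixing it with the variation between observations.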
Note
estimate_effect() computes the posterior predictive contrast between two scenarios under structural interventions, propagating full posterior uncertainty through the difference.
Whether this contrast admits a causal interpretation depends on the structural assumptions encoded in the model (the kernel): causal identification is a property of the model specification, not the estimation procedure.
When the user-defined model encodes appropriate causal assumptions—such as conditioning on confounders and specifying correct functional relationships—this contrast corresponds to a causal effect estimate.
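The role of conditioning on confounders can be illustrated with a small simulation. This sketch uses a hypothetical linear structural model and ordinary least squares in NumPy, not the Bayesian workflow of this library. It is meant only to show that the same estimation procedure gives different answers depending on whether the confounder is part of the model.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 20_000

# Hypothetical linear structural model with confounder c: c -> z, c -> y, z -> y.
c = rng.normal(size=n)
z = c + rng.normal(size=n)
y = 1.0 * z + 2.0 * c + rng.normal(size=n)  # true effect of z on y is 1.0

# Model that omits the confounder: the z coefficient absorbs part of c's effect.
design_naive = np.column_stack([np.ones(n), z])
coef_naive, *_ = np.linalg.lstsq(design_naive, y, rcond=None)

# Model that conditions on the confounder: the z coefficient targets the true effect.
design_adjusted = np.column_stack([np.ones(n), z, c])
coef_adjusted, *_ = np.linalg.lstsq(design_adjusted, y, rcond=None)

slope_naive, slope_adjusted = coef_naive[1], coef_adjusted[1]
```

Under these assumptions the naive slope is pulled well above the true value of 1.0, while the adjusted slope stays close to it.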
Example: Causal Network with Confounder#
This example illustrates a simple causal network. The variable Z has a direct causal effect on the outcome Y, while both are influenced by a shared confounder, C.
An additional variable, X, is an observed exogenous factor that influences Z but has no direct effect on Y.
Our objective is to estimate the causal effect of Z (or alternatively X) on Y, while properly accounting for the confounding influence of C.
We assume the following generative model for the observed data:
Model#
import logging

import jax.numpy as jnp
import numpyro.distributions as dist
from jax import nn, random
from jax.typing import ArrayLike
from numpyro import optim, plate, sample
from numpyro.infer import SVI, Trace_ELBO, init_to_feasible
from numpyro.infer.autoguide import AutoNormal

from aimz import ImpactModel

logging.basicConfig(level=logging.INFO, force=True)


def model(X: ArrayLike, C: ArrayLike, y: ArrayLike | None = None) -> None:
    # Observed confounder
    c = sample("c", dist.Exponential(), obs=C)

    # Priors for coefficients in the structural model
    # C -> Z and C -> Y
    beta_cz = sample("beta_cz", dist.Normal())
    beta_cy = sample("beta_cy", dist.Normal())
    # X -> Z and Z -> Y
    beta_xz = sample("beta_xz", dist.Normal())
    beta_zy = sample("beta_zy", dist.Normal())
    # Intercepts
    beta_z = sample("beta_z", dist.Normal())
    beta_y = sample("beta_y", dist.Normal())

    # Observation noise for Z
    sigma = sample("sigma", dist.Exponential())

    # Plate over data
    with plate("data", X.shape[0]):
        mu_z = beta_z + beta_cz * c + beta_xz * X.squeeze(axis=1)
        z = sample("z", dist.LogNormal(mu_z, sigma))
        logits = beta_y + beta_cy * c + beta_zy * z
        sample("y", dist.Bernoulli(logits=logits), obs=y)
Simulating Data under a Known Structural Model#
We generate synthetic data consistent with the assumed structure:
C is drawn from an exponential distribution.
X is a count variable from a Poisson distribution.
Z is generated as a noisy exponential function of C and X.
Y is a binary outcome influenced by both C and Z through a logistic model.
# Create a pseudo-random number generator key for JAX
rng_key = random.key(42)
# Sample C from an Exponential distribution
rng_key, rng_subkey = random.split(rng_key)
C = random.exponential(rng_subkey, shape=(100,))
# Sample X from a Poisson distribution
rng_key, rng_subkey = random.split(rng_key)
X = random.poisson(rng_subkey, lam=1, shape=(100, 1))
# Generate Z influenced by C and X
rng_key, rng_subkey = random.split(rng_key)
mu_z = -1.0 + 0.5 * C - 1.5 * X.squeeze()
sigma_z = 10.0 # Add substantial noise to reduce correlation between C and Z
Z = jnp.exp(random.normal(rng_subkey, shape=(100,)) * sigma_z + mu_z)
# Generate Y from a logistic regression on C and Z
rng_key, rng_subkey = random.split(rng_key)
logits = -2.0 + 5.0 * C + 0.1 * Z
p = nn.sigmoid(logits)
y = random.bernoulli(rng_subkey, p=p).astype(jnp.int32)
Fitting the Model and Estimating Effects#
We fit the model using stochastic variational inference. Once trained, we perform a counterfactual analysis to isolate the effect of Z on Y.
dt_factual represents predictions under the factual setting (with observed Z).
dt_counterfactual represents predictions under a counterfactual intervention where Z is set to zero.
Note
This model contains a local latent variable, which requires predict_on_batch() here.
Prefer predict() whenever it is compatible with the model.
See model compatibility for details.
Comparing these two distributions allows us to estimate the effect of Z on Y, adjusted for the influence of C.
im = ImpactModel(
    model,
    rng_key=rng_key,
    inference=SVI(
        model,
        guide=AutoNormal(model, init_loc_fn=init_to_feasible()),
        optim=optim.Adam(step_size=1e-3),
        loss=Trace_ELBO(),
    ),
)
im.fit_on_batch(X, y, C=C)

# Predict under factual (Z) and counterfactual (zeroed Z) scenarios
dt_factual = im.predict_on_batch(X, C=C, intervention={"z": Z})
dt_counterfactual = im.predict_on_batch(
    X,
    C=C,
    intervention={"z": jnp.zeros_like(Z)},
)

# Estimate effect of intervening on Z while conditioning on C
effect = im.estimate_effect(
    output_baseline=dt_factual,
    output_intervention=dt_counterfactual,
)
effect
<xarray.DataTree 'root'>
Group: /
├── Group: /posterior_predictive
│ Dimensions: (chain: 1, draw: 1000, y_dim_0: 100)
│ Coordinates:
│ * chain (chain) int64 8B 0
│ * draw (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
│ * y_dim_0 (y_dim_0) int64 800B 0 1 2 3 4 5 6 7 8 ... 92 93 94 95 96 97 98 99
│ Data variables:
│ y (chain, draw, y_dim_0) int32 400kB 0 0 1 1 0 0 1 ... 1 0 -1 -1 0 0
│ Attributes:
│ aimz_version: 0.12.0
└── Group: /posterior
Dimensions: (chain: 1, draw: 1000, z_dim_0: 100)
Coordinates:
* chain (chain) int64 8B 0
* draw (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
* z_dim_0 (z_dim_0) int64 800B 0 1 2 3 4 5 6 7 8 ... 92 93 94 95 96 97 98 99
Data variables:
beta_cy (chain, draw) float32 4kB 2.496 1.31 2.062 ... 1.907 2.169 1.718
beta_cz (chain, draw) float32 4kB 0.08686 0.1083 0.04103 ... 0.1054 0.1363
beta_xz (chain, draw) float32 4kB -0.0563 -0.05368 ... -0.04063 -0.01289
beta_y (chain, draw) float32 4kB -0.4724 -0.3403 ... -0.239 -0.1706
beta_z (chain, draw) float32 4kB -0.04289 -0.05741 ... -0.01758 -0.02535
beta_zy (chain, draw) float32 4kB 0.126 -0.2083 0.1889 ... 0.5462 0.1949
sigma (chain, draw) float32 4kB 0.3966 0.3815 0.3818 ... 0.3876 0.3586
z (chain, draw, z_dim_0) float32 400kB 1.15 0.8842 ... 0.8195 1.737
Attributes:
created_at: 2026-05-24T02:56:37.261669+00:00
aimz_version: 0.12.0