Interventions & Effect Estimation

This guide covers two closely related features:

  • The intervention argument, available in predictive methods such as predict(), predict_on_batch(), and their posterior predictive counterparts. Internally, it applies hard (do) interventions to the specified sample sites using NumPyro's do effect handler, generating counterfactual draws without rewriting the model.

  • The estimate_effect() method, which computes the elementwise difference between an intervention (counterfactual) scenario and a baseline (factual) scenario to quantify causal or policy impact.

Typical workflow:

  1. Generate a baseline predictive result under factual conditions (optionally passing intervention to hold certain sites at specific values).

  2. Generate another predictive result under a modified intervention mapping.

  3. Pass both results (or the argument dictionaries to generate them lazily) to estimate_effect() to obtain the effect output.

Each scenario is a DataTree produced by the prediction API or materialized on demand from argument dictionaries.

Interventions

The intervention argument is a mapping (dict[str, ArrayLike]) from sample site names to replacement values. During predictive sampling, each listed site is fixed to its value, which enables counterfactual or policy analysis. Values must broadcast to the site's per-observation shape (e.g., intervening on a length-N vector site generally requires shape (N,)). Multiple sites can be intervened on at once; sites that are not listed are sampled from their posterior (or prior) distribution as usual.
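The broadcasting rule can be illustrated with a minimal sketch that assumes the per-observation site shapes are known up front. prepare_intervention is a hypothetical helper written for illustration. It is not part of aimz, which may validate shapes differently.

```python
import jax.numpy as jnp


def prepare_intervention(intervention, site_shapes):
    """Hypothetical helper: broadcast each intervention value to its site's
    per-observation shape, failing early on incompatible shapes."""
    prepared = {}
    for name, value in intervention.items():
        target = site_shapes[name]
        try:
            prepared[name] = jnp.broadcast_to(jnp.asarray(value), target)
        except (ValueError, TypeError) as err:
            raise ValueError(
                f"intervention for {name!r} with shape {jnp.shape(value)} "
                f"does not broadcast to {target}"
            ) from err
    return prepared


N = 5
shapes = {"z": (N,), "w": (N,)}  # assumed per-observation site shapes

# A scalar broadcasts to (N,); a length-N array is kept as-is
ok = prepare_intervention({"z": 0.0, "w": jnp.arange(N)}, shapes)
print({k: v.shape for k, v in ok.items()})  # {'z': (5,), 'w': (5,)}

# A length-(N + 1) array cannot broadcast to (N,) and is rejected
rejected = False
try:
    prepare_intervention({"z": jnp.zeros(N + 1)}, shapes)
except ValueError as err:
    rejected = True
    print("rejected:", err)
```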

Setting in_sample=True stores draws in the posterior_predictive group, while in_sample=False stores them in the predictions group. When computing effects, the baseline and intervention scenarios must use the same group. Deterministic sites downstream of an intervened site automatically reflect the intervened values.

import jax.numpy as jnp
import numpyro

from aimz import ImpactModel


# Minimal sketch of a model exposing a stochastic site 'z'
def model(X, y=None):
    ...
    # Site we may choose to override at prediction time
    z = numpyro.sample("z", ...)
    ...


# Fit (details elided)
im = ImpactModel(model, ...).fit_on_batch(...)

# Baseline scenario: set 'z' to its observed/factual value Z
baseline = im.predict_on_batch(X, intervention={"z": Z})

# Modified scenario: counterfactual where we overwrite 'z' with zeros
modified = im.predict_on_batch(
    X,
    intervention={"z": jnp.zeros_like(Z)},
)

Effect Estimation

The estimate_effect() method computes the elementwise difference between two predictive scenarios (intervention minus baseline). It returns a DataTree whose predictive group holds these differences, with the sampling dimensions preserved.

One baseline and one intervention scenario must be provided, either eagerly (output_baseline / output_intervention) or lazily through argument dictionaries (args_baseline / args_intervention). Mixing is allowed; for example, a precomputed baseline can be supplied with output_baseline while the intervention is generated lazily with args_intervention (or the reverse). Both scenarios must come from the same predictive group (both posterior_predictive or both predictions) with matching variable sets and shapes.
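The selection rule can be sketched in plain Python. Everything here is hypothetical: resolve_scenario, predictive_group, and fake_predict are illustrative names, plain dictionaries stand in for DataTree objects, and aimz's actual validation may differ.

```python
from typing import Any, Callable, Mapping, Optional

PREDICTIVE_GROUPS = ("posterior_predictive", "predictions")


def resolve_scenario(
    output: Optional[Mapping[str, Any]],
    args: Optional[Mapping[str, Any]],
    predict: Callable[..., Mapping[str, Any]],
) -> Mapping[str, Any]:
    # Hypothetical: exactly one of a precomputed result or an argument dict
    if (output is None) == (args is None):
        raise ValueError("provide exactly one of output_* or args_*")
    # Eager: reuse the precomputed result; lazy: generate it now
    return output if output is not None else predict(**args)


def predictive_group(result: Mapping[str, Any]) -> str:
    # Hypothetical: locate the single predictive group in a result
    found = [g for g in PREDICTIVE_GROUPS if g in result]
    if len(found) != 1:
        raise ValueError(f"expected one predictive group, found {found}")
    return found[0]


def fake_predict(X, intervention, in_sample=False):
    # Stand-in for a prediction call; plain dicts replace DataTree objects
    group = "posterior_predictive" if in_sample else "predictions"
    return {group: {"y": [x + intervention["z"] for x in X]}}


# Mixed usage: precomputed baseline, lazily generated intervention
precomputed = fake_predict([1, 2, 3], {"z": 1})
base = resolve_scenario(precomputed, None, fake_predict)
intv = resolve_scenario(
    None, {"X": [1, 2, 3], "intervention": {"z": 0}}, fake_predict
)

g_base, g_intv = predictive_group(base), predictive_group(intv)
if g_base != g_intv:
    raise ValueError("baseline and intervention must share a predictive group")

# Elementwise difference: intervention - baseline
diff = [a - b for a, b in zip(intv[g_intv]["y"], base[g_base]["y"])]
print(g_base, diff)  # predictions [-1, -1, -1]
```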

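In the result, the shared predictive group keeps its name, and each of its variables is the elementwise difference

\[\Delta_{c,d,i} = \text{intervention}_{c,d,i} - \text{baseline}_{c,d,i},\]

where c indexes chains, d indexes draws, and i indexes the variable's remaining (e.g., observation) dimensions. The leading chain and draw dimensions are retained.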
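Because the difference is taken draw by draw, a binary outcome such as y in the example later in this guide yields contrasts in {-1, 0, 1}. Averaging those contrasts over chains and draws estimates the change in the outcome probability. The sketch below illustrates this with synthetic stand-in draws (not aimz output) laid out as (chain, draw, observation).

```python
import jax.numpy as jnp
from jax import random

# Stand-in Bernoulli draws with layout (chain, draw, observation)
key_base, key_int = random.split(random.PRNGKey(0))
p_base = jnp.array([0.2, 0.5, 0.9])
p_int = jnp.array([0.1, 0.5, 0.6])
y_base = random.bernoulli(key_base, p_base, shape=(1, 4000, 3)).astype(jnp.int32)
y_int = random.bernoulli(key_int, p_int, shape=(1, 4000, 3)).astype(jnp.int32)

# Elementwise contrast keeps every dimension; values lie in {-1, 0, 1}
effect = y_int - y_base
print(effect.shape)  # (1, 4000, 3)

# Averaging over chain and draw gives the change in P(y = 1) per observation
print(effect.mean(axis=(0, 1)))  # roughly [-0.1, 0.0, -0.3]
```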

Eager (precomputed scenarios):

effect = im.estimate_effect(
    output_baseline=baseline,
    output_intervention=modified,
)

Lazy (defer prediction):

effect = im.estimate_effect(
    args_baseline={
        "X": X,
        "intervention": {"z": Z},
        "in_sample": False,
    },
    args_intervention={
        "X": X,
        "intervention": {"z": jnp.zeros_like(Z)},
        "in_sample": False,
    },
)

Mixed (precomputed baseline, lazy intervention):

effect = im.estimate_effect(
    output_baseline=baseline,
    args_intervention={
        "X": X,
        "intervention": {"z": jnp.zeros_like(Z)},
        "in_sample": False,
    },
)

The returned DataTree holds the elementwise difference for every variable in the predictive group. Subsequent summaries (e.g., means or credible intervals) can be computed with Xarray, ArviZ, or standard NumPy / JAX utilities.
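As an illustration, suppose the effect for y is extracted as an xarray.DataArray with dims (chain, draw, y_dim_0), matching the output shown at the end of this guide; for instance, via something like effect["posterior_predictive"]["y"]. Per-observation means and intervals then reduce over chain and draw. The array below is synthetic stand-in data, and the 94% interval is only a conventional choice.

```python
import numpy as np
import xarray as xr

# Stand-in for the effect on y: synthetic contrasts in {-1, 0, 1}
rng = np.random.default_rng(0)
y_effect = xr.DataArray(
    rng.choice([-1, 0, 1], size=(1, 1000, 100), p=[0.3, 0.6, 0.1]),
    dims=("chain", "draw", "y_dim_0"),
)

# Posterior mean of the contrast for each observation
mean_effect = y_effect.mean(dim=("chain", "draw"))

# Central 94% interval per observation from empirical quantiles
bounds = y_effect.quantile([0.03, 0.97], dim=("chain", "draw"))

# Average effect over all observations (a single scalar summary)
average_effect = float(y_effect.mean())
print(mean_effect.shape, bounds.shape, average_effect)
```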

Note

estimate_effect() computes the posterior predictive contrast between two scenarios under structural interventions, propagating full posterior uncertainty through the difference. Whether this contrast admits a causal interpretation depends on the structural assumptions encoded in the model (the kernel): causal identification is a property of the model specification, not of the estimation procedure. When the user-defined model encodes appropriate causal assumptions, such as conditioning on confounders and specifying the correct functional relationships, the contrast corresponds to a causal effect estimate.
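To make the role of a confounder concrete, here is a toy simulation that swaps in ordinary least squares for a Bayesian model. It is an illustrative substitute, not the estimate_effect() procedure, and the data and coefficients are made up. With a confounder C driving both Z and Y, a model that omits C attributes part of C's influence to Z, whereas conditioning on C recovers the true coefficient.

```python
import jax.numpy as jnp
from jax import random

# Linear toy world: C -> Z, C -> Y, Z -> Y with true Z effect 0.5
k_c, k_z, k_y = random.split(random.PRNGKey(0), 3)
n = 20_000
C = random.normal(k_c, (n,))
Z = C + random.normal(k_z, (n,))
Y = 0.5 * Z + 2.0 * C + 0.1 * random.normal(k_y, (n,))

ones = jnp.ones(n)

# Model that ignores the confounder C
coef_naive, *_ = jnp.linalg.lstsq(jnp.column_stack([ones, Z]), Y)

# Model that conditions on the confounder C
coef_adj, *_ = jnp.linalg.lstsq(jnp.column_stack([ones, Z, C]), Y)

print(f"naive Z coefficient:    {float(coef_naive[1]):.2f}")  # near 1.5 (biased)
print(f"adjusted Z coefficient: {float(coef_adj[1]):.2f}")  # near 0.5 (true effect)
```

The naive slope is Cov(Y, Z) / Var(Z) = (0.5 · 2 + 2 · 1) / 2 = 1.5, which is why omitting C triples the apparent effect in this setup.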

Example: Causal Network with Confounder

This example illustrates a simple causal network. The variable Z has a direct causal effect on the outcome Y, while both are influenced by a shared confounder, C. An additional variable, X, is an observed exogenous factor that influences Z but has no direct effect on Y.

Our objective is to estimate the causal effect of Z (or alternatively X) on Y, while properly accounting for the confounding influence of C. We assume the following generative model for the observed data:

Model

import logging

import jax.numpy as jnp
import numpyro.distributions as dist
from jax import nn, random
from jax.typing import ArrayLike
from numpyro import optim, plate, sample
from numpyro.infer import SVI, Trace_ELBO, init_to_feasible
from numpyro.infer.autoguide import AutoNormal

from aimz import ImpactModel

logging.basicConfig(level=logging.INFO, force=True)


def model(X: ArrayLike, C: ArrayLike, y: ArrayLike | None = None) -> None:
    # Observed confounder
    c = sample("c", dist.Exponential(), obs=C)

    # Priors for coefficients in the structural model
    # C -> Z and C -> Y
    beta_cz = sample("beta_cz", dist.Normal())
    beta_cy = sample("beta_cy", dist.Normal())

    # X -> Z and Z -> Y
    beta_xz = sample("beta_xz", dist.Normal())
    beta_zy = sample("beta_zy", dist.Normal())

    # Intercepts
    beta_z = sample("beta_z", dist.Normal())
    beta_y = sample("beta_y", dist.Normal())

    # Scale of the noise on log(Z) (LogNormal)
    sigma = sample("sigma", dist.Exponential())

    # Plate over data
    with plate("data", X.shape[0]):
        mu_z = beta_z + beta_cz * c + beta_xz * X.squeeze(axis=1)
        z = sample("z", dist.LogNormal(mu_z, sigma))

        logits = beta_y + beta_cy * c + beta_zy * z
        sample("y", dist.Bernoulli(logits=logits), obs=y)
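Written out, the model code above corresponds to the following structural equations for observations i = 1, …, N. NumPyro's Normal() and Exponential() default to a standard normal and a unit rate, and Bernoulli(logits=…) uses the sigmoid link:

\[
\begin{aligned}
c_i &\sim \operatorname{Exponential}(1), \\
\beta_{cz}, \beta_{cy}, \beta_{xz}, \beta_{zy}, \beta_z, \beta_y &\sim \mathcal{N}(0, 1), \qquad \sigma \sim \operatorname{Exponential}(1), \\
z_i &\sim \operatorname{LogNormal}\left(\beta_z + \beta_{cz} c_i + \beta_{xz} x_i,\ \sigma\right), \\
y_i &\sim \operatorname{Bernoulli}\left(\operatorname{sigmoid}\left(\beta_y + \beta_{cy} c_i + \beta_{zy} z_i\right)\right).
\end{aligned}
\]

Under a hard intervention on z, as used later, the z_i line is replaced by the assignment z_i := z_i^* and every other line is unchanged.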

Simulating Data under a Known Structural Model

We generate synthetic data consistent with the assumed structure:

  • C is drawn from an exponential distribution.

  • X is a count variable from a Poisson distribution.

  • Z is generated as the exponential of a noisy linear function of C and X (i.e., it is log-normal given C and X).

  • Y is a binary outcome influenced by both C and Z through a logistic model.

# Create a pseudo-random number generator key for JAX
rng_key = random.key(42)

# Sample C from an Exponential distribution
rng_key, rng_subkey = random.split(rng_key)
C = random.exponential(rng_subkey, shape=(100,))

# Sample X from a Poisson distribution
rng_key, rng_subkey = random.split(rng_key)
X = random.poisson(rng_subkey, lam=1, shape=(100, 1))

# Generate Z influenced by C and X
rng_key, rng_subkey = random.split(rng_key)
mu_z = -1.0 + 0.5 * C - 1.5 * X.squeeze()
sigma_z = 10.0  # Add substantial noise to reduce correlation between C and Z
Z = jnp.exp(random.normal(rng_subkey, shape=(100,)) * sigma_z + mu_z)

# Generate Y from a logistic regression on C and Z
rng_key, rng_subkey = random.split(rng_key)
logits = -2.0 + 5.0 * C + 0.1 * Z
p = nn.sigmoid(logits)
y = random.bernoulli(rng_subkey, p=p).astype(jnp.int32)
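Because the data are simulated, the true interventional contrast is known and can serve as a reference for the estimate obtained below. The sketch below reproduces C and Z by reusing the seed and key-splitting order of the code above. It evaluates the known outcome probability with Z set to zero and with Z at its simulated values, and takes the difference in the same order as estimate_effect() (intervention minus baseline). Averaging the estimated effect for y over chains and draws should be comparable to this probability-scale contrast.

```python
import jax.numpy as jnp
from jax import nn, random

# Re-create the simulated C and Z above with the same seed and split order
rng_key = random.key(42)
rng_key, rng_subkey = random.split(rng_key)
C = random.exponential(rng_subkey, shape=(100,))
rng_key, rng_subkey = random.split(rng_key)
X = random.poisson(rng_subkey, lam=1, shape=(100, 1))
rng_key, rng_subkey = random.split(rng_key)
mu_z = -1.0 + 0.5 * C - 1.5 * X.squeeze()
sigma_z = 10.0
Z = jnp.exp(random.normal(rng_subkey, shape=(100,)) * sigma_z + mu_z)


def true_probability(C, Z):
    # Known outcome equation from the simulation: P(Y = 1 | C, Z)
    return nn.sigmoid(-2.0 + 5.0 * C + 0.1 * Z)


# Ground-truth contrast: do(Z = 0) minus the factual Z, per observation
true_effect = true_probability(C, jnp.zeros_like(Z)) - true_probability(C, Z)
print(true_effect.shape)  # (100,)
print(float(true_effect.mean()))  # average change in P(Y = 1)
```

Since Z is strictly positive and its true coefficient (0.1) is positive, every entry of true_effect is at most zero.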

Fitting the Model and Estimating Effects

We fit the model using stochastic variational inference. Once trained, we perform a counterfactual analysis to isolate the effect of Z on Y.

  • dt_factual represents predictions under the factual setting, with z held at its observed values Z.

  • dt_counterfactual represents predictions under a counterfactual intervention that sets z to zero.

Comparing these two distributions allows us to estimate the effect of Z on Y, adjusted for the influence of C.

Note

This model contains a local latent variable (z), so predict_on_batch() is required here. Prefer predict() whenever it is compatible with the model; see model compatibility for details.

im = ImpactModel(
    model,
    rng_key=rng_key,
    inference=SVI(
        model,
        guide=AutoNormal(model, init_loc_fn=init_to_feasible()),
        optim=optim.Adam(step_size=1e-3),
        loss=Trace_ELBO(),
    ),
)
im.fit_on_batch(X, y, C=C)

# Predict under factual (Z) and counterfactual (zeroed Z) scenarios
dt_factual = im.predict_on_batch(X, C=C, intervention={"z": Z})
dt_counterfactual = im.predict_on_batch(
    X,
    C=C,
    intervention={"z": jnp.zeros_like(Z)},
)

# Estimate effect of intervening on Z while conditioning on C
effect = im.estimate_effect(
    output_baseline=dt_factual,
    output_intervention=dt_counterfactual,
)
effect
<xarray.DataTree 'root'>
Group: /
├── Group: /posterior_predictive
│       Dimensions:  (chain: 1, draw: 1000, y_dim_0: 100)
│       Coordinates:
│         * chain    (chain) int64 8B 0
│         * draw     (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
│         * y_dim_0  (y_dim_0) int64 800B 0 1 2 3 4 5 6 7 8 ... 92 93 94 95 96 97 98 99
│       Data variables:
│           y        (chain, draw, y_dim_0) int32 400kB 0 0 1 1 0 0 1 ... 1 0 -1 -1 0 0
│       Attributes:
│           aimz_version:  0.12.0
└── Group: /posterior
        Dimensions:  (chain: 1, draw: 1000, z_dim_0: 100)
        Coordinates:
          * chain    (chain) int64 8B 0
          * draw     (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
          * z_dim_0  (z_dim_0) int64 800B 0 1 2 3 4 5 6 7 8 ... 92 93 94 95 96 97 98 99
        Data variables:
            beta_cy  (chain, draw) float32 4kB 2.496 1.31 2.062 ... 1.907 2.169 1.718
            beta_cz  (chain, draw) float32 4kB 0.08686 0.1083 0.04103 ... 0.1054 0.1363
            beta_xz  (chain, draw) float32 4kB -0.0563 -0.05368 ... -0.04063 -0.01289
            beta_y   (chain, draw) float32 4kB -0.4724 -0.3403 ... -0.239 -0.1706
            beta_z   (chain, draw) float32 4kB -0.04289 -0.05741 ... -0.01758 -0.02535
            beta_zy  (chain, draw) float32 4kB 0.126 -0.2083 0.1889 ... 0.5462 0.1949
            sigma    (chain, draw) float32 4kB 0.3966 0.3815 0.3818 ... 0.3876 0.3586
            z        (chain, draw, z_dim_0) float32 400kB 1.15 0.8842 ... 0.8195 1.737
        Attributes:
            created_at:    2026-05-24T03:02:58.460378+00:00
            aimz_version:  0.12.0