aimz.ImpactModel.predict_on_batch

ImpactModel.predict_on_batch(X, *, intervention=None, rng_key=None, in_sample=True, return_sites=None, return_datatree=True, **kwargs)

Predict the output based on the fitted model.

This method returns predictions for a single batch of input data and is better suited for:

1) Models incompatible with predict() due to their posterior sample shapes.

2) Scenarios where writing results to files (e.g., disk, cloud storage) is not desired.

3) Smaller datasets, since this method may be slower than predict() on large inputs due to limited parallelism.

Parameters:
  • X (ArrayLike) – Input data. The leading axis is the sample axis.

  • intervention (dict | None) – A dictionary mapping sample sites to their corresponding intervention values. Interventions enable counterfactual analysis by modifying the specified sample sites during prediction (posterior predictive sampling).

  • rng_key (Array | None) – A pseudo-random number generator key. By default, an internal key is used and split as needed.

  • in_sample (bool) – Specifies the group where posterior predictive samples are stored in the returned output. If True, samples are stored in the posterior_predictive group, indicating they were generated based on data used during model fitting. If False, samples are stored in the predictions group, indicating they were generated based on out-of-sample data.

  • return_sites (str | Iterable[str] | None) – Names of variables (sites) to return. If None, samples of param_output and the deterministic sites are returned.

  • return_datatree (bool) – If True, return a DataTree; otherwise return a dict.

  • **kwargs (object) – Additional arguments passed to the model.
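To make the intervention and return_sites semantics more concrete, here is a minimal NumPy sketch of a toy generative model. It calls nothing from aimz and is not the library's implementation; the function toy_predict, the site names z, mu and y, and the rule that an intervened site is fixed to the supplied value instead of being sampled are all assumptions for illustration. In the toy, y plays the role of param_output and mu is a deterministic site.

```python
import numpy as np


def toy_predict(X, *, intervention=None, return_sites=None, num_draws=500, seed=0):
    """Hypothetical toy model: z ~ Normal(0, 1), mu = X + 2 * z, y ~ Normal(mu, 0.1).

    Returns a dict mapping site names to arrays of shape (num_draws, len(X)).
    """
    rng = np.random.default_rng(seed)
    intervention = intervention or {}
    X = np.asarray(X, dtype=float)
    shape = (num_draws, X.shape[0])

    # Sampled site "z": fixed to the intervention value if one is supplied.
    if "z" in intervention:
        z = np.broadcast_to(np.asarray(intervention["z"], dtype=float), shape)
    else:
        z = rng.normal(0.0, 1.0, size=shape)

    mu = X + 2.0 * z  # deterministic site
    y = rng.normal(mu, 0.1)  # output site (the role of param_output)

    sites = {"z": z, "mu": mu, "y": y}
    deterministic = {"mu"}
    if return_sites is None:
        # Mirrors the documented default: output site plus deterministic sites.
        keep = {"y"} | deterministic
    elif isinstance(return_sites, str):
        keep = {return_sites}
    else:
        keep = set(return_sites)
    return {name: arr for name, arr in sites.items() if name in keep}


X = np.linspace(0.0, 1.0, 5)
factual = toy_predict(X)
counterfactual = toy_predict(X, intervention={"z": 1.0}, return_sites=["z", "y"])
# Per-observation difference in mean outcome when z is forced to 1.0.
effect = counterfactual["y"].mean(axis=0) - factual["y"].mean(axis=0)
```

Because z has mean zero in the factual run, the estimated effect of setting z to 1.0 should sit close to 2.0 for every observation, which is the kind of counterfactual contrast the intervention parameter is meant to support.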

Returns:

Posterior predictive samples. Posterior samples are included if available.
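The in_sample flag only changes the group under which the predictive samples are filed. The sketch below illustrates that routing with plain nested dicts; group_samples is a hypothetical helper, not how aimz builds its DataTree. The posterior_predictive and predictions group names come from the in_sample description, while the name posterior for the posterior-sample group is an assumption borrowed from the ArviZ convention.

```python
def group_samples(predictive, *, in_sample=True, posterior=None):
    """Hypothetical helper: file predictive samples under the group implied by in_sample."""
    group = "posterior_predictive" if in_sample else "predictions"
    tree = {group: dict(predictive)}
    if posterior is not None:
        # Assumed group name for posterior draws (ArviZ-style convention).
        tree["posterior"] = dict(posterior)
    return tree


samples = {"y": [[0.1, 0.2]]}
print(sorted(group_samples(samples, in_sample=False, posterior={"b": [[1.0]]})))
# → ['posterior', 'predictions']
```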

Raises:

TypeError – If param_output is passed as an argument.
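A minimal sketch of a guard with this behavior, assuming the output parameter is named y; for a real model, the relevant name is whatever param_output refers to. The function check_model_kwargs is hypothetical and only mirrors the documented rule that the output variable may not be supplied through **kwargs.

```python
def check_model_kwargs(kwargs, param_output="y"):
    """Hypothetical guard mirroring the documented TypeError."""
    if param_output in kwargs:
        raise TypeError(
            f"{param_output!r} is the output variable and cannot be passed as an argument."
        )
    return kwargs


check_model_kwargs({"x_extra": 1.0})  # accepted: not the output variable
try:
    check_model_kwargs({"y": [0.0, 1.0]})
except TypeError as err:
    print(err)  # 'y' is the output variable and cannot be passed as an argument.
```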

Return type:

xr.DataTree | dict[str, npt.NDArray]