aimz.ImpactModel.predict_on_batch
- ImpactModel.predict_on_batch(X, *, intervention=None, rng_key=None, in_sample=True, return_sites=None, return_datatree=True, **kwargs)
Predict the output based on the fitted model.
This method returns predictions for a single batch of input data and is better suited for:
1) Models incompatible with predict() due to their posterior sample shapes.
2) Scenarios where writing results to files (e.g., disk, cloud storage) is not desired.
3) Smaller datasets, since this method may be slower on large inputs due to limited parallelism.
- Parameters:
X (ArrayLike) – Input data. The leading axis is the sample axis.
intervention (dict | None) – A dictionary mapping sample sites to their corresponding intervention values. Interventions enable counterfactual analysis by modifying the specified sample sites during prediction (posterior predictive sampling).
rng_key (Array | None) – A pseudo-random number generator key. By default, an internal key is used and split as needed.
in_sample (bool) – Specifies the group where posterior predictive samples are stored in the returned output. If True, samples are stored in the posterior_predictive group, indicating they were generated based on data used during model fitting. If False, samples are stored in the predictions group, indicating they were generated based on out-of-sample data.
return_sites (str | Iterable[str] | None) – Names of variables (sites) to return. If None, samples param_output and deterministic sites.
return_datatree (bool) – If True, return a DataTree; otherwise return a dict.
**kwargs (object) – Additional arguments passed to the model.
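To illustrate what an intervention does during posterior predictive sampling, here is a toy, stdlib-only sketch. It is not aimz's implementation: the model, the sample_site helper, and the site names z and y are hypothetical. An intervened site is fixed to the supplied value instead of being sampled, and downstream sites see that fixed value. This is conceptually similar to NumPyro's do handler.

```python
import random
import statistics


def sample_site(name, draw, intervention):
    """Return the intervened value for `name` if one is given, else a fresh draw."""
    if name in intervention:
        return intervention[name]
    return draw()


def toy_model(x, rng, intervention=None):
    """Hypothetical model: z ~ Normal(1, 1), y ~ Normal(2 * x + z, 0.1)."""
    intervention = intervention or {}
    z = sample_site("z", lambda: rng.gauss(1.0, 1.0), intervention)
    # y depends on z, so fixing z changes the distribution of y downstream.
    y = sample_site("y", lambda: rng.gauss(2.0 * x + z, 0.1), intervention)
    return {"z": z, "y": y}


def predictive(x, num_samples=1000, seed=0, intervention=None):
    """Draw repeated samples from the toy model and group them by site."""
    rng = random.Random(seed)
    draws = [toy_model(x, rng, intervention) for _ in range(num_samples)]
    return {site: [d[site] for d in draws] for site in ("z", "y")}


factual = predictive(x=1.0)
counterfactual = predictive(x=1.0, intervention={"z": 0.0})
print(statistics.mean(factual["y"]))         # close to 3.0
print(statistics.mean(counterfactual["y"]))  # close to 2.0
```

Comparing the two runs gives a counterfactual effect of fixing z, which is the kind of analysis the intervention parameter is meant to support.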
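The interaction of in_sample and return_sites can be sketched with a hypothetical helper, resolve_output, operating on plain lists. The group names follow the description above. The helper, its param_output and deterministic_sites arguments, and the site names are assumptions for illustration, and the nested dict only loosely mimics the group layout of a returned DataTree.

```python
from __future__ import annotations

from collections.abc import Iterable


def resolve_output(
    samples: dict[str, list[float]],
    *,
    in_sample: bool = True,
    return_sites: str | Iterable[str] | None = None,
    param_output: str = "y",
    deterministic_sites: Iterable[str] = (),
) -> dict[str, dict[str, list[float]]]:
    """Hypothetical helper: choose the output group and the sites to keep."""
    # in_sample selects the group name, mirroring the documented behavior.
    group = "posterior_predictive" if in_sample else "predictions"
    if return_sites is None:
        # Default: the output site plus every deterministic site.
        sites = [param_output, *deterministic_sites]
    elif isinstance(return_sites, str):
        # A single name; checked first because a str is itself iterable.
        sites = [return_sites]
    else:
        sites = list(return_sites)
    return {group: {name: samples[name] for name in sites}}


samples = {"y": [1.0, 2.0], "mu": [0.9, 1.8], "z": [0.1, 0.2]}
default = resolve_output(samples, deterministic_sites=["mu"])
held_out = resolve_output(samples, in_sample=False, return_sites="z")
print(default)   # {'posterior_predictive': {'y': [1.0, 2.0], 'mu': [0.9, 1.8]}}
print(held_out)  # {'predictions': {'z': [0.1, 0.2]}}
```

Accepting a bare string alongside an iterable is convenient for callers, but the string case must be handled before the generic iterable case, or "z" would be treated as a sequence of characters.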
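The following sketch shows one way **kwargs forwarding can coexist with the TypeError listed under Raises. It rests on an assumption: param_output names the model argument that carries the observed output (hypothetically y), and supplying it through **kwargs would conflict with sampling that site, so it is rejected. The call_model wrapper and toy_model are hypothetical, not part of aimz.

```python
def call_model(model, X, *, param_output="y", **kwargs):
    """Hypothetical wrapper: forward extra keyword arguments to the model,
    but refuse a value for the output argument, which prediction must sample."""
    if param_output in kwargs:
        raise TypeError(
            f"{param_output!r} is the model output and cannot be passed as an argument."
        )
    return model(X, **kwargs)


def toy_model(X, y=None, scale=1.0):
    # Hypothetical model: ignores y and returns a scaled copy of X.
    return [scale * v for v in X]


print(call_model(toy_model, [1.0, 2.0], scale=2.0))  # [2.0, 4.0]
try:
    call_model(toy_model, [1.0], y=[0.0])
except TypeError as err:
    print("TypeError:", err)
```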
- Returns:
Posterior predictive samples. Posterior samples are included if available.
- Raises:
TypeError – If param_output is passed as an argument.
- Return type: