aimz.ImpactModel.predict
- ImpactModel.predict(X, *, intervention=None, rng_key=None, in_sample=True, return_sites=None, batch_size=None, output_dir=None, progress=True, **kwargs)
Predict the output based on the fitted model.
This method performs posterior predictive sampling to generate model-based predictions. It is optimized for batch processing of large input data and is not recommended for use in loops that process only a few inputs at a time. Results are written to disk in the Zarr format, with sampling and file writing decoupled and executed concurrently.
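The decoupling of sampling from file writing described above can be illustrated with a producer/consumer pattern. The sketch below is not aimz's implementation: it uses only the Python standard library, plain text files stand in for Zarr chunks, and the names `sample_batches`, `writer`, and `predict_to_disk` are hypothetical.

```python
import queue
import random
import tempfile
import threading
from pathlib import Path


def sample_batches(n_batches, batch_size, seed=0):
    """Stand-in for posterior predictive sampling: yields one batch at a time."""
    rng = random.Random(seed)
    for idx in range(n_batches):
        yield idx, [rng.gauss(0.0, 1.0) for _ in range(batch_size)]


def writer(q, out_dir):
    """Consume batches from the queue and write each one to its own file."""
    while True:
        item = q.get()
        if item is None:  # sentinel: the sampler has finished
            break
        idx, values = item
        path = out_dir / f"chunk_{idx:04d}.txt"
        path.write_text("\n".join(f"{v:.6f}" for v in values))


def predict_to_disk(n_batches=4, batch_size=3):
    """Sample in the main thread while a background thread writes to disk."""
    out_dir = Path(tempfile.mkdtemp(prefix="sketch_"))
    # A bounded queue keeps sampling from running far ahead of writing.
    q = queue.Queue(maxsize=2)
    t = threading.Thread(target=writer, args=(q, out_dir))
    t.start()
    for item in sample_batches(n_batches, batch_size):
        q.put(item)  # blocks only when the writer falls behind
    q.put(None)  # tell the writer to stop
    t.join()
    return out_dir


out_dir = predict_to_disk()
written = sorted(p.name for p in out_dir.iterdir())
print(written)
```

Because each batch carries a fixed overhead (queue handoff, one file per chunk), this pattern pays off for large inputs and is wasteful when called repeatedly on a handful of rows, which is consistent with the advice above against using the method in tight loops.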
- Parameters:
X (ArrayLike | ArrayLoader) – Input data. If array-like, the leading axis is the sample axis. Alternatively, a data loader that holds all array-like objects and handles batching internally.
intervention (dict | None) – A dictionary mapping sample sites to their corresponding intervention values. Interventions enable counterfactual analysis by modifying the specified sample sites during prediction (posterior predictive sampling).
rng_key (Array | None) – A pseudo-random number generator key. By default, an internal key is used and split as needed.
in_sample (bool) – Specifies the group where posterior predictive samples are stored in the returned output. If True, samples are stored in the posterior_predictive group, indicating they were generated based on data used during model fitting. If False, samples are stored in the predictions group, indicating they were generated based on out-of-sample data.
return_sites (str | Iterable[str] | None) – Names of variables (sites) to return. If None, samples param_output and deterministic sites.
batch_size (int | None) – The batch size for data loading during posterior predictive sampling. It also determines the chunk size used to store the samples. If None, it is determined automatically based on the input data and number of samples. Ignored if X is a data loader, in which case the data loader is expected to handle batching internally.
output_dir (str | Path | None) – The directory where the outputs will be saved. If the specified directory does not exist, it will be created automatically. If None, a default temporary directory will be created. A timestamped subdirectory will be generated within this directory to store the outputs. Outputs are saved in the Zarr format.
progress (bool) – Whether to display a progress bar.
**kwargs (object) – Additional arguments passed to the model.
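The documented handling of output_dir (create the directory if missing, fall back to a temporary directory when None, then add a timestamped subdirectory) can be sketched roughly as follows. This is an illustration only: `resolve_output_dir` is a hypothetical helper, and the timestamp format is an assumption, since the naming scheme aimz actually uses is not stated here.

```python
import tempfile
from datetime import datetime
from pathlib import Path


def resolve_output_dir(output_dir=None):
    """Hypothetical helper mirroring the documented `output_dir` behavior."""
    if output_dir is None:
        # No directory given: fall back to a fresh temporary directory.
        base = Path(tempfile.mkdtemp())
    else:
        base = Path(output_dir)
        # Create the directory (and any parents) if it does not exist yet.
        base.mkdir(parents=True, exist_ok=True)
    # Assumed timestamp format; the real subdirectory name may differ.
    stamp = datetime.now().strftime("%Y%m%dT%H%M%S%f")
    run_dir = base / stamp
    run_dir.mkdir()
    return run_dir
```

One consequence of the timestamped subdirectory is that repeated calls with the same output_dir do not overwrite each other's results.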
- Returns:
Posterior predictive samples. Posterior samples are included if available.
- Raises:
TypeError – If param_output is passed as an argument.
NotImplementedError – If a return site’s axis-1 size does not match the input batch size.
- Return type:
xr.DataTree
See also
cleanup() to remove the temporary directory if created.
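As a usage illustration for counterfactual analysis, the sketch below runs predict twice on the same input, once without and once with an intervention. It relies only on the keyword names in the signature above; `predict_with_and_without_intervention` is a hypothetical helper, and `model` is assumed to be an already fitted ImpactModel (or any object exposing a compatible predict method).

```python
def predict_with_and_without_intervention(
    model, X, intervention, *, in_sample=True, **predict_kwargs
):
    """Run a baseline and a counterfactual prediction on the same input.

    `model` is assumed to be fitted already. `intervention` maps sample
    site names to the values they should take during prediction.
    """
    # Baseline: posterior predictive sampling with no sites modified.
    baseline = model.predict(X, in_sample=in_sample, **predict_kwargs)
    # Counterfactual: same input, with the given sites overridden.
    counterfactual = model.predict(
        X, intervention=intervention, in_sample=in_sample, **predict_kwargs
    )
    return baseline, counterfactual
```

Each call writes its own results to disk, so passing the same output_dir through `predict_kwargs` keeps both runs under one parent directory. If no output_dir is given, temporary directories are created, and cleanup() can be used to remove them afterwards.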