aimz.ImpactModel
- class aimz.ImpactModel(kernel, rng_key, inference, *, param_input='X', param_output='y')
Impact modeling interface: fit, sample, predict, and estimate effects.
- __init__(kernel, rng_key, inference, *, param_input='X', param_output='y')
Initialize an ImpactModel instance.
- Parameters:
kernel (Callable) – A probabilistic model with NumPyro primitives.
rng_key (Array) – A pseudo-random number generator key.
inference (SVI | MCMC) – An inference method supported by NumPyro, such as an instance of SVI or MCMC.
param_input (str) – Name of the parameter in the kernel for the main input data.
param_output (str) – Name of the parameter in the kernel for the output data.
- Return type:
None
Warning
The rng_key parameter should be provided as a typed key array created with jax.random.key(), rather than a legacy uint32 key created with jax.random.PRNGKey().
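The difference between the two key styles can be checked directly in JAX. The following is a minimal sketch using only public JAX functions; the variable names are illustrative, and wrapping an existing legacy key is shown as one possible migration path, not as something the aimz documentation prescribes.

```python
import jax
import jax.numpy as jnp

# New-style typed key: a scalar array with a dedicated PRNG key dtype.
typed_key = jax.random.key(0)

# Legacy key: a raw uint32 array holding the key data.
legacy_key = jax.random.PRNGKey(0)

# Typed keys can be recognised by their dtype.
typed_is_key = jnp.issubdtype(typed_key.dtype, jax.dtypes.prng_key)
legacy_is_key = jnp.issubdtype(legacy_key.dtype, jax.dtypes.prng_key)

# An existing legacy key can be wrapped into a typed key.
wrapped_key = jax.random.wrap_key_data(legacy_key)
wrapped_is_key = jnp.issubdtype(wrapped_key.dtype, jax.dtypes.prng_key)
```

Here typed_is_key and wrapped_is_key are true while legacy_is_key is false, so a check of this kind can tell which style of key is about to be passed as rng_key.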
Methods
__init__(kernel, rng_key, inference, *[, ...]) – Initialize an ImpactModel instance.
cleanup() – Clean up the temporary directory created for storing outputs.
Clean up temporary directories for all ImpactModel instances.
estimate_effect([output_baseline, ...]) – Estimate the effect of an intervention.
fit(X[, y, num_samples, rng_key, progress, ...]) – Fit the impact model to the provided data using epoch-based training.
fit_on_batch(X, y, *[, num_steps, ...]) – Fit the impact model to the provided batch of data.
Check fitted status.
log_likelihood(X[, y, batch_size, ...]) – Compute the log-likelihood of the data under the given model.
predict(X, *[, intervention, rng_key, ...]) – Predict the output based on the fitted model.
predict_on_batch(X, *[, intervention, ...]) – Predict the output based on the fitted model.
sample(*[, num_samples, rng_key, ...]) – Draw posterior samples from a fitted model.
sample_posterior_predictive(X, *[, ...]) – Draw samples from the posterior predictive distribution.
sample_posterior_predictive_on_batch(X, *[, ...]) – Draw samples from the posterior predictive distribution.
sample_prior_predictive(X, *[, num_samples, ...]) – Draw samples from the prior predictive distribution.
sample_prior_predictive_on_batch(X, *[, ...]) – Draw samples from the prior predictive distribution.
set_posterior_sample(posterior_sample) – Set posterior samples for the model.
train_on_batch(X, y[, rng_key]) – Run a single VI step on the given batch of data.
Attributes
The underlying NumPyro inference object.
A probabilistic model with NumPyro primitives.
The cached KernelSpec or None if not yet built.
Parameter name in kernel for the input data.
Parameter name in kernel for the output data.
Posterior samples by variable name, or None if not set.
Pseudo-random number generator key.
Temporary directory path, or None if not set.
Variational inference result, or None if not set.