aimz.ImpactModel

class aimz.ImpactModel(kernel, rng_key, inference, *, param_input='X', param_output='y')

Impact modeling interface: fit, sample, predict, and estimate effects.

Parameters:
  • kernel (Callable)

  • rng_key (Array)

  • inference (SVI | MCMC)

  • param_input (str)

  • param_output (str)

__init__(kernel, rng_key, inference, *, param_input='X', param_output='y')

Initialize an ImpactModel instance.

Parameters:
  • kernel (Callable) – A probabilistic model with NumPyro primitives.

  • rng_key (Array) – A pseudo-random number generator key.

  • inference (SVI | MCMC) – An inference method supported by NumPyro, such as an instance of SVI or MCMC.

  • param_input (str) – Name of the parameter in the kernel for the main input data.

  • param_output (str) – Name of the parameter in the kernel for the output data.

Return type:

None

Warning

The rng_key parameter should be provided as a typed key array created with jax.random.key(), rather than a legacy uint32 key created with jax.random.PRNGKey().
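The difference between the two key styles can be checked directly in JAX. The sketch below creates a typed key and a legacy key, then converts the legacy key with jax.random.wrap_key_data(), which may help when existing code still produces PRNGKey() values. These are standard JAX functions; that aimz treats a wrapped key exactly like one created with jax.random.key() is an assumption, although both carry the same key dtype.

```python
import jax
import jax.numpy as jnp

# Recommended: a typed key array with a dedicated key dtype and scalar shape ().
typed_key = jax.random.key(0)

# Legacy: a raw uint32 array of shape (2,).
legacy_key = jax.random.PRNGKey(0)

# A legacy key can be converted into a typed key holding the same bits.
wrapped_key = jax.random.wrap_key_data(legacy_key)


def is_typed_key(key) -> bool:
    """Return True if `key` is a typed (new-style) JAX PRNG key array."""
    return jax.dtypes.issubdtype(key.dtype, jax.dtypes.prng_key)


print(typed_key.shape, legacy_key.shape, legacy_key.dtype)  # () (2,) uint32
print(is_typed_key(typed_key), is_typed_key(legacy_key), is_typed_key(wrapped_key))  # True False True
```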

Methods

__init__(kernel, rng_key, inference, *[, ...])

Initialize an ImpactModel instance.

cleanup()

Clean up the temporary directory created for storing outputs.

cleanup_models()

Clean up temporary directories for all ImpactModel instances.

estimate_effect([output_baseline, ...])

Estimate the effect of an intervention.

fit(X[, y, num_samples, rng_key, progress, ...])

Fit the impact model to the provided data using epoch-based training.

fit_on_batch(X, y, *[, num_steps, ...])

Fit the impact model to the provided batch of data.

is_fitted()

Check whether the model has been fitted.

log_likelihood(X[, y, batch_size, ...])

Compute the log-likelihood of the data under the given model.

predict(X, *[, intervention, rng_key, ...])

Predict the output based on the fitted model.

predict_on_batch(X, *[, intervention, ...])

Predict the output for the provided batch of data based on the fitted model.

sample(*[, num_samples, rng_key, ...])

Draw posterior samples from a fitted model.

sample_posterior_predictive(X, *[, ...])

Draw samples from the posterior predictive distribution.

sample_posterior_predictive_on_batch(X, *[, ...])

Draw samples from the posterior predictive distribution for the provided batch of data.

sample_prior_predictive(X, *[, num_samples, ...])

Draw samples from the prior predictive distribution.

sample_prior_predictive_on_batch(X, *[, ...])

Draw samples from the prior predictive distribution for the provided batch of data.

set_posterior_sample(posterior_sample)

Set posterior samples for the model.

train_on_batch(X, y[, rng_key])

Run a single VI step on the given batch of data.

Attributes

inference

The underlying NumPyro inference object.

kernel

A probabilistic model with NumPyro primitives.

kernel_spec

The cached KernelSpec, or None if not yet built.

param_input

Parameter name in kernel for the input data.

param_output

Parameter name in kernel for the output data.

posterior

Posterior samples by variable name, or None if not set.

rng_key

Pseudo-random number generator key.

temp_dir

Temporary directory path, or None if not set.

vi_result

Variational inference result, or None if not set.