aimz.ImpactModel.train_on_batch

ImpactModel.train_on_batch(X, y, rng_key=None, **kwargs)

Run a single variational inference (VI) step on the given batch of data.

Parameters:
  • X (ArrayLike) – Input data. The leading axis is the sample axis.

  • y (ArrayLike) – Output data. The leading axis is the sample axis.

  • rng_key (Array | None) – A pseudo-random number generator key. By default, an internal key is used and split as needed. The key is used only for initialization, which happens only when the internal SVI state has not yet been set.

  • **kwargs (object) – Additional keyword arguments passed to the model.

Returns:

  • Updated SVI state after the training step.

  • Loss value as a scalar array.

Return type:

tuple[SVIState, Array]

Note

This method updates the internal SVI state on every call, so capturing the returned state is optional. The returned loss value is useful for monitoring or logging training progress.
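The calling pattern described in this note can be sketched in plain Python. This is a hypothetical, self-contained stand-in, not aimz code: `ToyImpactModel` and `ToyState` are invented names, and the update is ordinary gradient descent on mean squared error rather than a variational inference step. It only mirrors the documented contract: the key matters only when the state is first initialized, the internal state is replaced on every call, and each call returns a `(state, loss)` pair (here the loss is a Python float rather than a scalar array).

```python
from __future__ import annotations

import random
from dataclasses import dataclass
from typing import Sequence


@dataclass(frozen=True)
class ToyState:
    """Hypothetical stand-in for an SVI state: one weight and a step counter."""
    w: float
    step: int


class ToyImpactModel:
    """Hypothetical model mirroring the documented train_on_batch contract.

    NOT aimz's implementation: it fits y ~ w * x by plain gradient descent
    on mean squared error instead of running variational inference.
    """

    def __init__(self, step_size: float = 0.1, seed: int = 0) -> None:
        self.step_size = step_size
        self._internal_key = seed  # used when no rng_key is passed
        self.state: ToyState | None = None  # set lazily on the first call

    def train_on_batch(
        self, X: Sequence[float], y: Sequence[float], rng_key: int | None = None
    ) -> tuple[ToyState, float]:
        if self.state is None:
            # The key only matters here, when the state is first initialized.
            key = self._internal_key if rng_key is None else rng_key
            self.state = ToyState(w=random.Random(key).uniform(-1.0, 1.0), step=0)
        w = self.state.w
        n = len(X)
        residuals = [w * x - t for x, t in zip(X, y)]
        loss = sum(r * r for r in residuals) / n
        grad = 2.0 * sum(r * x for r, x in zip(residuals, X)) / n
        # Replace the internal state on every call, then also return it.
        self.state = ToyState(w=w - self.step_size * grad, step=self.state.step + 1)
        return self.state, loss


# Data following y = 3x, processed in minibatches of 4 samples.
X = [i / 10 for i in range(20)]
y = [3.0 * x for x in X]
model = ToyImpactModel(step_size=0.1)
losses = []
for epoch in range(50):
    for start in range(0, len(X), 4):
        # The returned state is discarded; model.state already holds it.
        _, loss = model.train_on_batch(X[start:start + 4], y[start:start + 4])
        losses.append(loss)  # keep the loss for monitoring
```

With the real `ImpactModel`, a minibatch loop has the same shape: pass each batch to `train_on_batch`, keep the returned loss for logging, and rely on the internal state carrying over between calls.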