Model Persistence
Model persistence allows you to save a trained model to disk and reload it later for inference or continued training.
This page shows how to serialize and deserialize an ImpactModel instance using cloudpickle, which extends the standard pickle module to handle a wide range of Python objects, including closures and local functions.
An alternative is dill, which offers similar functionality.
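To illustrate why the standard pickle module is not always enough, here is a minimal sketch. The names `make_scaler` and `try_std_pickle` are hypothetical helpers, not part of aimz or cloudpickle. The standard module pickles functions by reference, so a function defined inside another function cannot be serialized, whereas cloudpickle serializes it by value. The cloudpickle part runs only if the package is installed.

```python
import pickle


def make_scaler(factor):
    """Return a closure that multiplies its input by `factor`."""

    def scale(x):
        return x * factor

    return scale


def try_std_pickle(obj):
    """Return True if the standard pickle module can serialize `obj`."""
    try:
        pickle.dumps(obj)
    except (pickle.PicklingError, AttributeError, TypeError):
        return False
    return True


scale_by_3 = make_scaler(3)

# Local functions are pickled by reference, so the standard module rejects them.
std_ok = try_std_pickle(scale_by_3)
print("standard pickle can serialize the closure:", std_ok)

restored = None
try:
    import cloudpickle
except ImportError:
    cloudpickle = None  # the comparison below is skipped without cloudpickle

if cloudpickle is not None:
    # cloudpickle serializes the function by value; the resulting bytes
    # can be read back with the standard pickle.loads.
    restored = pickle.loads(cloudpickle.dumps(scale_by_3))
    print("restored closure output:", restored(2))
```

The same reasoning applies to a model function defined in a notebook or script: it is not importable by name from another session, so it has to be serialized by value.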
Note
For MLflow users, see the MLflow Integration page for details on saving and loading models with MLflow.
Model Training
import logging
from pathlib import Path
import cloudpickle
import jax.numpy as jnp
import numpyro.distributions as dist
from jax import random
from jax.typing import ArrayLike
from numpyro import optim, sample
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal
logging.basicConfig(level=logging.INFO, force=True)
from aimz import ImpactModel
def model(X: ArrayLike, y: ArrayLike | None = None) -> None:
    """Linear regression model."""
    w = sample("w", dist.Normal().expand((X.shape[1],)))
    b = sample("b", dist.Normal())
    mu = jnp.dot(X, w) + b
    sigma = sample("sigma", dist.Exponential())
    sample("y", dist.Normal(mu, sigma), obs=y)
rng_key = random.key(42)
rng_key, rng_key_w, rng_key_b, rng_key_x, rng_key_e = random.split(rng_key, 5)
w = random.normal(rng_key_w, (10,))
b = random.normal(rng_key_b)
X = random.normal(rng_key_x, (1000, 10))
e = random.normal(rng_key_e, (1000,))
y = jnp.dot(X, w) + b + e
rng_key, rng_subkey = random.split(rng_key)
im = ImpactModel(
    model,
    rng_key=rng_subkey,
    inference=SVI(
        model,
        guide=AutoNormal(model),
        optim=optim.Adam(step_size=1e-3),
        loss=Trace_ELBO(),
    ),
)
im.fit_on_batch(X, y, progress=False);
Serialization
Save a trained ImpactModel (and optionally its input data) to disk for later use:
with Path("model.pkl").open("wb") as f:
    cloudpickle.dump((im, X, y), f)
Deserialization
Load a previously saved ImpactModel (and optionally its input data) from disk in a fresh session or a different runtime environment.
To use the loaded model correctly, the same dependencies, imports, and any constants or variables that the model relied on when it was saved must be available.
Any JAX array, whether part of the ImpactModel or the input data, will be placed on the default device.
from pathlib import Path
import cloudpickle
import jax.numpy as jnp
import numpyro.distributions as dist
from numpyro import sample
with Path("model.pkl").open("rb") as f:
    im, X, y = cloudpickle.load(f)
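Because a loaded model depends on the same packages being present, one option is to store a small environment record next to the pickled object and compare it at load time. The sketch below is an assumption about how this could be done, not an aimz feature: `environment_snapshot`, `diff_environment`, and the `PACKAGES` list are hypothetical, and a placeholder dict stands in for the model so that the example runs with the standard library alone. With a real model, `cloudpickle.dump` and `cloudpickle.load` would take the place of the `pickle` calls.

```python
import io
import pickle
import sys
from importlib import metadata


def environment_snapshot(packages):
    """Record the Python version and the installed version of each package."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None  # package is not installed here
    return {"python": list(sys.version_info[:2]), "packages": versions}


def diff_environment(saved, current):
    """Return human-readable mismatches between two snapshots."""
    problems = []
    if saved["python"] != current["python"]:
        problems.append(
            f"python: saved {saved['python']}, current {current['python']}"
        )
    for name, saved_version in saved["packages"].items():
        current_version = current["packages"].get(name)
        if saved_version != current_version:
            problems.append(
                f"{name}: saved {saved_version}, current {current_version}"
            )
    return problems


# Hypothetical list; adjust it to whatever the model actually imports.
PACKAGES = ["aimz", "jax", "numpyro", "cloudpickle"]

# Save: bundle the snapshot with the payload (a placeholder for the model).
buffer = io.BytesIO()
bundle = {
    "env": environment_snapshot(PACKAGES),
    "payload": {"note": "model goes here"},
}
pickle.dump(bundle, buffer)

# Load: compare the stored snapshot with the current environment first.
buffer.seek(0)
loaded = pickle.load(buffer)
mismatches = diff_environment(loaded["env"], environment_snapshot(PACKAGES))
for line in mismatches:
    print("environment mismatch:", line)
```

A check like this only reports version differences. It does not make a model saved under one set of versions loadable under another.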
Model Usage
# Resume training from the previous SVI state
im.fit_on_batch(X, y, progress=False)
# Predict using the loaded model
im.predict_on_batch(X)
<xarray.DataTree 'root'>
Group: /
├── Group: /posterior
│ Dimensions: (chain: 1, draw: 1000, w_dim_0: 10)
│ Coordinates:
│ * chain (chain) int64 8B 0
│ * draw (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
│ * w_dim_0 (w_dim_0) int64 80B 0 1 2 3 4 5 6 7 8 9
│ Data variables:
│ b (chain, draw) float32 4kB 0.3919 0.4467 0.423 ... 0.3477 0.4213
│ sigma (chain, draw) float32 4kB 1.023 1.035 1.068 ... 1.062 1.018 1.026
│ w (chain, draw, w_dim_0) float32 40kB 0.5846 0.8836 ... -0.03594
│ Attributes:
│ created_at: 2026-05-24T02:57:09.409429+00:00
│ aimz_version: 0.12.0
└── Group: /posterior_predictive
Dimensions: (chain: 1, draw: 1000, y_dim_0: 1000)
Coordinates:
* chain (chain) int64 8B 0
* draw (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
* y_dim_0 (y_dim_0) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
Data variables:
y (chain, draw, y_dim_0) float32 4MB 2.211 -0.9371 ... -2.104 -2.566
Attributes:
created_at: 2026-05-24T02:57:09.411730+00:00
aimz_version: 0.12.0
See Also
dill documentation
jax Array serialization