Changelog

All notable changes to this project will be documented in this file and are best viewed on the Changelog page.

The format is based on Keep a Changelog, and this project adheres to Semantic Versioning.

Unreleased

v0.12.0 - 2026-05-23

Changed

  • Input arrays of any shape with at least one dimension are now accepted; the leading axis is treated as the sample axis. Previously, X was required to be 2D and y 1D, and y with shape (n, 1) triggered a DataConversionWarning from scikit-learn (#199).

  • log_likelihood() now evaluates the kernel directly at each posterior draw, mirroring the per-draw pattern used by predictive sampling. The seeded kernel is also constructed once before the per-batch loop, so each batch reuses the same cached compilation (#202).

  • The writer-thread queue size used when streaming batched outputs now adapts to available host memory and the per-batch output size (#208).

  • Disk-backed methods (predict(), sample_prior_predictive(), log_likelihood()) now preallocate each site’s Zarr array and write each batch into a fixed slice, replacing the previous per-batch append. This avoids repeated Zarr resizing and is faster when the batch size is small and the number of batches is large. As a consequence, every return site must emit an axis-1 size equal to the input batch size; kernels with incompatible return sites raise NotImplementedError (#213).
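The adaptive queue bound from the writer-thread entry (#208) can be sketched briefly. The sketch assumes, hypothetically, that the queue is capped at the number of batches that fit in a fixed fraction of free host memory. The helper `queue_maxsize`, the 25% fraction, and the cap of 64 are illustrative choices and are not aimz's actual values. Free memory is passed in as an argument rather than measured.

```python
import queue


def queue_maxsize(
    available_bytes: int,
    batch_output_bytes: int,
    memory_fraction: float = 0.25,  # hypothetical share of free memory for queued batches
    lower: int = 1,
    upper: int = 64,  # hypothetical cap on queued batches
) -> int:
    """Return a queue bound so that queued batches fit in a fraction of free memory."""
    if batch_output_bytes <= 0:
        return upper
    budget = int(available_bytes * memory_fraction)
    # Whole batches that fit in the budget, clamped to [lower, upper].
    return max(lower, min(upper, budget // batch_output_bytes))


# Example: 8 GiB free and 256 MiB per batch give a 2 GiB budget, i.e. 8 batches.
maxsize = queue_maxsize(8 * 1024**3, 256 * 1024**2)
writer_queue: queue.Queue = queue.Queue(maxsize=maxsize)
```

Clamping to at least one slot keeps the writer usable when a single batch exceeds the budget. The upper cap bounds memory use when batches are tiny.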
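The move from per-batch appends to preallocated slice writes (#213) can be sketched with a NumPy array standing in for the Zarr array. This is a minimal illustration under the following assumptions:

- `write_site` is a hypothetical helper.
- Each return site's output has shape `(num_draws, batch_size, ...)`.
- The axis-1 check mirrors the `NotImplementedError` described in the entry.

```python
import numpy as np


def write_site(name, batches, n_total):
    """Write one return site's per-batch outputs into a single preallocated array.

    `batches` yields (input_batch_size, value) pairs, where `value` is assumed
    to have shape (num_draws, input_batch_size, ...).
    """
    out = None
    start = 0
    for batch_size, value in batches:
        value = np.asarray(value)
        if value.ndim < 2 or value.shape[1] != batch_size:
            # Without a per-sample axis 1, the batch cannot map onto a fixed slice.
            raise NotImplementedError(
                f"return site {name!r} has shape {value.shape}; "
                f"expected axis 1 of size {batch_size}"
            )
        if out is None:
            # Allocate once, using the first batch's draw axis, trailing dims, and dtype.
            out = np.empty((value.shape[0], n_total, *value.shape[2:]), dtype=value.dtype)
        out[:, start:start + batch_size] = value  # fixed slice: no resize, no append
        start += batch_size
    return out


# Usage: 10 samples split into batches of 4, 3, and 3, with 4 draws per sample.
draws = 4
x = np.arange(10.0)
batches = [(len(chunk), np.tile(chunk, (draws, 1))) for chunk in np.array_split(x, 3)]
mu = write_site("mu", batches, n_total=len(x))
```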

Fixed

  • Errors raised while the writer thread opens Zarr output groups are now reported through the existing writer error path, and queued items are drained before shutdown. This prevents the main thread from waiting indefinitely when a background writer fails before consuming its queue (#210).
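The writer-thread fix follows a common producer/consumer pattern, sketched below with the standard library. `Writer`, `open_output`, and the sentinel are hypothetical names, and a plain list stands in for the Zarr output group. The key idea is that a writer that fails at startup records its error and keeps draining the queue. A producer blocked on `put()` is then released instead of waiting forever.

```python
import queue
import threading

_SENTINEL = object()  # marks the end of the stream


class Writer:
    """Background writer that reports failures instead of leaving producers blocked."""

    def __init__(self, open_output, maxsize: int = 4):
        self.queue: queue.Queue = queue.Queue(maxsize=maxsize)
        self.error = None  # set to the exception if the writer fails
        self._open_output = open_output
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()

    def _run(self):
        try:
            sink = self._open_output()  # e.g. opening an output group; may raise
            while (item := self.queue.get()) is not _SENTINEL:
                sink.append(item)
        except Exception as exc:
            self.error = exc  # shared error path checked by the main thread
            self._drain()

    def _drain(self):
        # Keep consuming until the sentinel so that blocked put() calls return.
        while self.queue.get() is not _SENTINEL:
            pass

    def put(self, item):
        if self.error is not None:
            raise RuntimeError("writer failed") from self.error
        self.queue.put(item)

    def close(self):
        self.queue.put(_SENTINEL)
        self._thread.join()
        if self.error is not None:
            raise RuntimeError("writer failed") from self.error
```

On the success path, `close()` simply flushes and joins. If `open_output` raises, later `put()` calls and `close()` surface a `RuntimeError` chained to the original exception.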

Removed

  • Removed the scikit-learn dependency (#199).

v0.11.0 - 2026-04-29

Changed

  • Removed package-level logging configuration from aimz/__init__.py. On import, aimz no longer sets a log level, attaches a StreamHandler(sys.stdout), or calls logging.captureWarnings(True). The aimz logger now has only a logging.NullHandler() attached, and configuring handlers, levels, and warnings capture is the responsibility of the application. Log messages emitted by ImpactModel were also refined: trailing ellipses were removed, and posterior sampling now reports the number of samples being drawn. The output-directory cleanup notice emitted when predict_on_batch() or log_likelihood() encounters an error is now logged at the warning level (previously debug) (#192).
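The aimz logger now carries only a `NullHandler`, so an application that wants to see aimz output must configure logging itself. The following is a minimal sketch using the standard `logging` module. It assumes the package logger is named `"aimz"`; the format string and levels are arbitrary examples.

```python
import logging
import sys

# The application, not the library, decides where records go and at which level.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.WARNING,
    format="%(asctime)s %(name)s %(levelname)s: %(message)s",
)
# Ask for more detail from aimz only; records propagate to the root handler.
logging.getLogger("aimz").setLevel(logging.INFO)
# Route warnings.warn() calls through the "py.warnings" logger (opt-in).
logging.captureWarnings(True)
```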

Fixed

  • Fixed sample_prior_predictive() failing on multi-device meshes with ValueError: in_specs ... does not match the specs of the input ... @obs. The probe batch used to trace the kernel and draw global prior samples is now built with batch_size=1 and device=None, preventing JAX’s sharding-in-types from propagating the obs mesh axis onto global (non-batched) sample sites that are later passed as the replicated samples argument to the jax.shard_map() sampler (#194).

v0.10.0 - 2026-04-17

Added

Changed

  • ArrayDataset now employs NumPy-based indexing in ArrayLoader instead of triggering JAX tracing on each batch (#168).

  • Changed the default value of to_jax in ArrayDataset from True to False to avoid redundant conversion (#170).
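NumPy-based batching can be illustrated with a small generator that slices host arrays using an index array. Producing a batch this way involves no JAX tracing. `iter_batches` is a hypothetical helper and is not the ArrayLoader implementation.

```python
import numpy as np


def iter_batches(X, y, batch_size, rng=None):
    """Yield (X_batch, y_batch) pairs by NumPy fancy indexing on the host.

    If `rng` (a numpy.random.Generator) is given, rows are shuffled once per call.
    """
    n = X.shape[0]
    idx = np.arange(n) if rng is None else rng.permutation(n)
    for start in range(0, n, batch_size):
        sel = idx[start:start + batch_size]
        # Plain NumPy indexing: no tracing and no device transfer per batch.
        # Callers can convert with jax.numpy.asarray only when they need to.
        yield X[sel], y[sel]
```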

Fixed

  • Fixed auto-computed batch_size rounding down to zero on multi-device setups when MAX_ELEMENTS // num_samples is smaller than the number of devices (#172).

  • pad_array() now pads with NumPy when given NumPy arrays, avoiding premature device transfers, and skips padding entirely when n_pad is zero (#174).
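The zero batch size in the multi-device fix (#172) comes from integer rounding. For example, if `MAX_ELEMENTS // num_samples == 3` on 4 devices, rounding down to a multiple of the device count yields 0. One plausible guard, sketched below, falls back to one row per device. The helper name and formula are assumptions and are not aimz's code.

```python
def auto_batch_size(max_elements: int, num_samples: int, num_devices: int) -> int:
    """Largest device-divisible batch size within an element budget, never zero."""
    per_budget = max_elements // num_samples
    # Round down to a multiple of num_devices so each device gets an equal shard...
    rounded = (per_budget // num_devices) * num_devices
    # ...but fall back to one row per device when the budget is smaller than
    # the device count, the case that previously produced 0.
    return max(rounded, num_devices)
```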
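The `pad_array()` change can be sketched as a helper that dispatches on the input type and returns early when there is nothing to pad. `pad_rows`, the trailing zero padding, and the `axis` parameter are illustrative assumptions rather than the actual signature of `pad_array()`. JAX is imported only when a non-NumPy input needs it.

```python
import numpy as np


def pad_rows(x, n_pad: int, axis: int = 0):
    """Append n_pad zero entries along `axis`, staying in x's array library."""
    if n_pad == 0:
        return x  # nothing to pad: skip allocating a copy
    pad_width = [(0, 0)] * x.ndim
    pad_width[axis] = (0, n_pad)
    if isinstance(x, np.ndarray):
        return np.pad(x, pad_width)  # stays on the host; no device transfer
    import jax.numpy as jnp  # anything else is assumed to be a JAX array

    return jnp.pad(x, pad_width)
```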

v0.9.1 - 2025-12-08

Fixed

  • Fixed jax.shard_map() closure error for sharded rng_key in parallelism methods when using JAX 0.8 and newer versions (#140).

v0.9.0 - 2025-11-16

Added

  • Added the class method cleanup_models() to clean up temporary directories for all active model instances (#136).
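A class-level cleanup over all live instances is commonly built on a weak registry. The sketch below is hypothetical: `TempDirModel` and its `temp_dir` attribute are not aimz names. It only shows the pattern behind a class method such as `cleanup_models()`. Instances register themselves in a `weakref.WeakSet`, so the registry never keeps a model alive.

```python
import shutil
import tempfile
import weakref
from pathlib import Path


class TempDirModel:
    """Toy model that owns a temporary output directory."""

    _instances: "weakref.WeakSet[TempDirModel]" = weakref.WeakSet()

    def __init__(self):
        self.temp_dir = Path(tempfile.mkdtemp(prefix="model_"))
        TempDirModel._instances.add(self)  # weak reference: does not extend lifetime

    def cleanup(self):
        shutil.rmtree(self.temp_dir, ignore_errors=True)

    @classmethod
    def cleanup_models(cls):
        # Iterate over a snapshot, because the WeakSet may shrink during iteration.
        for model in list(cls._instances):
            model.cleanup()
```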

Changed

  • The output subdirectory naming convention has changed from using only a timestamp to the pattern <timestamp>_<caller_name>/, where <caller_name> is the name of the method that triggered the write operation (#138).

  • Lowered the logging level for exceptions during temporary directory cleanup from exception to debug to reduce console noise.
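The `<timestamp>_<caller_name>/` naming can be produced by a private helper that looks one frame up the call stack. This is a sketch: `OutputModel` is hypothetical, and the timestamp format is an assumption because the entry does not specify it.

```python
import inspect
import tempfile
from datetime import datetime
from pathlib import Path


class OutputModel:
    """Hypothetical model that writes each call's output to its own subdirectory."""

    def __init__(self, root):
        self.root = Path(root)

    def _new_output_dir(self) -> Path:
        # stack()[0] is this helper; stack()[1] is the public method that called it.
        caller_name = inspect.stack()[1].function
        timestamp = datetime.now().strftime("%Y%m%dT%H%M%S%f")  # assumed format
        path = self.root / f"{timestamp}_{caller_name}"
        path.mkdir(parents=True)
        return path

    def predict(self) -> Path:
        return self._new_output_dir()


with tempfile.TemporaryDirectory() as tmp:
    out_dir = OutputModel(tmp).predict()
    created = out_dir.is_dir()
```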

v0.8.1 - 2025-10-23

Changed

  • The minimum required versions are now Dask 2025.7, JAX 0.8, and Xarray 2025.7.

  • Replaced deprecated jax.experimental.shard_map.shard_map with jax.shard_map() to ensure compatibility with JAX 0.8 and newer versions (#128).

  • Exception messages are now logged before the writer thread shuts down, giving more immediate feedback from predict() and log_likelihood(), especially after a keyboard interrupt (#130).

v0.8.0 - 2025-10-14

Added

  • Extended MLflow autologging to support the fit_on_batch() method (#119).

  • Added __str__ and __repr__ methods to ImpactModel (#118).

  • KernelSpec now includes a sample_sites attribute listing all stochastic sample sites in the model kernel (#125).

v0.7.0 - 2025-09-29

Added

Changed

Fixed

  • Methods in ImpactModel no longer include an empty posterior data variable in the root node of the returned xarray.DataTree when no posterior samples are available (#91).

v0.6.0 - 2025-09-14

Added

Changed

Removed

  • Removed the tqdm dependency (#80).

Fixed

  • Methods in ImpactModel now handle empty posterior dictionaries ({}) gracefully instead of failing when no posterior samples are available (#76).

v0.5.0 - 2025-09-01

Added

Changed

Fixed

  • Enhanced data array validation to preserve device placement for JAX arrays (#53).

  • Fixed incompatibility with Zarr when models output arrays in bfloat16 by automatically promoting them to float32 before saving (#57).

  • Fixed the error message raised by sample_posterior_predictive() when self.param_output is passed as an argument; it previously referenced sample_prior_predictive() by mistake (#65).

v0.4.0 - 2025-08-18

Added

Changed

Removed

  • Removed the arviz dependency (#49).

Fixed

v0.3.2 - 2025-08-13

Changed

  • Updated predict() and predict_on_batch() to check for available posterior samples before returning outputs. This prevents errors when the model specification does not define posterior samples.

v0.3.1 - 2025-08-02

Fixed

v0.3.0 - 2025-07-18

Changed

Removed

  • Removed the torch dependency (#26).

v0.2.0 - 2025-07-10

Added

Changed

  • Adopted the jax.typing module for improved type hints.

  • Removed unnecessary JAX array type conversion in ImpactModel methods.

  • The fit() method now uses epoch-based (minibatch) training (#15).

  • Updated fit(), train_on_batch(), and fit_on_batch() to train the model using the internal SVI state, continuing from the last state if available (#15).
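Epoch-based minibatch training that resumes from stored state can be sketched with a toy trainer. For illustration only, the SVI update is replaced by a plain gradient step on a scalar parameter, and `Trainer` and its attributes are hypothetical. Calling `fit()` again continues from `self.state` rather than reinitializing.

```python
import random


class Trainer:
    """Toy epoch-based minibatch trainer that resumes from its last state."""

    def __init__(self, lr: float = 0.1):
        self.lr = lr
        self.state = None  # stands in for a stored SVI state

    def _step(self, theta: float, batch: list) -> float:
        # One gradient step on mean((theta - y)^2) over the minibatch.
        grad = sum(2 * (theta - y) for y in batch) / len(batch)
        return theta - self.lr * grad

    def fit(self, data: list, num_epochs: int, batch_size: int, seed: int = 0):
        rng = random.Random(seed)
        theta = 0.0 if self.state is None else self.state  # continue if available
        for _ in range(num_epochs):
            order = list(range(len(data)))
            rng.shuffle(order)  # reshuffle once per epoch
            for start in range(0, len(order), batch_size):
                batch = [data[i] for i in order[start:start + batch_size]]
                theta = self._step(theta, batch)
        self.state = theta
        return self
```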

Removed

  • Removed the jax-dataloader dependency (#14).

  • Removed the guide property, as it is part of the vi property.

v0.1.0 - 2025-06-27

Added

  • Initial public release.