Changelog#
All notable changes to this project will be documented in this file and are best viewed on the Changelog page.
The format is based on Keep a Changelog, and this project adheres to Semantic Versioning.
Unreleased#
v0.12.0 - 2026-05-23#
Changed#
Input arrays of any shape with at least one dimension are now accepted; the leading axis is treated as the sample axis. Previously, X was required to be 2D and y 1D, and y with shape (n, 1) triggered a DataConversionWarning from scikit-learn (#199).
log_likelihood() now evaluates the kernel directly at each posterior draw, mirroring the per-draw pattern used by predictive sampling. The seeded kernel is also constructed once before the per-batch loop, so each batch reuses the same cached compilation (#202).
Writer-thread queue sizing used when streaming batched outputs now adapts to available host memory and the per-batch output size (#208).
Disk-backed methods (predict(), sample_prior_predictive(), log_likelihood()) now preallocate each site’s Zarr array and write each batch into a fixed slice, replacing the previous per-batch append. This avoids repeated Zarr resizing and is faster when the batch size is small and the number of batches is large. As a consequence, every return site must emit an axis-1 size equal to the input batch size; kernels with incompatible return sites raise NotImplementedError (#213).
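The preallocate-and-slice writes described for the disk-backed methods above can be sketched with plain NumPy standing in for Zarr. This is not the library's implementation: the helper name write_batches, the (num_samples, n_rows) layout, and the in-memory array are assumptions made for illustration only.

```python
import numpy as np


def write_batches(batches, num_samples, n_rows):
    """Write per-batch outputs into one preallocated array (hypothetical helper)."""
    # Allocate the full output once instead of growing it on every batch.
    out = np.empty((num_samples, n_rows), dtype=np.float32)
    start = 0
    for batch_size, values in batches:
        # Each output must carry the batch size on axis 1; otherwise the
        # target slice cannot be determined ahead of time.
        if values.ndim < 2 or values.shape[1] != batch_size:
            raise NotImplementedError(
                f"expected axis-1 size {batch_size}, got shape {values.shape}"
            )
        out[:, start : start + batch_size] = values
        start += batch_size
    return out


# Three batches of sizes 2, 2, and 1, each holding 4 draws.
batches = [(n, np.full((4, n), float(i))) for i, n in enumerate([2, 2, 1])]
result = write_batches(batches, num_samples=4, n_rows=5)
```

Because every slice position is known before the loop starts, no resize is needed between batches, which is where the benefit for many small batches would come from.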
Fixed#
Writer-thread startup errors while opening Zarr output groups are now reported through the existing writer error path and queued items are drained before shutdown, preventing the main thread from waiting indefinitely when a background writer fails before consuming its queue (#210).
Removed#
The scikit-learn dependency (#199).
v0.11.0 - 2026-04-29#
Changed#
Removed package-level logging configuration from aimz/__init__.py. aimz no longer sets a log level, attaches a StreamHandler(sys.stdout), or calls logging.captureWarnings(True) on import; the aimz logger now only has a logging.NullHandler() attached. Configuring handlers, levels, and warnings capture is the responsibility of the application. Log messages emitted by ImpactModel were also refined: trailing ellipses were removed and posterior sampling now reports the number of samples being drawn. The output-directory cleanup notice raised when predict_on_batch() and log_likelihood() encounter an error is now logged at the warning level (previously debug) (#192).
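With logging configuration now left to the application, a setup along the following lines would bring console output back. It is a sketch using only the standard library; the INFO level and the format string are arbitrary choices for illustration, not the package's former defaults.

```python
import logging
import sys

# Application-side setup; the library itself only attaches a NullHandler.
logger = logging.getLogger("aimz")
logger.setLevel(logging.INFO)

handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(
    logging.Formatter("%(asctime)s %(name)s %(levelname)s %(message)s")
)
logger.addHandler(handler)

# Optional: route warnings.warn() output through the logging system.
logging.captureWarnings(True)
```

Applications that already configure the root logger can skip the handler and only adjust the level of the aimz logger.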
Fixed#
Fixed sample_prior_predictive() failing on multi-device meshes with ValueError: in_specs ... does not match the specs of the input ... @obs. The probe batch used to trace the kernel and draw global prior samples is now built with batch_size=1 and device=None, preventing JAX’s sharding-in-types from propagating the obs mesh axis onto global (non-batched) sample sites that are later passed as the replicated samples argument to the jax.shard_map() sampler (#194).
v0.10.0 - 2026-04-17#
Added#
Added support for Python 3.14.
estimate_effect() now accepts an on_batch keyword argument. When True, predictions are dispatched through predict_on_batch() and any raw dict results are automatically converted to xarray.DataTree (#180).
Changed#
ArrayDataset now employs NumPy-based indexing in ArrayLoader instead of triggering JAX tracing on each batch (#168).
Changed the default value of to_jax in ArrayDataset from True to False to avoid redundant conversion (#170).
Fixed#
Fixed auto-computed batch_size rounding down to zero on multi-device setups when MAX_ELEMENTS // num_samples is smaller than the number of devices (#172).
pad_array() now pads with NumPy when given NumPy arrays, avoiding premature device transfers, and skips padding entirely when n_pad is zero (#174).
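The batch_size fix can be illustrated with a small sketch. The heuristic below is hypothetical and is not the library's code; in particular, rounding the batch size down to a multiple of the device count is an assumption about how a zero could arise, and the numbers are made up.

```python
def auto_batch_size(max_elements, num_samples, num_devices):
    """Hypothetical batch-size heuristic; not the library's actual code."""
    per_batch = max_elements // num_samples
    # Rounding down to a multiple of the device count yields 0 whenever
    # per_batch < num_devices, which is the failure mode described above.
    rounded = (per_batch // num_devices) * num_devices
    # Clamp so that every device receives at least one row.
    return max(rounded, num_devices)


small = auto_batch_size(max_elements=1000, num_samples=400, num_devices=4)  # 4
large = auto_batch_size(max_elements=1000, num_samples=30, num_devices=4)  # 32
```

Without the final clamp, the first call would return 0 because 1000 // 400 is 2, which is smaller than the 4 devices.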
v0.9.1 - 2025-12-08#
Fixed#
Fixed jax.shard_map() closure error for sharded rng_key in parallelism methods when using JAX 0.8 and newer versions (#140).
v0.9.0 - 2025-11-16#
Added#
Added the class method cleanup_models() to clean up temporary directories for all active model instances (#136).
Changed#
The output subdirectory naming convention has changed from using only a timestamp to the pattern <timestamp>_<caller_name>/, where <caller_name> is the name of the method that triggered the write operation (#138).
Lowered the logging level for exceptions during temporary directory cleanup from exception to debug to reduce console noise.
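A path following the <timestamp>_<caller_name>/ pattern could be built as in the sketch below. The helper name make_output_subdir and the exact timestamp format are assumptions; the changelog does not state the format the library uses.

```python
from datetime import datetime
from pathlib import Path


def make_output_subdir(root, caller_name, now=None):
    """Build a <timestamp>_<caller_name> path (timestamp format is an assumption)."""
    now = now or datetime.now()
    # Hypothetical layout: date, time, and microseconds, then the method name.
    stamp = now.strftime("%Y%m%dT%H%M%S%f")
    return Path(root) / f"{stamp}_{caller_name}"


path = make_output_subdir(
    "outputs", "predict", datetime(2025, 11, 16, 9, 30, 0, 123456)
)
```

Including the method name makes it possible to tell from the directory listing alone which call produced a given set of results.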
v0.8.1 - 2025-10-23#
Changed#
The minimum required versions are: Dask 2025.7, JAX 0.8, and Xarray 2025.7.
Replaced deprecated jax.experimental.shard_map.shard_map with jax.shard_map() to ensure compatibility with JAX 0.8 and newer versions (#128).
Logging exception messages are displayed before the writer thread is shut down, providing a more immediate response for predict() and log_likelihood(), especially when interrupted by the keyboard (#130).
v0.8.0 - 2025-10-14#
Added#
Extended MLflow autologging to support the fit_on_batch() method (#119).
Added __str__ and __repr__ methods to ImpactModel (#118).
KernelSpec now includes a sample_sites attribute listing all stochastic sample sites in the model kernel (#125).
v0.7.0 - 2025-09-29#
Added#
output_dir attribute to the root and group nodes of xarray.DataTree objects returned by sample_prior_predictive(), predict(), and log_likelihood(), specifying the directory where results are saved (#85).
Introduced the public KernelSpec dataclass and the kernel_spec attribute on ImpactModel. This exposes a lazily-built, cached structural specification of the user kernel (fields: traced, return_sites, output_observed) so training and predictive methods avoid redundant model tracing (#98).
When available, an output_dir attribute is added to the root node of the xarray.DataTree object returned by estimate_effect(), specifying the directory where results are saved (#110).
Changed#
All tqdm progress bars now use dynamic_ncols=True to adjust column width dynamically (#93).
fit_on_batch(), sample_prior_predictive_on_batch(), sample_prior_predictive(), and train_on_batch() now reuse the cached kernel_spec and avoid redundant model tracing (#98).
set_posterior_sample() no longer accepts a return_sites parameter; downstream methods can now set it explicitly (#100).
set_posterior_sample() now raises an error when an empty posterior dictionary ({}) is provided (#101).
sample_prior_predictive_on_batch() and sample_prior_predictive() now include posterior samples in the returned results if available (#103).
sample_prior_predictive_on_batch(), sample_prior_predictive(), sample(), sample_posterior_predictive_on_batch(), sample_posterior_predictive(), predict_on_batch(), and predict() can now accept a single str or an iterable of str values for the return_sites parameter (#107).
sample_prior_predictive_on_batch() returns the default output site along with deterministic sites when return_sites is not specified, to be consistent with the behavior of other sampling methods (#108).
estimate_effect() returns a posterior group node in the xarray.DataTree object when posterior samples are available, to be consistent with other methods (#110).
Subdirectories under temp_dir now include microseconds in their names to avoid duplicates and file-exists errors (#110).
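Accepting either a single str or an iterable of str for return_sites usually comes down to a small normalization step. The helper below is a hypothetical sketch of that idea and is not taken from the library.

```python
from collections.abc import Iterable


def normalize_return_sites(return_sites):
    """Coerce None, a str, or an iterable of str into a tuple (hypothetical helper)."""
    if return_sites is None:
        return None
    # A bare string is itself iterable, so handle it before the generic case.
    if isinstance(return_sites, str):
        return (return_sites,)
    if isinstance(return_sites, Iterable):
        sites = tuple(return_sites)
        if all(isinstance(s, str) for s in sites):
            return sites
    raise TypeError("return_sites must be a str or an iterable of str")


single = normalize_return_sites("y")
several = normalize_return_sites(["y", "mu"])
```

Checking for str first matters: without it, a name such as "mu" would be split into the characters "m" and "u".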
Fixed#
Methods in ImpactModel no longer include an empty posterior data variable in the root node of the returned xarray.DataTree when no posterior samples are available (#91).
v0.6.0 - 2025-09-14#
Added#
sample_prior_predictive_on_batch(), sample(), sample_posterior_predictive_on_batch(), and predict_on_batch() methods in ImpactModel now support a return_datatree parameter. When set to True (the default), results are returned as an xarray.DataTree; otherwise, a dict is returned (#74).
MLflow integration for ImpactModel (#71).
Changed#
Methods in ImpactModel now automatically determine the batch_size if it is not provided, based on the input data and number of samples (#70).
sample_posterior_predictive_on_batch() and sample_posterior_predictive() no longer accept the in_sample argument. Results are now always written to the posterior_predictive group.
Removed#
Removed the tqdm dependency (#80).
Fixed#
Methods in ImpactModel now handle empty posterior dictionaries ({}) gracefully instead of failing when no posterior samples are available (#76).
v0.5.0 - 2025-09-01#
Added#
Added a return_sites parameter to the predict() and predict_on_batch() methods in ImpactModel, allowing users to specify which sites to include in the output (#55).
sample_prior_predictive_on_batch(), replacing sample_prior_predictive() (#67).
sample_posterior_predictive_on_batch(), replacing sample_posterior_predictive() (#67).
Changed#
Switched documentation build system from MkDocs to Sphinx and ReadTheDocs (https://aimz.readthedocs.io).
Added input X validation to sample_prior_predictive() (#65).
Exposed ImpactModel at the top-level package, allowing from aimz import ImpactModel (#67).
sample_prior_predictive() now returns an xarray.DataTree instead of a dictionary, and writes outputs to files like the other methods (#67).
sample_posterior_predictive() is now an alias of predict() and returns an xarray.DataTree (#67).
Fixed#
Enhanced data array validation to preserve device placement for JAX arrays (#53).
Fixed incompatibility with Zarr when models output arrays in bfloat16 by automatically promoting them to float32 before saving (#57).
Fixed the error message in sample_posterior_predictive() when self.param_output is passed as an argument, which previously incorrectly referenced sample_prior_predictive() (#65).
v0.4.0 - 2025-08-18#
Added#
Support for NumPyro MCMC in ImpactModel, including fit_on_batch(), sample(), and set_posterior_sample() methods (#35).
Changed#
ImpactModel methods predict(), predict_on_batch(), log_likelihood(), and estimate_effect() now return outputs as xarray DataTree instead of ArviZ InferenceData. Dimension names now follow the dim_N convention instead of the previous dimN style (#49).
fit(), fit_on_batch(), and train_on_batch() methods in ImpactModel now check for "/" in kernel site names to ensure compatibility with xarray DataTree (#49).
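The "/" check on kernel site names exists because DataTree addresses its nodes with "/"-separated paths. A validation step of this kind might look like the sketch below; the helper name and the error type are assumptions, not the library's code.

```python
def check_site_names(site_names):
    """Reject site names containing "/" (hypothetical helper)."""
    # DataTree addresses nodes with "/"-separated paths, so a slash inside a
    # site name would be ambiguous.
    bad = [name for name in site_names if "/" in name]
    if bad:
        raise ValueError(f"site names must not contain '/': {bad}")


check_site_names(["mu", "sigma", "y"])  # passes silently
```

Running the check at training time surfaces the problem before any results are written, rather than when the output tree is assembled.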
Removed#
Removed the arviz dependency (#49).
Fixed#
predict() in ImpactModel now checks for available posterior samples before falling back to predict_on_batch().
ArrayLoader validates that batch_size is a positive integer.
v0.3.2 - 2025-08-13#
Changed#
Updated predict() and predict_on_batch() to check for available posterior samples before returning outputs. This prevents errors when posterior samples are not defined based on the model specification.
v0.3.1 - 2025-08-02#
Fixed#
ArrayDataset and ArrayLoader now preserve the order in which input arrays are provided, ensuring consistent input mapping in methods like predict() and log_likelihood() (#43).
v0.3.0 - 2025-07-18#
Changed#
ImpactModel initialization parameter vi has been renamed to inference for compatibility with MCMC in future releases (#36).
ImpactModel now supports ArrayLoader for both input and output data (#24).
Renamed the posterior sample attribute of ImpactModel from posterior_samples_ to posterior, which is now initialized to None (#25).
ArrayLoader and ArrayDataset no longer require the torch dependency. ArrayDataset now accepts only named arrays, and ArrayLoader yields tuples of a dictionary and a padding integer (#26).
Removed#
Removed the torch dependency (#26).
v0.2.0 - 2025-07-10#
Added#
train_on_batch() and fit_on_batch() methods to ImpactModel (#15).
Custom ArrayDataset class for handling data in ImpactModel, removing the need for the jax-dataloader dependency (#14).
GitHub Pages documentation site (#10).
Installation instructions in the documentation site (#10).
ArrayLoader class supports shuffle parameter for epoch training for fit() (#15).
Changed#
Adopted jax.typing module for improved type hints.
Removed unnecessary JAX array type conversion in ImpactModel methods.
The fit() method now uses epoch-based (minibatch) training (#15).
Updated fit(), train_on_batch(), and fit_on_batch() to train the model using the internal SVI state, continuing from the last state if available (#15).
Removed#
Removed the jax-dataloader dependency (#14).
Removed the guide property, as it is part of the vi property.
v0.1.0 - 2025-06-27#
Added#
Initial public release.