google-deepmind/kfac-jax v0.0.8
kfac_jax 0.0.8
Repository: google-deepmind/kfac-jax
Tag: v0.0.8
Published: 2026-02-25T00:40:23Z
Prerelease: no
Release notes: NOTE: this is the last version that explicitly sets jax.config.jax_pmap_shmap_merge to False. After this release, users must set it themselves if their codebase is designed to use the old pmap. kfac_jax now supports both the old and the new pmap (though this may change in the future so that only the new pmap is supported).
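For codebases written against the old pmap, opting back in after 0.0.8 might look like the minimal sketch below. It assumes `jax.config` exposes the flag as an attribute and accepts `jax.config.update("jax_pmap_shmap_merge", False)`. The helper name `opt_into_legacy_pmap` is hypothetical and not part of kfac_jax. The `hasattr` guard is there because the flag may eventually be removed from JAX.

```python
def opt_into_legacy_pmap(config) -> bool:
    """Hypothetical helper: set jax_pmap_shmap_merge=False if the flag exists.

    `config` is expected to be `jax.config` (or any object with the same
    attribute access and `update(name, value)` method).
    Returns True if the flag was set, False if this JAX version lacks it.
    """
    # Assumption: newer JAX releases may drop the flag, so only touch it
    # when it is actually defined.
    if not hasattr(config, "jax_pmap_shmap_merge"):
        return False
    config.update("jax_pmap_shmap_merge", False)
    return True
```

In a real program you would call `opt_into_legacy_pmap(jax.config)` early, presumably before any pmap-based code is traced.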
What's Changed
- has_aux fix by @copybara-service[bot] in https://github.com/google-deepmind/kfac-jax/pull/341
- Adding a basic implementation of an adaptive technique to set the initial damping value (used in the automatic damping adaptation). by @copybara-service[bot] in https://github.com/google-deepmind/kfac-jax/pull/346
- Adding a note that clarifies the format of the array arguments to the optimizer's step() and init() functions. by @copybara-service[bot] in https://github.com/google-deepmind/kfac-jax/pull/347
- Filtering Polyak stats so that only scalar values are logged by @copybara-service[bot] in https://github.com/google-deepmind/kfac-jax/pull/348
- Enable training on a fixed number of batches from the training dataset in a pre-emption safe way. by @copybara-service[bot] in https://github.com/google-deepmind/kfac-jax/pull/352
- Minor code quality improvements by @copybara-service[bot] in https://github.com/google-deepmind/kfac-jax/pull/353
- Deterministic resume when num_batches specified. by @copybara-service[bot] in https://github.com/google-deepmind/kfac-jax/pull/354
- Silence pytype error for deprecated JAX API by @copybara-service[bot] in https://github.com/google-deepmind/kfac-jax/pull/355
- Remove layer tags from processed jaxpr in kfac_jax transforms. This is required as the initial layer tags for the underlying function are not valid anymore once we apply one of the kfac_jax transforms. In general, this is a necessity for subsequently using the jaxpr of the transformed function, either within or outside the kfac_jax framework. by @copybara-service[bot] in https://github.com/google-deepmind/kfac-jax/pull/357
- Fixing broken test by @copybara-service[bot] in https://github.com/google-deepmind/kfac-jax/pull/360
- Internal Change by @copybara-service[bot] in https://github.com/google-deepmind/kfac-jax/pull/359
- Internal Change by @copybara-service[bot] in https://github.com/google-deepmind/kfac-jax/pull/361
- Fixing a PyType issue caused by recent JAX CL. by @copybara-service[bot] in https://github.com/google-deepmind/kfac-jax/pull/365
- Ignore pytype errors produced with --use-functools-partial-overlay by @copybara-service[bot] in https://github.com/google-deepmind/kfac-jax/pull/364
- Minor improvement to schedule code for examples by @copybara-service[bot] in https://github.com/google-deepmind/kfac-jax/pull/366
- Removing broken "mask" feature from sigmoid_cross_entropy loss in examples code. by @copybara-service[bot] in https://github.com/google-deepmind/kfac-jax/pull/372
- [kfac_jax] Prepare for jax_pmap_shmap_merge=True. by @copybara-service[bot] in https://github.com/google-deepmind/kfac-jax/pull/373
- Improve logging of parameter registrations. by @copybara-service[bot] in https://github.com/google-deepmind/kfac-jax/pull/375
- Excluding opt_state from eval worker when possible. by @copybara-service[bot] in https://github.com/google-deepmind/kfac-jax/pull/374
- Updates to schedules module in examples code: by @copybara-service[bot] in https://github.com/google-deepmind/kfac-jax/pull/377
- Fixing issue that broke the tracer and scanner logic when a layer had a literal (i.e. a constant) as input. by @copybara-service[bot] in https://github.com/google-deepmind/kfac-jax/pull/378
- Minor fixes to docstrings. by @copybara-service[bot] in https://github.com/google-deepmind/kfac-jax/pull/376
- [pmap] Make kfac_jax get_first more robust under jax_pmap_shmap_merge by @copybara-service[bot] in https://github.com/google-deepmind/kfac-jax/pull/379
- Bumping version number in preparation for next official PyPI release. by @copybara-service[bot] in https://github.com/google-deepmind/kfac-jax/pull/381
- Adding function using_legacy_pmap() to detect when the legacy pmap is being used (i.e. when jax.config.jax_pmap_shmap_merge exists and is False). This should allow the 0.0.8 release to keep working once jax.config.jax_pmap_shmap_merge is removed from JAX, while still supporting older versions of JAX. by @copybara-service[bot] in https://github.com/google-deepmind/kfac-jax/pull/382
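The detection rule described in PR #382 (the flag exists and is False) can be sketched as follows. This is an illustrative re-implementation of the stated rule, not kfac_jax's actual `using_legacy_pmap()` source; the name `using_legacy_pmap_sketch` is hypothetical.

```python
_MISSING = object()  # sentinel: distinguishes "flag absent" from "flag is False"


def using_legacy_pmap_sketch(config) -> bool:
    """Hypothetical sketch: True iff config.jax_pmap_shmap_merge exists and is False.

    `config` is expected to be `jax.config`. If the flag has been removed
    from JAX, the attribute lookup falls back to the sentinel and the
    function reports that the legacy pmap is NOT in use.
    """
    value = getattr(config, "jax_pmap_shmap_merge", _MISSING)
    return value is not _MISSING and not value
```

With real JAX you would pass `jax.config`. Using a sentinel rather than `None` keeps the "flag absent" case separate from any falsy flag value, which matches the rule as stated in the note.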
Full Changelog: https://github.com/google-deepmind/kfac-jax/compare/v0.0.7...v0.0.8