Release · Google (DeepMind / Gemini) · published Feb 25, 2026 · seen 5d

google-deepmind/kfac-jax v0.0.8

Captured source

published Feb 25, 2026 · seen 5d · captured 13h · http 200 · method plain

kfac_jax 0.0.8

Repository: google-deepmind/kfac-jax

Tag: v0.0.8

Published: 2026-02-25T00:40:23Z

Prerelease: no

Release notes: NOTE: this is the last version that explicitly sets jax.config.jax_pmap_shmap_merge to False. After this release, users whose codebase is designed to use the old pmap must set this flag to False themselves. kfac_jax now supports both the old and the new pmap (though this may change in the future so that only the new one is supported).

What's Changed

  • has_aux fix by @copybara-service[bot] in https://github.com/google-deepmind/kfac-jax/pull/341
  • Adding a basic implementation of an adaptive technique to set the initial damping value (used in the automatic damping adaptation). by @copybara-service[bot] in https://github.com/google-deepmind/kfac-jax/pull/346
  • Adding a note that clarifies the format of the array arguments to the optimizer's step() and init() functions. by @copybara-service[bot] in https://github.com/google-deepmind/kfac-jax/pull/347
  • Filtering Polyak stats so that only scalar values are logged by @copybara-service[bot] in https://github.com/google-deepmind/kfac-jax/pull/348
  • Enable training on a fixed number of batches from the training dataset in a pre-emption safe way. by @copybara-service[bot] in https://github.com/google-deepmind/kfac-jax/pull/352
  • Minor code quality improvements by @copybara-service[bot] in https://github.com/google-deepmind/kfac-jax/pull/353
  • Deterministic resume when num_batches specified. by @copybara-service[bot] in https://github.com/google-deepmind/kfac-jax/pull/354
  • Silence pytype error for deprecated JAX API by @copybara-service[bot] in https://github.com/google-deepmind/kfac-jax/pull/355
  • Remove layer tags from processed jaxpr in kfac_jax transforms. This is required as the initial layer tags for the underlying function are not valid anymore once we apply one of the kfac_jax transforms. In general, this is a necessity for subsequently using the jaxpr of the transformed function, either within or outside the kfac_jax framework. by @copybara-service[bot] in https://github.com/google-deepmind/kfac-jax/pull/357
  • Fixing broken test by @copybara-service[bot] in https://github.com/google-deepmind/kfac-jax/pull/360
  • Internal Change by @copybara-service[bot] in https://github.com/google-deepmind/kfac-jax/pull/359
  • Internal Change by @copybara-service[bot] in https://github.com/google-deepmind/kfac-jax/pull/361
  • Fixing a PyType issue caused by recent JAX CL. by @copybara-service[bot] in https://github.com/google-deepmind/kfac-jax/pull/365
  • Ignore pytype errors produced with --use-functools-partial-overlay by @copybara-service[bot] in https://github.com/google-deepmind/kfac-jax/pull/364
  • Minor improvement to schedule code for examples by @copybara-service[bot] in https://github.com/google-deepmind/kfac-jax/pull/366
  • Removing broken "mask" feature from sigmoid_cross_entropy loss in examples code. by @copybara-service[bot] in https://github.com/google-deepmind/kfac-jax/pull/372
  • [kfac_jax] Prepare for jax_pmap_shmap_merge=True. by @copybara-service[bot] in https://github.com/google-deepmind/kfac-jax/pull/373
  • Improve logging of parameter registrations. by @copybara-service[bot] in https://github.com/google-deepmind/kfac-jax/pull/375
  • Excluding opt_state from eval worker when possible. by @copybara-service[bot] in https://github.com/google-deepmind/kfac-jax/pull/374
  • Updates to schedules module in examples code: by @copybara-service[bot] in https://github.com/google-deepmind/kfac-jax/pull/377
  • Fixing issue that broke the tracer and scanner logic when a layer had a literal (i.e. a constant) as input. by @copybara-service[bot] in https://github.com/google-deepmind/kfac-jax/pull/378
  • Minor fixes to docstrings. by @copybara-service[bot] in https://github.com/google-deepmind/kfac-jax/pull/376
  • [pmap] Make kfac_jax get_first more robust under jax_pmap_shmap_merge by @copybara-service[bot] in https://github.com/google-deepmind/kfac-jax/pull/379
  • Bumping version number in preparation for next official PyPI release. by @copybara-service[bot] in https://github.com/google-deepmind/kfac-jax/pull/381
  • Adding function using_legacy_pmap() to detect when the legacy pmap is being used (i.e. when jax.config.jax_pmap_shmap_merge exists and is False). This should allow the 0.0.8 release to keep working after jax.config.jax_pmap_shmap_merge is removed from JAX, while still supporting older versions of JAX. by @copybara-service[bot] in https://github.com/google-deepmind/kfac-jax/pull/382
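
The detection rule in the last entry (legacy pmap is in use only when `jax.config.jax_pmap_shmap_merge` exists and is False) can be sketched as follows. This is a guess at the logic, not the kfac_jax source: `legacy_pmap_in_use` is a hypothetical name, and taking the config object as a parameter is an assumption made so the rule can be exercised with `types.SimpleNamespace` instead of a real `jax.config`.

```python
from types import SimpleNamespace


def legacy_pmap_in_use(config) -> bool:
    """Hypothetical sketch: True only if the merge flag exists and is False.

    If the flag is missing (e.g. after its removal from JAX), the new pmap
    is the only implementation, so the answer is False.
    """
    if not hasattr(config, "jax_pmap_shmap_merge"):
        return False
    return not config.jax_pmap_shmap_merge


# Three cases: flag absent, flag set to False, flag set to True.
print(legacy_pmap_in_use(SimpleNamespace()))                            # False
print(legacy_pmap_in_use(SimpleNamespace(jax_pmap_shmap_merge=False)))  # True
print(legacy_pmap_in_use(SimpleNamespace(jax_pmap_shmap_merge=True)))   # False
```

Treating a missing flag as "not legacy" is what lets the same check run unchanged on both older JAX versions (flag present) and future ones (flag removed).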

Full Changelog: https://github.com/google-deepmind/kfac-jax/compare/v0.0.7...v0.0.8

Notability

notability 3.0/10

Routine minor release of an optimization library