Release · Google (DeepMind / Gemini) · published Mar 20, 2026

google-deepmind/optax v0.2.8

Optax 0.2.8

Repository: google-deepmind/optax

Tag: v0.2.8

Published: 2026-03-20T23:29:14Z

Prerelease: no

Release notes:

What's Changed

  • Following the JAX 0.9.2 release, the jax_pmap_shmap_merge config flag was removed, so jax.pmap is now always implemented on top of jax.jit and jax.shard_map and the old jax.pmap behavior can no longer be selected. Optax had opted into the old behavior to give users time to migrate; as of Optax 0.2.8 it no longer does so. This change shouldn't affect most users, but if you see errors or performance regressions as a result, update your code following JAX's migration guide (or use JAX 0.9.2 or earlier and set jax.config.update("jax_pmap_shmap_merge", False)).
  • Explicitly specify the dtype of the gradient accumulator in the MultiStep transform. by @copybara-service[bot] in https://github.com/google-deepmind/optax/pull/1605
  • feat: add preconditioning and coef presets to muon by @massena-t in https://github.com/google-deepmind/optax/pull/1602
  • Backwards compatibility export for the newton schulz iterator by @copybara-service[bot] in https://github.com/google-deepmind/optax/pull/1608
  • Remove TensorFlow dependency in Adversarial training example by @rajasekharporeddy in https://github.com/google-deepmind/optax/pull/1609
  • Improve lookahead docstrings with example and usage notes by @rdyro in https://github.com/google-deepmind/optax/pull/1619
  • Make sure inject_hyperparams uses the dtype inferred from parameters... by @copybara-service[bot] in https://github.com/google-deepmind/optax/pull/1615
  • Memory-optimization for microbatching. by @copybara-service[bot] in https://github.com/google-deepmind/optax/pull/1623
  • Remove TensorFlow dependency and migrate mlp_mnist to Flax NNX by @selamw1 in https://github.com/google-deepmind/optax/pull/1536
  • Let inject use the highest dtype found in the params as the default dtype of params. by @copybara-service[bot] in https://github.com/google-deepmind/optax/pull/1628
  • Support scheduling alpha for AdEMAmix by @copybara-service[bot] in https://github.com/google-deepmind/optax/pull/1630
  • [JAX] Suppress type errors found by pytype after correcting definition of jax.typing.ArrayLike. by @copybara-service[bot] in https://github.com/google-deepmind/optax/pull/1629
  • [JAX] Suppress type errors found by pytype after correcting definition of jax.typing.ArrayLike. by @copybara-service[bot] in https://github.com/google-deepmind/optax/pull/1633
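The Muon items refer to coefficient presets and to the Newton-Schulz iterator. Muon orthogonalizes matrix-shaped updates with such an iteration. Each step applies an odd polynomial to every singular value and pushes it toward 1. The standalone sketch below assumes nothing about Optax's own iterator: newton_schulz and PRESETS are hypothetical names, and Optax's presets and argument names may differ. The "muon_quintic" coefficients are the ones published with the original Muon optimizer.

```python
import jax
import jax.numpy as jnp

# Hypothetical coefficient presets (a, b, c) for the odd polynomial
# p(s) = a*s + b*s**3 + c*s**5 that each iteration applies to every singular value s.
PRESETS = {
    # Classic cubic Newton-Schulz: converges to the orthogonal polar factor.
    "cubic": (1.5, -0.5, 0.0),
    # Quintic coefficients from the original Muon write-up: faster, but singular
    # values only settle in a band around 1 (roughly 0.7 to 1.2) rather than at 1.
    "muon_quintic": (3.4445, -4.7750, 2.0315),
}


def newton_schulz(g, coeffs=PRESETS["cubic"], steps=10, eps=1e-7):
    """Approximately orthogonalize a 2-D matrix `g` (hypothetical helper)."""
    a, b, c = coeffs
    transpose = g.shape[0] > g.shape[1]
    x = g.T if transpose else g  # use the wide orientation so x @ x.T stays small
    x = x / (jnp.linalg.norm(x) + eps)  # Frobenius scaling puts singular values in [0, 1]
    for _ in range(steps):
        s = x @ x.T
        x = a * x + (b * s + c * (s @ s)) @ x
    return x.T if transpose else x


g = jax.random.normal(jax.random.PRNGKey(0), (4, 8))
o = newton_schulz(g, steps=50)  # o @ o.T is close to the identity
q = newton_schulz(g, PRESETS["muon_quintic"], steps=15)  # singular values near 1
```

The cubic preset trades speed for exact convergence. The quintic one reaches the useful band in a handful of steps, which is why a fixed small step count is typical in Muon-style optimizers.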
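To illustrate why microbatching saves memory, here is a generic sketch, not Optax's microbatching utility or the specific change in PR 1623. jax.lax.scan processes one chunk of the batch at a time and sums the chunk gradients into a running total. Peak activation memory then scales with the chunk size rather than the full batch. microbatched_grad and the linear-model loss are hypothetical, and the batch size is assumed to divide evenly into num_microbatches.

```python
import jax
import jax.numpy as jnp


def loss_fn(w, x, y):
    # Mean squared error of a linear model.
    return jnp.mean((x @ w - y) ** 2)


def microbatched_grad(w, x, y, num_microbatches):
    """Full-batch mean gradient computed chunk by chunk (hypothetical helper)."""
    # Reshape to (num_microbatches, micro_size, ...); assumes an even split.
    xs = x.reshape(num_microbatches, -1, *x.shape[1:])
    ys = y.reshape(num_microbatches, -1, *y.shape[1:])

    def body(acc, batch):
        xb, yb = batch
        g = jax.grad(loss_fn)(w, xb, yb)
        return jax.tree_util.tree_map(jnp.add, acc, g), None

    zeros = jax.tree_util.tree_map(jnp.zeros_like, w)
    total, _ = jax.lax.scan(body, zeros, (xs, ys))
    # Each chunk loss is a mean over equal-size chunks, so averaging the
    # chunk gradients reproduces the full-batch mean gradient.
    return jax.tree_util.tree_map(lambda t: t / num_microbatches, total)
```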

New Contributors

  • @massena-t made their first contribution in https://github.com/google-deepmind/optax/pull/1602
  • @selamw1 made their first contribution in https://github.com/google-deepmind/optax/pull/1536

Full Changelog: https://github.com/google-deepmind/optax/compare/v0.2.7...v0.2.8

Notability

notability 3.0/10

Minor release, routine maintenance