Release: Google (DeepMind / Gemini), published Apr 3, 2024

google-deepmind/kfac-jax v0.0.6


v0.0.6

Repository: google-deepmind/kfac-jax

Tag: v0.0.6

Published: 2024-04-03T17:12:00Z

Prerelease: no

Release notes:

What's Changed

  • Adding logging for the number of parameters and optimizer state. by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/125
  • Adding automatic cross-device averaging of auxiliary loss/models stats to optimizer. by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/139
  • Add rel_grad_norm and rel_update_norm stats logging by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/147
  • Fixing bug that would sometimes cause an exception for networks with scalar-valued parameters. by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/151
  • [JAX] Migrate XlaBuilder users to emit direct stablehlo MLIR lowerings. by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/161
  • Still fixing docs requirements dependencies. by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/166
  • Still fixing docs requirements dependencies. by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/169
  • Still fixing docs requirements dependencies. by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/171
  • Still fixing docs requirements dependencies. by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/173
  • Adding the capability to pass custom arguments to the registration functions, and calling them in a custom module, for standard losses in the example code. by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/175
  • Fix or ignore some pytype errors. by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/177
  • [LSC] Ignore incorrect type annotations related to jax.numpy APIs by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/176
  • Adding a sum_of_objects. by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/190
  • Adding Polyak averaging feature to example experiments codebase. by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/195
  • Adding precon_damping_mult feature to optimizer. by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/196
  • Reland https://github.com/google/jax/pull/10573. by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/199
  • Minor refactoring. by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/201
  • Fixing issue where loss_registered_reldiff was not computed properly in multi-device settings. by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/202
  • Adding a new schedule and applying some fixes to existing ones in the examples codebase. by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/204
  • Remove gradient normalization from the preconditioning function by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/206
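
For readers unfamiliar with the Polyak averaging entry above, the following is a minimal, library-agnostic sketch of the general technique in plain Python. It is not taken from the kfac-jax examples codebase: the function name `polyak_update`, the `decay` parameter, and the exponential-moving-average form are all assumptions chosen for illustration, and the release notes do not say which variant the examples implement.

```python
from typing import Dict

# Hypothetical parameter container: a flat mapping from name to scalar value.
Params = Dict[str, float]


def polyak_update(avg: Params, params: Params, decay: float) -> Params:
    """Return decay * avg + (1 - decay) * params, computed key by key.

    This is the exponential-moving-average form of Polyak averaging;
    it is an illustrative assumption, not the kfac-jax implementation.
    """
    return {k: decay * avg[k] + (1.0 - decay) * params[k] for k in params}


# Usage: keep a running average alongside the raw parameters.
params = {"w": 1.0, "b": 0.0}
avg = dict(params)  # initialise the average at the starting parameters
for _ in range(3):
    # Stand-in for an optimizer step: shift every parameter by 1.
    params = {k: v + 1.0 for k, v in params.items()}
    avg = polyak_update(avg, params, decay=0.5)

print(avg)  # {'w': 3.125, 'b': 2.125}
```

The averaged values lag behind the raw parameters, which is the intended smoothing effect; evaluation would typically use `avg` rather than `params`.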

Full Changelog: https://github.com/google-deepmind/kfac-jax/compare/v0.0.5...v0.0.6