google-deepmind/kfac-jax v0.0.6
v0.0.6
Repository: google-deepmind/kfac-jax
Tag: v0.0.6
Published: 2024-04-03T17:12:00Z
Prerelease: no
Release notes:
What's Changed
- Adding logging for the number of parameters and optimizer state. by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/125
- Adding automatic cross-device averaging of auxiliary loss/models stats to optimizer. by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/139
- Add `rel_grad_norm` and `rel_update_norm` stats logging by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/147
- Fixing bug that would sometimes cause an exception for networks with scalar-valued parameters. by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/151
- [JAX] Migrate XlaBuilder users to emit direct stablehlo MLIR lowerings. by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/161
- Still fixing docs requirements dependencies. by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/166
- Still fixing docs requirements dependencies. by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/169
- Still fixing docs requirements dependencies. by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/171
- Still fixing docs requirements dependencies. by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/173
- Adding the capability to pass custom arguments to the registration functions, and to call them in a custom module, for standard losses in the example code. by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/175
- Fix or ignore some pytype errors. by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/177
- [LSC] Ignore incorrect type annotations related to jax.numpy APIs by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/176
- Adding a `sum_of_objects`. by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/190
- Adding Polyak averaging feature to example experiments codebase. by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/195
- Adding precon_damping_mult feature to optimizer. by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/196
- Reland https://github.com/google/jax/pull/10573. by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/199
- Minor refactoring by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/201
- Fixing issue where loss_registered_reldiff was not computed properly in multi-device settings. by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/202
- Adding a new schedule and applying some fixes to existing ones in the examples codebase. by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/204
- Remove gradient normalization from the preconditioning function by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/206
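The `rel_grad_norm` and `rel_update_norm` stats from pull/147 can be illustrated with a minimal JAX sketch. It assumes each stat is the global L2 norm of the gradient (or update) pytree divided by the global L2 norm of the parameters; the exact definition inside kfac-jax may differ. The helpers `tree_norm` and `relative_norm_stats` are hypothetical and are not part of the kfac-jax API.

```python
import jax
import jax.numpy as jnp


def tree_norm(tree):
    """Global Euclidean (L2) norm over all leaves of a pytree."""
    leaves = jax.tree_util.tree_leaves(tree)
    return jnp.sqrt(sum(jnp.sum(jnp.square(leaf)) for leaf in leaves))


def relative_norm_stats(params, grads, updates, eps=1e-12):
    """Hypothetical helper: gradient and update norms relative to the parameter norm.

    Assumes rel_grad_norm = ||grads|| / ||params|| and
    rel_update_norm = ||updates|| / ||params||.
    """
    param_norm = tree_norm(params)
    denom = param_norm + eps  # guard against division by zero for all-zero params
    return {
        "param_norm": param_norm,
        "rel_grad_norm": tree_norm(grads) / denom,
        "rel_update_norm": tree_norm(updates) / denom,
    }


# Toy pytrees; "b" is a scalar-valued parameter.
params = {"w": jnp.array([3.0, 4.0]), "b": jnp.array(0.0)}
grads = {"w": jnp.array([0.3, 0.4]), "b": jnp.array(0.0)}
updates = {"w": jnp.array([-0.03, -0.04]), "b": jnp.array(0.0)}

stats = relative_norm_stats(params, grads, updates)
print({name: float(value) for name, value in stats.items()})
```

With these toy values the parameter norm is 5, so the relative gradient and update norms come out at roughly 0.1 and 0.01. Normalizing by the parameter norm makes the stats comparable across models of different scale.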
Full Changelog: https://github.com/google-deepmind/kfac-jax/compare/v0.0.5...v0.0.6
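The Polyak averaging feature added to the example experiments codebase (pull/195) keeps a running average of the parameters alongside the live ones. The sketch below assumes the exponential-moving-average form commonly used in practice, `avg <- decay * avg + (1 - decay) * params`; the examples codebase may use a different form or decay schedule. The function `polyak_update` is hypothetical and is not a kfac-jax API.

```python
import jax
import jax.numpy as jnp


def polyak_update(avg_params, params, decay):
    """Hypothetical helper: exponential moving average over a parameter pytree."""
    return jax.tree_util.tree_map(
        lambda avg, p: decay * avg + (1.0 - decay) * p, avg_params, params
    )


params = {"w": jnp.array([1.0, 2.0])}
avg = params  # start the average at the initial parameters

for _ in range(3):
    # Stand-in for an optimizer step: shift every parameter by +1.
    params = jax.tree_util.tree_map(lambda p: p + 1.0, params)
    avg = polyak_update(avg, params, decay=0.5)

print(avg["w"])
```

The averaged parameters lag behind the live ones (here `w` ends at `[4, 5]` while the average is `[3.125, 4.125]`), which smooths out step-to-step noise; they would typically be used only for evaluation, not fed back into training.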