google-deepmind/kfac-jax v0.0.7
kfac_jax 0.0.7
Repository: google-deepmind/kfac-jax
Tag: v0.0.7
Published: 2025-05-20T17:48:51Z
Prerelease: no
Release notes:
What's Changed
- Create a pyproject.toml file to replace the requirements.txt files by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/208
- Match kfac jaxpr debug info result paths with out vars. by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/220
- Nest the `compute_exact_quad_model` to allow filtering of `vectors` that will be multiplied by zero to save computing expensive matrix-vector products by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/222
- Remove deprecated `jax.tree_map` calls by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/227
- Fixing all Losses to return everything in non-auxiliary data during flattening, to avoid any tracer leaks when the weight is dynamic. by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/232
- Log different block class assignments in the curvature estimator. by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/234
- Internal cleanup by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/236
- Simplifying the LayerTag Primitive machinery. by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/237
- Internal change. by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/238
- Expanding Polyak averaging functionality in examples codebase. by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/239
- Simplifying the LossTag machinery. by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/240
- Minor refactor by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/241
- Updating schedule construction code in examples folder so that it properly detects misspelled argument names. by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/243
- Fixing bug in log_train_stats_with_polyak_avg_every_n_steps of example code. by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/245
- Replace Bernoulli distributions with Rademachers by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/247
- Adding feature to BlockDiagonalCurvature to return undamped diagonal. by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/249
- Add precon_power option to KFAC optimizer. by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/251
- Minor non-functional change. by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/255
- Change the default estimation mode of the curvature estimators to `ggn_curvature_prop` by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/253
- Added pytype None checks to accumulators.py. by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/203
- Remove the `TwoKroneckerFactored` class and use the `KroneckerFactored` class instead. by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/257
- Add TNT blocks to kfac_jax. by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/258
- Add an option to specify a different value function for the preconditioner's curvature estimator. by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/259
- Adding handling of jitted functions to graph scanner. by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/254
- Fix progress off by one. by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/260
- Split `curvature_estimator.py` module into a package. by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/261
- Adding the repeated dense graph patterns. by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/263
- Separated the `optimizers` module in kfac examples into separate modules by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/264
- Passing stats to _post_param_update_processing in examples code. by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/267
- Adding support for the "Schedule-free" method to be used as a wrapper for Optax optimizers in the examples codebase. by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/268
- [kfac-jax] Update graph matching test to support the new "algorithm" tuning parameters for dot_general that will be included in the next JAX release. by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/269
- Improving polynomial schedule in the examples codebase so that it works as expected when the initial value is *lower* than the final value. by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/271
- Changing automatic registration (aka the graph scanner) so that it doesn't automatically register a parameter if said parameter is used more than once in the graph. In that case, it resorts to the default "generic" registration (which doesn't make any structure assumptions about how the parameter is used). by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/276
- Changing optimizer to throw an exception when using burnin without a provided data iterator instead of silently skipping burnin. by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/277
- Removing check that initial_damping is not set when use_adaptive_damping is False. by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/280
- Add sharding rules to some more primitives so that backward pass of minformer passes. There are a couple of changes here: by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/283
- Removing hacky "fixes" to test_graph_matcher. Basically, the test insists that the manual registration includes all of the params from the main equation in the match found by the graph scanner. Instead of filtering these out, we now ensure that they are included in the manual registrations done in tests/models.py. Note that passing all these params won't be required when using manual registration in practice. Only certain params are mandatory for particular layers (based on the type of curvature block that gets assigned to them). by @copybara-service in https://github.com/google-deepmind/kfac-jax/pull/284
- Improved and simplified implementation of "debug" mode based on…
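
For downstream code that still calls `jax.tree_map`, the removal listed above (pull/227) suggests the same migration. The snippet below is a minimal sketch of that migration and is not taken from the kfac-jax sources; the `params` pytree and the scaling function are made up for illustration.

```python
import jax

# Hypothetical parameter pytree, used only for illustration.
params = {"w": 1.0, "b": [2.0, 3.0]}

# Deprecated spelling that the release notes say was removed from kfac-jax:
#   scaled = jax.tree_map(lambda x: 2.0 * x, params)

# Replacement: the same operation through jax.tree_util.
scaled = jax.tree_util.tree_map(lambda x: 2.0 * x, params)

print(scaled)
```

The result has the same tree structure as `params`, with every leaf doubled.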
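
The polynomial schedule fix listed above (pull/271) concerns schedules whose initial value is lower than the final value. The sketch below shows one common polynomial interpolation formula that behaves the same way in both directions. It is an assumption for illustration only, not the code from the kfac-jax examples folder; the function name and all argument names are hypothetical.

```python
def polynomial_schedule(step, init_value, end_value, power, total_steps):
    """Polynomial interpolation from init_value to end_value.

    Hypothetical helper, not the kfac-jax examples implementation.
    """
    # Clamp the step so the schedule holds end_value after total_steps.
    frac = min(max(step, 0), total_steps) / total_steps
    # (1 - frac) ** power decays from 1 to 0, so the sign of
    # (init_value - end_value) decides whether the schedule rises or falls.
    return (init_value - end_value) * (1.0 - frac) ** power + end_value


# Increasing schedule: initial value lower than final value.
rising = [polynomial_schedule(s, 0.1, 1.0, 2, 100) for s in (0, 50, 100, 150)]

# Decreasing schedule with the same formula.
falling = [polynomial_schedule(s, 1.0, 0.1, 2, 100) for s in (0, 50, 100, 150)]

print(rising)
print(falling)
```

Because the difference `init_value - end_value` carries the sign, no special case is needed for the increasing direction in this formulation.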