google-deepmind/dm-haiku
Description: JAX-based neural network library
Language: Python
License: Apache-2.0
Stars: 3240
Forks: 293
Open issues: 105
Created: 2020-02-18T07:14:02Z
Pushed: 2026-06-02T17:56:40Z
Default branch: main
Fork: no
Archived: no
README:
Haiku: [Sonnet] for [JAX]
[Overview](#overview) | [Why Haiku?](#why-haiku) | [Quickstart](#quickstart) | [Installation](#installation) | **Examples** | [User manual](#user-manual) | **Documentation** | [Citing Haiku](#citing-haiku)
> [!IMPORTANT]
> 📣 As of July 2023 [Google DeepMind] recommends that new projects adopt [Flax] instead of Haiku. [Flax] is a neural network library originally developed by [Google Brain] and now by [Google DeepMind]. 📣
>
> At the time of writing [Flax] has a superset of the features available in Haiku, a larger and more active development team, and more adoption with users outside of Alphabet. [Flax] has more extensive documentation, examples, and an active community creating end-to-end examples.
>
> Haiku will remain best-effort supported; however, the project will enter maintenance mode, meaning that development efforts will be focused on bug fixes and compatibility with new releases of JAX.
>
> New releases will be made to keep Haiku working with newer versions of Python and [JAX]; however, we will not be adding (or accepting PRs for) new features.
>
> We have significant usage of Haiku internally at [Google DeepMind] and currently plan to support Haiku in this mode indefinitely.
What is Haiku?
> Haiku is a tool
> For building neural networks
> Think: "[Sonnet] for [JAX]"
Haiku is a simple neural network library for [JAX] developed by some of the authors of [Sonnet], a neural network library for [TensorFlow].
Documentation on Haiku can be found at https://dm-haiku.readthedocs.io/.
Disambiguation: if you are looking for Haiku the operating system then please see https://haiku-os.org/.
Overview
[JAX] is a numerical computing library that combines NumPy, automatic differentiation, and first-class GPU/TPU support.
Haiku is a simple neural network library for JAX that lets users work with familiar object-oriented programming models while retaining full access to JAX's pure function transformations.
Haiku provides two core tools: a module abstraction, hk.Module, and a simple function transformation, hk.transform.
hk.Modules are Python objects that hold references to their own parameters, other modules, and methods that apply functions on user inputs.
hk.transform turns functions that use these object-oriented, functionally "impure" modules into pure functions that can be used with jax.jit, jax.grad, jax.pmap, etc.
Why Haiku?
There are a number of neural network libraries for JAX. Why should you choose Haiku?
Haiku has been tested by researchers at DeepMind at scale.
- DeepMind has reproduced a number of experiments in Haiku and JAX with relative ease. These include large-scale results in image and language processing, generative models, and reinforcement learning.
Haiku is a library, not a framework.
- Haiku is designed to make specific things simpler: managing model parameters and other model state.
- Haiku can be expected to compose with other libraries and work well with the rest of JAX.
- Haiku is otherwise designed to get out of your way: it does not define custom optimizers, checkpointing formats, or replication APIs.
Haiku does not reinvent the wheel.
- Haiku builds on the programming model and APIs of Sonnet, a neural network library with near-universal adoption at DeepMind. It preserves Sonnet's Module-based programming model for state management while retaining access to JAX's function transformations.
- Haiku APIs and abstractions are as close as reasonable to Sonnet. Many users have found Sonnet to be a productive programming model in TensorFlow; Haiku enables the same experience in JAX.
Transitioning to Haiku is easy.
- By design, transitioning from TensorFlow and Sonnet to JAX and Haiku is easy.
- Outside of new features (e.g. hk.transform), Haiku aims to match the API of Sonnet 2. Modules, methods, argument names, defaults, and initialization schemes should match.
Haiku makes other aspects of JAX simpler.
- Haiku offers a trivial model for working with random numbers. Within a transformed function, hk.next_rng_key() returns a unique RNG key.
- These unique keys are deterministically derived from an initial random key passed into the top-level transformed function, and are thus safe to use with JAX program transformations.
Quickstart
Let's take a look at an example neural network, loss function, and training loop. (For more examples, see our examples directory. The MNIST example is a good place to start.)
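```python
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np


def softmax_cross_entropy(logits, labels):
  one_hot = jax.nn.one_hot(labels, logits.shape[-1])
  return -jnp.sum(jax.nn.log_softmax(logits) * one_hot, axis=-1)


def loss_fn(images, labels):
  mlp = hk.Sequential([
      hk.Linear(300), jax.nn.relu,
      hk.Linear(100), jax.nn.relu,
      hk.Linear(10),
  ])
  logits = mlp(images)
  return jnp.mean(softmax_cross_entropy(logits, labels))


# Hypothetical stand-in for a real input pipeline (e.g. MNIST):
# yields a few batches of random flattened 28x28 images and integer labels.
def make_dataset(num_batches=10, batch_size=32):
  rng = np.random.default_rng(0)
  for _ in range(num_batches):
    images = rng.normal(size=(batch_size, 784)).astype(np.float32)
    labels = rng.integers(0, 10, size=(batch_size,))
    yield images, labels


input_dataset = make_dataset()

loss_fn_t = hk.transform(loss_fn)
loss_fn_t = hk.without_apply_rng(loss_fn_t)

rng = jax.random.PRNGKey(42)
dummy_images, dummy_labels = next(input_dataset)
params = loss_fn_t.init(rng, dummy_images, dummy_labels)


def update_rule(param, update):
  return param - 0.01 * update


for images, labels in input_dataset:
  grads = jax.grad(loss_fn_t.apply)(params, images, labels)
  params = jax.tree.map(update_rule, params, grads)
```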
The core of Haiku is hk.transform. The transform function allows you to write neural network functions that rely on parameters (here the weights…