google-deepmind/distrax


Language: Python

License: Apache-2.0

Stars: 635

Forks: 40

Open issues: 57

Created: 2021-04-01T17:03:49Z

Pushed: 2026-06-10T13:53:19Z

Default branch: main

Fork: no

Archived: no

README:

Distrax

Distrax is a lightweight library of probability distributions and bijectors. It acts as a JAX-native reimplementation of a subset of TensorFlow Probability (TFP), with some new features and emphasis on extensibility.

Installation

You can install the latest released version of Distrax from PyPI via:

pip install distrax

or you can install the latest development version from GitHub:

pip install git+https://github.com/deepmind/distrax.git

To run the tests or examples you will need to install additional requirements.

Design Principles

The general design principles for the DeepMind JAX Ecosystem are addressed in this blog. Additionally, Distrax places emphasis on the following:

1. Readability. Distrax implementations are intended to be self-contained and read as close to the underlying math as possible.

2. Extensibility. We have made it as simple as possible for users to define their own distribution or bijector. This is useful for example in reinforcement learning, where users may wish to define custom behavior for probabilistic agent policies.

3. Compatibility. Distrax is not intended as a replacement for TFP, and TFP contains many advanced features that we do not intend to replicate. To this end, we have made the APIs for distributions and bijectors as cross-compatible as possible, and provide utilities for transforming between equivalent Distrax and TFP classes.

Features

Distributions

Distributions in Distrax are simple to define and use, particularly if you're used to TFP. Let's compare the two side-by-side:

import distrax
import jax
import jax.numpy as jnp

from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions

key = jax.random.PRNGKey(1234)
mu = jnp.array([-1., 0., 1.])
sigma = jnp.array([0.1, 0.2, 0.3])

dist_distrax = distrax.MultivariateNormalDiag(mu, sigma)
dist_tfp = tfd.MultivariateNormalDiag(mu, sigma)

samples = dist_distrax.sample(seed=key)

# Both print 1.775
print(dist_distrax.log_prob(samples))
print(dist_tfp.log_prob(samples))
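
For intuition about what log_prob returns here: a diagonal multivariate normal factorises across dimensions, so its log-density is a sum of univariate normal log-densities. The sketch below writes that formula in plain Python using the same mu and sigma as above. It is an illustration of the math only, not Distrax's implementation, and the helper name diag_normal_log_prob is hypothetical.

```python
import math

def diag_normal_log_prob(x, mu, sigma):
    """Log-density of a diagonal multivariate normal (hypothetical helper)."""
    total = 0.0
    for x_i, mu_i, sigma_i in zip(x, mu, sigma):
        z = (x_i - mu_i) / sigma_i  # standardised coordinate
        total += -0.5 * z * z - math.log(sigma_i) - 0.5 * math.log(2.0 * math.pi)
    return total

mu = [-1.0, 0.0, 1.0]
sigma = [0.1, 0.2, 0.3]

# Evaluate at the mean, where the quadratic term vanishes and only the
# normalisation constants remain.
log_prob_at_mean = diag_normal_log_prob(mu, mu, sigma)
print(round(log_prob_at_mean, 3))
```

Positive log-probabilities, such as the one printed by the example above, are expected: with small scales the density can exceed 1.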

In addition to behaving consistently, Distrax distributions and TFP distributions are cross-compatible. For example:

mu_0 = jnp.array([-1., 0., 1.])
sigma_0 = jnp.array([0.1, 0.2, 0.3])
dist_distrax = distrax.MultivariateNormalDiag(mu_0, sigma_0)

mu_1 = jnp.array([1., 2., 3.])
sigma_1 = jnp.array([0.2, 0.3, 0.4])
dist_tfp = tfd.MultivariateNormalDiag(mu_1, sigma_1)

# Both print 85.237
print(dist_distrax.kl_divergence(dist_tfp))
print(tfd.kl_divergence(dist_distrax, dist_tfp))
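
The value 85.237 can be checked by hand, because the KL divergence between two diagonal Gaussians has a closed form that sums over dimensions. The following plain-Python sketch evaluates that closed form for the parameters above. The helper name kl_diag_normal is hypothetical, and this is not the code path that Distrax or TFP uses.

```python
import math

def kl_diag_normal(mu_0, sigma_0, mu_1, sigma_1):
    """KL(N(mu_0, diag(sigma_0^2)) || N(mu_1, diag(sigma_1^2))), hypothetical helper."""
    total = 0.0
    for m0, s0, m1, s1 in zip(mu_0, sigma_0, mu_1, sigma_1):
        # Per-dimension term: log(s1/s0) + (s0^2 + (m0 - m1)^2) / (2 s1^2) - 1/2
        total += math.log(s1 / s0) + (s0**2 + (m0 - m1) ** 2) / (2.0 * s1**2) - 0.5
    return total

kl = kl_diag_normal([-1.0, 0.0, 1.0], [0.1, 0.2, 0.3],
                    [1.0, 2.0, 3.0], [0.2, 0.3, 0.4])
print(round(kl, 3))  # 85.237
```

Most of the divergence comes from the first dimension, where the means are 2 apart but the target scale is only 0.2.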

Distrax distributions implement the method sample_and_log_prob, which returns samples and their log-probability in a single call. For some distributions, this is more efficient than calling sample and log_prob separately:

mu = jnp.array([-1., 0., 1.])
sigma = jnp.array([0.1, 0.2, 0.3])
dist_distrax = distrax.MultivariateNormalDiag(mu, sigma)

samples = dist_distrax.sample(seed=key, sample_shape=())
log_prob = dist_distrax.log_prob(samples)

# A one-line equivalent of the above is:
samples, log_prob = dist_distrax.sample_and_log_prob(seed=key, sample_shape=())
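
One way a fused method can save work is by reusing intermediate values from sampling. For a Gaussian drawn as x = mu + sigma * eps, the standardised value (x - mu) / sigma is just eps, so the log-probability can be computed without touching x again. The plain-Python sketch below illustrates this idea under that assumption. It is not a description of what Distrax does internally for any particular distribution, and both helper names are hypothetical.

```python
import math
import random

def sample_and_log_prob(rng, mu, sigma):
    """Draw x = mu + sigma * eps and score it by reusing eps (hypothetical helper)."""
    sample, log_prob = [], 0.0
    for mu_i, sigma_i in zip(mu, sigma):
        eps = rng.gauss(0.0, 1.0)
        sample.append(mu_i + sigma_i * eps)
        # (x - mu) / sigma equals eps, so it does not need to be recomputed.
        log_prob += -0.5 * eps * eps - math.log(sigma_i) - 0.5 * math.log(2.0 * math.pi)
    return sample, log_prob

def log_prob(x, mu, sigma):
    """Separate scoring pass that standardises x again (hypothetical helper)."""
    return sum(
        -0.5 * ((x_i - m) / s) ** 2 - math.log(s) - 0.5 * math.log(2.0 * math.pi)
        for x_i, m, s in zip(x, mu, sigma)
    )

mu = [-1.0, 0.0, 1.0]
sigma = [0.1, 0.2, 0.3]
x, fused = sample_and_log_prob(random.Random(1234), mu, sigma)
separate = log_prob(x, mu, sigma)
print(abs(fused - separate) < 1e-9)
```

The saving is small for a Gaussian, but it grows for distributions where scoring a sample requires an expensive inverse computation.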

TFP distributions can be passed to Distrax meta-distributions as inputs. For example:

key = jax.random.PRNGKey(1234)

mu = jnp.array([-1., 0., 1.])
sigma = jnp.array([0.2, 0.3, 0.4])
dist_tfp = tfd.Normal(mu, sigma)

metadist_distrax = distrax.Independent(dist_tfp, reinterpreted_batch_ndims=1)
samples = metadist_distrax.sample(seed=key)
print(metadist_distrax.log_prob(samples)) # Prints 0.38871175

To use Distrax distributions in TFP meta-distributions, Distrax provides the wrapper to_tfp. A wrapped Distrax distribution can be directly used in TFP:

key = jax.random.PRNGKey(1234)

distrax_dist = distrax.Normal(0., 1.)
wrapped_dist = distrax.to_tfp(distrax_dist)
metadist_tfp = tfd.Sample(wrapped_dist, sample_shape=[3])

samples = metadist_tfp.sample(seed=key)
print(metadist_tfp.log_prob(samples)) # Prints -3.3409896

Bijectors

A "bijector" in Distrax is an invertible function that knows how to compute its Jacobian determinant. Bijectors can be used to create complex distributions by transforming simpler ones. Distrax bijectors are functionally similar to TFP bijectors, with a few API differences. Here is an example comparing the two:

import distrax
import jax.numpy as jnp

from tensorflow_probability.substrates import jax as tfp
tfb = tfp.bijectors
tfd = tfp.distributions

# Same distribution.
distrax.Transformed(distrax.Normal(loc=0., scale=1.), distrax.Tanh())
tfd.TransformedDistribution(tfd.Normal(loc=0., scale=1.), tfb.Tanh())
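
To make the role of the Jacobian determinant concrete, the sketch below computes the log-density of Y = tanh(X) with X ~ N(0, 1) using the change-of-variables formula log p_Y(y) = log p_X(x) - log|dy/dx|, where x = atanh(y). It is written in plain Python as an illustration of the math; the function names are hypothetical and it does not reproduce Distrax internals.

```python
import math

def normal_log_prob(x):
    """Log-density of a standard normal."""
    return -0.5 * x * x - 0.5 * math.log(2.0 * math.pi)

def tanh_normal_log_prob(y):
    """Log-density of Y = tanh(X), X ~ N(0, 1), for y in (-1, 1)."""
    x = math.atanh(y)                     # inverse of the bijector
    forward_log_det = math.log1p(-y * y)  # log|d tanh(x)/dx| = log(1 - tanh(x)^2)
    return normal_log_prob(x) - forward_log_det

# Midpoint-rule check that the transformed density still integrates to about 1.
n = 20000
width = 2.0 / n
total = sum(
    math.exp(tanh_normal_log_prob(-1.0 + (i + 0.5) * width)) * width
    for i in range(n)
)
print(abs(total - 1.0) < 1e-2)
```

Without the log-determinant correction the integral would not come out close to 1, which is why a bijector has to know its Jacobian determinant.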

Additionally, Distrax bijectors can be composed and inverted:

bij_distrax = distrax.Tanh()
bij_tfp = tfb.Tanh()

# Same bijector.
inv_bij_distrax = distrax.Inverse(bij_distrax)
inv_bij_tfp = tfb.Invert(bij_tfp)

# These are both the identity bijector.
distrax.Chain([bij_distrax, inv_bij_distrax])
tfb.Chain([bij_tfp, inv_bij_tfp])
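
The reason such a chain is the identity can be sketched in plain Python: composing bijectors composes their functions and adds their log-determinants, and the log-determinant of an inverse is the negated forward one. In the sketch, each bijector is modelled as a function returning an (output, log-determinant) pair, and chain applies the last entry first. These are assumptions made for illustration with hypothetical names, not the Distrax Bijector interface.

```python
import math

# Each bijector is sketched as a function returning (output, log|det J|).
def tanh_forward(x):
    y = math.tanh(x)
    return y, math.log1p(-y * y)

def tanh_inverse(y):
    x = math.atanh(y)
    # The inverse log-determinant is the negated forward one.
    return x, -math.log1p(-y * y)

def chain(*steps):
    """Compose steps so that the last one is applied first, summing log-dets."""
    def composed(x):
        total_log_det = 0.0
        for step in reversed(steps):
            x, log_det = step(x)
            total_log_det += log_det
        return x, total_log_det
    return composed

identity = chain(tanh_forward, tanh_inverse)
y, log_det = identity(0.5)
print(abs(y - 0.5) < 1e-12, abs(log_det) < 1e-12)
```

The input comes back unchanged and the accumulated log-determinant is zero, which is exactly the behavior of an identity bijector.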

All TFP bijectors can be passed to Distrax, and can be freely composed with Distrax bijectors. For example, all of the following will work:

distrax.Inverse(tfb.Tanh())

distrax.Chain([tfb.Tanh(), distrax.Tanh()])

distrax.Transformed(tfd.Normal(loc=0., scale=1.), tfb.Tanh())

Distrax bijectors can also be passed to TFP, but first they must be transformed with to_tfp:

bij_distrax = distrax.to_tfp(distrax.Tanh())

tfb.Invert(bij_distrax)

tfb.Chain([tfb.Tanh(), bij_distrax])

tfd.TransformedDistribution(tfd.Normal(loc=0., scale=1.), bij_distrax)

Distrax also comes with Lambda, a convenient wrapper for turning simple JAX functions into bijectors. Here are a few Lambda examples with their TFP equivalents:

distrax.Lambda(lambda x: x)
# tfb.Identity()

distrax.Lambda(lambda x: 2*x + 3)
# tfb.Chain([tfb.Shift(3), tfb.Scale(2)])

distrax.Lambda(jnp.sinh)
# tfb.Sinh()…
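
For scalar functions like the three above, the inverse and the forward log-determinant that a bijector must provide can be written down by hand. The plain-Python sketch below lists them and checks the round trip at one point. It only illustrates the quantities involved; it does not show how Lambda obtains them, and the dictionary layout is a hypothetical choice.

```python
import math

# (forward, inverse, forward log|det J|) for the three functions above.
examples = {
    "identity": (lambda x: x, lambda y: y, lambda x: 0.0),
    "affine": (lambda x: 2 * x + 3, lambda y: (y - 3) / 2, lambda x: math.log(2.0)),
    "sinh": (math.sinh, math.asinh, lambda x: math.log(math.cosh(x))),
}

x = 0.7
round_trips = {
    name: inverse(forward(x)) for name, (forward, inverse, _) in examples.items()
}
log_dets = {name: log_det(x) for name, (_, _, log_det) in examples.items()}

for name in examples:
    print(name, abs(round_trips[name] - x) < 1e-12, log_dets[name])
```

Note that the affine map has a constant log-determinant, log 2, while the one for sinh depends on the input.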
