google-deepmind/chex

Python

Published Aug 6, 2020

Language: Python

License: Apache-2.0

Stars: 944

Forks: 69

Open issues: 70

Created: 2020-08-06T09:32:36Z

Pushed: 2026-06-10T13:59:50Z

Default branch: main

Fork: no

Archived: no

README:

Chex

Chex is a library of utilities for helping to write reliable JAX code.

This includes utils to help:

  • Instrument your code (e.g. assertions, warnings)
  • Debug (e.g. transforming pmaps into vmaps within a context manager).
  • Test JAX code across many variants (e.g. jitted vs non-jitted).

Installation

You can install the latest released version of Chex from PyPI via:

```shell
pip install chex
```

or you can install the latest development version from GitHub:

```shell
pip install git+https://github.com/deepmind/chex.git
```

Modules Overview

Dataclass (dataclass.py)

Dataclasses are a popular construct introduced in Python 3.7 that make it easy to specify typed data structures with minimal boilerplate code. They are not, however, compatible with JAX and dm-tree out of the box.

Chex provides a JAX-friendly dataclass implementation built on top of Python dataclasses.

The Chex dataclass implementation registers dataclasses as internal PyTree nodes to ensure compatibility with JAX data structures.

In addition, we provide a class wrapper that exposes dataclasses as collections.abc.Mapping descendants, which lets dm-tree methods process them (e.g. flatten and unflatten them) like ordinary Python dictionaries. See the `mappable_dataclass` docstring for more details.

Example:

```python
import chex
import jax
import jax.numpy as jnp
import tree  # dm-tree

@chex.dataclass
class Parameters:
  x: chex.ArrayDevice
  y: chex.ArrayDevice

parameters = Parameters(
    x=jnp.ones((2, 2)),
    y=jnp.ones((1, 2)),
)

# Dataclasses can be treated as JAX pytrees
jax.tree_util.tree_map(lambda x: 2.0 * x, parameters)

# and as mappings by dm-tree
tree.flatten(parameters)
```

NOTE: Unlike standard Python 3.7 dataclasses, Chex dataclasses cannot be constructed from positional arguments. Their constructor accepts arguments in the same format as the Python dict constructor. If necessary, a dataclass can be converted to a tuple with `to_tuple` and rebuilt from one with `from_tuple`.

```python
parameters = Parameters(
    jnp.ones((2, 2)),
    jnp.ones((1, 2)),
)
# ValueError: Mappable dataclass constructor doesn't support positional args.
```

Assertions (asserts.py)

One limitation of PyType annotations for JAX is that they cannot specify the ranks, shapes or dtypes of DeviceArrays. Chex includes a number of functions that allow flexible and concise specification of these properties.

For example, suppose you want to ensure that tensors t1, t2 and t3 all have the same shape, and that t4 has rank 2 while t5 has rank 3 or 4.

```python
chex.assert_equal_shape([t1, t2, t3])
chex.assert_rank([t4, t5], [2, {3, 4}])
```

More examples:

```python
from chex import (assert_devices_available, assert_equal_shape,
                  assert_numerical_grads, assert_rank, assert_shape,
                  assert_tpu_available, assert_tree_all_finite,
                  assert_trees_all_close, assert_type)

# x, y, z, tree_x, tree_y, f and j below are placeholders for your own values.

assert_shape(x, (2, 3))  # x has shape (2, 3)
assert_shape([x, y], [(), (2, 3)])  # x is scalar and y has shape (2, 3)

assert_rank(x, 0)  # x is scalar
assert_rank([x, y], [0, 2])  # x is scalar and y is a rank-2 array
assert_rank([x, y], {0, 2})  # x and y are scalar OR rank-2 arrays

assert_type(x, int)  # x has type `int` (x can be an array)
assert_type([x, y], [int, float])  # x has type `int` and y has type `float`

assert_equal_shape([x, y, z])  # x, y, and z have equal shapes

assert_trees_all_close(tree_x, tree_y)  # trees have the same structure and close values
assert_tree_all_finite(tree_x)  # all tree_x leaves are finite

assert_devices_available(2, 'gpu')  # 2 GPUs available
assert_tpu_available()  # at least 1 TPU available

assert_numerical_grads(f, (x, y), j)  # f^{(j)}(x, y) matches numerical grads
```

See the asserts.py documentation for the full list of supported assertions.

If you cannot find a specific assertion, please consider making a pull request or opening an issue on the bug tracker.

Optional Arguments

All Chex assertions support the following optional kwargs for customizing the exceptions they raise:

  • custom_message: A string to include in the emitted exception messages.
  • include_default_message: Whether to include the default Chex message in the emitted exception messages.
  • exception_type: An exception type to use. AssertionError by default.

For example, the following code:

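```python
import chex

# load_dataset, init_params, update_params and num_steps stand for your own training code.
dataset = load_dataset()
params = init_params()
for i in range(num_steps):
  params = update_params(params, dataset.sample())
  chex.assert_tree_all_finite(params,
                              custom_message=f'Failed at iteration {i}.',
                              exception_type=ValueError)
```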

will raise a ValueError that includes the step number as soon as params contain NaNs or infinities.

Static and Value (aka *Runtime*) Assertions

Chex divides all assertions into 2 classes: *static* and *value* assertions.

1. *static* assertions do not depend on the concrete values of tensors. Examples: assert_shape, assert_trees_all_equal_dtypes, assert_max_traces.

2. *value* assertions require access to tensor values, which are not available during JAX tracing (see How JAX primitives work), so such assertions need special treatment in *jitted* code.

To enable value assertions in a jitted function, decorate it with the chex.chexify() wrapper. Example:

```python
import chex
import jax
import jax.numpy as jnp

@chex.chexify
@jax.jit
def logp1_abs_safe(x: chex.Array) -> chex.Array:
  chex.assert_tree_all_finite(x)
  return jnp.log(jnp.abs(x) + 1)

logp1_abs_safe(jnp.ones(2))  # OK
logp1_abs_safe(jnp.array([jnp.nan, 3]))  # FAILS (in async mode)

# The error will be raised either at the next line OR at the next
# `logp1_abs_safe` call. See the docs for more detail on async mode.
logp1_abs_safe.wait_checks()  # Wait for the (async) computation to complete.
```

See [this…
