google-deepmind/chex
Language: Python
License: Apache-2.0
Stars: 944
Forks: 69
Open issues: 70
Created: 2020-08-06T09:32:36Z
Pushed: 2026-06-10T13:59:50Z
Default branch: main
Fork: no
Archived: no
README:
Chex
Chex is a library of utilities for helping to write reliable JAX code.
This includes utils to help:
- Instrument your code (e.g. assertions, warnings).
- Debug (e.g. transforming `pmaps` in `vmaps` within a context manager).
- Test JAX code across many `variants` (e.g. jitted vs non-jitted).
Installation
You can install the latest released version of Chex from PyPI via:
pip install chex
or you can install the latest development version from GitHub:
pip install git+https://github.com/deepmind/chex.git
Modules Overview
Dataclass (dataclass.py)
Dataclasses are a popular construct, introduced in Python 3.7, that make it easy to specify typed data structures with minimal boilerplate code. However, they are not compatible with JAX and dm-tree out of the box.
In Chex we provide a JAX-friendly dataclass implementation that reuses Python dataclasses.
The Chex dataclass implementation registers dataclasses as internal _PyTree_ nodes to ensure compatibility with JAX data structures.
In addition, we provide a class wrapper that exposes dataclasses as `collections.abc.Mapping` descendants, which allows dm-tree methods to process them (e.g. flatten and unflatten them) like ordinary Python dictionaries. See the `@mappable_dataclass` docstring for more details.
Example:
import chex
import jax
import jax.numpy as jnp
import tree  # dm-tree

@chex.dataclass
class Parameters:
  x: chex.ArrayDevice
  y: chex.ArrayDevice

parameters = Parameters(
    x=jnp.ones((2, 2)),
    y=jnp.ones((1, 2)),
)

# Dataclasses can be treated as JAX pytrees
jax.tree_util.tree_map(lambda x: 2.0 * x, parameters)

# and as mappings by dm-tree
tree.flatten(parameters)
NOTE: Unlike standard Python 3.7 dataclasses, Chex dataclasses cannot be constructed using positional arguments. Instead, their constructor accepts arguments in the same format as the Python dict constructor. If necessary, dataclasses can be converted to and from tuples with the to_tuple and from_tuple methods.
parameters = Parameters(
    jnp.ones((2, 2)),
    jnp.ones((1, 2)),
)  # ValueError: Mappable dataclass constructor doesn't support positional args.
Assertions (asserts.py)
One limitation of PyType annotations for JAX is that they do not support the specification of DeviceArray ranks, shapes or dtypes. Chex includes a number of functions that allow flexible and concise specification of these properties.
For example, suppose you want to ensure that tensors t1, t2 and t3 all have the same shape, and that tensors t4 and t5 have rank 2 and rank 3 or 4, respectively.
chex.assert_equal_shape([t1, t2, t3])
chex.assert_rank([t4, t5], [2, {3, 4}])
More examples:
from chex import assert_shape, assert_rank, ...
assert_shape(x, (2, 3)) # x has shape (2, 3)
assert_shape([x, y], [(), (2,3)]) # x is scalar and y has shape (2, 3)
assert_rank(x, 0) # x is scalar
assert_rank([x, y], [0, 2]) # x is scalar and y is a rank-2 array
assert_rank([x, y], {0, 2}) # x and y are scalar OR rank-2 arrays
assert_type(x, int) # x has type `int` (x can be an array)
assert_type([x, y], [int, float]) # x has type `int` and y has type `float`
assert_equal_shape([x, y, z]) # x, y, and z have equal shapes
assert_trees_all_close(tree_x, tree_y) # values and structure of trees match
assert_tree_all_finite(tree_x) # all tree_x leaves are finite
assert_devices_available(2, 'gpu') # 2 GPUs available
assert_tpu_available() # at least 1 TPU available
assert_numerical_grads(f, (x, y), j) # f^{(j)}(x, y) matches numerical grads
See asserts.py documentation to find all supported assertions.
If you cannot find a specific assertion, please consider making a pull request or opening an issue on the bug tracker.
Optional Arguments
All Chex assertions support the following optional kwargs for customizing the emitted exception messages:
- `custom_message`: A string to include in the emitted exception messages.
- `include_default_message`: Whether to include the default Chex message in the emitted exception messages.
- `exception_type`: An exception type to use. `AssertionError` by default.
For example, the following code:
import chex

# `load_dataset`, `init_params`, `update_params` and `num_steps` are
# placeholders for your own training code.
dataset = load_dataset()
params = init_params()
for i in range(num_steps):
  params = update_params(params, dataset.sample())
  chex.assert_tree_all_finite(params,
                              custom_message=f'Failed at iteration {i}.',
                              exception_type=ValueError)
will raise a ValueError that includes the step number when params get polluted with NaNs or other non-finite values.
Static and Value (aka *Runtime*) Assertions
Chex divides all assertions into 2 classes: *static* and *value* assertions.
1. *static* assertions do not depend on the concrete values of tensors. Examples: assert_shape, assert_trees_all_equal_dtypes, assert_max_traces.
2. *value* assertions require access to tensor values, which are not available during JAX tracing (see How JAX primitives work), so such assertions need special treatment in *jitted* code.
To enable value assertions in a jitted function, decorate it with the chex.chexify() wrapper. Example:
import chex
import jax
import jax.numpy as jnp

@chex.chexify
@jax.jit
def logp1_abs_safe(x: chex.Array) -> chex.Array:
  chex.assert_tree_all_finite(x)
  return jnp.log(jnp.abs(x) + 1)

logp1_abs_safe(jnp.ones(2))  # OK
logp1_abs_safe(jnp.array([jnp.nan, 3]))  # FAILS (in async mode)

# The error will be raised either at the next line OR at the next
# `logp1_abs_safe` call. See the docs for more detail on async mode.
logp1_abs_safe.wait_checks()  # Wait for the (async) computation to complete.
See [this…