google-deepmind/xarray_jax
Python
Language: Python
License: Apache-2.0
Stars: 45
Forks: 8
Open issues: 1
Created: 2025-06-19T14:20:30Z
Pushed: 2026-06-08T17:23:17Z
Default branch: main
Fork: no
Archived: no
README:
xarray_jax
This library supports using xarray datatypes together with the JAX library.
Installation
To install a specific tagged release:
pip install git+https://github.com/google-deepmind/xarray_jax.git@v0.1.0
What is xarray_jax?
JAX is a high-performance library designed for numerical computation and machine learning, operating primarily on arrays which can be contained in generic tree-like data structures known as PyTrees.
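As a small illustration of what a PyTree is, the sketch below flattens a nested container into its array leaves and rebuilds it using `jax.tree_util`. The `params` dictionary and its keys are made up for this example and are not part of xarray_jax.

```python
import jax
import jax.numpy as jnp

# A PyTree: nested dicts, lists and tuples with arrays at the leaves.
# The names below are illustrative only.
params = {
    "layer1": {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)},
    "layer2": [jnp.arange(4.0), jnp.array(2.0)],
}

# Flattening yields the leaves plus a treedef describing the container structure.
leaves, treedef = jax.tree_util.tree_flatten(params)
print(len(leaves))  # 4

# tree_map applies a function to every leaf and rebuilds the same structure.
doubled = jax.tree_util.tree_map(lambda x: 2 * x, params)

# Unflattening restores the original nesting from the leaves.
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
```

JAX transformations such as jit and grad rely on this flatten/unflatten round trip for every argument and return value.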
Standard xarray datatypes (DataArray, Dataset, etc.) can contain JAX arrays; however, they are not natively registered with JAX as PyTree nodes, which prevents their direct use within JAX's core transformations (jit, grad, vmap, etc.).
This library solves that problem. It registers xarray data structures as custom JAX PyTrees. This allows JAX to seamlessly flatten xarray objects into their raw arrays for accelerated computation and then unflatten the results back into fully labeled xarray objects, preserving critical metadata like dimension names and coordinates.
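The registration idea can be sketched with a toy container. The `LabeledArray` class below is hypothetical and far simpler than anything in xarray_jax; it only shows the general mechanism, assuming `jax.tree_util.register_pytree_node_class` is used to tell JAX which attribute is array data and which is static metadata. How xarray_jax actually performs its registration is not shown in this README.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class LabeledArray:
    """Hypothetical labeled container (not an xarray_jax class)."""

    def __init__(self, data, dims):
        self.data = data         # array leaf that JAX traces
        self.dims = tuple(dims)  # static metadata carried alongside

    def tree_flatten(self):
        # Children are traced by JAX; aux data must be hashable and static.
        return (self.data,), self.dims

    @classmethod
    def tree_unflatten(cls, dims, children):
        (data,) = children
        return cls(data, dims)


@jax.jit
def to_fahrenheit(x):
    # The function receives and returns the labeled container itself.
    return LabeledArray(x.data * 9 / 5 + 32, x.dims)


celsius = LabeledArray(jnp.array([[20.0, 25.0], [30.0, 35.0]]),
                       ("time", "location"))
leaves, treedef = tree_util.tree_flatten(celsius)
fahrenheit = to_fahrenheit(celsius)
print(fahrenheit.dims)  # ('time', 'location')
```

Because the dimension names travel in the static part of the tree, they survive the trip through `jax.jit` unchanged while only the array is traced.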
Quick Start
Here's a minimal example showing how to apply JAX's just-in-time (JIT) compilation to a function that operates directly on an xarray.DataArray containing jax.numpy data.
*Example:*
import jax
import jax.numpy as jnp
import xarray as xr
import numpy as np
import xarray_jax

# 1. Create a standard xarray.DataArray with JAX data
# Use np.datetime64 for time coordinates as recommended practice
temperature = xr.DataArray(
    data=jnp.array([[20.5, 21.2, 22.1],
                    [18.3, 19.7, 20.8]]),
    dims=('time', 'location'),
    coords={
        'time': np.array(['2023-01-01', '2023-01-02'], dtype='datetime64[D]'),
        'location': ['NYC', 'LA', 'Chicago']
    }
)

# 2. Define a pure Python function that works with the DataArray
# This function can use standard xarray methods like .mean()
def process_temperature(temp_data):
    fahrenheit = temp_data * 9/5 + 32
    return fahrenheit.mean(dim='location')

# 3. JIT-compile the function using jax.jit
# This would fail without xarray_jax registering xarray objects as PyTrees
jitted_process = jax.jit(process_temperature)

# 4. Execute the compiled function with the xarray object
result = jitted_process(temperature)
print("JIT compilation successful!")
print(result)
# Expected Output:
# <xarray.DataArray (time: 2)>
# array([70.28   , 67.27999], dtype=float32)
# Coordinates:
#   * time     (time) datetime64[ns] 2023-01-01 2023-01-02
Using xarray in JAX Transformations
The primary use case for xarray_jax is enabling JAX transformations on functions that work with xarray data structures. Since JAX arrays are now directly supported by xarray, you can create xarray objects with JAX arrays and use them directly.
Computing Gradients with jax.grad
You can compute gradients through functions operating on xarray.DataArray objects, provided the function returns a raw JAX scalar. Use xarray_jax.jax_data to extract the underlying JAX array when needed.
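The scalar requirement comes from `jax.grad` itself rather than from xarray. A minimal sketch on a raw JAX array, reusing the Quick Start temperature values, shows the shape of the result and checks it against the analytic derivative. The names `temps` and `squared_error` are chosen for this illustration only.

```python
import jax
import jax.numpy as jnp

# Same values as the Quick Start temperatures, as a raw JAX array.
temps = jnp.array([[20.5, 21.2, 22.1],
                   [18.3, 19.7, 20.8]])
target = 20.0


def squared_error(x):
    # Returns a 0-d JAX array, which is what jax.grad requires.
    return jnp.sum((x - target) ** 2)


# The gradient has the same shape as the input.
grads = jax.grad(squared_error)(temps)

# The analytic gradient of sum((x - target)^2) is 2 * (x - target).
analytic = 2.0 * (temps - target)
print(jnp.allclose(grads, analytic))  # True
```

With xarray inputs the same rule applies: the reduction may be written with xarray methods, but the value handed back to `jax.grad` has to be the underlying JAX scalar.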
*Example:*
# Assume 'temperature' DataArray exists from the Quick Start example

# Define a function that returns a scalar for gradient computation
def temperature_loss(temp_data):
    target = 20.0
    # Simple loss: penalize deviation from 20°C
    # Note: The differentiable loss function must return a raw JAX scalar
    loss_xarray = ((temp_data - target) ** 2).sum()
    return loss_xarray.data  # Extract JAX scalar

# Compute gradients with respect to the temperature data
grad_fn = jax.grad(temperature_loss)
gradients = grad_fn(temperature)  # Pass the original xarray object
print(gradients)
# Expected Output:
# <xarray.DataArray (time: 2, location: 3)>
# array([[ 1.       ,  2.4000015,  4.200001 ],
#        [-3.4000015, -0.5999985,  1.5999985]], dtype=float32)
# Coordinates:
#   * time      (time) datetime64[ns] 2023-01-01 2023-01-02
#   * location  (location) <U7 'NYC' 'LA' 'Chicago'

# array([[96.15     , 99.7      ],
#        [92.44     , 96.9600067]], dtype=float32)
# Coordinates:
#   * time (time)

# array([ 2.,  6., 12., 20., 30.], dtype=float32)
# Coordinates:
#   * time (time) int32 0 1 2 3 4
Notability
5.0/10. New utility repo from DeepMind, low traction.