Fork (Cohere), published May 15, 2024

cohere-ai/jax

forked from jax-ml/jax


Description: Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

License: Apache-2.0

Stars: 0

Forks: 0

Open issues: 0

Created: 2024-05-15T18:32:23Z

Pushed: 2024-05-15T18:20:10Z

Default branch: main

Fork: yes

Parent repository: jax-ml/jax

Archived: no

README:

JAX: Autograd and XLA


[Quickstart](#quickstart-colab-in-the-cloud) | [Transformations](#transformations) | [Install guide](#installation) | [Neural net libraries](#neural-network-libraries) | **Change logs** | **Reference docs**

What is JAX?

JAX is a Python library for accelerator-oriented array computation and program transformation, designed for high-performance numerical computing and large-scale machine learning.

With its updated version of Autograd, JAX can automatically differentiate native Python and NumPy functions. It can differentiate through loops, branches, recursion, and closures, and it can take derivatives of derivatives of derivatives. It supports reverse-mode differentiation (a.k.a. backpropagation) via [grad](#automatic-differentiation-with-grad) as well as forward-mode differentiation, and the two can be composed arbitrarily to any order.
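
The sketch below illustrates these claims with a toy function. The name `poly_with_control_flow` and its loop-and-branch body are hypothetical and exist only for illustration. It differentiates the function in reverse mode with `jax.grad` and in forward mode with `jax.jvp`. It then composes the two modes as `jax.jacfwd(jax.grad(...))` (forward-over-reverse) to get a second derivative.

```python
import jax
import jax.numpy as jnp


def poly_with_control_flow(x, n=3):
    # Hypothetical example: JAX traces through this Python loop and branch.
    acc = 0.0
    for k in range(1, n + 1):
        if k % 2 == 1:
            acc = acc + x ** k  # odd powers are added
        else:
            acc = acc - x ** k  # even powers are subtracted
    return acc  # for n=3 this is x - x**2 + x**3


x = 2.0

# Reverse mode: jax.grad returns df/dx for a scalar-output function.
rev = jax.grad(poly_with_control_flow)(x)

# Forward mode: jax.jvp pushes a tangent of 1.0 through the function.
_, fwd = jax.jvp(poly_with_control_flow, (x,), (1.0,))

# Composing the modes (forward-over-reverse) gives the second derivative.
second = jax.jacfwd(jax.grad(poly_with_control_flow))(x)

print(rev, fwd, second)
```

For n=3 the function is x - x² + x³. At x = 2, both `rev` and `fwd` should therefore equal 1 - 2x + 3x² = 9, and `second` should equal -2 + 6x = 10.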

What’s new is that JAX uses XLA to compile and run your NumPy programs on GPUs and TPUs. Compilation happens under the hood by default, with library calls getting just-in-time compiled and executed. But JAX also lets you just-in-time compile your own Python functions into XLA-optimized kernels using a one-function API, [jit](#compilation-with-jit). Compilation and automatic differentiation can be composed arbitrarily, so you can express sophisticated algorithms and get maximal performance without leaving Python. You can even program multiple GPUs or TPU cores at once using [pmap](#spmd-programming-with-pmap), and differentiate through the whole thing.
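
Here is a minimal sketch of `jit`, using a SELU-style activation as the example function; the `alpha` and `lmbda` constants are illustrative. The same function can run op by op or as one compiled XLA computation. The last line shows `jit` composed with `grad`.

```python
import jax
import jax.numpy as jnp


def selu(x, alpha=1.67, lmbda=1.05):
    # SELU-style activation written with ordinary jax.numpy calls.
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)


# jax.jit traces selu once per input shape/dtype and compiles it with XLA.
selu_jit = jax.jit(selu)

x = jnp.linspace(-3.0, 3.0, 1_000)
eager = selu(x)                # op-by-op execution
compiled = selu_jit(x)         # first call compiles; later calls reuse the kernel
compiled.block_until_ready()   # JAX dispatches asynchronously; wait for the result

# jit and grad compose: a compiled gradient of the summed activation.
grad_sum = jax.jit(jax.grad(lambda v: jnp.sum(selu(v))))
```

JAX dispatches work asynchronously, so any timing comparison between `selu` and `selu_jit` should call `block_until_ready()` before reading the clock.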

Dig a little deeper, and you'll see that JAX is really an extensible system for [composable function transformations](#transformations). Both [grad](#automatic-differentiation-with-grad) and [jit](#compilation-with-jit) are instances of such transformations. Others are [vmap](#auto-vectorization-with-vmap) for automatic vectorization and [pmap](#spmd-programming-with-pmap) for single-program multiple-data (SPMD) parallel programming of multiple accelerators, with more to come.
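
As a small illustration of `vmap`, the sketch below writes a function for a single pair of vectors and lets `vmap` add the batch dimension. The `in_axes` argument chooses which axis of each input is mapped, and `None` means "broadcast this argument unchanged". The function name and the random data are hypothetical.

```python
import jax
import jax.numpy as jnp


def dot_single(v, w):
    # Written for one pair of vectors: v and w both have shape (d,).
    return jnp.dot(v, w)


# Map over axis 0 of both arguments.
batched_dot = jax.vmap(dot_single, in_axes=(0, 0))

# Map over axis 0 of v only; a single w is shared across the batch.
dot_shared_w = jax.vmap(dot_single, in_axes=(0, None))

key_v, key_w = jax.random.split(jax.random.PRNGKey(0))
vs = jax.random.normal(key_v, (8, 3))  # batch of 8 vectors
ws = jax.random.normal(key_w, (8, 3))

out = batched_dot(vs, ws)              # shape (8,)
out_shared = dot_shared_w(vs, ws[0])   # shape (8,)
```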

This is a research project, not an official Google product. Expect bugs and sharp edges. Please help by trying it out, reporting bugs, and letting us know what you think!

```python
import jax.numpy as jnp
from jax import grad, jit, vmap

def predict(params, inputs):
    for W, b in params:
        outputs = jnp.dot(inputs, W) + b
        inputs = jnp.tanh(outputs)  # inputs to the next layer
    return outputs                  # no activation on last layer

def loss(params, inputs, targets):
    preds = predict(params, inputs)
    return jnp.sum((preds - targets)**2)

grad_loss = jit(grad(loss))  # compiled gradient evaluation function
perex_grads = jit(vmap(grad_loss, in_axes=(None, 0, 0)))  # fast per-example grads
```
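
The following sketch exercises the two compiled functions on random data. It repeats `predict` and `loss` so that it runs on its own. The `init_params` helper, the layer sizes, and the data are hypothetical. With `in_axes=(None, 0, 0)`, the whole `params` pytree is shared across the batch while `inputs` and `targets` are mapped row by row. Each gradient leaf therefore gains a leading batch axis.

```python
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap


def predict(params, inputs):
    for W, b in params:
        outputs = jnp.dot(inputs, W) + b
        inputs = jnp.tanh(outputs)
    return outputs


def loss(params, inputs, targets):
    preds = predict(params, inputs)
    return jnp.sum((preds - targets) ** 2)


def init_params(key, sizes):
    # Hypothetical helper: one (W, b) pair per layer with small random weights.
    params = []
    for d_in, d_out in zip(sizes[:-1], sizes[1:]):
        key, sub = jax.random.split(key)
        params.append((0.1 * jax.random.normal(sub, (d_in, d_out)), jnp.zeros(d_out)))
    return params


grad_loss = jit(grad(loss))
perex_grads = jit(vmap(grad_loss, in_axes=(None, 0, 0)))

params = init_params(jax.random.PRNGKey(0), [4, 16, 2])  # illustrative sizes
x_key, y_key = jax.random.split(jax.random.PRNGKey(1))
inputs = jax.random.normal(x_key, (32, 4))    # batch of 32 examples
targets = jax.random.normal(y_key, (32, 2))

batch_g = grad_loss(params, inputs, targets)  # gradient of the summed loss
ex_g = perex_grads(params, inputs, targets)   # one gradient per example
```

Because `loss` is a sum over examples, summing the per-example gradients over the batch axis should recover `batch_g` up to floating-point error.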

Contents

  • [Quickstart: Colab in the Cloud](#quickstart-colab-in-the-cloud)
  • [Transformations](#transformations)
  • [Current gotchas](#current-gotchas)
  • [Installation](#installation)
  • [Neural net libraries](#neural-network-libraries)
  • [Citing JAX](#citing-jax)
  • [Reference documentation](#reference-documentation)

Quickstart: Colab in the Cloud

Jump right in using a notebook in your browser, connected to a Google Cloud GPU. Here are some starter notebooks:

JAX now runs on Cloud TPUs. To try out the preview, see the Cloud TPU Colabs.

For a deeper dive into JAX:

  • [Notebooks](https://github.com/google/jax/tree/main/docs/notebooks)

Transformations

At its core, JAX is an extensible system for transforming numerical functions. Here are four transformations of primary interest: grad, jit, vmap, and pmap.

Automatic differentiation with grad

JAX has roughly the same API as Autograd. The most popular function is `grad` for reverse-mode gradients:

```python
from jax import grad
import jax.numpy as jnp

def tanh(x):  # Define a function
    y = jnp.exp(-2.0 * x)
    return (1.0 - y) / (1.0 + y)

grad_tanh = grad(tanh)  # Obtain its gradient function
print(grad_tanh(1.0))   # Evaluate it at x = 1.0
# prints 0.4199743
```

You can differentiate to any order with grad.

```python
print(grad(grad(grad(tanh)))(1.0))
# prints 0.62162673
```
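
The printed values can be checked against closed forms. With t = tanh x, differentiating gives tanh' = 1 - t², tanh'' = -2t(1 - t²), and tanh''' = (1 - t²)(6t² - 2). The sketch below compares each of these with the nested `grad` calls at x = 1.0.

```python
import jax.numpy as jnp
from jax import grad


def tanh(x):
    y = jnp.exp(-2.0 * x)
    return (1.0 - y) / (1.0 + y)


x = 1.0
t = jnp.tanh(x)

# Closed-form derivatives of tanh, written in terms of t = tanh(x).
first_expected = 1.0 - t ** 2
second_expected = -2.0 * t * (1.0 - t ** 2)
third_expected = (1.0 - t ** 2) * (6.0 * t ** 2 - 2.0)

first = grad(tanh)(x)
second = grad(grad(tanh))(x)
third = grad(grad(grad(tanh)))(x)
```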

For more advanced autodiff, you can use `jax.vjp` for reverse-mode vector-Jacobian products and `jax.jvp` for forward-mode Jacobian-vector products. The two can be composed arbitrarily with one another, and with other JAX transformations. Here's one way to compose those to make a function that efficiently computes full Hessian matrices:

```python
from jax import jit, jacfwd, jacrev

def hessian(fun):
    return jit(jacfwd(jacrev(fun)))
```
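
The sketch below shows the two primitives directly on a hypothetical vector-valued function `f`. `jax.jvp` computes J·v in forward mode, and the pullback returned by `jax.vjp` computes vᵀ·J in reverse mode; neither needs the full Jacobian J. The Jacobian is formed here only to check the results. The sketch repeats the `hessian` helper so it is self-contained and applies it to a hypothetical scalar function whose Hessian is easy to write by hand.

```python
import jax
import jax.numpy as jnp
from jax import jit, jacfwd, jacrev


def f(x):
    # Hypothetical vector-valued function R^3 -> R^3.
    return jnp.sin(x) * jnp.sum(x ** 2)


x = jnp.array([0.5, -1.0, 2.0])
v = jnp.array([1.0, 0.0, 0.0])

# Forward mode: Jacobian-vector product J @ v.
y, jv = jax.jvp(f, (x,), (v,))

# Reverse mode: vector-Jacobian product v @ J via the returned pullback.
y2, f_vjp = jax.vjp(f, x)
(vj,) = f_vjp(v)

J = jax.jacfwd(f)(x)  # full Jacobian, built only for checking


def hessian(fun):
    return jit(jacfwd(jacrev(fun)))


def scalar_fn(x):
    # Hypothetical scalar function: its Hessian is diag(6x) plus a 1 at (0, 1) and (1, 0).
    return jnp.sum(x ** 3) + x[0] * x[1]


H = hessian(scalar_fn)(x)
```

`jax.hessian` is built from the same forward-over-reverse composition, so `H` should agree with `jax.hessian(scalar_fn)(x)`.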

As with Autograd, you're free to use differentiation with Python control structures:

```python
from jax import grad

def abs_val(x):
    if x > 0:
        return x
    else:
        return -x

abs_val_grad = grad(abs_val)
```
…
