google-deepmind/tf2jax

Python

Published: Mar 2, 2022


Language: Python

License: Apache-2.0

Stars: 124

Forks: 20

Open issues: 21

Created: 2022-03-02T20:22:24Z

Pushed: 2026-04-30T12:08:16Z

Default branch: main

Fork: no

Archived: no

README:

TF2JAX

TF2JAX is an experimental library for converting TensorFlow functions/graphs to JAX functions.

Specifically, it aims to transform a tf.function, e.g.

@tf.function
def tf_fn(x):
    return tf.sin(tf.cos(x))

into a Python function equivalent to the following JAX code:

def jax_fn(x):
    return jnp.sin(jnp.cos(x))

Users can apply additional JAX transforms (e.g. jit, grad, vmap, make_jaxpr) to the converted function just as they would to any other JAX code.
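To illustrate, the sketch below applies jit, grad, vmap and make_jaxpr to jax_fn, the hand-written JAX equivalent shown above. It stands in for a converted function and does not call tf2jax itself. The gradient check relies on the identity d/dx sin(cos(x)) = -sin(x) * cos(cos(x)).

```python
import jax
import jax.numpy as jnp
import numpy as np


def jax_fn(x):
    # Hand-written equivalent of the converted tf_fn.
    return jnp.sin(jnp.cos(x))


# Compile the function with XLA.
fast_fn = jax.jit(jax_fn)

# Scalar derivative: d/dx sin(cos(x)) = -sin(x) * cos(cos(x)).
grad_fn = jax.grad(jax_fn)

# Vectorize the scalar gradient over a batch of inputs.
batched_grad_fn = jax.vmap(grad_fn)

xs = jnp.linspace(-1.0, 1.0, 5)
print(fast_fn(xs))
print(batched_grad_fn(xs))
# Inspect the traced program as a jaxpr.
print(jax.make_jaxpr(jax_fn)(1.0))
```

A converted function can be wrapped the same way, since it is ordinary JAX code once conversion is done.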


Installation

You can install the latest released version of TF2JAX from PyPI via:

pip install tf2jax

or you can install the latest development version from GitHub:

pip install git+https://github.com/google-deepmind/tf2jax.git

Motivations

TF2JAX enables existing TensorFlow functions and models (including SavedModel and TensorFlow Hub) to be reused and/or fine-tuned within JAX codebases. The conversion process is transparent to the users, which is useful for debugging and introspection.

This also provides a pathway for JAX users to integrate JAX functions serialized via jax2tf.convert back into their existing JAX codebases.

See the Alternatives section at the end for a comparison with jax2tf.call_tf, an alternative approach.

Disclaimer

This is experimental code with a potentially unstable API, and no guarantees are made about its use at this point in time. We strongly recommend thoroughly testing the resulting JAX functions to ensure they meet your requirements.

Quick start

The rest of this document assumes the following imports:

import jax
import jax.numpy as jnp
import numpy as np
import tensorflow as tf # Assumes this is v2.
import tf2jax

An example using the convert API and the Sonnet v2 MLP.

import sonnet.v2 as snt

model = snt.nets.MLP((64, 10,))

@tf.function
def forward(x):
    return model(x)

x = np.random.normal(size=(128, 16)).astype(np.float32)

# TF -> JAX, jax_params are the network parameters of the MLP
jax_func, jax_params = tf2jax.convert(forward, np.zeros_like(x))

# Call JAX, which also returns updated jax_params (e.g. variables, batchnorm stats)
jax_outputs, jax_params = jax_func(jax_params, x)

tf2jax.convert has the signature convert(fn: tf.Function, *args, **kwargs), where fn(*args, **kwargs) is used to trace fn and generate the corresponding tf.GraphDef. The zeros_like call is not necessary; it is only used here to demonstrate that the JAX function does not memorize the outputs.
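Because the converted function takes the parameters as an explicit argument and returns them alongside the outputs, fine-tuning in JAX reduces to differentiating with respect to jax_params. The sketch below uses a hypothetical hand-written stand-in, fake_jax_func, with the same (params, x) -> (outputs, params) calling convention, so it runs without TensorFlow. The parameter names "w" and "b", the squared-error loss and the plain SGD update are illustrative assumptions, not what tf2jax produces for the Sonnet MLP.

```python
import jax
import jax.numpy as jnp
import numpy as np

LEARNING_RATE = 0.1  # Illustrative value.


def fake_jax_func(params, x):
    """Hypothetical stand-in with the (params, x) -> (outputs, params) convention."""
    outputs = x @ params["w"] + params["b"]
    # A converted function may return updated params (e.g. batchnorm stats);
    # this stand-in has no state to update, so it returns them unchanged.
    return outputs, params


def loss_fn(params, x, y):
    outputs, updated_params = fake_jax_func(params, x)
    # Carry the updated params out as auxiliary data, not as part of the loss.
    return jnp.mean((outputs - y) ** 2), updated_params


@jax.jit
def sgd_step(params, x, y):
    (loss, updated_params), grads = jax.value_and_grad(loss_fn, has_aux=True)(
        params, x, y)
    # Plain SGD: subtract the scaled gradient from every parameter.
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, updated_params, grads)
    return new_params, loss


rng = np.random.default_rng(0)
x = rng.normal(size=(128, 16)).astype(np.float32)
y = rng.normal(size=(128, 10)).astype(np.float32)
params = {"w": np.zeros((16, 10), np.float32), "b": np.zeros((10,), np.float32)}

losses = []
for _ in range(20):
    params, loss = sgd_step(params, x, y)
    losses.append(float(loss))
```

Returning the updated params through has_aux keeps any state changes (such as running statistics) flowing into the next step without contributing to the gradient.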

Example with a pure function

If your function is pure, i.e. it does not capture any variables, then you can drop the parameters from the inputs and outputs of the converted function with tf2jax.convert_functional.

@tf.function
def forward(x):
    return tf.sin(tf.cos(x))

jax_func = tf2jax.convert_functional(forward, np.zeros_like(x))
jax_outputs = jax_func(x)
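One way to picture the relationship between the two entry points: for a pure function the parameter collection is empty, so the stateful form can be closed over and its parameter output discarded. The sketch below is a hypothetical illustration of that idea; stateful_fn and make_functional are made-up names, and this is not how tf2jax implements convert_functional.

```python
import jax.numpy as jnp
import numpy as np


def stateful_fn(params, x):
    """Hypothetical stand-in for tf2jax.convert's output on a pure function."""
    # A pure function captures no variables, so params is empty and passed through.
    return jnp.sin(jnp.cos(x)), params


def make_functional(stateful, params):
    """Hypothetical helper: bind params and keep only the outputs."""
    def functional(*args, **kwargs):
        outputs, _ = stateful(params, *args, **kwargs)
        return outputs
    return functional


functional_fn = make_functional(stateful_fn, {})
x = np.zeros((128, 16), np.float32)
print(functional_fn(x).shape)  # (128, 16)
```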

Randomness and PRNG Keys

A TensorFlow function that makes use of random ops will be converted to a JAX function that takes a PRNG key as a keyword-only argument. TF2JAX will complain loudly if a PRNG key is required but not provided.

jax_outputs, jax_params = jax_func(jax_params, x, rng=jax.random.PRNGKey(42))
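Since JAX randomness is explicit, calling the converted function repeatedly with the same key reproduces the same random draws. The usual JAX pattern is to split a fresh subkey per call, sketched below with a hypothetical stand-in, noisy_fn, that mimics a dropout-style random op behind a keyword-only rng argument. The TypeError raised when rng is omitted here comes from Python's keyword-only argument handling, not from tf2jax's own check.

```python
import jax
import jax.numpy as jnp


def noisy_fn(params, x, *, rng):
    """Hypothetical stand-in for a converted function that uses random ops."""
    # Keep each element with probability 0.5 and rescale, as dropout would.
    keep = jax.random.bernoulli(rng, p=0.5, shape=x.shape)
    return jnp.where(keep, x / 0.5, 0.0), params


x = jnp.ones((4, 8))
params = {}

key = jax.random.PRNGKey(42)
# Split off a fresh subkey for every call.
key, subkey1 = jax.random.split(key)
out1, params = noisy_fn(params, x, rng=subkey1)
key, subkey2 = jax.random.split(key)
out2, params = noisy_fn(params, x, rng=subkey2)

# Reusing a key reproduces exactly the same randomness.
out1_again, _ = noisy_fn(params, x, rng=subkey1)
```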

Custom Gradient

Custom gradient support is highly experimental; please report any errors.

@tf.function
@tf.custom_gradient
def forward(x):
    e = tf.exp(x)
    def grad(dy):
        return dy * tf.sin(x) + e  # This is deliberately the wrong gradient.
    return tf.reduce_sum(e), grad

with tf2jax.override_config("convert_custom_gradient", True):
    jax_func = tf2jax.convert_functional(forward, np.zeros_like(x))

jax_grads = jax.grad(jax_func)(x)
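Because the example gradient is deliberately wrong, it is easy to check that the custom rule is being used: with dy = 1 for the scalar output, jax.grad should return sin(x) + exp(x) rather than the true gradient exp(x). For comparison, the sketch below writes the same forward pass and the same wrong backward rule with JAX's own jax.custom_vjp. It is a JAX-side analogue for testing expectations, not a claim about what tf2jax generates internally.

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.custom_vjp
def jax_forward(x):
    return jnp.sum(jnp.exp(x))


def jax_forward_fwd(x):
    e = jnp.exp(x)
    # Save x and e as residuals for the backward pass.
    return jnp.sum(e), (x, e)


def jax_forward_bwd(residuals, dy):
    x, e = residuals
    # Mirrors the deliberately wrong TF gradient: dy * sin(x) + e.
    return (dy * jnp.sin(x) + e,)


jax_forward.defvjp(jax_forward_fwd, jax_forward_bwd)

x = np.linspace(-1.0, 1.0, 6).astype(np.float32)
custom_grads = jax.grad(jax_forward)(x)
```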

Support for Serialization Formats

SavedModel

SavedModel is the preferred format for serializing TF2 functions.

model = tf.Module()
model.f = forward
model.f(x) # Dummy call.
tf.saved_model.save(model, "/tmp/blah")

restored = tf.saved_model.load("/tmp/blah")
jax_func, jax_params = tf2jax.convert(restored.f, np.zeros_like(x))

If the restored function has an unambiguous signature, i.e. it was traced only once prior to export, then TF2JAX can convert it directly from its GraphDef without tracing it again.

jax_func, jax_params = tf2jax.convert_from_restored(restored.f)

TF-Hub

The (legacy, TF1) TF-Hub format is supported with minor boilerplate.

import tensorflow_hub as hub

hub_model = hub.load("/tmp/blah")
jax_func, jax_params = tf2jax.convert(tf.function(hub_model), tf.zeros_like(x))
jax_outputs, updated_jax_params = jax_func(jax_params, x)

JAX to TensorFlow and back again

tf2jax.convert_functional can convert the outputs of jax2tf.convert back into JAX code.

from jax.experimental import jax2tf
import tree  # dm-tree, provides tree.map_structure.

# Some JAX function.
def forward(*inputs):
    ...

# JAX -> TF
tf_func = jax2tf.convert(forward)

# JAX -> TF -> JAX
# `inputs` is a tuple of example arguments for `forward`, used for tracing.
jax_func = tf2jax.convert_functional(tf.function(tf_func), *tree.map_structure(np.zeros_like, inputs))

# JAX -> TF -> SavedModel -> TF
model = tf.Module()
model.f = tf.function(tf_func)
model.f(*tree.map_structure(tf.zeros_like, inputs))  # Dummy call to trace before export.
tf.saved_model.save(model, "/tmp/blah")
restored = tf.saved_model.load("/tmp/blah")

# JAX -> TF -> SavedModel -> TF -> JAX
jax_too_func = tf2jax.convert_functional(restored.f, *tree.map_structure(np.zeros_like, inputs))

Additional Configuration

The behaviour of TF2JAX can be configured globally via tf2jax.update_config, or configured locally via the context manager tf2jax.override_config.

Strict shape and dtype checking

By default, TF2JAX will assert that the input shapes to the converted function are compatible with the input shapes of the original function. This is because some functions have shape-dependent behaviours that will silently return incorrect outputs after conversion, e.g. some batchnorm implementations.
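To see why such a check matters, the hypothetical sketch below mimics a converted function in which the trace-time batch size (10, from a (10, 5) input) was folded into a constant, as a shape-dependent op might be. TRACED_BATCH_SIZE and converted_mean are illustrative names, not tf2jax internals. The point is that a (20, 5) input produces a plausible-looking but wrong result instead of an error.

```python
import jax.numpy as jnp
import numpy as np

# Hypothetical: the batch size seen when the original function was traced.
TRACED_BATCH_SIZE = 10


def converted_mean(x):
    # Divides by the trace-time batch size, not by x.shape[0].
    return jnp.sum(x, axis=0) / TRACED_BATCH_SIZE


x_small = np.ones((10, 5), np.float32)
x_large = np.ones((20, 5), np.float32)
print(converted_mean(x_small))  # [1. 1. 1. 1. 1.], the true mean
print(converted_mean(x_large))  # [2. 2. 2. 2. 2.], silently twice the true mean
```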

jax_func = tf2jax.convert_functional(forward, np.zeros((10, 5), np.float32))

# This will raise an error.
jax_func(np.zeros((20, 5), np.float32))

# This will not.
with…
