google-deepmind/dm_pix

Description: PIX is an image processing library in JAX, for JAX.

Language: Python

License: Apache-2.0

Stars: 440

Forks: 28

Open issues: 2

Created: 2021-06-30T16:25:50Z

Pushed: 2026-06-02T11:30:40Z

Default branch: master

Fork: no

Archived: no

README:

PIX

PIX is an image processing library in [JAX], for [JAX].

Overview

[JAX] is a library resulting from the union of [Autograd] and [XLA] for high-performance machine learning research. It provides [NumPy]- and [SciPy]-like APIs, automatic differentiation and first-class GPU/TPU support.

PIX is a library built on top of JAX that provides image processing functions and tools designed so that they can be optimised and parallelised through [jax.jit][jit], [jax.vmap][vmap] and [jax.pmap][pmap].

Installation

PIX is written in pure Python, but depends on C++ code via JAX.

Because the JAX installation differs depending on your CUDA version, PIX does not list JAX as a dependency in [pyproject.toml]; it is listed there for reference only, commented out.

First, follow [JAX installation instructions] to install JAX with the relevant accelerator support.

Then, install PIX using pip:

$ pip install dm-pix

Quickstart

To use PIX, you just need to import dm_pix as pix and use it right away!

For example, let's assume we have loaded the JAX logo (available in examples/assets/jax_logo.jpg) into a variable called image and want to flip it left to right.

![JAX logo]

All that's needed is the following code!

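import dm_pix as pix
import numpy as np
from PIL import Image  # Pillow, one of many options for loading images.

# Load an image into a NumPy array with your preferred library
# (path relative to the repository root).
image = np.asarray(Image.open("examples/assets/jax_logo.jpg"))

flip_left_right_image = pix.flip_left_right(image)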

And here is the result!

![JAX logo left-right]

All the functions in PIX can be [jax.jit][jit]ed, [jax.vmap][vmap]ed and [jax.pmap][pmap]ed, so they can all take advantage of optimization and parallelization, as in the following example.

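import dm_pix as pix
import jax
import numpy as np
from PIL import Image  # Pillow, one of many options for loading images.

# Load an image into a NumPy array with your preferred library
# (path relative to the repository root).
image = np.asarray(Image.open("examples/assets/jax_logo.jpg"))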

# Vanilla Python function.
flip_left_right_image = pix.flip_left_right(image)

# `jax.jit`ed function.
flip_left_right_image = jax.jit(pix.flip_left_right)(image)

# Assuming a single device, such as a CPU or a single GPU, we add a leading
# dimension of size 1 so that `image` can be used with the vectorized
# (`jax.vmap`) and multi-device parallelized (`jax.pmap`) versions of
# `pix.flip_left_right`. To learn more, see the JAX documentation of
# `jax.vmap` and `jax.pmap`.
image = image[np.newaxis, ...]

# `jax.vmap`ed function.
flip_left_right_image = jax.vmap(pix.flip_left_right)(image)

# `jax.pmap`ed function.
flip_left_right_image = jax.pmap(pix.flip_left_right)(image)

You can check for yourself that the four versions of pix.flip_left_right give the same result (after dropping the leading dimension added for jax.vmap and jax.pmap, and up to the accelerator's floating-point accuracy)!

Examples

We have a few examples in the [examples/] folder. They are not much more involved than the previous example, but they may be a good starting point for you!

Testing

We provide a suite of tests to help you both test your development environment and learn more about the library itself! All test files have a _test suffix and can be executed using pytest.

If you already have PIX installed, you just need to install some extra dependencies and run pytest as follows:

$ pip install -e ".[test]"
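$ python -m pytest [-n <NUMCPUS>] dm_pix

The optional -n flag is provided by the pytest-xdist plugin and sets how many worker processes run the tests in parallel.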

If you want an isolated virtual environment, you just need to run our utility bash script as follows:

$ ./test.sh

Citing PIX

This repository is part of the [DeepMind JAX Ecosystem]. To cite PIX, please use the [DeepMind JAX Ecosystem citation].

Contribute!

We are very happy to accept contributions!

Please read our [contributing guidelines](./CONTRIBUTING.md) and send us PRs!

[Autograd]: https://github.com/hips/autograd "Autograd on GitHub"
[DeepMind JAX Ecosystem]: https://deepmind.com/blog/article/using-jax-to-accelerate-our-research "DeepMind JAX Ecosystem"
[DeepMind JAX Ecosystem citation]: https://github.com/deepmind/jax/blob/main/deepmind2020jax.txt "Citation"
[JAX]: https://github.com/jax-ml/jax "JAX on GitHub"
[JAX installation instructions]: https://github.com/jax-ml/jax#installation "JAX installation"
[jit]: https://jax.readthedocs.io/en/latest/jax.html#jax.jit "jax.jit documentation"
[NumPy]: https://numpy.org/ "NumPy"
[pmap]: https://jax.readthedocs.io/en/latest/jax.html#jax.pmap "jax.pmap documentation"
[SciPy]: https://www.scipy.org/ "SciPy"
[XLA]: https://www.tensorflow.org/xla "XLA"
[vmap]: https://jax.readthedocs.io/en/latest/jax.html#jax.vmap "jax.vmap documentation"

[examples/]: ./examples/
[JAX logo]: ./examples/assets/jax_logo.jpg
[JAX logo left-right]: ./examples/assets/flip_left_right_jax_logo.jpg
[pyproject.toml]: ./pyproject.toml