google-deepmind/dm_pix
Description: PIX is an image processing library in JAX, for JAX.
Language: Python
License: Apache-2.0
Stars: 440
Forks: 28
Open issues: 2
Created: 2021-06-30T16:25:50Z
Pushed: 2026-06-02T11:30:40Z
Default branch: master
Fork: no
Archived: no
README:
PIX
PIX is an image processing library in [JAX], for [JAX].
Overview
[JAX] is a library resulting from the union of [Autograd] and [XLA] for high-performance machine learning research. It provides [NumPy]- and [SciPy]-style APIs, automatic differentiation, and first-class GPU/TPU support.
PIX is a library built on top of JAX with the goal of providing image processing functions and tools to JAX in a way that they can be optimised and parallelised through [jax.jit][jit], [jax.vmap][vmap] and [jax.pmap][pmap].
Installation
PIX is written in pure Python, but depends on C++ code via JAX.
Because JAX installation differs depending on your CUDA version, PIX does not list JAX as a dependency in [pyproject.toml]; it appears there only as a commented-out entry, for reference.
First, follow [JAX installation instructions] to install JAX with the relevant accelerator support.
Then, install PIX using pip:
$ pip install dm-pix
Quickstart
To use PIX, you just need to import dm_pix as pix and use it right away!
For example, let's assume we have loaded the JAX logo (available in examples/assets/jax_logo.jpg) into a variable called image and want to flip it left to right.
![JAX logo]
All that's needed is the following code!
    import dm_pix as pix

    # Load an image into a NumPy array with your preferred library.
    image = load_image()

    flip_left_right_image = pix.flip_left_right(image)
And here is the result!
![JAX logo left-right]
All the functions in PIX can be [jax.jit][jit]ed, [jax.vmap][vmap]ed and [jax.pmap][pmap]ed, so all the following functions can take advantage of optimization and parallelization.
    import dm_pix as pix
    import jax
    import numpy as np

    # Load an image into a NumPy array with your preferred library.
    image = load_image()

    # Vanilla Python function.
    flip_left_right_image = pix.flip_left_right(image)

    # `jax.jit`ed function.
    flip_left_right_image = jax.jit(pix.flip_left_right)(image)

    # Assuming a single device, such as a CPU or a single GPU, we add a
    # single leading dimension so that `image` can be used with the vectorized
    # (`jax.vmap`) or multi-device (`jax.pmap`) version of `pix.flip_left_right`.
    # To know more, please refer to JAX documentation of `jax.vmap` and `jax.pmap`.
    image = image[np.newaxis, ...]

    # `jax.vmap`ed function.
    flip_left_right_image = jax.vmap(pix.flip_left_right)(image)

    # `jax.pmap`ed function.
    flip_left_right_image = jax.pmap(pix.flip_left_right)(image)
You can check for yourself that the four versions of pix.flip_left_right return the same result (up to the accelerator's floating-point accuracy), once the extra leading dimension of the jax.vmap and jax.pmap outputs is accounted for!
Examples
We have a few examples in the [examples/] folder. They are not much more involved than the previous example, but they may be a good starting point for you!
Testing
We provide a suite of tests to help you both test your development environment and learn more about the library itself! All test files have the _test suffix and can be executed using pytest.
If you already have PIX installed, you just need to install some extra dependencies and run pytest as follows:
    $ pip install -e ".[test]"
    $ python -m pytest [-n ] dm_pix
If you want an isolated virtual environment, you just need to run our utility bash script as follows:
$ ./test.sh
Citing PIX
This repository is part of the [DeepMind JAX Ecosystem]. To cite PIX, please use the [DeepMind JAX Ecosystem citation].
Contribute!
We are very happy to accept contributions!
Please read our [contributing guidelines](./CONTRIBUTING.md) and send us PRs!
[Autograd]: https://github.com/hips/autograd "Autograd on GitHub"
[DeepMind JAX Ecosystem]: https://deepmind.com/blog/article/using-jax-to-accelerate-our-research "DeepMind JAX Ecosystem"
[DeepMind JAX Ecosystem citation]: https://github.com/deepmind/jax/blob/main/deepmind2020jax.txt "Citation"
[JAX]: https://github.com/jax-ml/jax "JAX on GitHub"
[JAX installation instructions]: https://github.com/jax-ml/jax#installation "JAX installation"
[jit]: https://jax.readthedocs.io/en/latest/jax.html#jax.jit "jax.jit documentation"
[NumPy]: https://numpy.org/ "NumPy"
[pmap]: https://jax.readthedocs.io/en/latest/jax.html#jax.pmap "jax.pmap documentation"
[SciPy]: https://www.scipy.org/ "SciPy"
[XLA]: https://www.tensorflow.org/xla "XLA"
[vmap]: https://jax.readthedocs.io/en/latest/jax.html#jax.vmap "jax.vmap documentation"
[examples/]: ./examples/
[JAX logo]: ./examples/assets/jax_logo.jpg
[JAX logo left-right]: ./examples/assets/flip_left_right_jax_logo.jpg
[pyproject.toml]: ./pyproject.toml