Introduction to the JAX library for ML in Python

JAX (Just After eXecution) is a recent machine learning library for expressing and composing numerical programs. It can compile those programs for CPUs as well as accelerators such as GPUs and TPUs, generating optimized code from pure Python. JAX works well for machine learning because it pairs the familiarity of Python and NumPy with hardware acceleration, and because it is built around composable, user-applied function transformations. These transformations include automatic differentiation, automatic batching, end-to-end compilation (via XLA), parallelization over multiple accelerators, and more. Researchers use it for a wide range of advanced applications, from studying the training dynamics of neural networks, to building machine learning solutions, to probabilistic programming, to developing accelerated numerical code, to scientific applications in physics and biology. Benchmarks of basic numerical functions have shown speedups of up to roughly 86x (8600%). This is highly valuable for data-heavy, application-facing models, or simply for getting more machine learning experiments done in a day.

Already understand why you want to use JAX? Jump forward to the code!

Some of its vital features are:

  • Just-in-Time (JIT) compilation.

  • Running NumPy code not only on CPUs but on GPUs and TPUs as well.

  • Automatic differentiation of both NumPy and native Python code.

  • Automatic vectorization.

  • Expressing and composing transformations of numerical programs.

  • Advanced (pseudo-)random number generation.

  • More options for control flow.

JAX's popularity is rising in the deep learning industry because of its speed: it is used increasingly both in machine learning programs and for accelerating research. More broadly, JAX provides a general foundation for high-performance scientific computing, which is useful in many fields beyond deep learning. If most of your work is not in Python, but you want to build some sort of hybrid model-based / neural-network system, then it is probably worth using JAX going forward. If most of your work is not in Python and you're using specialized software for your studies (thermodynamics, semiconductors, etc.), then JAX probably isn't the tool for you, unless you want to export data from those programs for custom computational processing. Suppose your area of interest is closer to physics/mathematics and incorporates computational methods (dynamical systems, differential geometry, statistical physics), and most of your work is in e.g. Mathematica. In that case, it's probably worth sticking with what you're using, especially if you have a large custom codebase.

Getting started with JAX

You can follow along in this Jupyter Notebook. Here we install JAX easily with pip from the command line:
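
At the time of writing, a CPU-only install looks roughly like this (the exact package extras and version pins may have changed, so check the installation guide linked below):

```
pip install --upgrade pip
pip install --upgrade "jax[cpu]"
```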

This, however, supports CPU only, which is useful for local development. If you want both CPU and GPU support, you should first install CUDA and cuDNN if they are not already installed. Also, make sure the jaxlib version matches the CUDA version you have.

Here is the JAX installation guide on GitHub for more installation options and troubleshooting.

We will import both JAX and NumPy into our notebook (linked here) for a comparison of different use cases:
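
A typical set of imports for the examples that follow (plain NumPy is kept alongside jax.numpy so the two can be compared):

```python
import numpy as np             # classic NumPy, runs on the CPU
import jax
import jax.numpy as jnp        # NumPy-like API that can run on CPU, GPU, or TPU
from jax import grad, jit, vmap, pmap, random
```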

Why use JAX?

Accelerated Linear Algebra (XLA compiler) — XLA is a domain-specific compiler for linear algebra that has been used extensively by TensorFlow and is one of the factors that makes JAX so fast. In order to perform matrix operations as fast as possible, the code is compiled into a set of computation kernels that can be extensively optimized based on the nature of the code.

Examples of such optimizations include:

  • Fusion of operations: Intermediate results are not saved in memory

  • Optimized layout: Optimize the “shape” in which an array is represented in memory

Just-in-time compilation to speed up functions — Just-in-time compilation is a way of executing code in which compilation happens at run time rather than before execution. In JAX, just-in-time compilation works together with Accelerated Linear Algebra (the XLA compiler). If we have a sequence of operations, the @jit decorator compiles them together using XLA. To use XLA and jit, one can use either the jit() function or the @jit decorator.
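
As a sketch, here is the selu activation from the official JAX quickstart (the original notebook may have used a different function), first jit-compiled and then timed with the %timeit notebook magic:

```python
import jax.numpy as jnp
from jax import jit, random

def selu(x, alpha=1.67, lmbda=1.05):
    # a chain of element-wise operations that XLA can fuse into a single kernel
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

selu_jit = jit(selu)  # equivalently, decorate selu with @jit

x = random.normal(random.PRNGKey(0), (1_000_000,))
selu_jit(x).block_until_ready()  # the first call compiles; later calls reuse the compiled kernel

# In a Jupyter notebook:
# %timeit selu(x).block_until_ready()      # plain version
# %timeit selu_jit(x).block_until_ready()  # compiled version, typically several times faster
```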

Using the %timeit command, as in the snippet above, we can see that the improvement in execution time is quite clear. We call block_until_ready because JAX uses asynchronous execution by default (more on this later). Although jit is incredibly useful in deep learning, it is not without limitations. One of its flaws is that if you use "if" statements in your function, jit may be unable to represent your function accurately (see the control-flow section below).

Automatic differentiation with the grad() function

As well as evaluating numerical functions, we also want to transform them. One such transformation is automatic differentiation. JAX can differentiate through all sorts of Python and NumPy functions, including loops, branches, recursion, and more. This is very useful in deep learning, as backpropagation becomes very easy.

In the example below, we define a simple quadratic function and take its derivative at the point 1.0. We also compute the derivative manually in order to verify that the result is correct.
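
A minimal sketch (the exact quadratic used in the original notebook may differ):

```python
from jax import grad

def f(x):
    # a simple quadratic: f(x) = 3x^2 + 2x + 5
    return 3.0 * x ** 2 + 2.0 * x + 5.0

df = grad(f)          # analytically, f'(x) = 6x + 2
print(df(1.0))        # 8.0
print(6 * 1.0 + 2)    # manual derivative at x = 1.0, also 8.0
```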

There is much more to automatic differentiation with JAX; if you are interested in its full capabilities, you can find more in the official documentation.

Auto-vectorization with vmap

Another transformation in JAX's API that you might find useful is vmap(), the vectorizing map. It has the familiar semantics of mapping a function along array axes, but instead of keeping the loop on the outside, it pushes the loop down into a function's primitive operations for better performance. When composed with jit(), it can be just as fast as adding the batch dimensions beforehand. In the example below, we take a function that operates on a single data point and vectorize it so that it can accept a batch of data points (a vector) of arbitrary size.
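
Here is a minimal sketch of the idea; the squaring function stands in for the d(x) mentioned in the next paragraph:

```python
import jax.numpy as jnp
from jax import vmap, jit

def square(x):
    # written for a single scalar input
    return x ** 2

batched_square = vmap(square)       # now accepts a whole vector of inputs
xs = jnp.arange(10.0)
print(batched_square(xs))           # [ 0.  1.  4. ... 81.]

fast_batched_square = jit(vmap(square))  # composes with jit for extra speed
```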

vmap batches all the values together and passes them through the function, so all values are squared at once. When d(x) was run without vmap, the square of each value was computed one at a time and appended to a list. Needless to say, the vectorized version is faster, though memory consumption increases because the whole batch is processed at once.

Replicate computation across devices with pmap

pmap is another transformation that enables us to replicate a computation onto multiple cores or devices and execute it in parallel. It automatically distributes the computation across all available devices and handles all the communication between them. You can run jax.devices() to check which devices are available.
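
A minimal sketch (the output assumes you actually have multiple devices, e.g. the cores of a TPU):

```python
import jax
import jax.numpy as jnp
from jax import pmap

print(jax.devices())        # the accelerators JAX can see

n = jax.device_count()
# replicate the same computation onto every available device
out = pmap(lambda x: x ** 2)(jnp.arange(n))
print(out)                  # one result per device, returned as a sharded array
```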

Notice that DeviceArray has now become ShardedDeviceArray, the structure that handles parallel execution. JAX also supports collective communication between devices. Suppose, for example, we want to perform an operation on values that live on different devices: we need to gather the data from all devices and find the mean.
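
This collective operation can be sketched with jax.lax.pmean, which averages a value across all devices participating in the named axis:

```python
from functools import partial
import jax
import jax.numpy as jnp
from jax import pmap

@partial(pmap, axis_name="i")
def normalize(x):
    # gather the mean of x across every device on axis "i", then divide by it
    return x / jax.lax.pmean(x, axis_name="i")

# needs more than one device to actually run in parallel
print(normalize(jnp.arange(1.0, jax.device_count() + 1)))
```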

The function above collects all the "x" values from the devices, finds their mean, and returns the result to each device so the parallel computation can continue. Note that this code will not run in parallel unless you have more than one device available to communicate with each other. With pmap, we can define our own computation patterns and exploit our devices in the best possible way.

Control flows

In Python programming, the order in which the program’s code is executed at runtime is called control flow. The control flow of a Python program is regulated by conditional statements, loops, and function calls.

Python has three types of control structures:

  • Sequential — statements are executed one after another, in sequence.

  • Selection — used for decisions and branching, i.e., if, if-else statements

  • Repetition — used for looping, i.e., repeating a piece of code multiple times.

Control flow with autodiff

When using grad in your Python functions, you can use regular Python control-flow structures with no problems, as if you were using Autograd (or PyTorch or TF Eager).
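
For example, here is the piecewise function from the official documentation's control-flow discussion (also used in the jit example below):

```python
from jax import grad

def f(x):
    # ordinary Python branching works fine under grad
    if x < 3:
        return 3.0 * x ** 2
    else:
        return -4.0 * x

print(grad(f)(2.0))   # 12.0  (derivative of 3x^2 at x = 2)
print(grad(f)(4.0))   # -4.0  (derivative of -4x)
```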

Control flow with jit

Control flow with jit, however, is more complicated, and by default it has more constraints.

When jit-compiling a function, we want to compile a version that can be cached and reused for many different argument values. To get a view of your Python code that is valid for many different argument values, JAX traces it on abstract values that represent sets of possible inputs. There are multiple levels of abstraction, and different transformations use different abstraction levels. If we trace using an abstract value, we get a view of the function that can be reused for any concrete value in the corresponding set (e.g. while working on different sets of arrays), which means we can save on compile time.
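
For instance, naively jit-compiling the branching function from the previous example fails, because the trace only sees an abstract value for x:

```python
from jax import jit

@jit
def f(x):
    if x < 3:
        return 3.0 * x ** 2
    else:
        return -4.0 * x

# f(2.0)  # raises an error: the abstract tracer for x cannot be
#         # converted to a concrete boolean in `x < 3`
```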

The function being traced above isn't committed to a specific concrete value, so in the line with if x < 3 the expression x < 3 evaluates to an abstract boolean. When Python attempts to coerce that to a concrete True or False, we get an error: we don't know which branch to take and can't continue tracing. You can relax the traceability constraints by having jit trace on more refined abstract values. Using the static_argnums argument to jit, we can specify that some arguments should be traced on their concrete values.
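
A sketch of this workaround:

```python
from jax import jit

def f(x):
    if x < 3:
        return 3.0 * x ** 2
    else:
        return -4.0 * x

f_jit = jit(f, static_argnums=(0,))
print(f_jit(2.0))   # works: x is treated as a concrete (static) value,
                    # at the cost of re-compiling for each new value of x
```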

Asynchronous dispatch

Essentially, this means that control is returned to the Python program even before operations are complete. Instead of the finished result, JAX returns a DeviceArray, which is a future: a value that will be produced on an accelerator device but isn't necessarily available immediately. The future can be passed to other operations before the computation is complete. Thus JAX allows Python code to run ahead of the accelerator, enqueuing operations for the hardware without having to wait for them.
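
A small illustration:

```python
import jax.numpy as jnp
from jax import random

x = random.normal(random.PRNGKey(0), (1000, 1000))
y = jnp.dot(x, x)                     # returns almost immediately: y is a future on the device
print(y.block_until_ready().shape)    # explicitly wait when the value is actually needed
```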

Pseudo-Random number generator (PRNG)

A random number generator has a state: the next "random" number is a function of the previous number and the state, and the sequence of values is finite, so it eventually repeats. Instead of a typical stateful pseudo-random number generator (PRNG), as in NumPy and SciPy, JAX random functions require the PRNG state (a key) to be passed explicitly as the first argument.
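
A minimal sketch of explicit key handling:

```python
from jax import random

key = random.PRNGKey(42)            # the explicit PRNG state ("key")
key, subkey = random.split(key)     # split the key before each use to get fresh randomness
x = random.normal(subkey, (3,))     # pass the state explicitly as the first argument
print(x)
```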

Something to note is that this explicit handling of PRNG state works well with vectorization and with parallel computation across devices.

JAX vs NumPy

  • Accelerator Devices — The differences between NumPy and JAX are most apparent in relation to accelerator devices, such as GPUs and TPUs. Classic NumPy's promotion rules are too willing to over-promote to 64-bit types, which is problematic for a system designed to run on accelerators. JAX uses floating-point promotion rules that are better suited to modern accelerator devices and are less aggressive about promoting floating-point types, similar to PyTorch.

  • Control Behavior — When performing unsafe type casts, JAX's behavior may be backend dependent and, in general, may diverge from NumPy's behavior. NumPy allows control over the result in these scenarios via the casting argument; JAX does not provide any such configuration, instead directly inheriting the behavior of XLA:ConvertElementType.

  • Arrays — JAX's array update functions, unlike their NumPy counterparts, operate out-of-place: the updated array is returned as a new array and the original array is not modified by the update (see the sketch after this list).

  • Inputs — NumPy is generally happy accepting Python lists or tuples as inputs to its API functions; JAX, however, returns an error. This is deliberate, because passing lists or tuples to traced functions can lead to silent performance degradation that might otherwise be difficult to detect.
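
For the out-of-place array updates mentioned above, here is a small sketch of the .at[...] syntax:

```python
import jax.numpy as jnp

arr = jnp.zeros(3)
new_arr = arr.at[0].set(10.0)   # returns an updated copy; arr itself is untouched
print(arr)        # [0. 0. 0.]
print(new_arr)    # [10.  0.  0.]
```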

Conclusion

I have briefly covered what makes JAX a great library; it promises to make ML programming more intuitive, structured, and clean. There is much more to this library than we have covered here, so go ahead and explore its more in-depth uses. You can learn more from its documentation here.

JAX also provides a whole ecosystem of exciting libraries like:

  • Haiku is a neural network library providing object-oriented programming models.

  • RLax is a library for deep reinforcement learning.

  • Jraph, pronounced “giraffe”, is a library used for Graph Neural Networks (GNNs).

  • Optax provides an easy one-liner interface to utilize gradient-based optimization methods efficiently.

  • Chex is a library of utilities for testing JAX code.

Follow me here for more AI, Machine Learning, and Data Science tutorials to come!

You can stay up to date with Accel.AI workshops, research, and social impact initiatives through our website, mailing list, meetup group, Twitter, and Facebook.

References

https://jax.readthedocs.io/en/latest/notebooks/quickstart.html

https://theaisummer.com/jax/

https://developer.nvidia.com/gtc/2020/video/s21989

https://www.shakudo.io/blog/a-quick-introduction-to-jax