TensorOps

Should I Switch From NumPy to JAX?

JAX vs NumPy: since its creation as an open source library in 2005, NumPy has ruled as the unquestioned favourite math tool among Python developers. JAX is the new toy in the box. With its impressive performance and better fit for deep learning use cases, can it capture the hearts of data scientists around the world and become their new favourite?

Other libraries have tried to take NumPy’s place before; will JAX finally do it?

If you’re a Python user, chances are you have used NumPy. Started by Travis Oliphant in 2005, this library is the Swiss Army knife for Python developers working on scientific computing. NumPy provides a high-performance multidimensional array object and a set of functions that together make it a useful tool for mathematical programming. NumPy’s usefulness has earned it a place as the underlying engine of other popular libraries like SciPy and Pandas.

JAX is (still) a research project under development. Introduced by Google in 2018, JAX is designed to replace NumPy by being faster, easier to use and a better fit for deep learning. It brings together Autograd and XLA (Accelerated Linear Algebra), allowing you to run your NumPy-style programs on GPUs and TPUs. Its just-in-time compiler transforms your Python functions into new XLA-optimized functions behind the scenes.

Feature                | NumPy                   | JAX
Hardware               | CPU                     | CPU, GPU, TPU
Originated by          | Travis Oliphant in 2005 | Google in 2018
Open source            | Yes                     | Yes
AutoGrad               | No                      | Yes
Immutable arrays       | No                      | Yes
Function derivation    | No                      | Yes
Function vectorisation | No                      | Yes
Parallel computation   | No                      | Yes
Execution              | Synchronous             | Asynchronous
Summary table

In this post, I will look at how familiar JAX’s syntax is for NumPy users and compare the performance of the two libraries. I will also go over the exclusive features JAX brings to the table and highlight where you need to be more careful when choosing JAX.

One thing’s for sure: you can’t always just wave a magic wand at your code and make it run faster with JAX. But there are still many reasons to switch from NumPy to JAX.

JAX Looks like NumPy

At its core, JAX is NumPy on steroids. Since JAX mirrors the NumPy API, you can use the syntax you’re already familiar with. For example, creating an array with JAX looks almost identical to NumPy:

import numpy as np
import jax.numpy as jnp

x = np.arange(10)
print(x)
# NumPy output: [0 1 2 3 4 5 6 7 8 9]
x_jax = jnp.arange(10)
print(x_jax)
# JAX output: [0 1 2 3 4 5 6 7 8 9]

In this case both libraries print the same values. The big appeal of JAX is that, for the most part, your code will run the same if you just replace import numpy as np with import jax.numpy as np. The convention, however, is to import jax.numpy as jnp to distinguish between the two. Next, I’ll show why replacing NumPy with JAX doesn’t always work.
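One way to see the drop-in compatibility is to write a function against an array module passed in as a parameter. This is a minimal sketch; the helper name standardize and the xp parameter are hypothetical, not part of either library.

```python
import numpy as np
import jax.numpy as jnp

def standardize(xp, v):
    # hypothetical helper: xp is the array module (numpy or jax.numpy);
    # the body only uses functions both modules provide with the same signature
    return (v - xp.mean(v)) / xp.std(v)

values = [1.0, 2.0, 3.0, 4.0]
a = standardize(np, np.array(values))    # NumPy array, float64
b = standardize(jnp, jnp.array(values))  # JAX array, float32 by default
print(np.allclose(a, np.asarray(b)))     # True
```

The small numerical differences come from JAX defaulting to 32-bit floats, which is why the comparison uses np.allclose rather than exact equality.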

Familiar syntax but not identical

One key difference between NumPy and JAX arrays is that JAX arrays are immutable: once they’re initialised, their contents cannot be changed.

import jax.numpy as jnp

x = jnp.arange(10)
x[0] = 10
# output: TypeError: '<class 'jaxlib.xla_extension.DeviceArray'>' 
# object does not support item assignment.
# JAX arrays are immutable.

If you need to change the contents of your arrays, you can keep those arrays as NumPy arrays, since JAX functions also accept NumPy arrays as input. There are workarounds for updating values in a JAX array, but they are not as convenient as NumPy’s in-place assignment.
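The workaround JAX documents is the functional .at[...] update syntax, which returns a modified copy instead of changing the array in place. A minimal sketch:

```python
import jax.numpy as jnp

x = jnp.arange(10)
# x[0] = 10 would raise a TypeError; .at[].set() returns a new array instead
y = x.at[0].set(10)
# other update methods follow the same pattern, e.g. .add() for x[idx] += v
z = y.at[1:3].add(100)

print(x.tolist())  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]  (original unchanged)
print(y.tolist())  # [10, 1, 2, 3, 4, 5, 6, 7, 8, 9]
print(z.tolist())  # [10, 101, 102, 3, 4, 5, 6, 7, 8, 9]
```

Inside jit-compiled code, XLA can often turn these copies into in-place updates, so the functional style does not necessarily cost extra memory.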

You may also notice that JAX stores arrays in the DeviceArray type (replaced by the unified jax.Array type in newer JAX releases), which lets the same object live on different types of accelerators. Later in the post I’ll mention the advantages of this.
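To see which devices JAX can place arrays on, and to move an array explicitly, you can use jax.devices() and jax.device_put(). A minimal sketch, assuming a recent JAX release where arrays are of type jax.Array; on a CPU-only machine the device list contains a single CPU device.

```python
import jax
import numpy as np

# devices JAX can see on this machine (CPU, GPU or TPU)
devices = jax.devices()
print(devices)

# move a NumPy array onto the first device; the result is a JAX array
x_host = np.arange(4.0)
x_dev = jax.device_put(x_host, devices[0])
print(type(x_dev))

# np.asarray copies the data back to host memory as a NumPy array
back = np.asarray(x_dev)
print(back)  # [0. 1. 2. 3.]
```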

JAX is faster

Even if you are not doing ML/DL, you may consider JAX simply because it can be faster than NumPy. JAX has several ways of achieving this speedup; I will showcase two of them, each applied to a typical task:

  1. JIT (just-in-time compilation), applied to a cubic polynomial evaluated on a large matrix.
  2. Extended accelerator support, leveraged for a large matrix dot product.

All of our experiments were run using Google Colab notebooks.

JIT

JAX brings just-in-time compilation (JIT) to the table. This allows it to compile your Python functions into fast, XLA-optimized executables. In this first example I used a function that evaluates a cubic polynomial element-wise on a 10000 x 10000 matrix, and tested it under the following conditions:

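  1. NumPy over CPU
  2. JAX over CPU with JIT

import jax
import jax.numpy as jnp
import numpy as np

def f(x):
  return -4*x*x*x + 9*x*x + 6*x - 3

x = np.random.randn(10000, 10000)
x_jax = jnp.array(x)  # copy to the device once, outside the timed code
f_jit = jax.jit(f)
f_jit(x_jax).block_until_ready()  # warm-up call: compiles f once

%timeit f(x) # NumPy
%timeit f_jit(x_jax).block_until_ready() # JAX + JIT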

In my run, with JIT, JAX performed the operation about 6 times faster than NumPy. Later in the post I explain how JAX does this by translating the Python code into its own intermediate language.
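Besides calling jax.jit(f) directly, you can apply jax.jit as a decorator. The function is traced and compiled on the first call for a given input shape and dtype, and later calls with matching inputs reuse the compiled version. A minimal sketch on a smaller matrix (the 500 x 500 size is an arbitrary choice), checked against NumPy:

```python
import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def f_jit(x):
    # same cubic polynomial as above, applied element-wise
    return -4*x*x*x + 9*x*x + 6*x - 3

x = np.random.randn(500, 500).astype(np.float32)
result = f_jit(jnp.asarray(x)).block_until_ready()  # first call compiles

expected = -4*x**3 + 9*x**2 + 6*x - 3  # NumPy reference
print(np.allclose(np.asarray(result), expected, rtol=1e-4, atol=1e-3))
```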

Extended accelerator support

It shouldn’t be surprising that matrix multiplication runs faster on accelerators like GPUs and TPUs. What JAX offers here is the ability to leverage accelerators automatically, with no changes to the code. Let me showcase that by multiplying a large (10000 x 10000) random matrix by its transpose on different hardware:

  1. NumPy (CPU)
  2. JAX (CPU)
  3. JAX (GPU)
  4. JAX (TPU)

import jax
import jax.numpy as jnp
import numpy as np

x = np.random.normal(size=(10000, 10000))
x_jax = jax.random.normal(jax.random.PRNGKey(0), (10000, 10000))

%timeit np.dot(x, x.T) # NumPy
# block_until_ready() waits for JAX's asynchronous computation to finish
%timeit jnp.dot(x_jax, x_jax.T).block_until_ready() # JAX

JAX and NumPy: on CPU NumPy wins, but JAX can run on accelerators

By keeping the code exactly the same and only changing Colab’s accelerator from CPU to GPU and then TPU, I was able to shrink the running time by orders of magnitude. In its best run, NumPy computed the dot product in approximately 11.7 seconds, outperforming JAX on a CPU. But on machines with GPUs and TPUs, JAX automatically leveraged the accelerators and reached up to a 120X improvement. Later in the post I explain why JAX doesn’t always run faster than NumPy on CPUs.
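Because JAX picks the accelerator automatically, it can be worth confirming which backend a notebook is actually using before comparing timings. A minimal check with jax.default_backend(); the 1000 x 1000 size is an arbitrary choice to keep it quick.

```python
import jax
import jax.numpy as jnp

# 'cpu', 'gpu' or 'tpu', depending on what the runtime provides
print(jax.default_backend())

# arrays are created on the default device, so the same code uses the
# accelerator when one is available
x = jax.random.normal(jax.random.PRNGKey(0), (1000, 1000))
y = jnp.dot(x, x.T).block_until_ready()
print(y.shape)  # (1000, 1000)
```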

Designed for ML and Data Science

Earlier in the post, you saw that at its core, JAX is NumPy on steroids. But that is only the tip of the iceberg. For example, JAX also includes a growing implementation of the SciPy API (jax.scipy). JAX also has three handy transformations that could make it the default backend for ML/DL libraries:

  1. grad: performs differentiation.
  2. vmap: vectorizes operations.
  3. pmap: parallelizes computation.

Let’s see how these functions work and why they are well suited to ML/DL.

Calculating gradients

Gradient calculation is of course important for machine learning and deep learning algorithms, which almost always use gradient descent as their optimization method. JAX’s grad function offers out-of-the-box automatic differentiation of Python functions. With autodiff, JAX makes it easy to compute even higher-order derivatives, like so:

import jax

def f(x):
  return -4*x*x*x + 9*x*x + 6*x - 3

dfdx = jax.grad(f)       # first derivative
d2fdx = jax.grad(dfdx)   # second derivative: jax.grad(jax.grad(f))
d3fdx = jax.grad(d2fdx)  # third derivative: jax.grad(jax.grad(jax.grad(f)))

Start with the original function f(x) = -4x³ + 9x² + 6x - 3.

grad returns its first three derivatives: f'(x) = -12x² + 18x + 6, f''(x) = -24x + 18 and f'''(x) = -24.

You can test the new functions. Make sure the input you feed to the transformed functions is a float value; grad raises an error for integer inputs.

f(1) # output: 8
dfdx(1.) # output: 12
d2fdx(1.) # output: -6
d3fdx(1.) # output: -24

Automatic differentiation of complex functions is also at the heart of back-propagation, which deep learning uses to update a network’s weights. This Kaggle kernel shows how grad can help implement a simple gradient descent, with more abstraction than NumPy can offer.
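As a sketch of the idea (not the Kaggle kernel’s code), here is plain gradient descent for a one-variable linear regression, with jax.grad supplying the gradient of the loss. The synthetic data, learning rate and step count are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp

# synthetic data from y = 3x + 2 (no noise, for a clear result)
xs = jnp.linspace(-1.0, 1.0, 50)
ys = 3.0 * xs + 2.0

def loss(params):
    w, b = params
    preds = w * xs + b
    return jnp.mean((preds - ys) ** 2)  # mean squared error

# grad differentiates with respect to the first argument, here a tuple
grad_loss = jax.jit(jax.grad(loss))

params = (0.0, 0.0)
learning_rate = 0.1
for step in range(500):
    gw, gb = grad_loss(params)
    params = (params[0] - learning_rate * gw, params[1] - learning_rate * gb)

print(params)  # close to (3.0, 2.0)
```

The gradient comes back with the same structure as the input (a tuple of two values), which is what lets the update step unpack it directly.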

VMAP: write loops -> execute on XLA

Array programs compiled with XLA can speed up some computations by a few orders of magnitude, if they can be expressed as vector algebra. One example of such a computation is the forward pass performed when training neural networks. With NumPy this is achievable, but it requires writing the code in a vectorised style up front. JAX provides the vmap transformation to automatically produce a vectorised implementation of a function.

Given a function written for a single example, vmap transforms it into a vectorised form that processes a whole batch, making it well suited to hardware accelerators like GPUs. For the ML developer this means an easier way to express math without losing performance.
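A minimal sketch of the idea: the function below is written for a single input vector, and jax.vmap turns it into one that processes a whole batch. The layer shapes and the names W, b and forward are hypothetical, chosen only for illustration.

```python
import jax
import jax.numpy as jnp

key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
W = jax.random.normal(key_w, (3, 4))  # hypothetical weights of one dense layer
b = jnp.zeros(3)

def forward(x):
    # written for ONE example: x has shape (4,)
    return jnp.tanh(W @ x + b)

batch = jax.random.normal(key_x, (8, 4))  # 8 examples

# vmap maps forward over axis 0 of its input, without a Python loop
batched_forward = jax.vmap(forward)
out = batched_forward(batch)
print(out.shape)  # (8, 3)

# same result as an explicit Python loop over the batch
looped = jnp.stack([forward(x) for x in batch])
print(jnp.allclose(out, looped, atol=1e-5))
```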

In our tests, however, we didn’t find a good use for vmap. Since it only works with JAX-compatible objects, the code you’re trying to vectorize is probably already vectorized.

PMAP – parallelization without effort

Conceptually, pmap is not very different from vmap. pmap enables multi-device parallelism: it takes a function that normally runs on one device and automatically distributes it to run in parallel across multiple devices, such as TPU cores or GPUs.

convolve_pmap = jax.pmap(convolve)

This single line lets JAX perform the convolution concurrently on the devices it can access, which can significantly speed up larger computations. Note that pmap splits its inputs along the leading axis, and that axis must not be larger than the number of available devices.
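Since the original convolve is not shown, here is a self-contained sketch with a hypothetical convolve (a small 1-D convolution written with a Python loop). The batch is sized with jax.local_device_count() so the leading axis matches the number of devices; on a CPU-only machine that count is 1 and the code still runs, just without any parallel speedup.

```python
import jax
import jax.numpy as jnp

def convolve(x, w):
    # hypothetical stand-in: 1-D "valid" convolution of x with a 3-tap kernel w
    output = []
    for i in range(1, len(x) - 1):
        output.append(jnp.dot(x[i - 1:i + 2], w))
    return jnp.array(output)

n_devices = jax.local_device_count()

# one signal and one kernel per device: the leading axis is split across devices
xs = jnp.arange(5.0 * n_devices).reshape(n_devices, 5)
ws = jnp.tile(jnp.array([2.0, 3.0, 4.0]), (n_devices, 1))

convolve_pmap = jax.pmap(convolve)
out = convolve_pmap(xs, ws)
print(out.shape)  # shape is (n_devices, 3)
```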

Under JAX’s Hood

As with most things, there’s always a catch! Are JAX’s transformations the solution for everything? Should you always pick JAX over NumPy?

JAX transformations are not perfect

Behind the scenes, JAX translates Python functions into a set of primitive operations expressed in an intermediate language called jaxpr. You can see how JAX interprets a function by calling jax.make_jaxpr on it:

import jax

def f(x):
  return -4*x*x*x + 9*x*x + 6*x - 3

trace_f = jax.make_jaxpr(f)

trace_f(3)  # trace f with an example integer input

The output should look like this:

{ lambda ; a:i32[]. let
    b:i32[] = mul a -4
    c:i32[] = mul b a
    d:i32[] = mul c a
    e:i32[] = mul a 9
    f:i32[] = mul e a
    g:i32[] = add d f
    h:i32[] = mul a 6
    i:i32[] = add g h
    j:i32[] = sub i 3
  in (j,) }

This breakdown into primitive operations is what enables the just-in-time compilation that yields the speedups I mentioned before. However, you may discover that JAX transformations don’t always work on your code. The following snippet, for example, raises an error:

import jax

def f(x):
  if x > 0:
    return x
  else:
    return 2 * x

jax.jit(f)(0)  # raises an error: Python `if` on a traced value

The function being jit-compiled branches with a Python if on the value of its input. When JAX traces a function for jit, it replaces the inputs with abstract tracers that carry only a shape and dtype, not a concrete value, so Python cannot decide which branch to take and tracing fails.
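A few common ways around this, sketched below: express the branch with jnp.where or jax.lax.cond so it becomes part of the traced computation, or mark the argument as static with static_argnums, which makes JAX recompile for every new value of that argument. The function names are illustrative.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Option 1: replace the Python `if` with a traceable element-wise select
@jax.jit
def f_where(x):
    return jnp.where(x > 0, x, 2 * x)

# Option 2: jax.lax.cond chooses between two functions at run time
@jax.jit
def f_cond(x):
    return jax.lax.cond(x > 0, lambda v: v, lambda v: 2 * v, x)

# Option 3: treat x as a compile-time constant (recompiles per distinct value)
@partial(jax.jit, static_argnums=0)
def f_static(x):
    if x > 0:
        return x
    else:
        return 2 * x

print(f_where(-3.0), f_cond(-3.0), f_static(-3.0))  # -6.0 -6.0 -6.0
```

jnp.where evaluates both branches and selects per element, which also works on arrays; lax.cond works on a scalar predicate; static_argnums only makes sense for arguments that take few distinct, hashable values.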

NumPy may outperform JAX

You may also find that JAX is not the most suitable tool for your specific use case. This comes down to the architectural differences between the two libraries.

NumPy runs exclusively on the CPU and executes its operations eagerly and synchronously. JAX can execute operations eagerly or, if you use jit, as compiled code after the first call; in both cases, operations are dispatched asynchronously. As a result, NumPy can outperform JAX in some cases, for example on small inputs, where the overhead of jit and dispatch outweighs the computation itself.

To test this, I ran the dot product function on smaller matrices using NumPy and JAX with jit. Inspired by this StackOverflow discussion, I found that there are some input size ranges in which JAX underperforms NumPy.
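A sketch of such a size sweep, using the standard-library timeit module instead of the notebook magic. The sizes and repeat count are arbitrary, and the crossover point depends on your hardware, so no timings are shown here. Note the block_until_ready() call: because JAX dispatches asynchronously, timing without it would mostly measure how long it takes to enqueue the work.

```python
import timeit

import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def dot_jax(a):
    return jnp.dot(a, a.T)

for n in (10, 100, 1000):
    x_np = np.random.normal(size=(n, n)).astype(np.float32)
    x_jax = jnp.asarray(x_np)
    dot_jax(x_jax).block_until_ready()  # compile for this shape before timing

    t_np = timeit.timeit(lambda: np.dot(x_np, x_np.T), number=20) / 20
    t_jax = timeit.timeit(lambda: dot_jax(x_jax).block_until_ready(), number=20) / 20
    print(f"n={n:5d}  NumPy: {t_np * 1e6:9.1f} us  JAX+jit: {t_jax * 1e6:9.1f} us")
```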

JAX’s philosophy may also make you think differently when writing code. For example, the four transformations (grad, jit, vmap and pmap) are designed to work only on pure functions: functions that always give the same result for the same input and have no side effects.

Running a transformation on an impure function might not raise an error right away, but side effects can surface later in the process without you noticing, silently corrupting your data and results.
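A minimal illustration of such a silent side effect, following the pattern described in JAX’s "sharp bits" documentation: a jitted function that reads a global variable captures its value at trace time, and a print inside it runs only while tracing. The names g and impure are illustrative.

```python
import jax

g = 0.0

def impure(x):
    print("tracing!")  # side effect: runs during tracing, not on every call
    return x + g       # reads a global: its value is baked in at trace time

impure_jit = jax.jit(impure)

print(impure_jit(4.0))  # prints "tracing!" and then 4.0
g = 10.0
print(impure_jit(5.0))  # no "tracing!"; prints 5.0, not 15.0 (stale global)
```

No error is raised at any point, which is exactly why this kind of bug is easy to miss; passing g as an explicit argument makes the function pure and fixes it.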

Final say: Change to JAX?

For new code, I’d say definitely yes. For old code, I’m not sure. For the most part, you can use JAX arrays and call JAX functions in place of NumPy’s. If you run your code on a machine with an accelerator, you may see speedups of orders of magnitude right away. But to truly leverage JAX’s benefits on CPU, you will need to adapt your code: JIT-compile parts of it, use vmap and pmap, and more. So if you’re writing new code it should be easy, but for old code it could be too big an effort for CPU-bound tasks.
