JAX is a relatively new Python library from Google for machine learning, designed for high-performance numerical computing. Its API for numerical functions is based on NumPy. Because both Python and NumPy are widely used and familiar, JAX is simple, flexible, and easy to learn.
JAX is described as “Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more,” where GPU and TPU stand for graphics processing unit and tensor processing unit.
The library builds on Autograd, which can automatically differentiate native Python and NumPy code. Its grad function transformation converts one function into another that returns the gradient of the original. JAX also offers the jit function transformation for just-in-time (JIT) compilation of existing functions, and vmap and pmap for vectorization and parallelization, respectively.
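The transformations named above can be sketched on a toy function. This is a minimal illustration, not a recommended pattern: square_sum and the sample arrays are hypothetical names chosen here, and pmap is left out because it only becomes interesting when several devices are available.

```python
import jax
import jax.numpy as jnp

def square_sum(x):
    # Scalar-valued function: sum of the squared entries of x.
    return jnp.sum(x ** 2)

# grad: builds a new function that returns d(square_sum)/dx, which is 2 * x.
grad_square_sum = jax.grad(square_sum)

# jit: builds a compiled version of the same function.
fast_square_sum = jax.jit(square_sum)

# vmap: maps the function over the leading axis of a batch of inputs.
batched_square_sum = jax.vmap(square_sum)

x = jnp.array([1.0, 2.0, 3.0])
batch = jnp.stack([x, 2.0 * x])  # shape (2, 3)

gradient = grad_square_sum(x)        # gradient with respect to x
compiled_value = fast_square_sum(x)  # same value as square_sum(x)
per_row = batched_square_sum(batch)  # one value per row of batch
```

Each transformation takes a function and returns a function, which is why they can be stacked freely.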
On a CPU, JAX is at most slightly faster than NumPy, which is already well optimized. JAX and NumPy effectively issue the same short series of BLAS and LAPACK calls, so there is little room for improvement over the NumPy reference implementation. With small arrays, JAX may even be a little slower.
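One way to check this on your own machine is a rough timing comparison like the sketch below. The helper time_call, the matrix size, and the repeat count are assumptions made for illustration. It is not a rigorous benchmark, and the numbers will vary with hardware and library versions.

```python
import time
import numpy as np
import jax.numpy as jnp

def time_call(fn, *args, repeats=10):
    # Average wall-clock seconds per call; illustrative only.
    start = time.perf_counter()
    for _ in range(repeats):
        result = fn(*args)
    return (time.perf_counter() - start) / repeats, result

rng = np.random.default_rng(0)
a_np = rng.standard_normal((128, 128)).astype(np.float32)
b_np = rng.standard_normal((128, 128)).astype(np.float32)
a_jx, b_jx = jnp.asarray(a_np), jnp.asarray(b_np)

def numpy_matmul(a, b):
    return a @ b

def jax_matmul(a, b):
    # JAX dispatches work asynchronously, so wait for the result
    # before the clock is stopped.
    return (a @ b).block_until_ready()

jax_matmul(a_jx, b_jx)  # warm-up call so one-time setup cost is excluded

numpy_seconds, numpy_result = time_call(numpy_matmul, a_np, b_np)
jax_seconds, jax_result = time_call(jax_matmul, a_jx, b_jx)
print(f"NumPy: {numpy_seconds:.6f} s/call, JAX: {jax_seconds:.6f} s/call")
```

The call to block_until_ready matters: without it, the JAX timing would measure only how long it takes to enqueue the work, not to finish it.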
TensorFlow and PyTorch can also be compiled, but their compiled modes were added after their initial development and therefore have some drawbacks. In TensorFlow 2.0, although eager execution is the default mode, it is not 100% compatible with graph mode, which sometimes makes for a poor developer experience.
PyTorch, for its part, has a history of being pushed toward less intuitive tensor formats as a consequence of its eager-execution design.
The two execution modes differ mainly in usability: graph execution is harder to learn and test and is not very intuitive. However, it is ideal for training large models.
For small models, beginners, and average developers, eager execution is more suitable.
JAX Advantage
The advantage of JAX is that it was designed for both execution modes (eager and graph) from the beginning, so it does not suffer from the problems of its predecessors, PyTorch and TensorFlow. Those deep learning libraries consist of high-level APIs for advanced deep learning methods. JAX, by comparison, is a more functional library for arbitrary differentiable programming. It lets you JIT-compile your own Python functions into XLA-optimized kernels using a single-function API (XLA, Accelerated Linear Algebra, is a domain-specific compiler for linear algebra). Compilation and automatic differentiation can be composed arbitrarily, so you can express sophisticated algorithms and get maximum performance without leaving Python.
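As a small illustration of composing compilation with automatic differentiation, the sketch below wraps jax.grad in jax.jit and uses the result for plain gradient descent. The loss function, the toy data, and the learning rate are hypothetical choices made for this example only.

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Mean squared error of a linear model that predicts y as x @ w.
    return jnp.mean((x @ w - y) ** 2)

# Compose transformations: differentiate with respect to w (the first
# argument), then compile the resulting gradient function with XLA.
loss_grad = jax.jit(jax.grad(loss))

x = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
true_w = jnp.array([2.0, -1.0])
y = x @ true_w  # noiseless targets generated from true_w

w = jnp.zeros(2)
learning_rate = 0.1  # hypothetical value that suits this toy problem
for _ in range(500):
    w = w - learning_rate * loss_grad(w, x, y)

final_loss = loss(w, x, y)
```

After the loop, w should be close to true_w. The whole training step is ordinary Python, and the only JAX-specific line is the one that builds loss_grad.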
JAX may currently be the most advanced library for machine learning (ML), and it promises to make machine learning programming more intuitive, structured, and clean. Above all, it could replace TensorFlow and PyTorch while offering important advantages over them.