JAX es una nueva biblioteca de Python de aprendizaje automático de Google, diseñada para la computación numérica de alto rendimiento. Su API para funciones numéricas se basa en NumPy. Tanto Python como NumPy son ampliamente utilizados y familiares, lo que hace que JAX sea simple, flexible y fácil de aprender.
JAX se define como «Transformaciones componibles de los programas Python+NumPy: diferenciar, vectorizar, JIT a GPU/TPU (graphic processor unit/ tensor processing unit) y más».
La biblioteca utiliza la transformación de la función de Autograd (puede diferenciar automáticamente el código nativo de Python y Numpy) para convertir una función en otra, que devuelve el gradiente de la original. Jax también ofrece un JIT de transformación de funciones para la compilación jit (Just in time compilation) de las funciones existentes y vmap y pmap para la vectorización y la paralelización, respectivamente.
JAX es un poco más rápido que NumPy que ya está bastante optimizado. JAX y NumPy generan eficazmente la misma serie corta de llamadas BLAS y LAPACK ejecutadas en una arquitectura de CPU y no hay mucho margen de mejora con respecto a la implementación de referencia de NumPy, puede darse el caso de que, con matrices pequeñas, JAX sea un poco más lento.
Mientras que TensorFlow y Pytorch pueden ser compilados, pero estos modos compilados se añadieron posteriormente a su desarrollo inicial y, por tanto, tienen algunos inconvenientes. En el caso de TensorFlow2.0, aunque el modo Eager Execution, es el modo por defecto, no es 100% compatible con el modo gráfico, lo que en ocasiones produce una mala experiencia para el desarrollador.
Pytorch tiene un mal historial de verse obligado a usar formatos de tensores menos intuitivos desde que adoptó la Eager Execution.
La diferencia entre los modos de ejecución es que Graph es difícil de aprender y de probar y es poco intuitivo. No obstante, la ejecución de gráficos es ideal para el entrenamiento de modelos grandes.
Para modelos pequeños, principiantes y desarrolladores promedio, la ejecución Eager es más adecuada.
Ventaja de JAX
La ventaja de JAX es que fue concebido para ambos modos de ejecución (Eager y Graph) desde el principio y adolece de los problemas de sus predecesores PyTorch y Tensorflow. Estas últimas bibliotecas de aprendizaje profundo constan de APIs de alto nivel para métodos avanzados de aprendizaje profundo. JAX, en comparación con éstos, es una biblioteca más funcional para una programación diferenciable arbitraria, permite compilar Jit tus propias funciones de Python en núcleos optimizados para XLA (XLA Accelerated Linear Algebra- es un compilador específico del área para álgebra lineal) utilizando una API de una sola función. La compilación y la diferenciación automática se pueden componer arbitrariamente, por lo que puedes expresar algoritmos sofisticados y obtener el máximo rendimiento sin tener que salir de Python.
JAX quizás sea actualmente lo más avanzado en términos Machine Learning (ML) y promete hacer que la programación de aprendizaje automático sea más intuitiva, estructurada y limpia. Y, sobre todo, puede reemplazar con importantes ventajas a Tensorflow y PyTorch.