To leverage JAX's powerful capabilities, such as automatic differentiation and just-in-time compilation, you'll first need to set up your environment. In this section, we'll guide you through the installation process and prepare you to start experimenting with JAX in your projects. If you're familiar with Python environments and package management, this should be a straightforward process.
Before setting up JAX, ensure that you have the following:
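Python: a working Python installation. You can check your version by running python --version or python3 --version in your terminal.
pip: Python's package installer. You can check that it is available by running pip --version.
JAX is designed to be flexible and efficient, taking advantage of hardware acceleration when available. The installation process can vary slightly depending on whether you want to utilize CPU, GPU, or TPU capabilities.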
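The prerequisite checks can also be run from Python itself. The following is a minimal sketch using only the standard library; looking for pip on PATH is an assumption made for illustration, since on some systems pip is only reachable as python -m pip.

```python
import shutil
import sys

# Report the interpreter version. The minimum Python version JAX supports
# changes between releases, so compare this against the JAX documentation.
python_version = ".".join(str(part) for part in sys.version_info[:3])
print("Python version:", python_version)

# Look for a pip executable on PATH (illustrative check only; pip may also
# be available solely through `python -m pip`).
pip_path = shutil.which("pip") or shutil.which("pip3")
print("pip found at:", pip_path if pip_path else "not found on PATH")
```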
For most users, starting with a CPU installation is a good choice, especially if you're new to JAX or don't have access to a GPU. You can install JAX with the following pip command:
pip install --upgrade "jax[cpu]"
This command installs JAX along with the necessary dependencies to run computations on your CPU.
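To see what the install actually placed in your environment, you can query the package metadata. This sketch uses the standard-library importlib.metadata module; the helper name installed_version is hypothetical and chosen only for this example.

```python
from importlib import metadata


def installed_version(package_name):
    """Return the installed version string of a package, or None if absent."""
    try:
        return metadata.version(package_name)
    except metadata.PackageNotFoundError:
        return None


# jax is the Python front end; jaxlib carries the compiled XLA backend.
for name in ("jax", "jaxlib"):
    version = installed_version(name)
    print(f"{name}: {version if version else 'not installed'}")
```

Because this reads metadata rather than importing the packages, it also works when an import would fail, which can help when diagnosing a broken install.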
If you have a compatible NVIDIA GPU and wish to utilize it for accelerated computations, JAX can be installed with GPU support. This requires CUDA and cuDNN, which must be installed on your system. Once set up, you can install JAX with GPU support as follows:
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
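Ensure that your CUDA and cuDNN versions match those supported by the JAX release you're installing. The name of the CUDA extra and the need for the -f URL have changed across JAX releases (recent releases document "jax[cuda12]", for example), so consult the JAX installation documentation for the exact command and for compatibility details.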
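After a GPU install, it is worth confirming that JAX actually sees the accelerator, since a mismatched setup can leave JAX running on the CPU. A minimal sketch, assuming JAX is already installed:

```python
import jax

# List every device JAX can see and the backend it will use by default.
devices = jax.devices()
backend = jax.default_backend()

print("Default backend:", backend)
for device in devices:
    print(" ", device.platform, device.id)

# On a CPU-only install, or when the GPU setup was not picked up,
# the default backend is reported as "cpu".
if backend == "cpu":
    print("No accelerator backend is active; computations will run on the CPU.")
```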
For users with access to Google Cloud's TPUs, JAX can also be configured to utilize these for even faster computations. However, setting up TPUs involves additional steps beyond the scope of this introduction. Refer to Google's Cloud TPU documentation for guidance.
Once installed, you can verify that JAX is correctly set up by running a simple script to check its functionality. Open a Python shell and execute the following:
import jax.numpy as jnp
# Create an array using JAX
x = jnp.array([1, 2, 3])
# Perform a simple computation
y = jnp.sum(x ** 2)
print(y)
The output should be 14, which confirms that JAX is operational. The use of jax.numpy is intentional, as JAX provides a NumPy-like API that makes transitioning from NumPy straightforward.
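The check above only exercises the NumPy-like API. To confirm that automatic differentiation and just-in-time compilation also work, you can extend it as in the sketch below. The function name sum_of_squares is chosen for this example only.

```python
import jax
import jax.numpy as jnp


def sum_of_squares(x):
    # Same computation as the check above, wrapped in a function
    # so that JAX transformations can be applied to it.
    return jnp.sum(x ** 2)


# jax.grad needs floating-point inputs, so use floats here.
x = jnp.array([1.0, 2.0, 3.0])

# The gradient of sum(x^2) with respect to x is 2 * x.
gradient = jax.grad(sum_of_squares)(x)
print(gradient)  # [2. 4. 6.]

# jax.jit compiles the function with XLA on its first call.
result = jax.jit(sum_of_squares)(x)
print(result)  # 14.0
```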
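JAX allows you to configure several aspects of its execution. For instance, you can control the number of CPU threads JAX uses by setting the XLA_FLAGS environment variable before running your Python script. XLA_FLAGS holds a single space-separated string, so put all flags into one export; a second export would overwrite the first:
export XLA_FLAGS="--xla_cpu_multi_thread_eigen=true --xla_cpu_multi_thread_eigen_max_nthreads=4"
Adjust the max_nthreads value according to your system's CPU capabilities. The set of flags XLA accepts varies between releases, and XLA typically aborts at startup when it encounters a flag it does not recognize, so confirm that these flags are supported by your installed version.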
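The variable can also be set from inside Python, as long as that happens before JAX is imported. The sketch below appends a flag to whatever XLA_FLAGS already contains instead of replacing it. The helper with_xla_flag is hypothetical and written for this example; it is not part of JAX.

```python
import os


def with_xla_flag(current_value, flag):
    """Return an XLA_FLAGS string with `flag` appended, skipping duplicates."""
    flags = current_value.split()
    if flag not in flags:
        flags.append(flag)
    return " ".join(flags)


# Preserve any flags that are already set in the environment.
os.environ["XLA_FLAGS"] = with_xla_flag(
    os.environ.get("XLA_FLAGS", ""),
    "--xla_cpu_multi_thread_eigen=true",
)
print(os.environ["XLA_FLAGS"])

# Import jax only after this point; the flags are read when the
# backend is initialized, so later changes would have no effect.
```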
With JAX installed and configured, you're ready to start writing high-performance numerical code. As you proceed through this course, you'll explore JAX's capabilities in greater depth and apply them to your data science and machine learning tasks. JAX's strength lies in integrating with existing Python workflows while providing tools for differentiation and hardware acceleration.
© 2025 ApX Machine Learning