Now that you understand what JAX is and its core design philosophy centered around function transformations, let's get your environment set up so you can start writing and running JAX code. Installing JAX is typically straightforward, but requires slightly different steps depending on whether you want to run on a CPU, NVIDIA GPU, or Google TPU.
This course assumes you have a working Python 3 installation and are familiar with using `pip` or `conda` for package management. JAX supports only recent Python versions and drops older ones on a regular schedule, so check the official installation guide for the current minimum. Using a virtual environment is highly recommended to avoid conflicts with other projects or your system's Python installation.
Before installing JAX, create and activate a virtual environment.
Using `venv`:

```shell
# Create a virtual environment named .venv (or your preferred name)
python -m venv .venv

# Activate the environment
# On Linux/macOS:
source .venv/bin/activate
# On Windows (Command Prompt):
.\.venv\Scripts\activate.bat
# On Windows (PowerShell):
.\.venv\Scripts\Activate.ps1
```
Using `conda`:

```shell
# Create a conda environment named 'jax-env' with Python
conda create -n jax-env python=3.12  # Or another Python version supported by current JAX
# Activate the environment
conda activate jax-env
```
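Before installing anything, you can confirm from Python that the environment is actually active. This is a minimal sketch. The helper name `active_environment` is hypothetical. It relies on two heuristics: a `venv` changes `sys.prefix` away from `sys.base_prefix`, and `conda activate` sets the `CONDA_DEFAULT_ENV` variable. Conda's own `base` environment also sets that variable, so treat the conda result as a hint rather than proof.

```python
import os
import sys
from typing import Optional


def active_environment() -> Optional[str]:
    """Hypothetical helper: describe the active virtual environment, or return None."""
    # venv/virtualenv: sys.prefix points inside the environment,
    # while sys.base_prefix still points at the base interpreter.
    if sys.prefix != sys.base_prefix:
        return f"venv at {sys.prefix}"
    # conda: `conda activate` sets CONDA_DEFAULT_ENV to the environment name.
    # This is also set for conda's "base" environment, so it is only a heuristic.
    conda_env = os.environ.get("CONDA_DEFAULT_ENV")
    if conda_env:
        return f"conda environment '{conda_env}'"
    return None


if __name__ == "__main__":
    env = active_environment()
    print(env if env else "No virtual environment appears to be active.")
```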
Once your environment is active, you can proceed with the JAX installation.
If you only plan to use your CPU, installation is simple. JAX itself is a pure Python package, while `jaxlib` contains the compiled backend (XLA) and platform-specific code.

Install the latest versions of `jax` and `jaxlib` using pip:

```shell
pip install --upgrade jax jaxlib
```

This command installs the pre-built `jaxlib` wheel for your operating system and CPU architecture. Recent releases of `jax` also pull in a compatible `jaxlib` as a dependency, so `pip install --upgrade jax` has the same effect.
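Because `jax` and `jaxlib` are released together, a mismatch between the two installed versions is a common cause of import errors. The sketch below reads both versions from the package metadata using only the standard library, so it works even if importing `jax` itself fails. The helper name `installed_version` is hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Hypothetical helper: return the installed version of `package`, or None if missing."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    for pkg in ("jax", "jaxlib"):
        print(f"{pkg}: {installed_version(pkg) or 'not installed'}")
```

Both packages should report versions from the same release. If either shows "not installed", rerun the pip command inside the active environment.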
To use an NVIDIA GPU, you need a compatible NVIDIA driver. Depending on the installation method, you either install a matching CUDA toolkit and cuDNN on your system yourself, or let pip install them as packages. In both cases, the CUDA and cuDNN versions must match the ones the pre-built `jaxlib` wheels were built against.

The most reliable way to install the correct GPU-enabled build is to follow the instructions in the official JAX GitHub repository, because the required versions and installation commands change frequently.

Visit: https://github.com/google/jax#installation

Find the command for your CUDA setup. At the time of writing it looks something like this, but do not run it without checking the official instructions first:

```shell
# Example only: check the official JAX installation guide for the current command!
pip install --upgrade "jax[cuda12]"
```

The bracketed extra (here `cuda12`) selects the target CUDA build. Older JAX releases used a different form, such as `jaxlib[cuda11_cudnn82]` combined with a `-f` link to a release index. Installing a build that does not match your driver and CUDA setup will likely leave JAX unable to detect or use your GPU. Always refer to the official JAX installation guide for the up-to-date command.
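Before you debug JAX itself, it helps to confirm that the NVIDIA driver is installed and visible. The `nvidia-smi` tool ships with the driver. This sketch calls it with its standard `--query-gpu` and `--format` options. The helper name `nvidia_gpu_info` is hypothetical, and a `None` result only means the tool could not be found or run.

```python
import shutil
import subprocess
from typing import List, Optional


def nvidia_gpu_info() -> Optional[List[str]]:
    """Hypothetical helper: one 'name, driver_version' line per GPU, or None."""
    # nvidia-smi is installed with the NVIDIA driver; if it is not on PATH,
    # the driver is probably missing or not visible to this environment.
    if shutil.which("nvidia-smi") is None:
        return None
    try:
        result = subprocess.run(
            ["nvidia-smi", "--query-gpu=name,driver_version", "--format=csv,noheader"],
            capture_output=True,
            text=True,
            check=True,
            timeout=30,
        )
    except (subprocess.CalledProcessError, subprocess.TimeoutExpired, OSError):
        # nvidia-smi exists but failed, for example because the driver is not loaded.
        return None
    return [line.strip() for line in result.stdout.splitlines() if line.strip()]


if __name__ == "__main__":
    info = nvidia_gpu_info()
    if info is None:
        print("nvidia-smi not found or failed: check your NVIDIA driver installation.")
    else:
        for line in info:
            print(line)
```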
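Using JAX on TPUs is primarily done through Google Cloud Platform (GCP) or Google Colaboratory (Colab).

In Colab, once you select a TPU runtime (Runtime -> Change runtime type -> TPU), JAX is often pre-installed. Colab also manages the TPU drivers and underlying software stack. If you need to install or upgrade JAX yourself on a TPU machine, the official installation guide currently uses the `tpu` extra, which also installs the TPU runtime library (`libtpu`):

```shell
pip install --upgrade "jax[tpu]"
```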
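On a Colab or Cloud TPU runtime, a quick way to confirm that the TPU cores are visible is to count the devices JAX finds on each platform. This is a minimal sketch with two assumptions: `cpu`, `gpu`, and `tpu` are the only platforms of interest, and the helper name `device_summary` is hypothetical. It relies on `jax.devices(platform)` raising `RuntimeError` when no backend for that platform is available.

```python
from typing import Dict

try:
    import jax
except ImportError:  # JAX is not installed in this environment
    jax = None


def device_summary() -> Dict[str, int]:
    """Hypothetical helper: count devices per platform (empty dict if JAX is missing)."""
    if jax is None:
        return {}
    counts: Dict[str, int] = {}
    for platform in ("cpu", "gpu", "tpu"):
        try:
            counts[platform] = len(jax.devices(platform))
        except RuntimeError:
            # No backend for this platform is available (or it failed to initialize).
            counts[platform] = 0
    return counts


if __name__ == "__main__":
    print(device_summary())
```

On a working TPU runtime, the `tpu` count should be non-zero. A count of zero usually means the runtime type or the installed `jax`/`libtpu` versions need checking.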
After installation, you can verify that JAX is installed correctly and check which devices (CPU, GPU, TPU) it can detect. Open a Python interpreter or run a script with the following code:
```python
import jax
import jax.numpy as jnp

try:
    devices = jax.devices()
    print("JAX detected the following devices:")
    for i, device in enumerate(devices):
        print(f"{i}: {device.platform.upper()} ({device.device_kind})")

    # Test a simple computation
    key = jax.random.PRNGKey(0)
    x = jax.random.normal(key, (10,))
    y = jnp.dot(x, x)
    print(f"\nSuccessfully executed a simple JAX operation. Result: {y}")

    # Check default device
    print(f"\nDefault device: {jax.default_backend()}")
except Exception as e:
    print("An error occurred during JAX verification:")
    print(e)
    print("\nPlease check your installation steps, especially GPU driver/CUDA versions if applicable.")
```
Example output (CPU only):

```
JAX detected the following devices:
0: CPU (cpu)

Successfully executed a simple JAX operation. Result: 10.811179161071777

Default device: cpu
```

Example output (single GPU; your GPU model will vary):

```
JAX detected the following devices:
0: GPU (NVIDIA GeForce RTX 3090)

Successfully executed a simple JAX operation. Result: 10.811179161071777

Default device: gpu
```

Example output (multiple GPUs or TPU): several GPU or TPU devices are listed.

```
JAX detected the following devices:
0: TPU (TPU v3)
1: TPU (TPU v3)
...

Successfully executed a simple JAX operation. Result: 10.811179161071777

Default device: tpu
```

Called without arguments, `jax.devices()` lists only the devices of the default backend. On a GPU or TPU machine, the CPU therefore does not appear in this list. The exact numerical result and the number of printed digits may also differ on your system.
If the script runs without errors and lists the expected devices, your JAX installation is ready. If you encounter errors, particularly ones involving CUDA or GPUs, consult the official JAX installation guide and double-check that your driver, CUDA toolkit, and `jaxlib` versions are compatible.
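When JAX fails on a GPU machine, temporarily forcing the CPU backend helps separate GPU setup problems from other issues. The `JAX_PLATFORMS` environment variable restricts which platforms JAX initializes. It must be set before JAX initializes its backends, and setting it before `import jax` is the safest choice. This is a minimal sketch that sets it from Python. Exporting it in your shell works equally well.

```python
import os

# Restrict JAX to the CPU backend. Set this before importing jax.
os.environ["JAX_PLATFORMS"] = "cpu"

try:
    import jax
    import jax.numpy as jnp
except ImportError:  # JAX is not installed in this environment
    jax = None

if jax is not None:
    x = jnp.arange(4.0)
    print("Default backend:", jax.default_backend())  # cpu, because of JAX_PLATFORMS
    print("Sum:", float(x.sum()))  # 0 + 1 + 2 + 3 = 6.0
```

If this runs but the unrestricted script fails, the problem most likely lies in the GPU driver, CUDA, or `jaxlib` build. Unset the variable to let JAX use your accelerator again. Recent JAX releases also provide `jax.print_environment_info()`, which prints the `jax`, `jaxlib`, and Python versions in one call. That output is useful to include when reporting installation problems.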
With JAX installed, you're ready to move on to working directly with JAX arrays and exploring its core features.
© 2025 ApX Machine Learning