Before distributing computations across multiple devices using primitives like pmap, it's essential to understand how JAX identifies and manages the available hardware accelerators. JAX automatically detects CPUs, GPUs, and TPUs connected to your system (or assigned within a distributed environment) and provides utilities to inspect and interact with them.

JAX offers several functions to query the computational resources it can access. The most fundamental is jax.devices(). It returns a list of all devices JAX can potentially use, across all participating hosts in a multi-host setup, or just the local host's devices when running standalone.
import jax
# List all devices JAX can see globally (may be identical to local_devices in single-host)
all_devices = jax.devices()
print(f"All available devices: {all_devices}")
# Get the total number of devices globally
num_devices = jax.device_count()
print(f"Total number of devices: {num_devices}")
# List only devices local to the current process/host
local_devices_list = jax.local_devices()
print(f"Local devices: {local_devices_list}")
# Get the number of local devices
num_local_devices = jax.local_device_count()
print(f"Number of local devices: {num_local_devices}")
# Example output (might vary based on hardware)
# All available devices: [CpuDevice(id=0)] # If only CPU is available
# All available devices: [TpuDevice(id=0), TpuDevice(id=1), ...] # On a TPU Pod slice
# All available devices: [GpuDevice(id=0), GpuDevice(id=1)] # On a dual-GPU machine
# Total number of devices: 1
# Local devices: [CpuDevice(id=0)]
# Number of local devices: 1
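jax.devices() also accepts an optional platform name if you only want devices of a particular backend. A minimal sketch (the CPU backend is normally always present, so this call is safe; requesting a backend that is not available raises an error):

# Query devices of a specific platform; an absent backend raises a RuntimeError
cpu_devices = jax.devices("cpu")
print(f"CPU devices: {cpu_devices}")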
Each element in the list returned by jax.devices() or jax.local_devices() is a Device object. These objects contain information about the specific hardware, such as its platform ('cpu', 'gpu', or 'tpu'), a unique ID within that platform, and potentially other attributes such as the process index in multi-host scenarios.
if local_devices_list:
    device = local_devices_list[0]
    print(f"First local device: ID={device.id}, Platform={device.platform}")

# Example output: First local device: ID=0, Platform=gpu
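In multi-host setups, each Device also exposes a process_index attribute identifying the host process that owns it. A small sketch that prints it for every local device (the example output is illustrative and assumes a single-host, dual-GPU machine):

# Inspect all local devices, including which process owns each one
for d in jax.local_devices():
    print(f"id={d.id}, platform={d.platform}, process_index={d.process_index}")

# Example output (illustrative):
# id=0, platform=gpu, process_index=0
# id=1, platform=gpu, process_index=0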
Understanding the distinction between jax.devices() and jax.local_devices() is particularly relevant in multi-host TPU environments. jax.devices() provides a global view across all hosts connected to the TPU Pod slice, while jax.local_devices() shows only the devices directly attached to the current Python process. For single-host GPU or CPU setups, these two functions often return the same list.
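A rough sketch of how the global and local views relate (on a single host, process_count() is 1 and both lists are identical; the multi-host split only appears when the script is launched as part of a multi-process job):

import jax

# Which process (host) is this, and how many participate in total?
print(f"Process {jax.process_index()} of {jax.process_count()}")

global_devices = jax.devices()        # every device across all hosts
local_devices = jax.local_devices()   # only devices attached to this process

# The local devices are the subset of global devices owned by this process
mine = [d for d in global_devices if d.process_index == jax.process_index()]
print(f"{len(local_devices)} local of {len(global_devices)} global devices")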
By default, JAX operations and array creation target the first device listed by jax.local_devices(), which is typically cpu:0, gpu:0, or tpu:0 depending on your setup and backend configuration.
import jax.numpy as jnp
# Creates the array on the default device (usually device 0)
x = jnp.ones((3, 3))
print(f"Default device for x: {x.device()}")
# Example output: Default device for x: GpuDevice(id=0)
This default behavior is convenient for single-device workflows, but for distributed computing, you often need more explicit control.
jax.device_put
The jax.device_put() function allows you to explicitly place a NumPy array or a Python scalar onto a specific JAX device, returning a DeviceArray (JAX's array type) resident on that device. Its signature is jax.device_put(x, device=None):
- x: The data to be placed (e.g., a NumPy array or a Python scalar/list).
- device: The target Device object (obtained from jax.devices() or jax.local_devices()). If None, the default device is used.

import numpy as np
if num_local_devices > 1:
    # Create a NumPy array on the host CPU
    host_array = np.random.rand(2, 2)
    print(f"NumPy array type: {type(host_array)}")

    # Explicitly place it onto the second available JAX device
    target_device = jax.local_devices()[1]
    device_array = jax.device_put(host_array, device=target_device)

    print(f"Placed array type: {type(device_array)}")
    print(f"Array is now on device: {device_array.device()}")

    # Example output (on a multi-GPU system):
    # NumPy array type: <class 'numpy.ndarray'>
    # Placed array type: <class 'jaxlib.xla_extension.DeviceArray'>
    # Array is now on device: GpuDevice(id=1)
elif num_local_devices == 1:
    print("Only one local device available, skipping explicit placement example.")
else:
    print("No JAX devices found.")
jax.device_put() is significant because:

- Preparing data for pmap: When using pmap, you often prepare data shards on the host and then use jax.device_put (or rely on pmap's implicit placement) to ensure data starts on the correct devices before the parallel computation begins. Although pmap can handle the distribution implicitly for input arguments, explicitly placing initial model parameters or states might sometimes be necessary for clarity or specific initialization patterns.
- Controlling initial placement: While jit compilation often determines where operations run, jax.device_put can influence the location of initial data, which can be important for performance, preventing unnecessary transfers later.

Once an array is placed on a device (either explicitly via jax.device_put or implicitly), JAX tries to keep computations involving that array on the same device, minimizing data transfers. Operations involving arrays on different devices might trigger data movement or result in errors if the operation isn't defined across devices.
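A minimal sketch of that behavior, using the same .device() accessor as the earlier examples (the printed device names are illustrative):

import jax
import jax.numpy as jnp
import numpy as np

# Commit an array to an explicit device
dev = jax.local_devices()[0]
a = jax.device_put(np.arange(4.0), device=dev)

# A computation on the committed array produces its result on the same device,
# without moving the data back to the host in between
b = jnp.sin(a) * 2.0
print(f"Input device:  {a.device()}")
print(f"Result device: {b.device()}")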
You can influence which devices JAX sees, particularly GPUs, using environment variables set before importing JAX. The most common one is CUDA_VISIBLE_DEVICES.
# Example: Make only GPU 1 visible to JAX (and other CUDA applications)
export CUDA_VISIBLE_DEVICES=1
python my_jax_script.py
Inside my_jax_script.py, jax.local_devices() would likely list only GpuDevice(id=0) (JAX re-enumerates the visible devices starting from 0), which actually corresponds to the physical GPU 1. Setting CUDA_VISIBLE_DEVICES="" typically hides all GPUs, forcing JAX to use the CPU.
JAX also respects the JAX_PLATFORMS environment variable. If set, JAX will only initialize the specified platform backends (e.g., JAX_PLATFORMS=cpu or JAX_PLATFORMS=gpu). This can be useful for forcing a specific backend when multiple are available.
Mastering device management is foundational for distributed computing. Knowing which devices are available, where your data resides (x.device()), and how to control placement (jax.device_put) are prerequisites for effectively using pmap and scaling your computations. In the following sections, we will build on this foundation to implement parallel execution patterns.