Okay, let's put the concepts from this chapter into practice. You've learned that JAX provides a NumPy-like API via jax.numpy, and that JAX arrays are immutable and reside on specific devices (CPU, GPU, or TPU). Now we'll work through some fundamental array operations. Make sure you have JAX installed (the CPU-only build installs with pip install -U jax).
First, let's import jax.numpy:
import jax.numpy as jnp
import jax
import numpy as np # Often useful to compare or for operations JAX doesn't cover
print(f"JAX version: {jax.__version__}")
print(f"Default JAX backend: {jax.default_backend()}")
# Check available devices (may vary based on your setup)
print(f"Available JAX devices: {jax.devices()}")
Similar to NumPy, you can create JAX arrays from Python lists or tuples, or using dedicated creation functions.
# From Python list
py_list = [1.0, 2.5, 3.0, 4.2]
jax_array_from_list = jnp.array(py_list)
print("JAX array from list:", jax_array_from_list)
print("Type:", type(jax_array_from_list))
print("dtype:", jax_array_from_list.dtype)
# Create specific arrays
zeros_array = jnp.zeros((2, 3)) # Shape (2 rows, 3 columns)
print("\nZeros array:\n", zeros_array)
ones_array = jnp.ones((3, 2), dtype=jnp.int32) # Specify data type
print("\nOnes array (int32):\n", ones_array)
print("dtype:", ones_array.dtype)
range_array = jnp.arange(0, 10, 2) # Start, stop, step
print("\nRange array:", range_array)
linspace_array = jnp.linspace(0, 1, 5) # Start, stop, number of points
print("\nLinspace array:", linspace_array)
Notice the output type is specific to JAX (printed as jaxlib.xla_extension.ArrayImpl or similar; every JAX array is an instance of jax.Array). Unlike NumPy, JAX defaults to 32-bit floats (float32) on every backend for performance; 64-bit precision must be enabled explicitly via the jax_enable_x64 configuration flag. You can also set the dtype directly, as shown with jnp.ones.
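If you do need 64-bit precision, the x64 flag has to be flipped before any arrays are created; a minimal sketch (assumed to run at program start, in a fresh interpreter):
# Assumption: executed at startup, before any JAX arrays exist
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
print(jnp.array([1.0]).dtype)  # float64 once x64 is enabled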
For random numbers, JAX uses an explicit, functional pseudo-random number generator (PRNG) approach, in contrast to NumPy's hidden global state. You create and pass around an explicit PRNGKey.
# Create a pseudo-random number generator key
key = jax.random.PRNGKey(42) # Seed is 42
# Generate random numbers (e.g., uniform distribution between 0 and 1)
random_array = jax.random.uniform(key, shape=(2, 2))
print("\nRandom array (uniform):\n", random_array)
# Important: To get new random numbers, you must 'split' the key
key, subkey = jax.random.split(key)
random_normal_array = jax.random.normal(subkey, shape=(3,)) # Standard normal distribution
print("\nRandom array (normal):\n", random_normal_array)
Using and splitting keys ensures reproducibility, which is essential for debugging and consistent results, especially when function transformations are involved later.
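When you need several independent draws, jax.random.split can also produce many subkeys at once; a short sketch of the common carry-and-split pattern:
# Split one key into a new carry key plus three independent subkeys
key = jax.random.PRNGKey(0)
key, *subkeys = jax.random.split(key, num=4)
for i, sk in enumerate(subkeys):
    print(f"Draw {i}:", jax.random.normal(sk, shape=(2,)))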
You can check array properties just like in NumPy:
print("\nInspecting random_array:")
print("Shape:", random_array.shape)
print("Size:", random_array.size) # Total number of elements
print("Number of dimensions:", random_array.ndim)
print("Data type:", random_array.dtype)
# Check which device(s) the array lives on: CPU, GPU, or TPU,
# depending on your setup and JAX configuration
print("Devices:", random_array.devices())
Arithmetic operators work element-wise, creating new arrays. Remember, JAX arrays are immutable.
a = jnp.array([1, 2, 3])
b = jnp.array([4, 5, 6])
# Element-wise operations
c_add = a + b
print("\na + b =", c_add)
c_mul = a * b
print("a * b =", c_mul)
# Scalar operations
c_scalar_add = a + 10
print("a + 10 =", c_scalar_add)
c_scalar_mul = a * 2
print("a * 2 =", c_scalar_mul)
# Check that 'a' remains unchanged (immutability)
print("Original 'a':", a)
Matrix multiplication uses the @ operator or jnp.matmul:
mat_a = jnp.array([[1, 2], [3, 4]])
mat_b = jnp.array([[5, 6], [7, 8]])
mat_product = mat_a @ mat_b
# Equivalent: mat_product = jnp.matmul(mat_a, mat_b)
print("\nMatrix product (mat_a @ mat_b):\n", mat_product)
Indexing and slicing work similarly to NumPy, but because of immutability they return new arrays rather than views into the original data. Modifying a slice of a JAX array in place is not possible.
data = jnp.arange(12).reshape((3, 4))
print("\nOriginal data:\n", data)
# Get a single element
element = data[1, 2] # Row 1, Column 2
print("Element at [1, 2]:", element)
# Note: Accessing a single element returns a 0-dimensional array
# Get a row
row_1 = data[1, :] # Row 1, all columns
print("Row 1:", row_1)
# Get a column
col_2 = data[:, 2] # All rows, Column 2
print("Column 2:", col_2)
# Get a sub-array (slice)
sub_array = data[0:2, 1:3] # Rows 0-1, Columns 1-2
print("Sub-array (0:2, 1:3):\n", sub_array)
# Direct item assignment is not supported: JAX arrays are immutable.
# This is different from NumPy, where slices are often writable views.
try:
    data[0, 0] = 99  # raises TypeError
except TypeError as e:
    print(f"\nAs expected, direct item assignment failed: {e}")
    print("Use `.at[index].set(value)` for functional updates.")
# The JAX-idiomatic way to update is an indexed, out-of-place update:
updated_data = data.at[0, 0].set(99)
print("\nOriginal data (still unchanged):\n", data)
print("Updated data (new array):\n", updated_data)
The .at[...].set(...) syntax is the JAX way to perform out-of-place updates, returning a modified copy while leaving the original array untouched. This functional approach is necessary for compatibility with JAX transformations like jit and grad.
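The .at indexer is not limited to set; variants such as add, multiply, min, and max perform the corresponding out-of-place updates. A short sketch:
# Each call returns a new array; `data` itself is never modified
bumped = data.at[1, :].add(100)     # add 100 to every element of row 1
scaled = data.at[:, 0].multiply(2)  # double column 0
print("Row 1 bumped:\n", bumped)
print("Column 0 scaled:\n", scaled)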
JAX provides many of the element-wise universal functions (ufuncs) found in NumPy:
x = jnp.linspace(0, jnp.pi * 2, 5)
print("\nx:", x)
y_sin = jnp.sin(x)
print("sin(x):", y_sin)
y_exp = jnp.exp(x / (jnp.pi * 2)) # Scale x to be 0 to 1 before exp
print("exp(x scaled):", y_exp)
Let's visualize sin(x):
{"data": [{"x": x.tolist(), "y": y_sin.tolist(), "type": "scatter", "mode": "lines+markers", "marker": {"color": "#339af0"}, "line": {"color": "#339af0"}}], "layout": {"title": "Sine Function using JAX Arrays", "xaxis": {"title": "x (radians)"}, "yaxis": {"title": "sin(x)"}, "margin": {"l": 40, "r": 20, "t": 40, "b": 40}, "height": 300}}
Values generated by
jnp.linspace
and transformed byjnp.sin
.
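To reproduce a similar plot locally, here is a minimal sketch using matplotlib (assumed to be installed; it is not part of JAX):
import matplotlib.pyplot as plt
# JAX arrays convert to NumPy automatically when passed to matplotlib
plt.plot(x, y_sin, marker="o")
plt.xlabel("x (radians)")
plt.ylabel("sin(x)")
plt.title("Sine Function using JAX Arrays")
plt.show()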
Aggregation functions familiar from NumPy, such as sum, mean, max, and min, are available. You can perform reductions over the entire array or along specific axes.
matrix = jnp.arange(12).reshape((3, 4))
print("\nMatrix:\n", matrix)
total_sum = jnp.sum(matrix)
print("Total sum:", total_sum)
sum_per_column = jnp.sum(matrix, axis=0)  # axis=0 collapses rows, giving one sum per column
print("Sum per column (axis=0):", sum_per_column)
mean_per_row = jnp.mean(matrix, axis=1)  # axis=1 collapses columns, giving one mean per row
print("Mean per row (axis=1):", mean_per_row)
max_val = jnp.max(matrix)
print("Maximum value:", max_val)
You can change the shape of an array without changing its data, which again produces a new array.
original = jnp.arange(6)
print("\nOriginal 1D array:", original)
reshaped_2x3 = original.reshape((2, 3))
# Equivalent: reshaped_2x3 = jnp.reshape(original, (2, 3))
print("Reshaped to (2, 3):\n", reshaped_2x3)
reshaped_3x2 = jnp.reshape(original, (3, 2))
print("Reshaped to (3, 2):\n", reshaped_3x2)
This hands-on practice covers the most common array operations. You should feel comfortable creating, manipulating, and inspecting JAX arrays using the jax.numpy interface. The similarities to NumPy make this transition relatively smooth, but keep the core differences, particularly immutability and explicit PRNG handling, in mind. These features are fundamental to how JAX achieves high performance and enables powerful function transformations, which we will explore in the next chapters. Experiment with these operations yourself to solidify your understanding.