Having established what JAX is and how its functional approach is built on function transformations, let's look at the fundamental building block you'll work with: JAX arrays. If you're comfortable with NumPy, you'll find the jax.numpy API, conventionally imported as jnp, very familiar. It mirrors much of the standard NumPy interface, which makes the transition smooth.
import jax
import jax.numpy as jnp  # Standard convention
print(f"JAX version: {jax.__version__}")
Just like NumPy, you can create JAX arrays from existing Python lists or tuples, or use specialized functions.
From Python Lists/Tuples:
# Create an array from a list
python_list = [1.0, 2.0, 3.0]
jax_array_from_list = jnp.array(python_list)
print(f"Array from list: {jax_array_from_list}")
print(f"Type: {type(jax_array_from_list)}")
# Create a 2D array from nested lists
python_nested_list = [[1, 2], [3, 4]]
jax_2d_array = jnp.array(python_nested_list)
print(f"2D Array:\n{jax_2d_array}")
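As in NumPy, jnp.array infers the element type from the Python values it receives, and it accepts tuples and existing NumPy arrays as well as lists. The sketch below assumes the default 32-bit mode (64-bit types disabled), in which integers become int32 and floats become float32.

```python
import numpy as np
import jax.numpy as jnp

# A tuple works just like a list as input to jnp.array
from_tuple = jnp.array((1, 2, 3))

# dtype is inferred from the values: ints -> int32, floats -> float32
# (assuming the default 32-bit mode, i.e. jax_enable_x64 is off)
int_array = jnp.array([1, 2, 3])
float_array = jnp.array([1.0, 2.0, 3.0])

# An existing NumPy array can be converted too; np.arange usually yields
# int64, which JAX stores as int32 in 32-bit mode
from_numpy = jnp.asarray(np.arange(3))

print(from_tuple, from_tuple.dtype)
print(int_array.dtype, float_array.dtype, from_numpy.dtype)
```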
Using jax.numpy Functions:
jax.numpy provides equivalents for common NumPy array creation routines:
# Array of zeros
zeros_array = jnp.zeros((2, 3)) # Shape (2 rows, 3 columns)
print(f"Zeros array:\n{zeros_array}")
# Array of ones
ones_array = jnp.ones((3,), dtype=jnp.int32) # Shape (3,) with integer type
print(f"Ones array: {ones_array}")
print(f"Ones array dtype: {ones_array.dtype}")
# Array with a range of values
range_array = jnp.arange(5) # Similar to Python's range(5) -> [0, 1, 2, 3, 4]
print(f"Arange array: {range_array}")
# Linearly spaced array
linspace_array = jnp.linspace(0, 1, 5) # 5 points from 0 to 1 (inclusive)
print(f"Linspace array: {linspace_array}")
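One detail worth checking is how the endpoints behave: jnp.arange excludes its stop value, while jnp.linspace includes it unless endpoint=False is passed, matching NumPy. The short check below compares against NumPy directly; the specific values are chosen only for illustration.

```python
import numpy as np
import jax.numpy as jnp

# arange excludes the stop value, like Python's range: 0.0, 0.25, 0.5, 0.75
stepped = jnp.arange(0.0, 1.0, 0.25)
# linspace includes the stop value by default: 0.0, 0.25, 0.5, 0.75, 1.0
inclusive = jnp.linspace(0.0, 1.0, 5)
# ...unless endpoint=False is passed, as in NumPy
exclusive = jnp.linspace(0.0, 1.0, 4, endpoint=False)

# The results agree with NumPy's versions of the same calls
assert np.allclose(np.asarray(stepped), np.arange(0.0, 1.0, 0.25))
assert np.allclose(np.asarray(exclusive), np.asarray(stepped))
print(stepped, inclusive, exclusive, sep="\n")
```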
Generating random numbers in JAX differs significantly from NumPy because JAX's transformations require pure functions. A pure function always returns the same output for the same input and has no side effects. NumPy's random functions rely on hidden global state that changes with every call, which violates purity.
JAX handles randomness using explicit pseudorandom number generator (PRNG) keys. You create an initial key and then generate new keys from existing ones whenever you need more random numbers. This makes random number generation reproducible and compatible with JAX transformations.
from jax import random
# Create an initial PRNG key. Seed is typically an integer.
key = random.PRNGKey(0)
print(f"Initial key: {key}")
# Generate random numbers (e.g., from a normal distribution)
# The key is not modified: the same key always produces the same numbers.
normal_random_numbers = random.normal(key, shape=(2, 2))
print(f"Normal random numbers:\n{normal_random_numbers}")
# To generate more random numbers, 'split' the key
key, subkey = random.split(key) # Creates a new key and a subkey
uniform_random_numbers = random.uniform(subkey, shape=(3,))
print(f"\nSplit key: {key}")  # 'key' now refers to a new key; the old key value is unchanged
print(f"Subkey: {subkey}")
print(f"Uniform random numbers: {uniform_random_numbers}")
# Calling random.normal again with the *same* original key yields the *same* result
# This demonstrates the purity and reproducibility
same_normal_numbers = random.normal(random.PRNGKey(0), shape=(2, 2))
print(f"\nSame normal numbers:\n{same_normal_numbers}")
assert jnp.allclose(normal_random_numbers, same_normal_numbers)
Managing these keys explicitly is essential for writing correct and reproducible JAX code, especially when using transformations like jit or vmap.
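To make the connection to transformations concrete, here is a minimal sketch: random.split(key, 4) produces a stack of four keys, and vmap maps a pure sampling function over them so that each row is drawn from its own key. The function name sample is hypothetical, chosen for illustration.

```python
import jax
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
# split(key, num) returns `num` new keys stacked along the first axis
keys = random.split(key, 4)

# Hypothetical pure sampler: all of its randomness comes from the `k` argument
def sample(k):
    return random.normal(k, shape=(3,))

# vmap maps the sampler over the stack of keys: one independent draw per key
batch = jax.vmap(sample)(keys)
# jit compiles the same function; given the same key, it gives the same numbers
jitted = jax.jit(sample)(keys[0])

print(batch.shape)
print(jnp.allclose(jitted, batch[0]))
```

Because the keys are ordinary inputs, the whole computation is reproducible: re-running it with random.PRNGKey(0) regenerates exactly the same batch.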
JAX arrays share familiar properties with NumPy arrays:
x = jnp.arange(12).reshape((3, 4))
print(f"Array x:\n{x}")
print(f"Shape: {x.shape}") # Tuple indicating dimensions (rows, columns)
print(f"Data type: {x.dtype}") # Type of elements (e.g., float32, int32)
print(f"Number of dims: {x.ndim}") # Number of axes (2 for a matrix)
print(f"Total elements: {x.size}") # Total number of elements (3 * 4 = 12)
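Because these attributes follow NumPy's conventions, a JAX array and its NumPy conversion report the same values. The sketch below uses np.asarray, which returns a regular NumPy array (transferring the data to host memory if it lives on an accelerator).

```python
import numpy as np
import jax.numpy as jnp

x = jnp.arange(12).reshape((3, 4))
# Convert to a regular NumPy array on the host
x_np = np.asarray(x)

# The same attributes are available, with the same values
for attr in ("shape", "dtype", "ndim", "size"):
    print(attr, getattr(x, attr), getattr(x_np, attr))

# Transposing swaps the axes in both libraries
print(x.T.shape, x_np.T.shape)
```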
By default, JAX uses 32-bit precision (float32, int32) for performance, especially on accelerators. You can enable 64-bit precision through the jax_enable_x64 configuration option if needed, though it may reduce speed.
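A minimal sketch of switching precision, using jax.config.update with the jax_enable_x64 flag. It is normally set once at program startup, before any arrays are created; the sketch toggles it back off at the end only so that it does not change the 32-bit behavior assumed elsewhere in this section.

```python
import jax
import jax.numpy as jnp

# Default 32-bit mode: a list of Python floats becomes float32
default_dtype = jnp.array([1.0, 2.0]).dtype
print(default_dtype)

# Enable 64-bit types (usually done once, at program startup)
jax.config.update("jax_enable_x64", True)
x64_dtype = jnp.array([1.0, 2.0]).dtype
print(x64_dtype)

# Restore the default so later code keeps the 32-bit behavior
jax.config.update("jax_enable_x64", False)
```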
This is perhaps the most important distinction from standard NumPy. JAX arrays are immutable. Once created, their values cannot be changed in place.
import numpy as np  # Needed for the NumPy comparison
numpy_array = np.array([1, 2, 3])
numpy_array[0] = 100  # Works fine in NumPy
print(f"NumPy array after modification: {numpy_array}")
jax_array = jnp.array([1, 2, 3])
try:
    # This raises a TypeError because JAX arrays are immutable
    jax_array[0] = 100
except TypeError as e:
    print(f"\nError trying to modify JAX array in place: {e}")
# Instead, use functional 'indexed update' syntax
# This creates a *new* array with the updated value.
updated_jax_array = jax_array.at[0].set(100)
print(f"Original JAX array (unchanged): {jax_array}")
print(f"Updated JAX array (new object): {updated_jax_array}")
# You can perform more complex updates too:
incremented_array = jax_array.at[1].add(10) # Add 10 to element at index 1
print(f"Incremented array: {incremented_array}")
This immutability is crucial for JAX's function transformations (jit, grad, vmap, pmap). It ensures that functions remain pure, without hidden side effects caused by modifying inputs directly. While it requires a slightly different way of thinking about updates, the array.at[index].set(value) pattern becomes natural with practice and fits well within the functional paradigm.
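Indexed updates compose with transformations like any other operation. In this sketch, a hypothetical jitted function zero_first returns a modified copy while the caller's array stays unchanged. (JAX's documentation notes that under jit, when the input is not used afterwards, the compiler may perform such an update in place as an optimization, without changing these semantics.)

```python
import jax
import jax.numpy as jnp

@jax.jit
def zero_first(v):
    # Pure function: returns a new array with index 0 set to 0,
    # leaving the caller's array untouched
    return v.at[0].set(0)

original = jnp.array([5, 6, 7])
result = zero_first(original)
print(original)  # [5 6 7]
print(result)    # [0 6 7]
```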
Arithmetic operations and universal functions (ufuncs) work much like in NumPy:
a = jnp.array([[1, 2], [3, 4]])
b = jnp.array([[5, 6], [7, 8]])
# Element-wise operations
print(f"a + b:\n{a + b}")
print(f"a * b:\n{a * b}") # Element-wise multiplication
# Matrix multiplication
print(f"Matrix product (jnp.dot):\n{jnp.dot(a, b)}")
print(f"Matrix product (@):\n{a @ b}")
# Universal functions
print(f"Sine of a:\n{jnp.sin(a)}")
print(f"Exponential of a:\n{jnp.exp(a)}")
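One way to convince yourself the operations above behave like NumPy's is to compute the same results with both libraries and compare. Integer products match exactly; ufuncs such as exp return floats, and since JAX computes them in float32, the sketch compares with a small tolerance.

```python
import numpy as np
import jax.numpy as jnp

a = jnp.array([[1, 2], [3, 4]])
b = jnp.array([[5, 6], [7, 8]])
a_np, b_np = np.asarray(a), np.asarray(b)

# Element-wise and matrix products agree exactly with NumPy
assert np.array_equal(np.asarray(a * b), a_np * b_np)
assert np.array_equal(np.asarray(a @ b), a_np @ b_np)

# exp returns floats (float32 in JAX's default mode), so allow a tolerance
assert np.allclose(np.asarray(jnp.exp(a)), np.exp(a_np), rtol=1e-5)

# Matrix product values: [[19, 22], [43, 50]]
print(a @ b)
```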
Standard NumPy-style indexing and slicing are used for reading data:
data = jnp.arange(10)
print(f"\nOriginal data: {data}")
# Get a single element
print(f"Element at index 3: {data[3]}")
# Get a slice
print(f"Slice from index 2 to 5: {data[2:5]}")
# Multi-dimensional indexing
matrix = jnp.arange(9).reshape((3, 3))
print(f"Matrix:\n{matrix}")
print(f"Element at row 1, col 2: {matrix[1, 2]}")
print(f"First two rows:\n{matrix[:2, :]}")
print(f"First column:\n{matrix[:, 0]}")
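Beyond basic slicing, NumPy's integer-array and boolean-mask indexing also work for reading data. Note that boolean masks produce an output whose size depends on the data, so this sketch assumes eager execution (outside jit), where the array values are concrete.

```python
import numpy as np
import jax.numpy as jnp

data = jnp.arange(10)

# Integer-array ("fancy") indexing picks the listed positions
picked = data[jnp.array([1, 3, 5])]

# A boolean mask selects the matching elements (eager mode only)
evens = data[data % 2 == 0]

# Negative indices and negative steps work as in NumPy
reversed_tail = data[-3:][::-1]

print(picked, evens, reversed_tail, sep="\n")
```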
Remember, due to immutability, if you need to modify parts of an array, you must use the .at[...].set(...) (or .add, .multiply, etc.) syntax to create a new, updated array.
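A few of these variants side by side. Each call returns a new array and leaves arr unchanged. One documented difference from NumPy's in-place x[idx] += y: with .at[...].add, repeated indices accumulate, so every update is applied rather than only the last.

```python
import numpy as np
import jax.numpy as jnp

arr = jnp.array([1, 2, 3, 4])

doubled_first = arr.at[0].multiply(2)   # [2 2 3 4]
slice_set = arr.at[1:3].set(0)          # [1 0 0 4]
capped = arr.at[:].min(3)               # element-wise min with 3: [1 2 3 3]

# Index 0 appears twice, so 1 is added twice: [3 2 3 4]
accumulated = arr.at[jnp.array([0, 0])].add(1)

print(arr)  # still [1 2 3 4]
```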
JAX automatically handles placing arrays and executing computations on the available hardware (CPU, GPU, or TPU). You generally don't need to manage this manually for basic operations. You can inspect where an array resides:
x = jnp.ones(3)
# devices() returns the set of devices that hold this array's data
print(f"\nArray x is on device(s): {x.devices()}")
# jax.devices() lists every device available to JAX
print(f"Available devices: {jax.devices()}")
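If you do want to control placement explicitly, jax.device_put copies an array to a chosen device. This sketch simply targets the first device JAX reports, so it runs on a CPU-only machine as well as on a GPU or TPU host.

```python
import jax
import jax.numpy as jnp

# The backend JAX selected by default, e.g. "cpu", "gpu", or "tpu"
print(jax.default_backend())

# Explicitly place an array on a specific device (here, the first one listed)
target = jax.devices()[0]
y = jax.device_put(jnp.arange(3), target)

# The set of devices holding y should contain exactly `target`
print(y.devices())
```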
Understanding device placement becomes more relevant when using pmap for multi-device parallelism, as covered in a later chapter.
With this understanding of JAX arrays, their similarities to NumPy arrays, and the critical concept of immutability, you are ready to explore JAX's powerful function transformations, starting with jax.jit for accelerating your code.
© 2025 ApX Machine Learning