If you've worked with NumPy, you'll find JAX's array library, jax.numpy, remarkably familiar. This is by design. JAX aims to provide a high-performance numerical computing environment that feels comfortable for Python users already acquainted with NumPy, the cornerstone of scientific computing in Python.
Conventionally, jax.numpy is imported as jnp:
import numpy as np
import jax.numpy as jnp
import jax
# Check available devices (CPU is always present)
print(f"JAX Devices: {jax.devices()}")
# NumPy array creation
np_array = np.array([1.0, 2.0, 3.0])
print(f"NumPy Array: {np_array}, Type: {type(np_array)}")
# JAX array creation
jnp_array = jnp.array([1.0, 2.0, 3.0])
print(f"JAX Array: {jnp_array}, Type: {type(jnp_array)}")
You'll notice that many function names and behaviors are identical. Creating arrays, performing element-wise operations, and calculating sums, means, or standard deviations often involve the exact same function calls, just using jnp instead of np.
# Basic operations look similar
a_np = np.arange(6).reshape((2, 3))
b_np = np.array([[1, 1, 1], [2, 2, 2]])
c_np = a_np + b_np * 2
a_jnp = jnp.arange(6).reshape((2, 3))
b_jnp = jnp.array([[1, 1, 1], [2, 2, 2]])
c_jnp = a_jnp + b_jnp * 2 # Same syntax as NumPy
print(f"NumPy Result:\n{c_np}")
print(f"JAX Result:\n{c_jnp}")
# Check if results are close (useful for comparing float results)
print(f"Results are close: {np.allclose(c_np, c_jnp)}")
This similarity significantly lowers the barrier to entry. Much of your existing NumPy knowledge is directly transferable. However, beneath this familiar surface lie fundamental differences critical to understanding JAX's purpose and capabilities.
While the API strives for compatibility, JAX operates differently from NumPy in several important ways:
Immutability: This is perhaps the most significant difference in day-to-day coding. NumPy arrays are mutable, meaning you can change their values in place. JAX arrays, on the other hand, are immutable. You cannot modify a JAX array after it's created; operations that seem to modify an array actually return a new array with the updated values.
np_array = np.array([1, 2, 3])
np_array[0] = 100 # This works fine in NumPy
print(f"Modified NumPy array: {np_array}")
jnp_array = jnp.array([1, 2, 3])
try:
    # This will raise a TypeError in JAX
    jnp_array[0] = 100
except TypeError as e:
    print(f"\nError modifying JAX array in place: {e}")
# The JAX way: create a new array with the updated value
# Uses indexed update syntax: .at[index].set(value)
updated_jnp_array = jnp_array.at[0].set(100)
print(f"Original JAX array (unchanged): {jnp_array}")
print(f"Updated JAX array (new object): {updated_jnp_array}")
Immutability is a core tenet of functional programming and is essential for JAX's function transformations (like jit and grad) to work correctly and reliably, especially across different hardware accelerators. It prevents side effects, making code easier to reason about, parallelize, and differentiate.
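The .at syntax is not limited to set. It also provides out-of-place versions of other common in-place updates, such as add, multiply, min, and max, and it accepts slices as well as single indices. A short illustrative sketch:
vec = jnp.arange(5)
print(vec.at[1].add(10))        # add 10 to element 1, returning a new array
print(vec.at[2:4].multiply(3))  # scale a slice out-of-place
print(vec.at[0].max(7))         # element 0 becomes max(vec[0], 7)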
Hardware Acceleration: Standard NumPy operations run exclusively on the CPU. JAX is designed from the ground up to run on hardware accelerators such as Graphics Processing Units (GPUs) and Tensor Processing Units (TPUs), in addition to CPUs. JAX typically handles device placement automatically, putting data and computations on the default device, which is an accelerator (GPU or TPU) when one is available. You often don't need to change your code significantly to benefit from GPU/TPU speedups. We discuss device management later in this chapter.
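To see where arrays live and to place them explicitly, you can use jax.devices and jax.device_put. A minimal sketch (the devices() method on arrays is available in recent JAX versions; on a CPU-only machine every array will report a CPU device):
# Arrays are created on JAX's default device
mat = jnp.ones((3, 3))
print(f"mat lives on: {mat.devices()}")
# Explicitly place an array on a chosen device
cpu0 = jax.devices("cpu")[0]
mat_cpu = jax.device_put(mat, cpu0)
print(f"mat_cpu lives on: {mat_cpu.devices()}")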
Execution Model (Lazy Evaluation & Compilation): NumPy operations execute eagerly. When you type c = a + b, the addition happens immediately. JAX dispatches operations asynchronously, and when combined with transformations like jax.jit (Just-In-Time compilation), it first traces your function into an internal computation graph and only materializes results when they are actually needed (e.g., printed or saved). jax.jit compiles your Python function into optimized XLA (Accelerated Linear Algebra) code targeted at the available hardware (CPU/GPU/TPU). This compilation step, which happens automatically behind the scenes, is a primary source of JAX's performance advantage over standard NumPy for computationally intensive tasks. We will study jax.jit in detail in Chapter 2.
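As an illustrative sketch (the function and matrix size here are arbitrary), note the use of block_until_ready, which forces JAX's asynchronous dispatch to finish before timing:
import time

def repeated_matmul(x):
    # Deliberately repetitive work so compilation has something to optimize
    for _ in range(10):
        x = (x @ x.T) / x.shape[0]
    return x

compiled = jax.jit(repeated_matmul)

x = jnp.ones((500, 500))
compiled(x).block_until_ready()   # first call traces and compiles
start = time.perf_counter()
compiled(x).block_until_ready()   # later calls run the cached XLA program
print(f"Compiled call took {time.perf_counter() - start:.4f}s")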
Function Transformations: The most profound difference isn't in the jax.numpy API itself, but in what JAX enables around it. JAX provides composable function transformations (a short sketch of how they compose follows this list):
jax.jit: For Just-In-Time (JIT) compilation to accelerate code.
jax.grad: For automatic differentiation (computing gradients).
jax.vmap: For automatic vectorization (mapping functions over array axes).
jax.pmap: For parallelization across multiple devices (SPMD programming).
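As a brief, illustrative sketch of how these transformations compose (the function f below is just an example, not from the text):
def f(x):
    return jnp.sum(x ** 2)

grad_f = jax.grad(f)                           # df/dx = 2x
batched_grad_f = jax.vmap(grad_f)              # map grad_f over a leading batch axis
fast_batched_grad_f = jax.jit(batched_grad_f)  # and compile the whole thing

batch = jnp.arange(6.0).reshape(2, 3)
print(fast_batched_grad_f(batch))              # [[ 0.  2.  4.] [ 6.  8. 10.]]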
Standard NumPy has no equivalent for these transformations. They are the core reason JAX is so effective for machine learning research and high-performance computing. We will explore each of these transformations in subsequent chapters.
Type Promotion and Precision: While generally compatible, there can be subtle differences in default data types (e.g., JAX defaults to 32-bit floats unless configured otherwise, whereas NumPy defaults to 64-bit floats) and in how types are promoted during operations on mixed types. It's good practice to be explicit about data types using the dtype argument when precision matters.
# Check default float types (may vary based on JAX config)
np_float = np.array([1.0]).dtype
jnp_float = jnp.array([1.0]).dtype
print(f"\nDefault NumPy float type: {np_float}")
print(f"Default JAX float type: {jnp_float}")
# Explicitly set dtype. Requesting float64 requires enabling 64-bit mode
# (off by default; otherwise JAX silently downcasts the array to float32)
jax.config.update("jax_enable_x64", True)
jnp_float64 = jnp.array([1.0, 2.0], dtype=jnp.float64)
print(f"JAX array with explicit float64: {jnp_float64.dtype}")
Here's a quick comparison:
Feature | NumPy (numpy) | JAX (jax.numpy)
---|---|---
API Similarity | - | High, mimics NumPy API
Mutability | Mutable (in-place modification) | Immutable
Hardware | CPU | CPU, GPU, TPU
Execution Model | Eager | Lazy (often via JIT), compiled via XLA
Transformations | None | jit, grad, vmap, pmap
Primary Goal | General numerical computing | High-performance, differentiable computation
State Handling | Allows side effects | Prefers pure functions (explicit state passing)
You don't necessarily have to replace all your NumPy code with jax.numpy; the two can coexist. Reserve jax.numpy for the core numerical computations within your algorithms, particularly those you intend to accelerate with jit, differentiate with grad, vectorize with vmap, or parallelize with pmap.
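For instance, a common pattern (sketched here with made-up data) is to prepare inputs with standard NumPy, hand them to a jitted JAX function, and convert back with np.asarray when a plain NumPy array is needed:
# Prepare data with standard NumPy
data = np.random.rand(4, 3).astype(np.float32)

@jax.jit
def normalize(x):
    # Core computation in jax.numpy so it can be compiled and differentiated
    return (x - jnp.mean(x, axis=0)) / jnp.std(x, axis=0)

result = normalize(data)         # NumPy input is converted to a JAX array automatically
result_np = np.asarray(result)   # back to NumPy when needed
print(type(result), type(result_np))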
Think of jax.numpy as the NumPy-like interface to JAX's powerful compilation and transformation engine. By understanding both the similarities and the fundamental differences concerning immutability, execution, and hardware capabilities, you can effectively leverage JAX for demanding computational tasks.