As outlined in the chapter introduction, JAX achieves significant performance gains by compiling your Python functions. The primary tool for this is the jax.jit transformation. Think of jit (Just-In-Time compilation) as a way to take a standard Python function operating on JAX arrays and convert it into a highly optimized, fused sequence of operations specific to your hardware (CPU, GPU, or TPU).
jax.jit

There are two main ways to apply the jit transformation:

1. As a decorator: place @jax.jit on the line before the function definition.
2. As a function: call jax.jit directly, passing your function as an argument. This returns a new, compiled version of your function.

Let's look at a simple example. Suppose we have a function that performs a few numerical operations:
import jax
import jax.numpy as jnp
import time
# A function with some numerical computations
def complex_computation(x, weight, bias):
    y = jnp.dot(x, weight) + bias
    z = jnp.tanh(y)
    return jnp.mean(z)
# Create some random data (split the key so each draw is independent)
key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)
x_data = jax.random.normal(k1, (1000, 500))
weight_data = jax.random.normal(k2, (500, 200))
bias_data = jax.random.normal(k3, (200,))
# --- Method 1: Using jit as a decorator ---
@jax.jit
def compiled_computation_decorator(x, weight, bias):
    y = jnp.dot(x, weight) + bias
    z = jnp.tanh(y)
    return jnp.mean(z)
# --- Method 2: Using jit as a function ---
compiled_computation_functional = jax.jit(complex_computation)
# --- Timing the execution ---
# Time the original Python function
# Run once to avoid any initial overhead unrelated to JAX
_ = complex_computation(x_data, weight_data, bias_data).block_until_ready()
start_time = time.time()
result_original = complex_computation(x_data, weight_data, bias_data).block_until_ready()
end_time = time.time()
print(f"Original function time: {end_time - start_time:.6f} seconds")
# Time the JIT-compiled function (decorator version)
# First call includes compilation time
start_time_compile = time.time()
result_compiled_decorator = compiled_computation_decorator(x_data, weight_data, bias_data).block_until_ready()
end_time_compile = time.time()
print(f"Compiled function (decorator) first call (incl. compile): {end_time_compile - start_time_compile:.6f} seconds")
# Second call uses the cached compiled code
start_time_cached = time.time()
result_compiled_decorator_cached = compiled_computation_decorator(x_data, weight_data, bias_data).block_until_ready()
end_time_cached = time.time()
print(f"Compiled function (decorator) second call (cached): {end_time_cached - start_time_cached:.6f} seconds")
# Verify results are the same (within floating point tolerances)
print(f"Results match: {jnp.allclose(result_original, result_compiled_decorator)}")
# Timing the functional version (should be similar after first compile)
_ = compiled_computation_functional(x_data, weight_data, bias_data).block_until_ready() # Compile
start_time_func = time.time()
result_compiled_functional = compiled_computation_functional(x_data, weight_data, bias_data).block_until_ready()
end_time_func = time.time()
print(f"Compiled function (functional) subsequent call: {end_time_func - start_time_func:.6f} seconds")
Important Note on Timing: Notice the use of .block_until_ready() after each function call we want to time. JAX uses asynchronous dispatch by default, meaning operations are queued but might not complete immediately. block_until_ready() ensures the computation finishes before we record the end time, giving us accurate measurements.
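To see why this matters, here is a minimal sketch of asynchronous dispatch (the matmul_sum function and the array size are illustrative, not part of the example above). It times the same jitted call twice: once measuring only how long it takes to queue the work, and once waiting for the result:

```python
import time

import jax
import jax.numpy as jnp

# Illustrative jitted function (hypothetical, chosen to be large enough to measure)
@jax.jit
def matmul_sum(m):
    return jnp.dot(m, m).sum()

key = jax.random.PRNGKey(0)
a = jax.random.normal(key, (2000, 2000))

# Warm up: the first call triggers tracing and compilation
_ = matmul_sum(a).block_until_ready()

# Without block_until_ready we only measure the time to queue the work
start = time.perf_counter()
result = matmul_sum(a)
dispatch_time = time.perf_counter() - start

# With block_until_ready we measure until the computation actually finishes
start = time.perf_counter()
result = matmul_sum(a).block_until_ready()
compute_time = time.perf_counter() - start

print(f"Dispatch only: {dispatch_time:.6f}s, full computation: {compute_time:.6f}s")
```

On accelerators the gap between the two measurements can be large; on CPU it may be small, but the second measurement is the trustworthy one either way.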
You'll observe a pattern:

1. The first call to a jit-compiled function (compiled_computation_decorator or compiled_computation_functional) is often slower than the original. This is because JAX needs to perform a crucial step called tracing (which we'll discuss in the next section) and then compile the traced operations using XLA (the Accelerated Linear Algebra compiler).
2. Subsequent calls reuse the cached compiled code and are typically much faster than the original, uncompiled function.

The difference between the decorator (@jax.jit) and functional (jax.jit(fn)) approaches is primarily stylistic. The decorator is often preferred for its readability when defining functions. The functional form is useful when you want to compile a function obtained from elsewhere (e.g., a library function, though many JAX library functions are already JIT-compiled internally where appropriate).
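As a sketch of when the functional form is handy (gelu_mean and scaled_norm below are illustrative helpers, not part of the example above), you can compile a function without modifying its definition, and you can combine jax.jit with functools.partial to fix some arguments first:

```python
import functools

import jax
import jax.numpy as jnp

# Hypothetical helper standing in for a function defined elsewhere
def gelu_mean(x):
    return jnp.mean(jax.nn.gelu(x))

# Compile it without touching its source
fast_gelu_mean = jax.jit(gelu_mean)

# The functional form also composes with functools.partial
def scaled_norm(x, scale):
    return jnp.linalg.norm(x) * scale

fast_double_norm = jax.jit(functools.partial(scaled_norm, scale=2.0))

x = jnp.arange(6.0)
print(fast_gelu_mean(x))
print(fast_double_norm(x))
```

This pattern keeps the original functions reusable in their uncompiled form while giving you compiled variants where performance matters.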
In essence, jax.jit provides a straightforward way to request compilation for your performance-sensitive numerical functions. By understanding when and how to apply it, you can unlock substantial speedups for your JAX code. The next section will explore the tracing mechanism that makes this possible.
© 2025 ApX Machine Learning