Let's put the theory into practice. In this section, we'll apply jax.jit to a few functions and observe its effect on execution speed. We'll also explore how to measure performance accurately and see firsthand some of the concepts discussed earlier, like the cost of the initial compilation.
First, make sure you have JAX installed and import the necessary libraries. We'll primarily use jax, jax.numpy (conventionally imported as jnp), and Python's timeit module for basic timing.
import jax
import jax.numpy as jnp
import timeit
import numpy as np # We'll use standard NumPy for comparison setup
# Check available devices (optional, but good practice)
print(f"JAX devices: {jax.devices()}")
# Helper function for more robust timing
def time_func(func, *args, num_runs=10, warmup_runs=2):
    """Times a function execution, handling warmup and JAX async."""
    # Warmup runs
    for _ in range(warmup_runs):
        result = func(*args)
        if isinstance(result, jax.Array):
            result.block_until_ready()  # Ensure JAX computation completes
    # Timed runs
    times = []
    for _ in range(num_runs):
        start_time = timeit.default_timer()
        result = func(*args)
        if isinstance(result, jax.Array):
            result.block_until_ready()  # IMPORTANT: Wait for JAX op to finish
        end_time = timeit.default_timer()
        times.append(end_time - start_time)
    return np.mean(times), np.std(times)
# Create some sample data on the default device
key = jax.random.PRNGKey(0)
key_x, key_y = jax.random.split(key)  # Use distinct subkeys so x and y differ
size = 1000
x_jnp = jax.random.normal(key_x, (size, size))
y_jnp = jax.random.normal(key_y, (size, size))
# Ensure data is on the device before timing
x_jnp.block_until_ready()
y_jnp.block_until_ready()
Notice the use of result.block_until_ready(). JAX operations are dispatched asynchronously. This means that when you call a JAX function, Python might return control to your script before the actual computation on the GPU/TPU (or even CPU) has finished. For accurate timing, we need to explicitly wait for the result to be computed.
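To see the effect, you can compare a timing that waits for the result against one that only measures dispatch. This is a minimal sketch reusing the arrays created above; the gap is most pronounced on GPU/TPU and may be small on CPU.

# Compare timing with and without waiting for the result
# (uses x_jnp and y_jnp from the setup above)
start = timeit.default_timer()
_ = jnp.dot(x_jnp, y_jnp)                      # returns once the op is dispatched
dispatch_only = timeit.default_timer() - start

start = timeit.default_timer()
jnp.dot(x_jnp, y_jnp).block_until_ready()      # waits for the computation to finish
full_compute = timeit.default_timer() - start

print(f"Dispatch only:    {dispatch_only:.6f} s")
print(f"Full computation: {full_compute:.6f} s")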
Let's define a function that performs several jax.numpy operations:
def compute_heavy(a, b):
    """A sample function with several JAX operations."""
    c = jnp.dot(a, b)
    d = jnp.sin(c)
    e = jnp.log(jnp.abs(d) + 1e-6)  # Add epsilon for numerical stability
    f = jnp.sum(e)
    return f
# Time the original function
mean_time_orig, std_time_orig = time_func(compute_heavy, x_jnp, y_jnp)
print(f"Original function time: {mean_time_orig:.6f} +/- {std_time_orig:.6f} seconds")
# Now, let's JIT-compile the function
compute_heavy_jit = jax.jit(compute_heavy)
# Time the JIT-compiled function
# Note: The *very first* call includes compilation time. time_func's warmup
# runs absorb it, so the reported mean reflects steady-state execution.
print("Timing JIT version (compilation happens during the warmup runs)...")
mean_time_jit, std_time_jit = time_func(compute_heavy_jit, x_jnp, y_jnp)
print(f"JIT function time: {mean_time_jit:.6f} +/- {std_time_jit:.6f} seconds")
# Calculate speedup
speedup = mean_time_orig / mean_time_jit
print(f"\nApproximate speedup: {speedup:.2f}x")
When you run this, you should observe a significant difference in execution time between the original and the JIT-compiled versions after the initial compilation cost. The JIT compiler fuses the sequence of jnp operations into a single, optimized kernel that runs much faster on the target accelerator. The exact speedup depends on your hardware (CPU vs GPU/TPU), the size of the arrays, and the complexity of the function. For purely numerical code like this, the gains from jit are often substantial.
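If you are curious what jit hands to the compiler, jax.make_jaxpr shows the traced sequence of primitive operations that XLA then fuses and optimizes. This is purely illustrative; the jaxpr is an intermediate representation, not the final fused kernel.

# Peek at the traced program for compute_heavy
print(jax.make_jaxpr(compute_heavy)(x_jnp, y_jnp))
# The printed jaxpr lists primitives such as dot_general, sin, abs, add,
# log, and reduce_sum, which XLA fuses into optimized kernels.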
As discussed previously, jit works by tracing the function with abstract values. This can sometimes lead to unexpected behavior or errors if the control flow depends on the specific values being traced.
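A quick way to see tracing in action is to print an argument from inside a jitted function. The show_tracing function below is a small illustrative sketch: during the trace you see an abstract Tracer carrying only shape and dtype, not concrete numbers.

@jax.jit
def show_tracing(x):
    # This print executes at trace time, so it shows an abstract Tracer
    # (shape and dtype only), not concrete values.
    print("Inside jit, x is:", x)
    return x * 2

_ = show_tracing(jnp.arange(3.0))   # prints something like Traced<ShapedArray(float32[3])>
_ = show_tracing(jnp.ones(3))       # same shape/dtype: cached, no re-trace, nothing printed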
Consider this function:
def conditional_computation(x, threshold):
    """Performs different computations based on a condition."""
    if jnp.sum(x) > threshold:
        return jnp.dot(x, x.T) * 2
    else:
        return jnp.dot(x, x.T) / 2
# Try JITing it
conditional_computation_jit = jax.jit(conditional_computation)
# Create some small data
x_small = jnp.array([1.0, 2.0, 3.0])
# Run with a value that satisfies the condition.
# Tracing hits the Python `if` on an abstract value, so this raises an error.
print("Running JIT with condition TRUE:")
try:
    result_true = conditional_computation_jit(x_small, 5.0)
    result_true.block_until_ready()
    print(f"Result (True): {result_true}")
except Exception as e:
    print(f"Caught an error (as expected): {e}")

# The same happens regardless of whether the condition would be True or False:
# the problem occurs at trace time, before any concrete values are known.
print("\nRunning JIT with condition FALSE:")
try:
    result_false = conditional_computation_jit(x_small, 10.0)
    result_false.block_until_ready()
    print(f"Result (False): {result_false}")
except Exception as e:
    print(f"Caught an error (as expected): {e}")
# Example using jax.lax for staged-out control flow
import jax.lax

@jax.jit
def conditional_computation_lax(x, threshold):
    """Uses lax.cond for JIT-compatible conditional logic."""
    return jax.lax.cond(
        jnp.sum(x) > threshold,  # Condition
        lambda op: op * 2,       # True branch function
        lambda op: op / 2,       # False branch function
        jnp.dot(x, x.T),         # Operand passed to the chosen branch
    )
print("\nRunning JIT with lax.cond:")
result_true_lax = conditional_computation_lax(x_small, 5.0)
result_true_lax.block_until_ready()
print(f"Result lax (True): {result_true_lax}")
result_false_lax = conditional_computation_lax(x_small, 10.0)
result_false_lax.block_until_ready()
print(f"Result lax (False): {result_false_lax}")
When JAX traces conditional_computation, it encounters the Python if statement. Because the condition jnp.sum(x) > threshold depends on the value of x (which is abstract during tracing), the if has no concrete boolean to branch on, and JAX raises a concretization error (reported as a TracerBoolConversionError in recent versions) instead of producing a compiled artifact.
The standard JAX way to handle this is to use structured control flow primitives such as jax.lax.cond (for conditionals) or jax.lax.scan and jax.lax.fori_loop (for loops). These functions are designed to be traceable by jit. The conditional_computation_lax example demonstrates lax.cond, which stages the conditional logic out into the compiled XLA graph, so no Python-level if needs to be evaluated during tracing.
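The loop primitives follow the same pattern. As a minimal sketch (sum_of_squares is illustrative, not part of the examples above), jax.lax.fori_loop can replace a value-dependent Python loop inside jit:

@jax.jit
def sum_of_squares(x):
    """Accumulates x[i]**2 with a staged-out loop instead of a Python for loop."""
    def body(i, acc):
        return acc + x[i] ** 2
    return jax.lax.fori_loop(0, x.shape[0], body, jnp.zeros((), x.dtype))

print(sum_of_squares(jnp.array([1.0, 2.0, 3.0])))  # 14.0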
Sometimes, an argument to your function determines the structure of the computation rather than just participating as data. For instance, consider a function that applies a matrix multiplication repeatedly:
def apply_n_times(x, n):
    """Applies matrix multiplication n times."""
    y = x
    for _ in range(n):  # Python loop!
        y = jnp.dot(y, x)
    return y
# Try JITing without specifying static arguments
apply_n_times_jit = jax.jit(apply_n_times)
x_matrix = jax.random.normal(key, (50, 50))
x_matrix.block_until_ready()

print("Timing apply_n_times_jit (n=2):")
try:
    time_func(apply_n_times_jit, x_matrix, 2)
except Exception as e:
    # range(n) needs a concrete integer, but 'n' is an abstract tracer here
    print(f"Caught an error: {e}")
# Now, use static_argnums to tell JIT that 'n' affects the computation structure
apply_n_times_jit_static = jax.jit(apply_n_times, static_argnums=(1,))  # Index 1 corresponds to 'n'

print("\nTiming apply_n_times_jit_static (n=2):")
mean_t, std_t = time_func(apply_n_times_jit_static, x_matrix, 2)  # First call compiles a version for n=2
print(f"n=2: {mean_t:.6f} +/- {std_t:.6f} seconds")

print("\nTiming apply_n_times_jit_static (n=3):")
mean_t, std_t = time_func(apply_n_times_jit_static, x_matrix, 3)  # First call *with n=3* compiles a separate version
print(f"n=3: {mean_t:.6f} +/- {std_t:.6f} seconds")

print("\nTiming apply_n_times_jit_static (n=2) AGAIN:")
mean_t, std_t = time_func(apply_n_times_jit_static, x_matrix, 2)  # Reuses the cached compilation for n=2
print(f"n=2 again: {mean_t:.6f} +/- {std_t:.6f} seconds")
In the first attempt (apply_n_times_jit), the number of iterations of the Python for loop depends directly on the value of n. Because jit traces n as an abstract value, range(n) has no concrete integer to work with, and JAX raises an error rather than compiling the function.
By using jax.jit(..., static_argnums=(1,)), we tell JAX that the argument at index 1 (n) is static. This means JAX won't trace through its value. Instead, it treats n as a Python constant for a given compilation, and the loop is unrolled for that concrete value. If you call the function with a different static value (changing n from 2 to 3), JAX recognizes this and compiles a new, specialized version of the function for that specific value of n. Subsequent calls with the same static value (n=2 again) reuse the cached, already compiled version. Re-compilation is therefore limited to genuinely new static values, while the compiled code stays optimized for the specific structural variation controlled by the static argument.
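If you prefer referring to the argument by name rather than position, jax.jit also accepts static_argnames. The decorator-style sketch below (apply_n_times_named is just a renamed copy for illustration) expresses the same idea:

from functools import partial

@partial(jax.jit, static_argnames="n")
def apply_n_times_named(x, n):
    y = x
    for _ in range(n):  # unrolled at trace time because n is a static Python int
        y = jnp.dot(y, x)
    return y

result = apply_n_times_named(x_matrix, 3)  # compiles a version specialized for n=3
result.block_until_ready()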
Let's visualize the timing difference measured for compute_heavy above.
# Data for plotting (using the results measured above)
labels = ['Original Function', 'JIT Compiled']
mean_times = [mean_time_orig, mean_time_jit]
std_devs = [std_time_orig, std_time_jit]
{"data": [{"type": "bar", "x": ["Original Function", "JIT Compiled"], "y": [mean_time_orig, mean_time_jit], "marker": {"color": ["#4263eb", "#f76707"]}, "error_y": {"type": "data", "array": [std_time_orig, std_time_jit], "visible": true, "color": "#495057"}}], "layout": {"title": "Function Execution Time Comparison", "yaxis": {"title": "Average Execution Time (seconds)", "type": "log", "tickformat": ".1e"}, "xaxis": {"title": "Function Version"}, "bargap": 0.3, "template": "plotly_white", "width": 600, "height": 400}}
Comparison of average execution time (log scale) for the original Python function and its JIT-compiled version. Error bars represent standard deviation over multiple runs. Lower is better. Note the significant reduction in time after JIT compilation.
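To reproduce a similar chart locally, a minimal matplotlib sketch (matplotlib is an assumption here; it is not used elsewhere in this section) could look like this:

import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(6, 4))
ax.bar(labels, mean_times, yerr=std_devs, color=["#4263eb", "#f76707"], capsize=6)
ax.set_yscale("log")  # log scale, matching the chart above
ax.set_ylabel("Average Execution Time (seconds)")
ax.set_xlabel("Function Version")
ax.set_title("Function Execution Time Comparison")
plt.tight_layout()
plt.show()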
This hands-on session demonstrated the practical application of jax.jit. You've seen how it can significantly accelerate numerical code, how Python control flow interacts with tracing, and how to use static_argnums to manage compilation for functions whose structure depends on certain arguments. Remember to always use .block_until_ready() when timing JAX code and be mindful of the initial compilation cost. As you build more complex JAX programs, jit will be an indispensable tool in your performance optimization toolkit.