JAX can achieve significant performance gains by compiling your Python functions, and the primary tool for this is the jax.jit transformation. Think of jit (Just-In-Time compilation) as a way to take a standard Python function operating on JAX arrays and convert it into a highly optimized, fused sequence of operations specific to your hardware (CPU, GPU, or TPU).
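To make "compiling" a little more concrete, the sketch below uses JAX's ahead-of-time API (jax.jit(f).lower(...) followed by .compile()) to inspect the program that jit hands to XLA for a small input. It is a minimal illustration and assumes a recent JAX release. The exact text printed by as_text() varies between versions (recent releases emit StableHLO), and the tiny input shapes are arbitrary choices made for readability.

```python
import jax
import jax.numpy as jnp


def complex_computation(x, weight, bias):
    y = jnp.dot(x, weight) + bias
    z = jnp.tanh(y)
    return jnp.mean(z)


# Small, arbitrary example inputs
x = jnp.ones((4, 3))
weight = jnp.ones((3, 2))
bias = jnp.zeros((2,))

# Lower the jitted function for these argument shapes/dtypes (nothing is executed yet)
lowered = jax.jit(complex_computation).lower(x, weight, bias)
print(lowered.as_text())  # the program handed to XLA, as text

# Compile it; XLA applies its hardware-specific optimizations (such as fusion) here
compiled = lowered.compile()
print(compiled(x, weight, bias))
```

In everyday code you never need these steps: calling the jitted function performs the lowering and compilation automatically on the first call.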
There are two main ways to apply the jit transformation:
- As a decorator: place @jax.jit on the line before the function definition.
- As a function: call jax.jit directly, passing your function as an argument. This returns a new, compiled version of your function.

Let's look at a simple example. Suppose we have a function that performs a few numerical operations:
```python
import time

import jax
import jax.numpy as jnp

# A function with some numerical computations
def complex_computation(x, weight, bias):
    y = jnp.dot(x, weight) + bias
    z = jnp.tanh(y)
    return jnp.mean(z)

# Create some random data; split the key so each array gets independent values
key = jax.random.PRNGKey(0)
key_x, key_w, key_b = jax.random.split(key, 3)
x_data = jax.random.normal(key_x, (1000, 500))
weight_data = jax.random.normal(key_w, (500, 200))
bias_data = jax.random.normal(key_b, (200,))

# --- Method 1: Using jit as a decorator ---
@jax.jit
def compiled_computation_decorator(x, weight, bias):
    y = jnp.dot(x, weight) + bias
    z = jnp.tanh(y)
    return jnp.mean(z)

# --- Method 2: Using jit as a function ---
compiled_computation_functional = jax.jit(complex_computation)

# --- Timing the execution ---
# Time the original (un-jitted) function.
# Run it once first so one-time warm-up overhead is not included in the measurement.
_ = complex_computation(x_data, weight_data, bias_data).block_until_ready()
start_time = time.perf_counter()
result_original = complex_computation(x_data, weight_data, bias_data).block_until_ready()
end_time = time.perf_counter()
print(f"Original function time: {end_time - start_time:.6f} seconds")

# Time the JIT-compiled function (decorator version)
# First call includes tracing and compilation time
start_time_compile = time.perf_counter()
result_compiled_decorator = compiled_computation_decorator(x_data, weight_data, bias_data).block_until_ready()
end_time_compile = time.perf_counter()
print(f"Compiled function (decorator) first call (incl. compile): {end_time_compile - start_time_compile:.6f} seconds")

# Second call reuses the cached compiled code
start_time_cached = time.perf_counter()
result_compiled_decorator_cached = compiled_computation_decorator(x_data, weight_data, bias_data).block_until_ready()
end_time_cached = time.perf_counter()
print(f"Compiled function (decorator) second call (cached): {end_time_cached - start_time_cached:.6f} seconds")

# Verify results are the same (within floating-point tolerance)
print(f"Results match: {jnp.allclose(result_original, result_compiled_decorator)}")

# Time the functional version (should be similar to the decorator version after its first call)
_ = compiled_computation_functional(x_data, weight_data, bias_data).block_until_ready()  # trigger compilation
start_time_func = time.perf_counter()
result_compiled_functional = compiled_computation_functional(x_data, weight_data, bias_data).block_until_ready()
end_time_func = time.perf_counter()
print(f"Compiled function (functional) subsequent call: {end_time_func - start_time_func:.6f} seconds")
```
Important Note on Timing: Notice the use of .block_until_ready() after each function call we want to time. JAX uses asynchronous dispatch by default: a call returns to Python as soon as the work has been queued, before the computation has necessarily finished. block_until_ready() waits until the result has actually been computed, so the recorded end time reflects the real execution time rather than just the dispatch time.
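Single measurements like the ones above can be noisy. A minimal sketch of a somewhat more robust pattern, assuming you only need rough wall-clock numbers: warm the function up, call block_until_ready() on every result, and report the median of several runs. The helper time_fn and its n_warmup/n_runs parameters are hypothetical names (not part of JAX), and tanh_mean is an arbitrary example function.

```python
import statistics
import time

import jax
import jax.numpy as jnp


def time_fn(fn, *args, n_warmup=1, n_runs=10):
    """Hypothetical helper: median wall-clock time (seconds) of fn(*args).

    Assumes fn returns a single JAX array. Warm-up calls absorb one-time
    tracing/compilation; block_until_ready() makes each measurement wait
    for the computation to finish instead of timing only the dispatch.
    """
    for _ in range(n_warmup):
        fn(*args).block_until_ready()
    timings = []
    for _ in range(n_runs):
        start = time.perf_counter()
        fn(*args).block_until_ready()
        timings.append(time.perf_counter() - start)
    return statistics.median(timings)


@jax.jit
def tanh_mean(x):
    return jnp.mean(jnp.tanh(x))


x = jnp.ones((1000, 1000))
print(f"median time: {time_fn(tanh_mean, x):.6f} s")
```

Using the median rather than the mean keeps one slow outlier run (for example, one interrupted by other work on the machine) from dominating the result.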
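You'll observe a pattern:

- The first call to a jit-compiled function (compiled_computation_decorator or compiled_computation_functional) is often slower than the original. This is because JAX needs to perform an important step called tracing (which we'll discuss in the next section) and then compile the traced operations using XLA (Accelerated Linear Algebra), the compiler JAX uses under the hood.
- Subsequent calls with arguments of the same shapes and dtypes reuse the cached compiled code. They skip tracing and compilation entirely and typically run faster than the original function.

The difference between the decorator (@jax.jit) and functional (jax.jit(fn)) approaches is primarily stylistic. The decorator is often preferred for its readability when defining functions. The functional form is useful when you want to compile a function defined elsewhere (for example, in a library) without modifying its definition, or when you want to keep the uncompiled version around for comparison, as the timing example does with complex_computation. Note that many JAX library functions are already JIT-compiled internally where appropriate.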
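A minimal sketch of the equivalence between the two forms, using a hypothetical normalize function that is not part of the original example. The functional form leaves the original Python function untouched, so both the uncompiled and the compiled version remain available side by side.

```python
import jax
import jax.numpy as jnp


def normalize(v):
    """Hypothetical example function: scale v to unit L2 norm."""
    return v / jnp.linalg.norm(v)


# Functional form: compile an existing function without editing its definition
normalize_jit = jax.jit(normalize)


# Decorator form: equivalent to writing normalize_decorated = jax.jit(normalize_decorated)
@jax.jit
def normalize_decorated(v):
    return v / jnp.linalg.norm(v)


v = jnp.array([3.0, 4.0])
print(normalize(v))            # uncompiled: operations run one at a time
print(normalize_jit(v))        # compiled via jax.jit(fn)
print(normalize_decorated(v))  # compiled via @jax.jit
```

All three calls produce the same unit vector; the only difference is whether the whole function is compiled as one unit.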
In essence, jax.jit provides a straightforward way to request compilation for your performance-sensitive numerical functions. It pays off most for functions that are called many times with inputs of the same shapes and dtypes, because the one-time compilation cost is then amortized over many fast, cached calls. The next section explains the tracing mechanism that makes this possible.
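As a small preview of tracing, the sketch below makes the compilation cache visible. It relies on the fact that plain Python side effects inside a jitted function run only while JAX traces it, not when cached compiled code is executed. The counter trace_count and the function scaled_sum are hypothetical names chosen for illustration.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter, incremented only while JAX traces


@jax.jit
def scaled_sum(x):
    global trace_count
    trace_count += 1  # Python side effect: runs during tracing, not on cached calls
    return jnp.sum(x * 2.0)


scaled_sum(jnp.ones(3))  # first call with shape (3,): trace and compile
scaled_sum(jnp.ones(3))  # same shape and dtype: cached code, no retrace
scaled_sum(jnp.ones(4))  # new shape (4,): trace and compile again
print(trace_count)  # 2
```

Because each new input shape or dtype triggers a fresh trace and compilation, calling a jitted function with many different shapes can erase the speedup that caching normally provides.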
© 2026 ApX Machine Learning