JAX achieves its performance and capabilities primarily through function transformations like jit
(Just-In-Time compilation), grad
(automatic differentiation), vmap
(vectorization), and pmap
(parallelization). These transformations analyze and rewrite your Python code to run efficiently on accelerators. For this analysis and transformation process to work reliably and predictably, JAX operates best on functions that adhere to the principles of functional programming, specifically the concept of functional purity.
A pure function is a function that exhibits two main characteristics:

1. Determinism: given the same inputs, it always returns the same output.
2. No side effects: it does not modify external state or interact with the outside world (printing, writing files, mutating globals or arguments).
Let's look at simple Python examples:
# Pure function: Always returns the same output for the same inputs, no side effects.
def pure_add(a, b):
    return a + b

# Impure function: Has a side effect (printing to the console).
def impure_add_and_print(a, b):
    result = a + b
    print(f"Calculated {a} + {b} = {result}")  # Side effect!
    return result

# Impure function: Modifies external state (a global variable).
call_count = 0

def impure_increment_counter(x):
    global call_count
    call_count += 1  # Side effect! Modifies global state.
    return x + 1

# Impure function: Modifies an input argument in-place.
def impure_append_to_list(data_list, value):
    data_list.append(value)  # Side effect! Modifies the input list.
    return data_list
In the examples above, pure_add is pure: calling pure_add(2, 3) will always return 5. In contrast, impure_add_and_print performs printing, impure_increment_counter changes the global call_count, and impure_append_to_list modifies the list passed to it. These actions are side effects.
JAX's function transformations rely heavily on the assumption that the functions they operate on are pure. Here’s why:
Tracing Mechanism: Transformations like jax.jit work by tracing your Python function. JAX executes the function once with abstract placeholder values (tracers) that record the sequence of operations performed. This recorded sequence (represented in an intermediate language called jaxpr) is then compiled (e.g., via XLA for execution on GPU/TPU). If the function has side effects, these occur during the initial trace but not during subsequent runs of the compiled code, leading to unexpected behavior. For example, a print statement inside a jit-compiled function will typically execute only once, during tracing, not every time the compiled function is called.
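This tracing behavior is easy to observe directly. The sketch below (function name is illustrative) adds a print side effect inside a jitted function; the message appears during the trace, but later calls with the same argument shapes reuse the compiled code and print nothing:

```python
import jax

def add_and_print(a, b):
    # Side effect: runs while JAX traces the function, not on later calls.
    print("tracing add_and_print")
    return a + b

jitted = jax.jit(add_and_print)
first = jitted(1.0, 2.0)   # triggers tracing; the message prints here
second = jitted(3.0, 4.0)  # same shapes/dtypes: cached code, no print
```

The numeric results are still correct (3.0 and 7.0); only the side effect silently disappears on repeated calls.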
Caching and Optimization: JAX caches the compiled versions of functions. When you call a jit-compiled function with arguments of the same shape and type, JAX reuses the already compiled code. This caching assumes the function's output depends only on its inputs (determinism). If a function's behavior depends on external state (like a global variable), the cached version can become stale or produce incorrect results when that external state changes. Purity ensures the cache remains valid.
Automatic Differentiation: jax.grad works by analyzing the mathematical operations within a function to compute gradients. It needs a clear data-flow path from inputs to outputs. Side effects introduce operations that are often non-differentiable (how do you compute the gradient of a print statement?) or that obscure the relationship between inputs and outputs, making automatic differentiation unreliable or impossible. Modifying values in place can also break the chain-rule application that autodiff relies on.
Vectorization and Parallelization: Transformations like jax.vmap and jax.pmap replicate function execution across data dimensions or hardware devices. If the mapped function has side effects, especially ones that modify shared state, it can produce race conditions and non-deterministic outcomes: which parallel execution gets to modify the shared state first? Pure functions guarantee that each execution is independent and produces consistent results, making parallelization safe and predictable.
Many standard programming patterns, particularly in object-oriented programming or when dealing with iterative algorithms like training machine learning models, inherently involve changing state over time. Consider updating model weights during optimization, managing the state of an optimizer (like momentum values), or even simple counters. These often involve modifying object attributes or data structures in place, patterns that directly conflict with the requirement of functional purity.
# Typical stateful pattern (impure)
class Counter:
    def __init__(self):
        self.n = 0

    def count(self):
        self.n += 1  # In-place modification (side effect)
        return self.n

    def reset(self):
        self.n = 0  # In-place modification (side effect)

# Using the impure counter
my_counter = Counter()
print(my_counter.count())  # Output: 1
print(my_counter.count())  # Output: 2

# Applying JAX transformations to methods like count() can lead to issues.
Because JAX transformations work best with pure functions, we need alternative patterns to handle state updates that avoid side effects. The core idea, which we will explore next, is to make state handling explicit: pass the current state into the function as an argument and have the function return the new, updated state as part of its output. This approach keeps the functions pure while still allowing us to model stateful computations.
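A minimal sketch of this explicit-state pattern, rewriting the counter above as a pure function (the state-in, state-out convention shown here is the idea, not a specific JAX API):

```python
# Pure counter: the state is an argument, and the updated state is returned.
def count(n):
    new_n = n + 1
    return new_n, new_n  # (new_state, output)

state = 0
state, value = count(state)  # value == 1
state, value = count(state)  # value == 2
```

Calling count with the same state always yields the same result, so the function can be traced, cached, differentiated, and mapped safely.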
© 2025 ApX Machine Learning