Automatic differentiation works by applying the chain rule recursively through the primitive operations that make up a function. This process relies on each primitive having a well-defined way to compute its contribution to the overall gradient, typically via its Jacobian-vector product (JVP) or vector-Jacobian product (VJP) rule. However, not all operations are mathematically differentiable, or they might operate on types (like integers) where differentiation isn't standardly defined. Furthermore, sometimes you might choose to prevent gradient flow through certain parts of your computation for modeling or performance reasons. JAX provides mechanisms to handle these situations.
Several types of operations can pose challenges for automatic differentiation:
- Mathematical Discontinuities: Functions with jumps or sharp corners lack a unique derivative at the point of discontinuity. Examples include:
  - Step-like functions (jnp.round, jnp.floor, jnp.ceil).
  - The sign function (jnp.sign) at zero.
  - Comparison operations (>, <, >=, <=, ==, !=), which produce boolean results. While boolean results themselves aren't differentiated, using them in control flow (lax.cond, lax.while_loop) or arithmetic (where they might be cast to 0/1) can create discontinuities in the function being differentiated; the short sketch after this list illustrates the effect.
  - jnp.argmax or jnp.argmin: These return indices (integers), and the operation itself is discontinuous. Small changes in input values can cause the resulting index to jump.
- Integer Operations: Differentiation is typically defined over continuous fields (real or complex numbers). Operations that inherently involve integers, like integer casting or indexing into an array based on a computed integer value, don't have standard derivatives with respect to those integers.
- Undefined Gradient Rules: Some JAX primitives might simply not have JVP or VJP rules implemented, especially less common or experimental ones. Attempting to differentiate through them will usually result in an error.
When JAX encounters an operation during differentiation for which it doesn't have a defined VJP or JVP rule, it will typically raise a TypeError.
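For example, attempting to differentiate a function whose output is integer-valued, such as an index from jnp.argmax, fails with an error rather than returning a silent result (a TypeError in recent JAX versions; exact messages vary). A minimal sketch:

import jax
import jax.numpy as jnp

try:
    # argmax returns an integer index, which cannot be differentiated
    jax.grad(lambda x: jnp.argmax(x))(jnp.array([1.0, 3.0, 2.0]))
except Exception as e:  # typically a TypeError about non-differentiable output types
    print(type(e).__name__, ":", e)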
For some operations that are mathematically non-differentiable but common (like jnp.round or integer casting), JAX often defines a gradient rule that returns zero. This is a pragmatic choice: it prevents errors but effectively treats the operation's output as a constant with respect to its input during the differentiation process. This might be the desired behavior, but it's important to be aware of it. For example, computing jax.grad(lambda x: jnp.round(x))(x_val) will likely give you 0.0.
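A quick check confirms this zero-gradient behavior:

import jax
import jax.numpy as jnp

# Both operations have gradient rules defined, but the gradient is zero
# because the functions are flat almost everywhere.
print(jax.grad(jnp.round)(2.7))  # expected: 0.0
print(jax.grad(jnp.floor)(2.7))  # expected: 0.0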
jax.lax.stop_gradient
The primary tool JAX provides for controlling gradient flow in these situations is jax.lax.stop_gradient. This function has a simple behavior:
- In the forward pass, jax.lax.stop_gradient(x) simply returns x.
- During differentiation, the gradient flowing back through it is zero.

Essentially, jax.lax.stop_gradient tells JAX: "Compute the forward pass using this value, but treat this value as a constant during any differentiation pass."
Let's see how this works in practice. Consider a function where we want to use a value in the computation but prevent gradients from flowing back through the calculation of that value.
import jax
import jax.numpy as jnp

# A function where gradients flow normally
def f_normal(x):
    y = jnp.sin(x)
    z = jnp.cos(x)
    return y * z  # Product rule applies: d/dx (sin(x)cos(x)) = cos^2(x) - sin^2(x)

# A function where gradient through cos(x) is stopped
def f_stopped(x):
    y = jnp.sin(x)
    # Treat cos(x) as a constant for differentiation purposes
    z = jax.lax.stop_gradient(jnp.cos(x))
    # Gradient behaves like d/dx (sin(x) * K) = cos(x) * K, where K is the value of z
    return y * z

grad_f_normal = jax.grad(f_normal)
grad_f_stopped = jax.grad(f_stopped)

x_val = jnp.pi / 4.0  # 45 degrees

print(f"x = {x_val:.3f}")
print(f"f_normal(x) = {f_normal(x_val):.3f}")
print(f"f_stopped(x) = {f_stopped(x_val):.3f}")  # Forward pass is identical

print("\nGradients:")
# Normal gradient: cos(2*x) = cos(pi/2) = 0
print(f"Gradient (normal) = {grad_f_normal(x_val):.3f}")
# Stopped gradient: cos(x) * stop_gradient(cos(x)) = cos(x_val) * cos(x_val)
# Evaluated at x = pi/4: cos(pi/4) * cos(pi/4) = (1/sqrt(2)) * (1/sqrt(2)) = 0.5
print(f"Gradient (stopped) = {grad_f_stopped(x_val):.3f}")
In f_stopped, the stop_gradient call ensures that during the backward pass, no gradient signal flows into the computation of jnp.cos(x). The value z computed by jnp.cos(x) is used in the forward pass, but from the perspective of differentiation, it's treated as if it were a pre-computed constant. Therefore, the gradient calculation effectively becomes d/dx(sin(x) × constant) = cos(x) × constant, where the constant is the value of z.
Using jax.lax.stop_gradient is appropriate in several scenarios:
- Treating a non-differentiable operation as a constant: if your computation involves a function like jnp.round, using stop_gradient makes the choice explicit, although JAX might implicitly do this for some functions by defining a zero gradient.
- Freezing a value or sub-computation: if a quantity should be used in the forward pass but not receive gradients (for example, a target value derived from the same parameters), stop_gradient can achieve this; see the short sketch after this list.
- Decoupling parts of a computation: when a value should influence the output without contributing gradients to the quantities that produced it, stop_gradient can prune the gradient path.
- Avoiding unneeded backward work: if you know gradients through a particular branch are not required, stop_gradient can prevent unnecessary backward computations through that path (though XLA's dead code elimination often handles zero gradients efficiently).
- Containing numerically unstable gradients: stop_gradient can isolate these parts, although addressing the root cause of the instability is usually preferable.
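As a small illustration of the freezing use case, here is a sketch (the loss function, params, and the 0.99 factor are made up for the example): the target is computed from the same parameters as the prediction, but stop_gradient keeps it fixed during differentiation.

import jax
import jax.numpy as jnp

def loss(params, x):
    prediction = params * x
    # A bootstrapped target derived from the same params; treat it as a constant
    target = jax.lax.stop_gradient(0.99 * params * x)
    return (prediction - target) ** 2

# Gradient flows only through `prediction`, not through `target`
print(jax.grad(loss)(2.0, 3.0))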
While stop_gradient provides a zero gradient, this isn't always the desired behavior. If you need a non-zero "gradient signal" for optimization despite mathematical non-differentiability (common in areas like quantizing neural networks or training networks with discrete steps), you might consider:
- Smooth approximations: replace the non-differentiable function with a smooth surrogate, for example jax.nn.sigmoid(k * x) with a large k to approximate a step function. The gradient will be well-defined.
- Custom differentiation rules: use jax.custom_vjp or jax.custom_jvp (covered earlier in this chapter) to define a custom gradient behavior. A common technique here is the straight-through estimator (STE), where the forward pass uses the non-differentiable function (e.g., rounding), but the backward pass uses the gradient of a surrogate function (often the identity function, effectively passing the incoming gradient straight through).

import jax
import jax.numpy as jnp

@jax.custom_vjp
def round_straight_through(x):
    """Forward pass uses jnp.round, backward pass is identity."""
    return jnp.round(x)

# Define the forward and backward pass functions for custom_vjp
def round_straight_through_fwd(x):
    # Forward pass returns the primal output and residuals for the backward pass
    return round_straight_through(x), None

def round_straight_through_bwd(residuals, g):
    # Backward pass receives the incoming gradient 'g' and returns the
    # gradient w.r.t. the inputs. Here, we pass 'g' straight through.
    return (g,)

# Register the forward and backward functions with the custom_vjp-decorated function
round_straight_through.defvjp(round_straight_through_fwd, round_straight_through_bwd)

def ste_example(x):
    rounded_x = round_straight_through(x)
    return rounded_x * x  # d/dx = rounded_x * 1 + d/dx(rounded_x) * x = rounded_x + 1*x

grad_ste = jax.grad(ste_example)

x_val = 2.7
print(f"\nStraight-Through Estimator Example at x={x_val}")
print(f"ste_example({x_val}) = {ste_example(x_val):.3f}")  # Uses round -> 3.0 * 2.7 = 8.1
# Gradient with the STE: round(x) + x = 3.0 + 2.7 = 5.7
print(f"Gradient grad_ste({x_val}) = {grad_ste(x_val):.3f}")
This example uses jax.custom_vjp to implement an STE for jnp.round. The forward pass computes jnp.round(x), but the VJP rule is defined to simply pass the incoming gradient g back as the gradient with respect to x, effectively using the identity function's gradient for the backward pass. This allows gradient-based optimization to "pass through" the rounding operation.
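For comparison, the smooth-approximation alternative mentioned above might look like the following sketch (soft_step and the sharpness parameter k are illustrative names, not library functions):

import jax
import jax.numpy as jnp

def soft_step(x, k=10.0):
    # A steep sigmoid approximates a 0/1 step; larger k gives a sharper transition
    return jax.nn.sigmoid(k * x)

print(soft_step(0.3))            # close to 1.0 for positive x
print(jax.grad(soft_step)(0.3))  # a well-defined, non-zero gradient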
Handling non-differentiable functions requires understanding both the mathematical limitations and the tools JAX provides. jax.lax.stop_gradient offers a direct way to prevent gradient flow, while techniques like smooth approximations or custom differentiation rules provide more control when a zero gradient is not sufficient. Choosing the right approach depends on the specific problem and the desired behavior during optimization.