Automatic differentiation is a core component of JAX, enabling gradient-based optimization for machine learning models. While jax.grad provides a convenient interface for obtaining gradients, understanding and controlling the underlying mechanisms offers greater flexibility and efficiency for complex tasks.

This chapter moves beyond the basics of jax.grad. You will learn about the two fundamental modes of automatic differentiation: forward-mode (Jacobian-vector products, JVPs) and reverse-mode (vector-Jacobian products, VJPs). We will cover how to compute these directly using jax.jvp and jax.vjp, which form the building blocks for jax.grad and other transformations.
Key topics include:
- Defining custom differentiation rules with jax.custom_vjp and jax.custom_jvp, useful for numerical stability, performance optimization, or handling non-JAX code (a sketch follows this overview).
- Differentiating through control flow primitives such as lax.scan, lax.cond, and lax.while_loop.
- Handling non-differentiable functions and stopping gradient flow with jax.lax.stop_gradient.

By the end of this chapter, you will have a more detailed understanding of JAX's autodiff system and the tools to apply it effectively in advanced scenarios.
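As a preview of the custom-rule machinery covered in sections 4.6 and 4.7, the sketch below attaches a hand-written JVP to a numerically troublesome function. The function log1pexp and its tangent rule are illustrative choices; the pattern follows the standard jax.custom_jvp workflow.

```python
import jax
import jax.numpy as jnp

# log1pexp(x) = log(1 + exp(x)). For large x the naive reverse-mode rule
# produces inf/inf = nan, so we supply a numerically stable JVP by hand.
@jax.custom_jvp
def log1pexp(x):
    return jnp.log(1.0 + jnp.exp(x))

@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    y = log1pexp(x)
    # d/dx log(1 + exp(x)) = sigmoid(x), written in a form that stays finite.
    y_dot = x_dot / (1.0 + jnp.exp(-x))
    return y, y_dot

print(jax.grad(log1pexp)(100.0))  # 1.0 rather than nan
```

jax.custom_vjp follows the same idea but lets you specify the reverse-mode rule directly; both are covered in detail later in this chapter.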
4.1 Review of Forward- and Reverse-Mode Autodiff
4.2 Jacobian-Vector Products (JVPs) with jax.jvp
4.3 Vector-Jacobian Products (VJPs) with jax.vjp
4.4 Higher-Order Derivatives
4.5 Computing Full Jacobians and Hessians
4.6 Custom Differentiation Rules with jax.custom_vjp
4.7 Custom Differentiation Rules with jax.custom_jvp
4.8 Differentiation through Control Flow Primitives
4.9 Handling Non-Differentiable Functions
4.10 Practice: Implementing a Custom Gradient