Automatic differentiation is a core component of JAX, enabling gradient-based optimization for machine learning models. While jax.grad provides a convenient interface for obtaining gradients, understanding and controlling the underlying mechanisms offers greater flexibility and efficiency for complex tasks.
This chapter moves beyond the basics of jax.grad. You will learn about the two fundamental modes of automatic differentiation: forward-mode (Jacobian-vector products, JVPs) and reverse-mode (vector-Jacobian products, VJPs). We will cover how to compute these directly using jax.jvp and jax.vjp. These functions are the building blocks for jax.grad, which is itself implemented with reverse mode, and for other transformations such as jax.jacfwd and jax.jacrev.
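A minimal sketch of both modes follows. The function f and the input values are arbitrary choices made only for illustration. For a scalar-valued function, pulling back the cotangent 1.0 with jax.vjp gives the same vector as jax.grad. A single jax.jvp call with a basis tangent gives one directional derivative, which here is the first gradient component.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Illustrative scalar-valued function of a vector input: R^3 -> R
    return jnp.sum(jnp.sin(x) * x ** 2)

x = jnp.array([0.5, 1.0, 2.0])

# Forward mode: push the tangent v through f, returning (f(x), J @ v)
v = jnp.array([1.0, 0.0, 0.0])
y, jvp_out = jax.jvp(f, (x,), (v,))

# Reverse mode: jax.vjp returns f(x) and a function that maps a cotangent u to u^T J
y2, f_vjp = jax.vjp(f, x)
(vjp_out,) = f_vjp(jnp.ones_like(y2))  # u = 1.0 because the output is a scalar

# For a scalar output, the VJP with cotangent 1.0 equals the gradient
g = jax.grad(f)(x)
print(jnp.allclose(vjp_out, g))     # True
# The JVP with the basis vector e_0 recovers the first gradient component
print(jnp.allclose(jvp_out, g[0]))  # True
```

To recover the full gradient with jax.jvp, you would need one call per input dimension. A single reverse pass is enough for a scalar output, and this cost asymmetry is why gradients of scalar losses are usually computed in reverse mode.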
Key topics include:
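- Defining custom differentiation rules with jax.custom_vjp and jax.custom_jvp, useful for numerical stability, performance optimization, or handling non-JAX code.
- Differentiating through control flow primitives such as lax.scan, lax.cond, and lax.while_loop.
- Handling non-differentiable functions and stopping gradient flow with jax.lax.stop_gradient.

By the end of this chapter, you will have a more detailed understanding of JAX's autodiff system and the tools to apply it effectively in advanced scenarios.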
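Three of the tools listed above can be previewed in one short script. This is a minimal sketch under our own choices, not the chapter's implementations. The names softplus, naive_softplus, loss, and scaled_product are hypothetical. The custom backward rule for softplus uses the closed-form derivative sigmoid(x) instead of letting autodiff differentiate a formula that overflows.

```python
import jax
import jax.numpy as jnp
from jax import lax

# --- 1. jax.custom_vjp (hypothetical softplus example) ---
@jax.custom_vjp
def softplus(x):
    # Stable forward value: log(1 + exp(x)) computed as logaddexp(x, 0)
    return jnp.logaddexp(x, 0.0)

def softplus_fwd(x):
    # Return the primal output and the residual (just x) for the backward pass
    return softplus(x), x

def softplus_bwd(x, g):
    # d/dx softplus(x) = sigmoid(x), which stays finite for large |x|
    return (g * jax.nn.sigmoid(x),)

softplus.defvjp(softplus_fwd, softplus_bwd)

def naive_softplus(x):
    # exp(100.0) overflows to inf in float32, so the backward pass computes 0 * inf
    return jnp.log(1.0 + jnp.exp(x))

print(jax.grad(naive_softplus)(100.0))  # nan
print(jax.grad(softplus)(100.0))        # 1.0

# --- 2. jax.lax.stop_gradient: treat part of a computation as a constant ---
def loss(w):
    target = lax.stop_gradient(2.0 * w)  # no gradient flows through target
    return (w - target) ** 2

# Only the first w is differentiated: 2 * (w - c) with c = 2w = 6 at w = 3
print(jax.grad(loss)(3.0))  # -6.0 (without stop_gradient it would be 6.0)

# --- 3. Reverse mode through lax.scan ---
def scaled_product(x, xs):
    def step(carry, a):
        # Multiply the running carry by the next element; no per-step output
        return carry * a, None
    final, _ = lax.scan(step, x, xs)
    return final

xs = jnp.array([1.0, 2.0, 3.0])
# scaled_product(x, xs) = x * 1 * 2 * 3 = 6x, so d/dx = 6
print(jax.grad(scaled_product)(jnp.array(1.5), xs))  # 6.0
```

The starting carry is passed as a float32 array so its type matches the carry the scan body returns.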
4.1 Review of Forward- and Reverse-Mode Autodiff
4.2 Jacobian-Vector Products (JVPs) with jax.jvp
4.3 Vector-Jacobian Products (VJPs) with jax.vjp
4.4 Higher-Order Derivatives
4.5 Computing Full Jacobians and Hessians
4.6 Custom Differentiation Rules with jax.custom_vjp
4.7 Custom Differentiation Rules with jax.custom_jvp
4.8 Differentiation through Control Flow Primitives
4.9 Handling Non-Differentiable Functions
4.10 Practice: Implementing a Custom Gradient
© 2026 ApX Machine Learning