Many computational tasks, especially in machine learning model training, rely on calculating the gradients of functions. Automatic differentiation provides an efficient way to compute these derivatives. This chapter introduces `jax.grad`, the core JAX transformation for obtaining gradient functions from Python code operating on numerical inputs.
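To make this concrete, here is a minimal sketch of the pattern the chapter develops; the function `f` below is an illustrative choice, not one of the chapter's examples:

```python
import jax

# A simple scalar-valued function: f(x) = x^2 + 3x
def f(x):
    return x**2 + 3.0 * x

# jax.grad transforms f into a new function that returns df/dx
df_dx = jax.grad(f)

print(df_dx(2.0))  # df/dx = 2x + 3, so this prints 7.0
```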
You will learn how to:

- Apply `jax.grad` to compute the gradient of a scalar-valued function.
- Compute higher-order derivatives by composing `grad` with itself (`grad` of `grad`).
- Obtain a function's value and its gradient together with `jax.value_and_grad`.

By the end of this chapter, you will be able to use `jax.grad` effectively to differentiate your numerical functions within the JAX framework.
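As a preview of the later sections, the following sketch composes `grad` with itself to obtain a second derivative and uses `jax.value_and_grad` to return the value and gradient in a single call; the function `f` here is again an arbitrary illustration:

```python
import jax
import jax.numpy as jnp

# An illustrative scalar function: f(x) = x * sin(x)
def f(x):
    return x * jnp.sin(x)

# Second derivative by composing grad with itself
d2f_dx2 = jax.grad(jax.grad(f))
print(d2f_dx2(1.0))  # f''(x) = 2*cos(x) - x*sin(x)

# Value and gradient computed together in one call
value, grad = jax.value_and_grad(f)(1.0)
print(value, grad)   # f(1.0) and f'(1.0)
```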
3.1 Understanding Gradients
3.2 Introducing `jax.grad`
3.3 How Autodiff Works: Reverse Mode
3.4 Differentiating with Respect to Arguments
3.5 Higher-Order Derivatives (`grad` of `grad`)
3.6 Value and Gradient (`jax.value_and_grad`)
3.7 Differentiation and Control Flow
3.8 Limitations and Considerations
3.9 Hands-on Practical: Computing Gradients