Many computational tasks, especially in machine learning model training, rely on calculating the gradients of functions. Automatic differentiation provides an efficient way to compute these derivatives. This chapter introduces `jax.grad`, the core JAX transformation for obtaining gradient functions from Python code operating on numerical inputs.
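You will learn how to:
- Use `jax.grad` to compute the gradient ∇f(x) of a scalar-valued function f.
- Understand how reverse-mode automatic differentiation works behind `grad`.
- Choose which arguments to differentiate with respect to, and compute higher-order derivatives by nesting `grad`.
- Compute a function's value and its gradient together with `jax.value_and_grad`.

By the end of this chapter, you will be able to use `jax.grad` effectively to differentiate your numerical functions within the JAX framework.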
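As a rough preview of what the chapter builds toward, the sketch below applies `jax.grad` and `jax.value_and_grad` to a small scalar-valued function. The functions `f` and `h` and the input values are illustrative assumptions chosen for this example, not taken from the chapter's later sections.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Illustrative scalar-valued function: f(x) = sum_i x_i^2, so grad f(x) = 2x.
    return jnp.sum(x ** 2)


def h(t):
    # Illustrative scalar function of a scalar: h(t) = t^3, so h''(t) = 6t.
    return t ** 3


x = jnp.array([1.0, 2.0, 3.0])

# jax.grad returns a new function that evaluates the gradient of f.
grad_f = jax.grad(f)
g = grad_f(x)

# jax.value_and_grad returns the function value and the gradient in one call.
value, g2 = jax.value_and_grad(f)(x)

# Applying grad twice gives a second derivative.
second = jax.grad(jax.grad(h))(2.0)

print(g)       # gradient of f at x
print(value)   # f(x)
print(second)  # h''(2.0)
```

Note that `jax.grad` expects the function it transforms to return a scalar, which is why `f` reduces its input with `jnp.sum`.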
3.1 Understanding Gradients
3.2 Introducing `jax.grad`
3.3 How Autodiff Works: Reverse Mode
3.4 Differentiating with Respect to Arguments
3.5 Higher-Order Derivatives (`grad` of `grad`)
3.6 Value and Gradient (`jax.value_and_grad`)
3.7 Differentiation and Control Flow
3.8 Limitations and Considerations
3.9 Hands-on Practical: Computing Gradients
© 2025 ApX Machine Learning