As mentioned in the chapter introduction, gradients are fundamental in many computational fields, particularly machine learning. Training models often involves optimizing a function (like a loss function) by iteratively adjusting parameters. Gradients tell us how to make these adjustments effectively. But what exactly is a gradient?
Informally, for a function that takes multiple numerical inputs and produces a single numerical output (a scalar function), the gradient is a vector that points in the direction of the steepest increase of the function at a specific input point. The magnitude (length) of this gradient vector indicates how steep that increase is.
Consider a simple function with one input variable, f(x). You might remember from calculus that its derivative, often written as f′(x) or df/dx, represents the slope of the tangent line to the function's graph at the point x. It tells you the instantaneous rate of change of the function's output with respect to its input.
Figure: the function f(x) = x² and the tangent line at x = 2. The slope of the tangent line (4) is the derivative (gradient) at that point.
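As a quick preview of the tooling this chapter builds toward, jax.grad can compute this slope directly. The snippet below is only a minimal check; the lambda and the evaluation point 2.0 are illustrative choices.

```python
import jax

# Derivative of f(x) = x**2, evaluated at x = 2.0.
df = jax.grad(lambda x: x**2)
print(df(2.0))  # 4.0
```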
Now, let's extend this to a function with multiple input variables, say f(x₁, x₂, ..., xₙ). The gradient generalizes the derivative concept. Instead of a single slope, we get a vector of partial derivatives. The partial derivative ∂f/∂xᵢ measures how the function f changes if we slightly change the input xᵢ while keeping all other inputs (xⱼ with j ≠ i) constant.
The gradient of f at a point (x₁, ..., xₙ) is the vector containing all its partial derivatives:
∇f(x₁, ..., xₙ) = [∂f/∂x₁, ∂f/∂x₂, ..., ∂f/∂xₙ]

The symbol ∇ (nabla) is commonly used to denote the gradient operator.
For example, consider the function f(x, y) = x² + sin(y). The partial derivative with respect to x is ∂f/∂x = 2x. The partial derivative with respect to y is ∂f/∂y = cos(y). The gradient vector is ∇f(x, y) = [2x, cos(y)].
At the point (x = 1, y = 0), the gradient is ∇f(1, 0) = [2(1), cos(0)] = [2, 1]. This vector [2, 1] indicates the direction from (1, 0) in which the function f(x, y) increases most rapidly.
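As a hedged sanity check of this hand calculation, here is one way it might look with jax.grad, writing f to take a single array so the full gradient comes back as one vector. The function name f and the packing of (x, y) into an array are choices made for this sketch, not requirements.

```python
import jax
import jax.numpy as jnp

# f(x, y) = x**2 + sin(y), written to take one array so that
# jax.grad returns the full gradient vector [df/dx, df/dy].
def f(params):
    x, y = params
    return x**2 + jnp.sin(y)

grad_f = jax.grad(f)
print(grad_f(jnp.array([1.0, 0.0])))  # [2. 1.]
```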
A primary use case for gradients is function minimization. Imagine you have a "cost" or "loss" function that measures how poorly your machine learning model is performing. You want to find the model parameters that make this cost as small as possible.
The gradient ∇f points in the direction of the steepest ascent. Therefore, the negative gradient, −∇f, points in the direction of the steepest descent.
This is the core idea behind the gradient descent algorithm:

new_parameters = old_parameters − learning_rate × ∇f(old_parameters)

Here, the learning_rate is a small positive scalar that controls the step size. Each step moves the parameters closer to a local minimum of the function.
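To make the update rule concrete, here is a minimal sketch of a gradient descent loop in JAX. The quadratic loss, starting point, learning rate, and number of steps are all arbitrary illustrative choices.

```python
import jax
import jax.numpy as jnp

# Illustrative loss: f(params) = sum(params**2), minimized at params = 0.
def loss(params):
    return jnp.sum(params**2)

grad_loss = jax.grad(loss)

learning_rate = 0.1                    # small positive step size
params = jnp.array([1.0, -2.0, 3.0])   # arbitrary starting parameters

for _ in range(100):
    params = params - learning_rate * grad_loss(params)

print(params)  # all entries close to 0
```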
Calculating gradients manually using the rules of calculus, as we did for f(x, y) = x² + sin(y), is feasible for simple functions. However, the functions encountered in machine learning (like deep neural networks) can involve millions of parameters and complex compositions of operations. Deriving the gradients manually becomes impractical and highly prone to errors.
We could try numerical differentiation, which approximates the gradient by evaluating the function at slightly perturbed points (e.g., using the forward-difference approximation (f(x + h) − f(x)) / h for a small h). However, this method suffers from approximation errors (due to the choice of h) and can be computationally expensive, as it requires extra function evaluations for each dimension of the gradient.
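As a rough sketch of what that looks like in code: the helper name numerical_grad and the step size h below are arbitrary choices for illustration, not a library API.

```python
import jax.numpy as jnp

def f(params):
    x, y = params
    return x**2 + jnp.sin(y)

def numerical_grad(func, params, h=1e-5):
    # Forward differences: one extra evaluation of func per input dimension.
    base = func(params)
    parts = []
    for i in range(params.shape[0]):
        perturbed = params.at[i].add(h)
        parts.append((func(perturbed) - base) / h)
    return jnp.stack(parts)

print(numerical_grad(f, jnp.array([1.0, 0.0])))  # roughly [2. 1.], but only approximately
```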
Symbolic differentiation, as performed by computer algebra systems, manipulates mathematical expressions to find the exact derivative expression. While exact, it can lead to very complex and potentially inefficient expressions ("expression swell"), especially for large computations.
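For a small taste of the symbolic approach, here is what differentiating the same function might look like with SymPy, assuming it is installed (SymPy is a separate computer algebra library, not part of JAX).

```python
import sympy

x, y = sympy.symbols("x y")
expr = x**2 + sympy.sin(y)

# Symbolic differentiation returns derivative *expressions*, not numbers.
print(sympy.diff(expr, x))  # 2*x
print(sympy.diff(expr, y))  # cos(y)
```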
This is where Automatic Differentiation (AD) comes in. AD is a set of techniques that computes the exact numerical value of a function's gradient efficiently by systematically applying the chain rule of calculus at the level of elementary operations (addition, multiplication, sin, cos, etc.) within the function's computation. It avoids the approximation errors of numerical differentiation and the potential expression swell of symbolic differentiation.
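To see the chain-rule idea on a tiny composition, consider g(x) = sin(x²), which AD handles by differentiating each elementary operation and combining the pieces. The sketch below compares jax.grad against the hand-derived result 2x·cos(x²); the name g and the evaluation point are arbitrary.

```python
import jax
import jax.numpy as jnp

# g(x) = sin(x**2) is a composition of two elementary operations:
# squaring, then sin. AD differentiates each and applies the chain rule.
def g(x):
    return jnp.sin(x**2)

x = 1.5
print(jax.grad(g)(x))            # exact chain-rule value: cos(x**2) * 2x
print(jnp.cos(x**2) * 2.0 * x)   # hand-derived value, agrees up to floating point precision
```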
JAX's grad transformation is built upon highly optimized implementations of AD, specifically reverse-mode AD (also known as backpropagation in the context of neural networks). This mode is particularly efficient for the common case in machine learning where we have a function with many inputs mapping to a single scalar output (like a loss function).
In the following sections, we will explore how to use jax.grad to leverage the power of automatic differentiation for your own Python functions.