Having established that Zygote.jl provides the machinery for automatic differentiation, we now focus on how to practically utilize the computed gradients within the Flux.jl training workflow. Gradients are fundamental to training neural networks; they quantify how much the loss function will change if a model parameter (like a weight or bias) is altered slightly. Optimizers then use this information to adjust the parameters in a direction that minimizes the loss.
Flux.gradient
In Flux.jl, gradients are typically computed using Zygote.gradient. While Zygote is the underlying automatic differentiation engine, Flux often provides convenient wrappers or uses Zygote's functions directly. To calculate gradients, you need three main components: a loss function that returns a single scalar value, the collection of parameters to differentiate with respect to (gathered with Flux.params), and the data needed to evaluate the loss.
Let's look at a typical usage pattern:
using Flux, Zygote
# Define a simple model
model = Dense(3, 2) # 3 input features, 2 output features
# Sample input data and target output
x_sample = rand(Float32, 3)
y_target = rand(Float32, 2)
# Define a loss function. It takes the model, input, and target.
# This is a common pattern, but the loss function itself only needs to see
# what's necessary to compute the scalar loss.
function mse_loss(m, x, y_true)
y_pred = m(x)
return sum((y_pred .- y_true).^2) / length(y_true) # Mean Squared Error
end
# Collect the model's parameters using Flux.params
# This tells Zygote which variables to differentiate.
parameters = Flux.params(model)
# Calculate the gradients
# The first argument to Zygote.gradient is an anonymous function
# that calls our loss function.
grads = Zygote.gradient(() -> mse_loss(model, x_sample, y_target), parameters)
In this snippet, Zygote.gradient is called with an anonymous function, () -> mse_loss(model, x_sample, y_target). This function takes no arguments and, when called, executes our loss calculation. The second argument, parameters, is a collection obtained from Flux.params(model), which tells Zygote the variables for which gradients should be computed.
The result, grads, is a Zygote.Grads object. This object behaves like a dictionary whose keys are the original parameter arrays (e.g., model.weight, model.bias) and whose values are the corresponding gradient arrays.
The Zygote.Grads object provides a clean way to access the gradient for any specific parameter you passed to Flux.params. For instance, to get the gradients of the loss with respect to the weights and bias of our Dense layer model:
# Assuming 'model' and 'grads' are from the previous example
gradient_weights = grads[model.weight]
gradient_bias = grads[model.bias]
println("Gradient for weights:\n", gradient_weights)
println("Gradient for bias:\n", gradient_bias)
If your model is a Chain of multiple layers, Flux.params(model) will collect parameters from all layers within the chain, and the grads object will then contain entries for each of these parameters. For example, if model = Chain(Dense(3, 4), Dense(4, 2)), then Flux.params(model) would include the weights and biases of both Dense layers, and grads would provide access to their respective gradients.
# Example for a Chain
complex_model = Chain(
Dense(10, 5, relu), # Layer 1
Dense(5, 2) # Layer 2
)
x_complex = rand(Float32, 10)
y_complex_target = rand(Float32, 2)
# Parameters for the complex model
params_complex = Flux.params(complex_model)
# Gradients for the complex model
grads_complex = Zygote.gradient(() -> mse_loss(complex_model, x_complex, y_complex_target), params_complex)
# Accessing gradients for the first layer's weights
# The layers in a Chain are typically accessed by index
grad_layer1_weights = grads_complex[complex_model[1].weight]
grad_layer2_bias = grads_complex[complex_model[2].bias]
# println("Gradient for Layer 1 weights:\n", grad_layer1_weights)
# println("Gradient for Layer 2 bias:\n", grad_layer2_bias)
This direct mapping between parameters and their gradients is highly convenient for debugging or for implementing custom training logic.
In a standard training loop, these gradients are used by an optimizer to update the model's parameters. The general formula for a simple gradient descent update of a parameter W is:

$W_{\text{new}} = W_{\text{old}} - \eta \cdot \nabla_W L$

where L is the loss, ∇_W L is the gradient of the loss with respect to W, and η is the learning rate. Flux.jl optimizers (like ADAM, or Descent for plain SGD) implement variations of this rule.
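To make the update rule concrete, here is a minimal sketch that applies one plain gradient descent step by hand to the Dense model from the first example, reusing grads from above; the learning rate of 0.01 is an arbitrary illustrative value.
# One manual gradient descent step (sketch, reusing `model` and `grads` from above)
eta = 0.01f0   # learning rate (illustrative value)
model.weight .-= eta .* grads[model.weight]   # W_new = W_old - eta * gradient
model.bias   .-= eta .* grads[model.bias]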
While Flux.train! automates this process, understanding the individual steps is informative:
1. Run the forward pass: feed a batch of inputs through the model and compute the scalar loss.
2. Compute the gradients of the loss with respect to the parameters (with Zygote.gradient).
3. Let the optimizer use these gradients to update the parameters.
This cycle of operations makes up one step of training a neural network: from input data through model prediction, loss calculation, and gradient computation, to parameter updates via an optimizer.
When you use Flux.train!(loss_function, params, data, optimizer), Flux performs these steps internally. For each batch of data in data, it:
1. Calls loss_function (which should execute the forward pass and return the loss).
2. Computes the gradients of that loss with respect to params using Zygote.gradient.
3. Updates the parameters with the optimizer (e.g., Flux.Optimise.update!(opt, p, g) for each parameter p and gradient g).
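The following loop is a minimal sketch of what Flux.train! does with these pieces. It reuses model, mse_loss, and parameters from the earlier example and assumes data is a collection of (x, y) pairs; opt and the choice of ADAM() are purely illustrative.
# A hand-written training loop mirroring Flux.train!
opt = ADAM()   # optimizer with default settings
for (x, y) in data   # `data` is assumed to be a collection of (input, target) pairs
    gs = Zygote.gradient(() -> mse_loss(model, x, y), parameters)
    Flux.Optimise.update!(opt, parameters, gs)   # apply the gradients to all parameters
end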
Manually inspecting gradients can be very useful, especially when debugging a network that isn't learning correctly or when trying to understand the learning dynamics. A gradient entry can be nothing if a parameter was never used in the loss or was not collected by Flux.params. Gradients can also become NaN (Not a Number) due to numerical overflow; large gradients cause drastic updates to parameters, often overshooting the optimal values.
One common practice is to monitor the norm of the gradients. A large norm might indicate exploding gradients, while a very small norm could suggest vanishing gradients.
# After computing grads:
using LinearAlgebra: norm   # `norm` lives in the standard library LinearAlgebra
for p in parameters
    g = grads[p]
    if g !== nothing
        println("Norm of gradient for parameter: ", norm(g))
    else
        println("No gradient for parameter (or parameter was not used in loss).")
    end
end
Techniques like gradient clipping (scaling down gradients if their norm exceeds a threshold) can help manage exploding gradients, while architectural changes (e.g., using ReLU activations, residual connections) or careful initialization can mitigate vanishing gradients. These are more advanced topics but highlight why understanding gradients is important.
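As a rough sketch of norm-based clipping, reusing parameters and grads from above (the threshold of 5 is an arbitrary illustrative value), you could rescale oversized gradients in place before the optimizer applies them:
# Clip each gradient whose norm exceeds a threshold (illustrative value)
using LinearAlgebra: norm
clip_threshold = 5.0f0
for p in parameters
    g = grads[p]
    g === nothing && continue   # skip parameters without a gradient
    n = norm(g)
    if n > clip_threshold
        g .*= clip_threshold / n   # scale the gradient down in place
    end
end
Flux also provides clipping as composable optimisers (ClipNorm and ClipValue in Flux.Optimise), which can be combined with other optimisers.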
Zygote.pullback
While Zygote.gradient is convenient for obtaining gradients of a scalar loss function with respect to a set of parameters, Zygote also offers a more fundamental function called Zygote.pullback.
A "pullback" is a function that, given a gradient from "above" (i.e., the gradient of some later computation with respect to the output of your current function), computes the gradients with respect to the inputs of your current function.
Zygote.gradient(f, args...) is essentially a shorthand for two steps:
1. y, back = Zygote.pullback(f, args...), which runs the forward pass and returns the output y together with the pullback function back.
2. back(dy), which runs the reverse pass, where dy is the gradient of the final scalar output (implicitly 1.0 for a scalar loss function, since dL/dL = 1).
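For instance, here is a minimal sketch, reusing model, x_sample, y_target, and parameters from the earlier example, that reproduces Zygote.gradient through an explicit pullback:
# Equivalent to Zygote.gradient(() -> mse_loss(model, x_sample, y_target), parameters)
loss_value, back = Zygote.pullback(() -> mse_loss(model, x_sample, y_target), parameters)
grads_again = back(one(loss_value))   # seed with dL/dL = 1 to obtain a Zygote.Grads object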
You typically don't need to use Zygote.pullback directly when training standard Flux models with a scalar loss; it becomes useful mainly in more advanced scenarios that require finer control over how the reverse pass is run. For most deep learning tasks covered in this course, Zygote.gradient (often used implicitly by Flux.train!) will be sufficient.
Understanding how gradients are computed, structured, and utilized is a significant step toward mastering neural network training. In the upcoming hands-on practical, you'll put this knowledge to use as you build and train your first neural network in Flux.jl.