Automatic differentiation, provided in the Julia ecosystem by libraries such as Zygote.jl, is essential for deep learning: it computes the gradients that drive the Flux.jl training workflow. Gradients are fundamental to training neural networks; they quantify how much the loss function changes if a model parameter (such as a weight or bias) is altered slightly. Optimizers then use this information to adjust the parameters in a direction that minimizes the loss.
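As a quick, self-contained illustration before bringing Flux into the picture, Zygote can differentiate an ordinary Julia function directly. This is a minimal sketch with a toy scalar function (the function f here is only for illustration):
using Zygote
# Differentiate f(x) = 3x^2 + 2x at x = 5.0.
# The analytic derivative is 6x + 2, so the expected result is 32.
f(x) = 3x^2 + 2x
Zygote.gradient(f, 5.0)   # returns (32.0,)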
Flux.gradient
In Flux.jl, gradients are typically computed using Zygote.gradient. While Zygote is the underlying automatic differentiation engine, Flux often provides convenient wrappers or directly uses Zygote's functions. To calculate gradients, you need three main components: a loss function that returns a single scalar value, the collection of parameters to differentiate with respect to (gathered with Flux.params), and sample input data together with its target outputs.
Let's look at a typical usage pattern:
using Flux, Zygote
# Define a simple model
model = Dense(3, 2) # 3 input features, 2 output features
# Sample input data and target output
x_sample = rand(Float32, 3)
y_target = rand(Float32, 2)
# Define a loss function. It takes the model, input, and target.
# This is a common pattern, but the loss function itself only needs to see
# what's necessary to compute the scalar loss.
function mse_loss(m, x, y_true)
    y_pred = m(x)
    return sum((y_pred .- y_true).^2) / length(y_true) # Mean Squared Error
end
# Collect the model's parameters using Flux.params
# This tells Zygote which variables to differentiate.
parameters = Flux.params(model)
# Calculate the gradients
# The first argument to Zygote.gradient is an anonymous function
# that calls our loss function.
grads = Zygote.gradient(() -> mse_loss(model, x_sample, y_target), parameters)
In this snippet, Zygote.gradient is called with an anonymous function () -> mse_loss(model, x_sample, y_target). This function takes no arguments and, when called, executes our loss calculation. The second argument, parameters, is a collection obtained from Flux.params(model), which tells Zygote the variables for which gradients should be computed.
The result, grads, is a Zygote.Grads object. This object behaves like a dictionary where the keys are the original parameter arrays (e.g., model.weight, model.bias) and the values are the corresponding gradient arrays.
The Zygote.Grads object provides a clean way to access the gradient for any specific parameter you passed to Flux.params. For instance, to get the gradient of the loss with respect to the weights of our Dense layer model:
# Assuming 'model' and 'grads' are from the previous example
gradient_weights = grads[model.weight]
gradient_bias = grads[model.bias]
println("Gradient for weights:\n", gradient_weights)
println("Gradient for bias:\n", gradient_bias)
If your model is a Chain of multiple layers, Flux.params(model) will collect parameters from all layers within the chain. The grads object will then contain entries for each of these parameters. For example, if model = Chain(Dense(3, 4), Dense(4, 2)), then Flux.params(model) would include weights and biases for both Dense layers, and grads would provide access to their respective gradients.
# Example for a Chain
complex_model = Chain(
    Dense(10, 5, relu), # Layer 1
    Dense(5, 2)         # Layer 2
)
x_complex = rand(Float32, 10)
y_complex_target = rand(Float32, 2)
# Parameters for the complex model
params_complex = Flux.params(complex_model)
# Gradients for the complex model
grads_complex = Zygote.gradient(() -> mse_loss(complex_model, x_complex, y_complex_target), params_complex)
# Accessing gradients for the first layer's weights
# The layers in a Chain are typically accessed by index
grad_layer1_weights = grads_complex[complex_model[1].weight]
grad_layer2_bias = grads_complex[complex_model[2].bias]
# println("Gradient for Layer 1 weights:\n", grad_layer1_weights)
# println("Gradient for Layer 2 bias:\n", grad_layer2_bias)
This direct mapping between parameters and their gradients is highly convenient for debugging or for implementing custom training logic.
In a standard training loop, these gradients are used by an optimizer to update the model's parameters. The general formula for a simple gradient descent update for a parameter $\theta$ is:
$$\theta \leftarrow \theta - \eta \, \nabla_\theta L$$
where $L$ is the loss, $\nabla_\theta L$ is the gradient of the loss with respect to $\theta$, and $\eta$ is the learning rate. Flux.jl optimizers (like ADAM, SGD, etc.) implement variations of this rule.
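As a minimal sketch, this rule can be applied by hand to the parameters and grads computed in the first example (the learning rate eta is an arbitrary illustrative value; Flux optimizers do this, and more, for you):
# Minimal sketch of one manual gradient descent step, reusing `parameters`
# and `grads` from the earlier example.
eta = 0.01f0                    # learning rate
for p in parameters
    g = grads[p]
    g === nothing && continue   # skip parameters that received no gradient
    p .-= eta .* g              # in-place update: p <- p - eta * gradient
end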
While Flux.train! automates this process, understanding the steps is informative:
1. Forward pass: the model produces predictions from the input data, and the loss function reduces them against the targets to a single scalar loss.
2. Backward pass: gradients of the loss with respect to every parameter are computed (via Zygote.gradient).
3. Parameter update: the optimizer uses these gradients to adjust the parameters.
This cycle repeats at every step of training a neural network, from input data through model prediction, loss calculation, gradient computation, and finally parameter updates via an optimizer.
When you use Flux.train!(loss_function, params, data, optimizer), Flux performs these steps internally. For each batch of data in data, it:
1. Calls loss_function (which should execute the forward pass and return the loss).
2. Computes the gradients of that loss with respect to params using Zygote.gradient.
3. Updates the parameters with the optimizer (e.g., Flux.Optimise.update!(opt, p, g) for each parameter p and gradient g).
Manually inspecting gradients can be very useful, especially when debugging a network that isn't learning correctly or when trying to understand the learning dynamics.
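To make those steps concrete, here is a sketch of one such iteration written out by hand, reusing complex_model and its data from above (Descent is Flux's plain gradient descent optimizer; the variable names opt, ps, and gs are just illustrative):
# Sketch of a single Flux.train!-style iteration done manually, reusing
# `complex_model`, `mse_loss`, `x_complex`, and `y_complex_target` from above.
opt = Descent(0.01)  # plain gradient descent optimizer provided by Flux
ps  = Flux.params(complex_model)
gs  = Zygote.gradient(() -> mse_loss(complex_model, x_complex, y_complex_target), ps)
for p in ps
    gs[p] === nothing && continue
    Flux.Optimise.update!(opt, p, gs[p])  # apply the optimizer's update rule in place
end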
Things to look for when inspecting gradients include:
- A gradient of nothing, which usually means the parameter did not influence the loss or was not collected by Flux.params.
- Gradients that have become NaN (Not a Number) due to numerical overflow.
- Very large gradients: large gradients cause drastic updates to parameters, often overshooting the optimal values.
One common practice is to monitor the norm of the gradients. A large norm might indicate exploding gradients, while a very small norm could suggest vanishing gradients.
# After computing grads:
using LinearAlgebra  # `norm` lives in the LinearAlgebra standard library

for p in parameters
    g = grads[p]
    if g !== nothing
        # println("Norm of gradient for parameter: ", norm(g))
    else
        # println("No gradient for parameter (or parameter was not used in loss).")
    end
end
Techniques like gradient clipping (scaling down gradients if their norm exceeds a threshold) can help manage exploding gradients, while architectural changes (e.g., using ReLU activations, residual connections) or careful initialization can mitigate vanishing gradients. These are more advanced topics but highlight why understanding gradients is important.
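As a minimal sketch of norm-based clipping applied to the grads from earlier (clip_threshold is an arbitrary illustrative value, not a recommended setting):
# Minimal sketch of gradient norm clipping, reusing `parameters` and `grads`
# from earlier; `clip_threshold` is an arbitrary illustrative value.
using LinearAlgebra

clip_threshold = 1.0f0
for p in parameters
    g = grads[p]
    g === nothing && continue
    gnorm = norm(g)
    if gnorm > clip_threshold
        g .*= clip_threshold / gnorm  # rescale so the gradient's norm equals the threshold
    end
end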
Zygote.pullback
While Zygote.gradient is convenient for obtaining gradients of a scalar loss function with respect to a set of parameters, Zygote also offers a more fundamental function called Zygote.pullback.
A "pullback" is a function that, given a gradient from "above" (i.e., the gradient of some later computation with respect to the output of your current function), computes the gradients with respect to the inputs of your current function.
Zygote.gradient(f, args...) is essentially a shorthand for two steps:
1. y, back = Zygote.pullback(f, args...), which runs the forward pass and returns both the result y and a pullback function back.
2. back(dy), where dy is the gradient of the final scalar output with respect to y (which is implicitly 1.0 for a scalar loss function, meaning $\frac{\partial L}{\partial L} = 1$).
You typically don't need to use Zygote.pullback directly when training standard Flux models with a scalar loss. However, it becomes very useful for more advanced scenarios, such as when you need both the forward-pass value and the gradients from a single evaluation, or when working with functions whose outputs are not a single scalar.
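Here is a brief sketch of that two-step pattern, reusing the model, loss, and data from the first example (the names loss_value, back, and grads_from_pullback are just illustrative):
# Sketch of the pullback pattern, reusing `model`, `mse_loss`, `x_sample`,
# and `y_target` from the first example.
ps = Flux.params(model)
loss_value, back = Zygote.pullback(() -> mse_loss(model, x_sample, y_target), ps)
grads_from_pullback = back(1.0f0)   # seed with 1 because the loss is a scalar
# `loss_value` is the forward-pass loss; `grads_from_pullback` can be indexed
# like the Grads object from Zygote.gradient, e.g. grads_from_pullback[model.weight]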
For most deep learning tasks covered in this course, Zygote.gradient (often used implicitly by Flux.train!) will be sufficient.
Understanding how gradients are computed, structured, and utilized is a significant step toward mastering neural network training. In the upcoming hands-on practical, you'll put this knowledge to use as you build and train your first neural network in Flux.jl.