With your neural network architecture defined using Flux.jl, a suitable loss function selected to quantify prediction errors, and an optimizer chosen to direct the learning process, we arrive at the core of training: iteratively adjusting the model's parameters. This training regimen is typically organized into multiple passes over the entire dataset, known as epochs. Within each epoch, the data is usually processed in smaller segments called mini-batches to make computation more manageable and often to improve the learning dynamics.
Each step, or iteration, within this training loop fundamentally consists of a sequence of operations:

1. Perform a forward pass: feed the current mini-batch through the model to obtain predictions.
2. Compute the loss by comparing those predictions to the targets.
3. Compute the gradients of the loss with respect to the model's parameters.
4. Update the parameters using the optimizer and the computed gradients.
Flux.jl provides the tools to implement this training loop efficiently. Let's examine how these operations translate into Flux.jl code.
First, you need an optimizer and its associated state. Flux.jl uses an explicit state management system for optimizers, which is initialized using Flux.setup. This state tracks things like momentum or adaptive learning rates for more advanced optimizers.
using Flux # Also brings in Optimisers.jl functionalities
# Assume 'model' is your defined Flux model (e.g., a Chain)
# model = Chain(Dense(10, 5, relu), Dense(5, 1))
# Choose an optimizer rule
opt_rule = Adam(0.001) # Adam optimizer with a learning rate of 0.001
# Setup the optimizer state for the model
opt_state = Flux.setup(opt_rule, model)
The opt_state now holds the necessary information for the Adam optimizer to work with your specific model.
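As a small illustration (the two-layer Chain here is an assumed example, not part of the running code above), the state returned by Flux.setup mirrors the model's structure, and different optimizer rules are set up in exactly the same way:

using Flux

example_model = Chain(Dense(10 => 5, relu), Dense(5 => 1))

# Adam's state stores per-parameter moment estimates alongside the rule's hyperparameters
adam_state = Flux.setup(Adam(0.001), example_model)

# Plain gradient descent keeps no per-parameter state beyond the learning rate in its rule
sgd_state = Flux.setup(Descent(0.01), example_model)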
To calculate the loss and the gradients simultaneously, Flux.jl offers Flux.withgradient. Used with the do-block syntax shown below, it takes the model (the value to differentiate with respect to) and a function that computes the loss, and it returns both the loss value and the gradients.
# Assume x_batch, y_batch are available
# Assume loss_fn(m, x, y) is defined, e.g.:
# loss_fn(m, x, y) = Flux.mse(m(x), y)
# Inside your training iteration:
loss_value, grads = Flux.withgradient(model) do m
    # The model 'm' passed here is the one for which gradients are computed.
    # loss_fn performs the forward pass (m(x_batch)) and computes the loss.
    loss_fn(m, x_batch, y_batch)
end
Here, loss_value is the scalar result of loss_fn, and grads is a collection of gradients. Specifically, grads[1] will contain the gradients corresponding to the parameters of model.
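To make the structure of grads concrete, here is a brief sketch (the model, data, and loss below are illustrative assumptions): the gradient in grads[1] mirrors the model's fields, so each layer's weight gradient has the same size as the weight itself.

using Flux

m_demo = Chain(Dense(10 => 5, relu), Dense(5 => 1))
x_demo = rand(Float32, 10, 32)  # 10 features, batch of 32
y_demo = rand(Float32, 1, 32)

val, gs = Flux.withgradient(m_demo) do m
    Flux.mse(m(x_demo), y_demo)
end

# gs[1] mirrors the model: one entry per layer, each holding gradient
# arrays for that layer's weight and bias
size(gs[1].layers[1].weight) == size(m_demo[1].weight)  # true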
With the gradients obtained, the final step in an iteration is to update the model's parameters using the optimizer:
Flux.update!(opt_state, model, grads[1])
The Flux.update! function modifies the parameters of model in place, using the logic of the optimizer rule (e.g., Adam) stored in opt_state and the computed grads[1].
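Continuing the same sketch (all names here are illustrative), a single call to Flux.update! visibly changes the parameters in place:

st_demo = Flux.setup(Adam(0.001), m_demo)  # optimizer state for the sketch model

w_before = copy(m_demo[1].weight)      # snapshot of the first layer's weights
Flux.update!(st_demo, m_demo, gs[1])   # one optimizer step, mutates m_demo in place

m_demo[1].weight == w_before           # false: the weights have been adjusted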
Neural network training requires data to be in a format suitable for processing, typically as multi-dimensional arrays or tensors. As mentioned, training is usually performed on mini-batches. Flux.jl provides Flux.DataLoader (re-exported from MLUtils.jl; older releases exposed it as Flux.Data.DataLoader) as a convenient utility for batching and shuffling your dataset.
using Flux  # DataLoader comes from MLUtils.jl and is available as Flux.DataLoader

# Assuming train_X (features) and train_Y (labels) are your full dataset arrays
# e.g., train_X is a 10x1000 matrix (10 features, 1000 samples)
# train_Y is a 1x1000 matrix (1 output, 1000 samples)

batch_size = 32
train_loader = Flux.DataLoader((train_X, train_Y), batchsize=batch_size, shuffle=true)
# In your training loop, you would then iterate over train_loader:
# for (x_batch_from_loader, y_batch_from_loader) in train_loader
# # ... perform training steps ...
# end
Using DataLoader with shuffle=true ensures that the model sees the data in a different order each epoch, which can help prevent overfitting and improve generalization.
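As a quick check with small random arrays (a hedged sketch, not tied to any particular dataset), each batch from the loader has the batch size as its last dimension, and the final batch may be smaller:

using Flux

X = rand(Float32, 10, 100)  # 10 features, 100 samples
Y = rand(Float32, 1, 100)
loader = Flux.DataLoader((X, Y), batchsize=32, shuffle=true)

length(loader)          # 4 batches: three of 32 samples and a final one of 4
xb, yb = first(loader)
size(xb), size(yb)      # ((10, 32), (1, 32))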
Let's put all these pieces together into a functional training loop. This example will define a simple model, loss function, optimizer, and then iterate through epochs and batches to train the model.
using Flux  # DataLoader is available as Flux.DataLoader (from MLUtils.jl)

# 1. Define the model
model = Chain(
    Dense(10 => 5, relu),  # 10 input features, 5 output neurons, ReLU activation
    Dense(5 => 1)          # 5 input features, 1 output neuron (e.g., for regression)
)

# 2. Define the loss function
# This version of loss_fn takes the model, input, and target.
# It's suitable for use with Flux.withgradient(model) do m ... end
loss_fn(m, x, y) = Flux.mse(m(x), y)  # Mean Squared Error

# 3. Define the optimizer and set up its state
opt_rule = Adam(0.001)  # Adam optimizer with a learning rate of 0.001
opt_state = Flux.setup(opt_rule, model)

# 4. Prepare data
# For a practical example, replace these random arrays with your actual data
dummy_X = rand(Float32, 10, 100)  # 10 features, 100 samples
dummy_Y = rand(Float32, 1, 100)   # 1 output, 100 samples
train_loader = Flux.DataLoader((dummy_X, dummy_Y), batchsize=32, shuffle=true)
# 5. The training loop
num_epochs = 10
println("Starting training for $num_epochs epochs...")

for epoch in 1:num_epochs
    epoch_cumulative_loss = 0.0
    batches_processed = 0

    for (x_batch, y_batch) in train_loader
        # Calculate loss and gradients for the current batch.
        # loss_fn takes the model 'm_in_grad' as its first argument here.
        current_loss, grads = Flux.withgradient(model) do m_in_grad
            loss_fn(m_in_grad, x_batch, y_batch)
        end

        # Update model parameters using the calculated gradients
        Flux.update!(opt_state, model, grads[1])

        epoch_cumulative_loss += current_loss
        batches_processed += 1
    end

    avg_epoch_loss = epoch_cumulative_loss / batches_processed
    println("Epoch: $epoch, Average Batch Loss: $avg_epoch_loss")
end

println("Training complete.")
In this comprehensive loop:

- The outer loop runs for num_epochs.
- train_loader provides mini-batches of data.
- Flux.withgradient is used to compute both the current_loss for the batch and the grads (gradients). The model itself (model) is passed as the first argument to Flux.withgradient, and the anonymous function do m_in_grad ... end receives this model (here named m_in_grad) to use for the forward pass and loss calculation within the gradient computation context.
- Flux.update! applies the optimizer's logic to adjust the model's parameters in place using grads[1].

The following diagram illustrates the flow of a single training iteration:
An overview of a single training iteration, showing the flow of data, computation of loss and gradients, and parameter updates.
This detailed process of iterating through data, calculating losses, deriving gradients, and updating parameters is fundamental to training most neural networks. With Flux.jl, these steps are expressed quite directly, allowing for both clarity and flexibility in building and training your deep learning models.