While TensorFlow Keras provides a convenient `model.fit()` method enriched by a system of Callbacks for managing various aspects of the training process, PyTorch encourages a more hands-on approach. As you've seen, the training loop in PyTorch is something you write explicitly. This explicitness extends to how you implement training control mechanisms like early stopping, model checkpointing, and dynamic learning rate adjustments. Instead of predefined Callback objects, you'll integrate this logic directly into your training script. This gives you a granular level of control and a clear view of what happens at each step.
If you've worked extensively with Keras, you're likely familiar with its Callback system. Callbacks are objects passed to the `fit()` method that can perform actions at various stages of training (e.g., at the beginning or end of an epoch, before or after a batch). Common examples include:

- `ModelCheckpoint`: Saves the model or weights at some frequency.
- `EarlyStopping`: Stops training when a monitored metric has stopped improving.
- `ReduceLROnPlateau`: Reduces the learning rate when a metric has stopped improving.
- `TensorBoard`: Logs events for visualization with TensorBoard.

These Callbacks abstract away the underlying logic, making it easy to add common functionalities. In PyTorch, you'll achieve similar results by writing the logic yourself.
Let's explore how to implement these common training control patterns within a standard PyTorch training loop.
Early stopping helps prevent overfitting by halting training if the model's performance on a validation set ceases to improve for a specified number of consecutive epochs (often called "patience").
To implement early stopping, you'll need to:

1. Track the best validation metric (here, the lowest validation loss) seen so far.
2. Count how many consecutive epochs have passed without improvement.
3. Stop training once that count reaches your patience threshold.
Here's how you might integrate this into your training loop:
```python
import torch

# Assuming these are defined elsewhere:
# model, train_loader, val_loader, optimizer, criterion
# num_epochs, patience

best_val_loss = float('inf')
epochs_no_improve = 0

for epoch in range(num_epochs):
    # --- Training phase ---
    model.train()
    for batch_data, batch_labels in train_loader:
        optimizer.zero_grad()
        outputs = model(batch_data)
        loss = criterion(outputs, batch_labels)
        loss.backward()
        optimizer.step()

    # --- Validation phase ---
    model.eval()
    current_val_loss = 0.0
    with torch.no_grad():
        for val_data, val_labels in val_loader:
            val_outputs = model(val_data)
            loss = criterion(val_outputs, val_labels)
            current_val_loss += loss.item()
    current_val_loss /= len(val_loader)

    print(f"Epoch {epoch+1}/{num_epochs}, Validation Loss: {current_val_loss:.4f}")

    # Early stopping logic
    if current_val_loss < best_val_loss:
        best_val_loss = current_val_loss
        epochs_no_improve = 0
        # Optionally, save the best model here (see next section)
        torch.save(model.state_dict(), 'best_model_checkpoint.pth')
        print("Validation loss improved. Saved best model.")
    else:
        epochs_no_improve += 1
        print(f"Validation loss did not improve for {epochs_no_improve} epoch(s).")
        if epochs_no_improve >= patience:
            print(f"Early stopping triggered after {epoch+1} epochs.")
            break
```
In this snippet, `patience` would be an integer you define (e.g., 5 or 10). If the validation loss doesn't improve for `patience` consecutive epochs, the training loop terminates.
Model checkpointing involves saving your model's state (usually its learned parameters) during training. This is useful for several reasons: you can resume training after an interruption, keep the best-performing version of the model rather than whatever the final epoch happens to produce, and revisit intermediate models for later evaluation.
As shown in the early stopping example above, a common strategy is to save the model whenever the validation metric improves:
```python
# Inside the validation phase, after calculating current_val_loss
if current_val_loss < best_val_loss:
    best_val_loss = current_val_loss
    epochs_no_improve = 0
    torch.save(model.state_dict(), 'best_model_checkpoint.pth')  # Saving the model
    print(f"Epoch {epoch+1}: Validation loss improved to {best_val_loss:.4f}, saving model...")
else:
    epochs_no_improve += 1  # ... (rest of the early stopping logic shown earlier)
```
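To use a saved checkpoint later, you load the saved `state_dict` back into a model with the same architecture. A minimal sketch:

```python
# Restore the saved weights into an already-constructed model
model.load_state_dict(torch.load('best_model_checkpoint.pth'))
model.eval()  # switch to evaluation mode before running inference
```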
You can also save checkpoints at fixed epoch intervals, regardless of validation performance, if you want to keep a history of models or for very long training runs.
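A minimal sketch of interval-based checkpointing follows; the `save_every` interval and the filename pattern are illustrative choices, and bundling the optimizer state with the model is one common way to make the run resumable:

```python
save_every = 10  # illustrative interval; pick one that suits your run length

# Inside the epoch loop, after validation:
if (epoch + 1) % save_every == 0:
    # Save everything needed to resume training in one file
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'val_loss': current_val_loss,
    }, f'checkpoint_epoch_{epoch+1}.pth')
```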
Adjusting the learning rate during training is a common technique to improve convergence and final model performance. PyTorch provides the `torch.optim.lr_scheduler` module, which offers various ways to alter the learning rate over time. This is analogous to Keras's `LearningRateScheduler` or `ReduceLROnPlateau` callbacks.
To use a learning rate scheduler:

1. Instantiate a scheduler from `torch.optim.lr_scheduler`, passing it the optimizer whose learning rate it should manage.
2. Call `scheduler.step()` at the appropriate point in your training loop (usually after each epoch, or sometimes after each batch, depending on the scheduler). For schedulers like `ReduceLROnPlateau`, you pass the metric to monitor (e.g., validation loss) to the `step()` method.

Let's see an example using `ReduceLROnPlateau`, which reduces the learning rate when a monitored metric has stopped improving, similar to its Keras counterpart:
```python
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Assuming optimizer is already defined:
# optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Initialize the scheduler.
# mode='min': the LR is reduced when the monitored quantity stops decreasing.
# factor: the multiplier applied to the learning rate (new_lr = lr * factor).
# patience: the number of epochs with no improvement before the LR is reduced.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3, verbose=True)

# ... (inside your training loop, after the validation phase) ...
# After calculating current_val_loss for the epoch:
scheduler.step(current_val_loss)
```
When `scheduler.step(current_val_loss)` is called, the scheduler checks whether `current_val_loss` has improved according to its `patience` setting. If not, it reduces the learning rate of the `optimizer` by the specified `factor`. The `verbose=True` argument will print a message when the learning rate is adjusted.
Other common schedulers include:

- `StepLR`: Decays the learning rate by a factor `gamma` every `step_size` epochs.
- `ExponentialLR`: Decays the learning rate by a factor `gamma` every epoch.
- `CosineAnnealingLR`: Adjusts the learning rate using a cosine annealing schedule.

The key is to call `scheduler.step()` at the correct frequency: after validation for metric-driven, epoch-wise schedulers like `ReduceLROnPlateau`, once per epoch for other epoch-wise schedulers, or immediately after `optimizer.step()` for batch-wise schedulers.
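As a brief illustration, here is a minimal sketch using `StepLR` as an epoch-wise scheduler; the `step_size` and `gamma` values are arbitrary examples:

```python
from torch.optim.lr_scheduler import StepLR

# Multiply the learning rate by gamma every step_size epochs (example values)
scheduler = StepLR(optimizer, step_size=30, gamma=0.1)

for epoch in range(num_epochs):
    # ... training and validation phases ...
    scheduler.step()  # unlike ReduceLROnPlateau, StepLR takes no metric
```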
While Keras has a `TensorBoard` callback that automatically logs many metrics, in PyTorch you'll typically use `torch.utils.tensorboard.SummaryWriter` to explicitly log values. You can log training loss, validation loss, accuracies, learning rates, or any other custom metric at any point in your loop.
```python
from torch.utils.tensorboard import SummaryWriter

# Initialize writer
writer = SummaryWriter('runs/my_experiment_name')

# ... (inside your training loop) ...

# Logging training loss (per batch or averaged per epoch)
# Assume train_loss is calculated per batch:
# writer.add_scalar('Loss/train_batch', train_loss.item(), epoch * len(train_loader) + batch_idx)

# Logging validation loss (per epoch)
# Assume current_val_loss is calculated per epoch:
writer.add_scalar('Loss/validation_epoch', current_val_loss, epoch)

# Logging learning rate (per epoch)
writer.add_scalar('LearningRate', optimizer.param_groups[0]['lr'], epoch)

# ... (after training) ...
writer.close()
```
You then run TensorBoard pointing at the `runs` directory (e.g., `tensorboard --logdir=runs`) to visualize these metrics. This gives you precise control over what gets logged and when.
The way training control is handled differs significantly. In Keras, callbacks are somewhat like plugins to the `fit` method's execution cycle. In PyTorch, these control mechanisms are integral parts of your custom loop.
*Comparison of training loop structures: Keras Callbacks act at defined points within `model.fit()`, while PyTorch integrates control logic directly into the custom-written loop.*
The main differences:

- **Execution points:** Keras Callbacks run at predefined hooks (e.g., `on_epoch_end`, `on_batch_begin`). In PyTorch, you decide precisely where your logic executes within the loop structure you've built: after a batch, before validation, after the optimizer step. The control is entirely yours.
- **State management:** You manage state (like `best_val_loss` or `epochs_no_improve`) directly as variables in your training script, or within a class if you structure your training loop that way. Keras Callbacks often encapsulate their own state.

While PyTorch doesn't have a formal Callback system like Keras, you can still create reusable components. If you find yourself repeatedly writing the same logic for, say, early stopping across different projects, you can encapsulate it into Python functions or even simple classes.
For example, an early stopping function might look like:
```python
def check_early_stopping(current_val_loss, best_val_loss, epochs_no_improve, patience):
    """Update early stopping bookkeeping; return the new state and a stop flag."""
    stop_training = False
    if current_val_loss < best_val_loss:
        best_val_loss = current_val_loss
        epochs_no_improve = 0
        # Potentially return a flag to save the model here
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= patience:
            stop_training = True
            print(f"Early stopping triggered. Patience: {patience}.")
    return best_val_loss, epochs_no_improve, stop_training

# In your loop:
# best_val_loss, epochs_no_improve, should_stop = check_early_stopping(...)
# if should_stop:
#     break
```
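If you prefer to keep the bookkeeping state together rather than threading it through function arguments, a small class works just as well. The `EarlyStopper` below is a hypothetical helper, not a PyTorch API:

```python
class EarlyStopper:
    """Hypothetical reusable helper that tracks validation loss across epochs."""
    def __init__(self, patience=5):
        self.patience = patience
        self.best_val_loss = float('inf')
        self.epochs_no_improve = 0

    def step(self, val_loss):
        """Record this epoch's validation loss; return True if training should stop."""
        if val_loss < self.best_val_loss:
            self.best_val_loss = val_loss
            self.epochs_no_improve = 0
        else:
            self.epochs_no_improve += 1
        return self.epochs_no_improve >= self.patience

# In your loop:
# stopper = EarlyStopper(patience=5)
# if stopper.step(current_val_loss):
#     break
```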
These approaches don't replicate the full event-driven nature of Keras Callbacks, but they promote code reuse for common patterns within your explicit PyTorch loops. More advanced patterns, such as those involving hooks to modify gradients or activations directly, will be touched upon in Chapter 6 when discussing `torch.nn.Module.register_forward_hook` and similar functionalities. For most training control, direct loop integration is the standard PyTorch practice.
In summary, managing training control in PyTorch involves writing explicit Python logic within your training and validation loops. While this may initially seem more involved than using Keras Callbacks, it provides a high degree of transparency and customization, allowing you to tailor the training process precisely to your needs. As you become more familiar with PyTorch, you'll likely appreciate this direct control over the mechanics of model training.