Alright, let's roll up our sleeves and put theory into practice. You've learned about the individual components of a PyTorch training regimen: loss functions, optimizers, and the mechanics of gradient computation. Now, we'll weave these together to construct a complete training and evaluation loop from scratch. This exercise is designed to solidify your understanding of how PyTorch handles model training, offering a clear contrast to the more abstracted model.fit() and model.evaluate() methods you're familiar with from TensorFlow Keras.
We'll tackle a common task: image classification. To keep things focused on the loop mechanics, we'll use the FashionMNIST dataset, a drop-in replacement for MNIST often used for benchmarking, and a relatively simple convolutional neural network (CNN).
Before we move into the code, ensure you have PyTorch and TorchVision installed. If you're working in a new environment, you can typically install them with:
pip install torch torchvision
We'll also need matplotlib for a simple visualization at the end, so run pip install matplotlib if you don't have it.
First, let's import the necessary PyTorch modules and set up our device (GPU if available, otherwise CPU). This is a standard starting point for most PyTorch scripts.
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
# Setup device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
Using torch.device keeps the script portable: the same code runs on a GPU when one is available and falls back to the CPU otherwise.
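As a minimal sketch of what that portability looks like in practice, the snippet below moves one tensor to the selected device and creates another directly on it. The tensor names and shapes are made up for illustration.

```python
import torch

# Pick the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Tensors are created on the CPU by default; .to(device) returns a tensor
# on the target device (the same tensor if it already lives there).
x = torch.ones(2, 3)
x_dev = x.to(device)

# Operands of an operation must share a device, so this tensor is created
# directly on `device`.
y_dev = torch.zeros(2, 3, device=device)
result = x_dev + y_dev

print(result.device.type)  # "cuda" or "cpu", depending on the machine
```

The same pattern is applied to the model and to every batch later in this exercise.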
We'll use torchvision to load and preprocess the FashionMNIST dataset. torchvision.transforms provides convenient tools for data augmentation and normalization.
# Transformations
transform = transforms.Compose([
    transforms.ToTensor(),  # Converts PIL image or NumPy ndarray to tensor and scales to [0,1]
    transforms.Normalize((0.5,), (0.5,))  # Normalize with mean 0.5 and std 0.5
])
# Load FashionMNIST dataset
train_dataset = torchvision.datasets.FashionMNIST(
    root='./data', train=True, download=True, transform=transform
)
test_dataset = torchvision.datasets.FashionMNIST(
    root='./data', train=False, download=True, transform=transform
)
# Create DataLoaders
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
# For reference, class names
classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
           'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot')
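Here, transforms.ToTensor() converts images into PyTorch tensors with values in [0, 1], and transforms.Normalize() then computes (x - 0.5) / 0.5, mapping pixel values to [-1, 1], which often helps with training stability. DataLoader handles batching, shuffling, and parallel data loading. On Windows and macOS, where worker processes are started with spawn, a script that uses num_workers > 0 typically needs its top-level code under an if __name__ == "__main__": guard.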
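To make both steps concrete, here is a small sketch that avoids downloading anything. It reproduces the Normalize arithmetic on a few made-up pixel values and runs a DataLoader over a synthetic TensorDataset (a stand-in for FashionMNIST, not part of the original script).

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical pixel values as ToTensor() would produce them, in [0, 1].
pixels = torch.tensor([0.0, 0.25, 0.5, 1.0])
mean, std = 0.5, 0.5
normalized = (pixels - mean) / std   # what Normalize((0.5,), (0.5,)) computes
print(normalized.tolist())           # [-1.0, -0.5, 0.0, 1.0]

# Ten made-up "images" with integer labels, batched four at a time.
images = torch.zeros(10, 1, 28, 28)
labels = torch.arange(10)
loader = DataLoader(TensorDataset(images, labels), batch_size=4, shuffle=False)

batch_sizes = [batch_images.size(0) for batch_images, _ in loader]
print(batch_sizes)                   # [4, 4, 2]
```

Note that the last batch is smaller than batch_size because 10 is not a multiple of 4. The same happens with FashionMNIST, whose 60,000 training images do not divide evenly by 64.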
Let's define a simple CNN. If you've worked through Chapter 2, this structure will look familiar. We subclass nn.Module and define our layers in __init__ and the forward pass in the forward method.
class SimpleCNN(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, padding=1, stride=1),  # 28x28x1 -> 28x28x16
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),  # 28x28x16 -> 14x14x16
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1, stride=1),  # 14x14x16 -> 14x14x32
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)  # 14x14x32 -> 7x7x32
        )
        self.fc_layers = nn.Sequential(
            nn.Linear(32 * 7 * 7, 128),  # Flatten 7*7*32 feature map
            nn.ReLU(),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        x = self.conv_layers(x)
        x = x.view(x.size(0), -1)  # Flatten the output for the fully connected layer
        x = self.fc_layers(x)
        return x

# Instantiate the model and move it to the device
model = SimpleCNN(num_classes=len(classes)).to(device)
print(model)
This model has two convolutional blocks (convolution, ReLU, max pooling) followed by two fully connected layers. It's a standard, yet effective, architecture for image classification tasks like FashionMNIST.
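A quick way to double-check the 32 * 7 * 7 input size of the first linear layer is to push a dummy batch through the convolutional stack and inspect the shapes. The sketch below rebuilds the same layers on their own so it can run independently. The batch size of 4 is arbitrary.

```python
import torch
import torch.nn as nn

# Same convolutional stack as SimpleCNN.conv_layers, rebuilt here so the
# snippet runs on its own.
conv_layers = nn.Sequential(
    nn.Conv2d(1, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Conv2d(16, 32, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),
)

# A dummy batch of 4 grayscale 28x28 images, laid out as N x C x H x W.
dummy = torch.zeros(4, 1, 28, 28)

with torch.no_grad():
    features = conv_layers(dummy)

print(features.shape)  # torch.Size([4, 32, 7, 7])

# Flatten everything except the batch dimension, as forward() does.
flat = features.view(features.size(0), -1)
print(flat.shape)      # torch.Size([4, 1568])
```

PyTorch stores images channels-first (N x C x H x W), unlike the channels-last default in Keras, so the 7x7x32 in the code comments corresponds to a tensor of shape (32, 7, 7) per image.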
Next, we select a loss function and an optimizer. For multi-class classification, CrossEntropyLoss is a common choice; it expects raw logits, which is why SimpleCNN has no softmax layer at the end. For the optimizer, Adam is a popular and effective algorithm.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
Notice how we pass model.parameters() to the optimizer. This tells the optimizer which tensors (weights and biases of our model) it needs to update during training. The learning rate lr is a hyperparameter you can tune.
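To illustrate what the criterion expects, the sketch below feeds nn.CrossEntropyLoss a made-up batch of raw logits and integer class labels, then reproduces the same value with an explicit log-softmax followed by negative log-likelihood. The tensors are random placeholders, not model outputs.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical batch: 3 samples, 10 classes, raw scores (no softmax applied).
logits = torch.randn(3, 10)
labels = torch.tensor([2, 0, 9])  # integer class indices, not one-hot vectors

criterion = nn.CrossEntropyLoss()
loss = criterion(logits, labels)

# The same quantity in two explicit steps: log-softmax, then NLL.
manual = F.nll_loss(F.log_softmax(logits, dim=1), labels)

print(loss.item(), manual.item())  # the two values match
```

For readers coming from Keras, this behaves much like SparseCategoricalCrossentropy with from_logits=True: labels are class indices and the softmax is folded into the loss.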
This is where PyTorch's explicitness shines. We manually iterate through epochs and batches.
def train_one_epoch(epoch_index, tb_writer=None):
    running_loss = 0.
    last_loss = 0.
    correct_predictions = 0
    total_samples = 0
    # Set model to training mode
    model.train()
    for i, data in enumerate(train_loader):
        # Every data instance is an input + label pair
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)  # Move data to the configured device
        # Zero your gradients for every batch!
        optimizer.zero_grad()
        # Make predictions for this batch
        outputs = model(inputs)
        # Compute the loss and its gradients
        loss = criterion(outputs, labels)
        loss.backward()  # Computes dloss/dx for every parameter x which has requires_grad=True
        # Adjust learning weights
        optimizer.step()  # Updates the value of x using the gradient x.grad
        # Gather data and report
        running_loss += loss.item()
        predicted = outputs.argmax(dim=1)  # Index of the largest logit is the predicted class
        total_samples += labels.size(0)
        correct_predictions += (predicted == labels).sum().item()
        if i % 100 == 99:  # Log every 100 batches
            last_loss = running_loss / 100  # Mean loss per batch over the last 100 batches
            current_accuracy = 100 * correct_predictions / total_samples  # Running accuracy so far this epoch
            print(f' Epoch {epoch_index + 1}, Batch {i + 1:5d} loss: {last_loss:.3f}, Accuracy: {current_accuracy:.2f}%')
            if tb_writer:
                tb_writer.add_scalar('Training Loss', last_loss, epoch_index * len(train_loader) + i)
                tb_writer.add_scalar('Training Accuracy', current_accuracy, epoch_index * len(train_loader) + i)
            running_loss = 0.
    epoch_accuracy = 100 * correct_predictions / total_samples
    return last_loss, epoch_accuracy  # Mean loss of the last logged 100 batches, and epoch accuracy
Let's break down the critical steps within the batch loop:
inputs, labels = inputs.to(device), labels.to(device): Data for the current batch is moved to the GPU (if available).
optimizer.zero_grad(): This is extremely important. PyTorch accumulates gradients by default. If you don't zero them out at the start of each batch, you'd be accumulating gradients from previous batches, leading to incorrect updates.
outputs = model(inputs): The forward pass, where the input data flows through the network to produce predictions.
loss = criterion(outputs, labels): The loss function compares the model's predictions (outputs) with the true labels (labels).
loss.backward(): This is where the magic of autograd happens. It computes the gradients of the loss with respect to all model parameters that have requires_grad=True.
optimizer.step(): The optimizer updates the model's parameters using the gradients computed in the backward() call.
The following diagram illustrates the flow of operations within a typical PyTorch training loop for one batch:
Flow of operations for processing one batch within the training loop.
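The gradient accumulation behind the zero_grad() step is easy to see on a single scalar parameter. The sketch below is a toy example, not part of the training script: it calls backward() twice without resetting, then resets and calls it once more.

```python
import torch

w = torch.tensor(1.0, requires_grad=True)

# First backward pass: d(3 * w)/dw = 3.
(3 * w).backward()
first = w.grad.item()

# Second backward pass without zeroing: the new gradient is added to the old one.
(3 * w).backward()
accumulated = w.grad.item()

# Reset the stored gradient, then run a third pass.
w.grad.zero_()
(3 * w).backward()
after_reset = w.grad.item()

print(first, accumulated, after_reset)  # 3.0 6.0 3.0
```

In the training loop, optimizer.zero_grad() performs this reset for every parameter the optimizer manages, so each step() uses the gradient of the current batch only.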
After each epoch, or at the end of training, you'll want to evaluate your model on a separate dataset (e.g., validation or test set) to see how well it generalizes. The evaluation loop is similar to the training loop, but with key differences:
def evaluate_model(loader):
    # Set model to evaluation mode
    model.eval()  # or model.train(False)
    running_vloss = 0.0
    correct_predictions = 0
    total_samples = 0
    # Disable gradient computation for evaluation
    with torch.no_grad():  # Important!
        for vdata in loader:
            vinputs, vlabels = vdata
            vinputs, vlabels = vinputs.to(device), vlabels.to(device)
            voutputs = model(vinputs)
            vloss = criterion(voutputs, vlabels)
            running_vloss += vloss.item()
            predicted = voutputs.argmax(dim=1)  # Index of the largest logit is the predicted class
            total_samples += vlabels.size(0)
            correct_predictions += (predicted == vlabels).sum().item()
    avg_vloss = running_vloss / len(loader)  # Average loss over all validation batches
    accuracy = 100 * correct_predictions / total_samples
    print(f'Validation Loss: {avg_vloss:.3f}, Validation Accuracy: {accuracy:.2f}%')
    return avg_vloss, accuracy
Key differences in the evaluation loop:
model.eval(): This sets the model to evaluation mode. It's important because some layers, like Dropout and BatchNorm, behave differently during training and evaluation. model.eval() ensures they use their evaluation-time behavior.
with torch.no_grad():: This context manager disables gradient tracking. During evaluation, we don't need gradients, so no computation graph is built, which saves memory and computation. It does not change layer behavior, which is why model.eval() is still needed alongside it.
Now, let's combine these functions into a main script that trains the model for a few epochs and evaluates it.
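As a brief aside before the full script, both evaluation-time behaviors can be checked in isolation. The sketch below uses a standalone Dropout layer and a small Linear layer as stand-ins (SimpleCNN itself contains neither Dropout nor BatchNorm), so the names and sizes here are illustrative only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

layer = nn.Dropout(p=0.5)
x = torch.ones(1000)

layer.train()
train_out = layer(x)  # roughly half the entries zeroed, survivors scaled by 1 / (1 - p) = 2

layer.eval()
eval_out = layer(x)   # dropout is a no-op in evaluation mode

linear = nn.Linear(4, 2)
inp = torch.randn(3, 4)

tracked = linear(inp)        # recorded by autograd
with torch.no_grad():
    untracked = linear(inp)  # same values, but no graph is built

print(tracked.requires_grad, untracked.requires_grad)  # True False
```

The two switches are independent: eval() changes how certain layers compute their output, while no_grad() only stops autograd from recording operations. The full script that follows relies on evaluate_model to apply both.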
num_epochs = 5  # For demonstration; usually more epochs are needed
train_losses = []
val_losses = []
train_accuracies = []
val_accuracies = []
print("Starting Training...")
for epoch in range(num_epochs):
    print(f'EPOCH {epoch + 1}:')
    # Training (train_one_epoch switches the model to training mode itself)
    # avg_loss is the mean loss over the last logged 100 batches of the epoch
    avg_loss, train_acc = train_one_epoch(epoch)  # No TensorBoard writer for simplicity here
    train_losses.append(avg_loss)
    train_accuracies.append(train_acc)
    # Validation (evaluate_model switches the model to evaluation mode itself)
    avg_vloss, val_acc = evaluate_model(test_loader)
    val_losses.append(avg_vloss)
    val_accuracies.append(val_acc)
    print(f'EPOCH {epoch + 1} Summary: Train Loss: {avg_loss:.3f}, Train Acc: {train_acc:.2f}%, Val Loss: {avg_vloss:.3f}, Val Acc: {val_acc:.2f}%')
    print("-" * 30)
print('Finished Training')
# Optional: Save the model
# torch.save(model.state_dict(), 'fashion_mnist_cnn.pth')
# print('Model saved to fashion_mnist_cnn.pth')
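This script iterates for num_epochs. In each epoch, it calls train_one_epoch to train the model on the training data and then evaluate_model to assess its performance on the test data. Here the test set doubles as a validation set to keep the example short; in a real project you would hold out a separate validation split and touch the test set only once at the end. We also collect loss and accuracy values, which can be useful for plotting later.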
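The commented-out torch.save line in the script stores only the model's state_dict, so loading it back requires building the same architecture first. The sketch below shows that round trip under a few assumptions: a small nn.Linear stands in for the trained SimpleCNN, and the file name is a placeholder inside a temporary directory.

```python
import os
import tempfile

import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in for the trained SimpleCNN; any nn.Module is saved the same way.
trained = nn.Linear(4, 2)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "demo_weights.pth")  # hypothetical file name
    torch.save(trained.state_dict(), path)

    # Rebuild the same architecture, then copy the saved weights into it.
    restored = nn.Linear(4, 2)
    restored.load_state_dict(torch.load(path, map_location="cpu"))

restored.eval()  # switch to evaluation mode before running inference

x = torch.randn(3, 4)
with torch.no_grad():
    same = torch.equal(trained(x), restored(x))

print(same)  # True
```

Passing map_location="cpu" lets weights saved on a GPU machine load on a CPU-only one. The restored model can then be moved with .to(device) as usual.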
A common practice is to plot the training and validation loss/accuracy over epochs to monitor for overfitting and to understand training dynamics.
import matplotlib.pyplot as plt
epochs_range = range(1, num_epochs + 1)
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, train_losses, label='Training Loss', color='#1c7ed6', marker='o')
plt.plot(epochs_range, val_losses, label='Validation Loss', color='#fd7e14', marker='x')
plt.title('Training and Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_accuracies, label='Training Accuracy', color='#1c7ed6', marker='o')
plt.plot(epochs_range, val_accuracies, label='Validation Accuracy', color='#fd7e14', marker='x')
plt.title('Training and Validation Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy (%)')
plt.legend()
plt.tight_layout()
plt.show()
If you run the full script, you should see plots similar to this (actual values will vary based on the run):
Example plot showing a typical trend of training and validation loss decreasing over epochs.
Coming from TensorFlow Keras, where model.compile() and model.fit() abstract away many of these details, the PyTorch approach might initially seem more verbose. However, this explicitness is one of PyTorch's strengths.
While Keras provides convenience for standard workflows, mastering the PyTorch training loop empowers you to tackle a wider range of research problems and implement more sophisticated training procedures. This hands-on exercise has laid the foundation. As you progress, you'll find yourself adding more features to this basic loop, such as learning rate schedulers, early stopping, and custom logging, all of which integrate naturally into this explicit structure.
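As one possible illustration of how such features slot into the loop, the sketch below adds a StepLR scheduler and a simple patience-based early-stopping check. Everything here is a stand-in: the model is a tiny nn.Linear, the training step runs on random data, and the validation losses are made-up numbers used only to exercise the stopping logic. The names patience, best_loss, and stopped_at are hypothetical.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

model = nn.Linear(4, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1)
# Halve the learning rate every 2 epochs.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)

# Made-up validation losses, NOT real results; they only drive the logic below.
fake_val_losses = [0.90, 0.70, 0.72, 0.71, 0.73, 0.60]

patience = 2  # stop after this many epochs without improvement
best_loss = float("inf")
epochs_without_improvement = 0
lrs = []
stopped_at = None

for epoch, val_loss in enumerate(fake_val_losses):
    # Stand-in for train_one_epoch: a single step on random data.
    optimizer.zero_grad()
    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    optimizer.step()

    lrs.append(optimizer.param_groups[0]["lr"])
    scheduler.step()  # called once per epoch, after optimizer.step()

    # Early stopping: track the best validation loss seen so far.
    if val_loss < best_loss:
        best_loss = val_loss
        epochs_without_improvement = 0
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            stopped_at = epoch
            break

print(lrs, stopped_at)
```

With these placeholder losses, the loop stops at epoch index 3, after two consecutive epochs without improvement, and the recorded learning rate drops from 0.1 to 0.05 once two epochs have passed. In the real script, val_loss would come from evaluate_model rather than a hard-coded list.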
© 2025 ApX Machine Learning