As you transition from TensorFlow Keras to PyTorch, you'll find that specifying a loss function, a way to measure how far your model's predictions are from the actual target values, is handled more explicitly. In Keras, you typically define the loss as part of the model.compile() step, often using a string identifier like 'binary_crossentropy' or an instance from tf.keras.losses. PyTorch requires you to instantiate the loss function yourself and then call it directly within your custom training loop. This approach, while requiring a few more lines of code, gives you a clearer view and finer control over the process.
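To make the contrast concrete, here is a minimal sketch of the two styles; the model, outputs, and targets named in the comments are placeholders, not objects defined in this section:
# TensorFlow Keras: the loss is attached once at compile time
# model.compile(optimizer='adam', loss='binary_crossentropy')

# PyTorch: create the loss object yourself, then call it in the training loop
import torch.nn as nn

criterion = nn.BCEWithLogitsLoss()
# loss = criterion(model_outputs, targets)  # called explicitly on each batch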
Loss Functions in torch.nn
PyTorch provides a rich set of pre-defined loss functions within the torch.nn module. These are implemented as classes that inherit from torch.nn.Module. To use one, you first create an instance of the desired loss class, and then, during your training loop, you call this instance with your model's predictions and the ground truth targets.
For example, to use Mean Squared Error (MSE) loss, you would do something like this:
import torch
import torch.nn as nn
# Instantiate the loss function
mse_loss_fn = nn.MSELoss()
# Example model output and target
predictions = torch.randn(10, 1, requires_grad=True) # 10 samples, 1 output feature
targets = torch.randn(10, 1)
# Calculate the loss
loss = mse_loss_fn(predictions, targets)
print(loss) # Output: tensor(..., grad_fn=<MseLossBackward0>)
# The loss tensor can then be used for backpropagation
# loss.backward()
Many loss functions are also available in a functional form in torch.nn.functional (often imported as F). For instance, F.mse_loss(predictions, targets) would achieve the same as the example above. While functional forms can be convenient for stateless operations, using the nn.Module versions is generally preferred for the primary loss function of your model, as they can maintain state (like the reduction type) and integrate more seamlessly into the nn.Module paradigm.
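As a quick sketch of this equivalence, reusing the predictions and targets tensors from the example above:
import torch.nn.functional as F

# Functional form: options are passed on every call
loss_functional = F.mse_loss(predictions, targets, reduction='mean')

# Module form: configuration (here, the reduction) lives in the instantiated object
loss_module = nn.MSELoss(reduction='mean')(predictions, targets)

print(torch.allclose(loss_functional, loss_module))  # True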
When moving from TensorFlow, there are a few important distinctions in how PyTorch loss functions operate, particularly regarding input expectations and target formats.
This is one of the most common areas of confusion for developers switching from TensorFlow.
Multi-class Classification (nn.CrossEntropyLoss):
- tf.keras.losses.CategoricalCrossentropy often expects probabilities as input (i.e., the output of a softmax layer), unless from_logits=True is specified.
- nn.CrossEntropyLoss expects raw, unnormalized scores (logits) directly from your model's final linear layer. It combines a LogSoftmax layer and a Negative Log-Likelihood Loss (NLLLoss) in one efficient step. This means you should not apply a Softmax activation to your model's output before passing it to nn.CrossEntropyLoss (a short sketch of this equivalence follows this list).

Binary Classification (nn.BCEWithLogitsLoss vs. nn.BCELoss):
- tf.keras.losses.BinaryCrossentropy can take either probabilities (output of a sigmoid, from_logits=False) or logits (from_logits=True).
- nn.BCELoss: Expects probabilities as input (i.e., model output passed through torch.sigmoid).
- nn.BCEWithLogitsLoss: Expects raw logits as input. This version is more numerically stable and is generally recommended over using nn.BCELoss preceded by a separate torch.sigmoid call. It combines a sigmoid layer and the BCELoss in one operation.
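As a small sketch of the multi-class behavior, with made-up logits and labels, nn.CrossEntropyLoss applied to raw logits gives the same result as an explicit LogSoftmax followed by NLLLoss:
import torch
import torch.nn as nn

logits = torch.randn(4, 3)            # 4 samples, 3 classes, raw scores (no softmax)
labels = torch.tensor([0, 2, 1, 2])   # integer class indices

ce_loss = nn.CrossEntropyLoss()(logits, labels)
nll_loss = nn.NLLLoss()(nn.LogSoftmax(dim=1)(logits), labels)
print(torch.allclose(ce_loss, nll_loss))  # True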
Target formats also differ:

nn.CrossEntropyLoss (Multi-class):
- tf.keras.losses.CategoricalCrossentropy typically expects targets to be one-hot encoded, while tf.keras.losses.SparseCategoricalCrossentropy expects integer class labels.
- nn.CrossEntropyLoss expects target tensors to contain integer class indices. For a C-class problem, the target for each sample should be an integer in the range [0, C-1]. For example, for an N-sample batch, the predictions might have shape (N, C) (logits) and the targets shape (N,) (class indices).

nn.BCEWithLogitsLoss / nn.BCELoss (Binary/Multi-label):
- Targets should be floats (0.0 or 1.0) with the same shape as the model's output, typically (N, 1) or (N,). For multi-label classification, where each sample can belong to multiple classes, predictions and targets would have shape (N, C), where C is the number of classes, and targets are binary indicators (0 or 1). A short sketch of these shapes follows this list.
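The following sketch, using made-up tensors, shows the expected shapes and dtypes for the multi-class and multi-label cases:
import torch
import torch.nn as nn

# Multi-class: logits of shape (N, C), targets of shape (N,) holding integer class indices
mc_logits = torch.randn(4, 3)
mc_targets = torch.tensor([2, 0, 1, 1])
print(nn.CrossEntropyLoss()(mc_logits, mc_targets))

# Multi-label: logits and targets both of shape (N, C), targets are 0./1. floats
ml_logits = torch.randn(4, 3)
ml_targets = torch.tensor([[1., 0., 1.],
                           [0., 0., 1.],
                           [1., 1., 0.],
                           [0., 1., 0.]])
print(nn.BCEWithLogitsLoss()(ml_logits, ml_targets))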
Both TensorFlow and PyTorch loss functions allow you to specify how the losses calculated for each item in a batch should be aggregated into a single scalar value.

In TensorFlow Keras, you can pass a reduction argument in the loss constructor, using tf.keras.losses.Reduction.SUM, tf.keras.losses.Reduction.NONE, or implicitly 'auto', which usually means averaging over the batch.

PyTorch's nn.Module-based loss functions have a reduction parameter in their constructor, which can be set to:
- 'mean' (default): The sum of the per-sample losses is divided by the number of elements.
- 'sum': The per-sample losses are summed.
- 'none': No reduction is applied; the loss for each element is returned.

# Example of reduction
targets = torch.tensor([0., 1., 0.])
predictions_logits = torch.tensor([-1.0, 2.0, -0.5]) # Logits
loss_fn_mean = nn.BCEWithLogitsLoss(reduction='mean')
print(loss_fn_mean(predictions_logits, targets)) # Output: tensor(0.3048)
loss_fn_sum = nn.BCEWithLogitsLoss(reduction='sum')
print(loss_fn_sum(predictions_logits, targets)) # Output: tensor(0.9143) # 0.3048 * 3
loss_fn_none = nn.BCEWithLogitsLoss(reduction='none')
print(loss_fn_none(predictions_logits, targets)) # Output: tensor([0.3133, 0.1269, 0.4741])
Here’s a quick comparison of some frequently used loss functions and their equivalents:
| Use Case | TensorFlow Keras (tf.keras.losses) | PyTorch (torch.nn) | PyTorch Input Expectation | PyTorch Target Expectation (Typical Batch) |
|---|---|---|---|---|
| Regression (Mean Square) | MeanSquaredError() | MSELoss() | Any real values | Same shape as input |
| Regression (Mean Absolute) | MeanAbsoluteError() | L1Loss() | Any real values | Same shape as input |
| Binary Classification | BinaryCrossentropy(from_logits=True) | BCEWithLogitsLoss() (Recommended) | Logits | (N, 1) or (N,), floats 0.0 or 1.0 |
| Binary Classification | BinaryCrossentropy(from_logits=False) | BCELoss() (Requires sigmoid on input) | Probabilities | (N, 1) or (N,), floats 0.0 or 1.0 |
| Multi-class Classification | CategoricalCrossentropy(from_logits=True) or SparseCategoricalCrossentropy() | CrossEntropyLoss() (Recommended) | Logits (N, C) | Class indices (N,), long integers |
| Multi-label Classification | BinaryCrossentropy(from_logits=True) | BCEWithLogitsLoss() | Logits (N, C) | Binary matrix (N, C), floats 0.0 or 1.0 |
The relationship between the final activation function (like Sigmoid or Softmax) and the loss function is important. PyTorch's preferred loss functions for classification (nn.BCEWithLogitsLoss and nn.CrossEntropyLoss) often incorporate the activation step for numerical stability and efficiency.
Diagram: Common pathways for model outputs, activations (if separate), and loss functions in TensorFlow/Keras versus PyTorch for binary and multi-class classification. Note PyTorch's recommended direct use of logits with specific loss functions.
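A small numerical check, with made-up logits, illustrates the combined operation for the binary case: BCEWithLogitsLoss on raw logits matches BCELoss applied to sigmoid outputs, with the combined form being the numerically safer choice:
import torch
import torch.nn as nn

logits = torch.tensor([-1.0, 2.0, -0.5])
targets = torch.tensor([0., 1., 0.])

combined = nn.BCEWithLogitsLoss()(logits, targets)
separate = nn.BCELoss()(torch.sigmoid(logits), targets)
print(torch.allclose(combined, separate))  # True, but the combined version is more stable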
In practice:
- For multi-class classification, use nn.CrossEntropyLoss. Ensure your model's final layer outputs raw logits (no Softmax) and your targets are integer class labels.
- For binary classification, use nn.BCEWithLogitsLoss. Your model's final layer should output a single logit per sample. Targets should be floats (0.0 or 1.0).
- For multi-label classification, use nn.BCEWithLogitsLoss. Your model's final layer should output C logits per sample (one for each class). Targets should be a binary matrix of shape (N, C) with 0.0s and 1.0s.
- For regression, nn.MSELoss (for L2 loss) and nn.L1Loss (for L1 loss, Mean Absolute Error) are common choices.

If you need a loss function not available in torch.nn, you can easily implement your own.
For simple custom losses that don't require learnable parameters or need to maintain state, a plain Python function that takes predictions and targets and returns a scalar loss tensor is sufficient.
def my_custom_rmse_loss(predictions, targets):
    return torch.sqrt(torch.mean((predictions - targets)**2))

# Usage:
# loss = my_custom_rmse_loss(model_outputs, ground_truth)
If your custom loss function needs to store parameters or inherit nn.Module behavior (e.g., to be discoverable by model.children()), you can subclass nn.Module:
class MyParameterizedLoss(nn.Module):
    def __init__(self, alpha=0.5):
        super().__init__()
        self.alpha = alpha

    def forward(self, predictions, targets):
        # Example: A weighted sum of L1 and L2 loss
        l1_loss = torch.abs(predictions - targets).mean()
        mse_loss = ((predictions - targets)**2).mean()
        return self.alpha * l1_loss + (1 - self.alpha) * mse_loss

# Usage:
# criterion = MyParameterizedLoss(alpha=0.3)
# loss = criterion(model_outputs, ground_truth)
Here's how a loss function typically fits into a PyTorch training step:
import torch
import torch.nn as nn
import torch.optim as optim
# Assume model, train_loader are defined
# model = YourNeuralNetwork()
# optimizer = optim.Adam(model.parameters(), lr=0.001)
# For multi-class classification
criterion = nn.CrossEntropyLoss()
# --- Inside your training loop ---
# for inputs, labels in train_loader:
#     optimizer.zero_grad()              # Zero the gradients
#
#     outputs = model(inputs)            # Forward pass (outputs are logits)
#
#     loss = criterion(outputs, labels)  # Calculate loss
#
#     loss.backward()                    # Backward pass (compute gradients)
#     optimizer.step()                   # Update weights
# --- End of training loop snippet ---
In this snippet, criterion(outputs, labels) directly computes the loss. outputs are the raw logits from the model, and labels are the integer class indices. This loss tensor then drives the backpropagation process via loss.backward().
Understanding these nuances, especially around logits and the combined operations within PyTorch loss functions like CrossEntropyLoss and BCEWithLogitsLoss, will make your transition from TensorFlow Keras much smoother and help you avoid common mistakes. By explicitly defining and using loss functions, you gain the more granular control that is characteristic of PyTorch's design.