When you train a neural network, minimizing the loss function is the primary goal of the optimizer. However, the loss value itself, especially for complex functions like cross-entropy, doesn't always provide a human-interpretable measure of how well your model is performing on its intended task. This is where performance metrics come in. Metrics like accuracy, precision, recall, or F1-score give you a clearer picture of your model's capabilities.
If you're coming from TensorFlow Keras, you're accustomed to specifying metrics directly in the model.compile() method:
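# Keras example (assumes model is an already-built tf.keras model)
import tensorflow as tf
model.compile(optimizer='adam',
              loss='binary_crossentropy',  # Precision() computes binary precision, so pair it with a binary task
              metrics=['accuracy', tf.keras.metrics.Precision()])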
Keras then takes care of calculating and reporting these metrics during training (model.fit()) and evaluation (model.evaluate()). These Keras metrics, like tf.keras.metrics.Accuracy(), are stateful objects: they automatically accumulate values (e.g., total correct predictions and total samples) across batches and compute the final metric value from those totals.
PyTorch, in line with its more explicit philosophy, requires you to be more involved in the metric calculation process. Core PyTorch has no built-in equivalent to Keras's automatic metric tracking, partly because it has no built-in training loop: you write the loop yourself, so you typically compute metrics yourself as well.
For simple metrics, direct tensor operations within your training or evaluation loop are often sufficient. For example, to calculate accuracy for a batch in a classification task:
# PyTorch manual accuracy calculation for a batch
# outputs: raw logits from the model, shape (batch_size, num_classes)
# labels: ground-truth class indices, shape (batch_size,)
predicted_classes = outputs.argmax(dim=1)  # index of the highest logit per sample
correct_predictions = (predicted_classes == labels).sum().item()
batch_accuracy = correct_predictions / labels.size(0)
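To see the batch computation end to end, here is a minimal sketch on hand-made tensors. The logits and labels are made-up values standing in for a real model's output on a batch of 4 samples with 3 classes.

```python
import torch

# Hypothetical logits for a batch of 4 samples and 3 classes
outputs = torch.tensor([
    [2.0, 0.1, -1.0],  # argmax -> class 0
    [0.2, 1.5, 0.3],   # argmax -> class 1
    [0.0, 0.4, 3.0],   # argmax -> class 2
    [1.2, 0.9, 0.1],   # argmax -> class 0
])
labels = torch.tensor([0, 1, 1, 0])  # ground-truth class indices

predicted_classes = outputs.argmax(dim=1)                         # tensor([0, 1, 2, 0])
correct_predictions = (predicted_classes == labels).sum().item()  # 3 (third sample is wrong)
batch_accuracy = correct_predictions / labels.size(0)
print(f"Batch accuracy: {batch_accuracy:.2f}")  # Batch accuracy: 0.75
```

Note that .item() turns the one-element count tensor into a Python int, so the division is ordinary Python arithmetic.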
To get epoch-level accuracy, you'd need to sum correct_predictions and the total number of samples (labels.size(0)) across all batches in that epoch, then perform the division:
# Accumulating for epoch-level accuracy
# (assumes model and data_loader are already defined)
# Initialize at the start of each epoch, before the batch loop
epoch_total_correct = 0
epoch_total_samples = 0
for inputs, labels in data_loader:
    outputs = model(inputs)
    predicted_classes = outputs.argmax(dim=1)
    epoch_total_correct += (predicted_classes == labels).sum().item()
    epoch_total_samples += labels.size(0)
# After the batch loop, divide once
epoch_accuracy = epoch_total_correct / epoch_total_samples
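Summing counts before dividing matters because the last batch of an epoch is often smaller than the others, and simply averaging per-batch accuracies gives each of its samples extra weight. A small worked example with hypothetical (correct, total) counts for three batches:

```python
# Hypothetical (correct, total) counts per batch; the last batch is smaller
batch_results = [(30, 32), (28, 32), (1, 4)]

# Pooled epoch accuracy: sum the counts first, then divide once
total_correct = sum(correct for correct, _ in batch_results)  # 59
total_samples = sum(total for _, total in batch_results)      # 68
pooled_accuracy = total_correct / total_samples                # about 0.8676

# Naive alternative: average the per-batch accuracies (0.9375, 0.875, 0.25)
batch_accuracies = [correct / total for correct, total in batch_results]
mean_of_batches = sum(batch_accuracies) / len(batch_accuracies)  # 0.6875

print(f"Pooled accuracy: {pooled_accuracy:.4f}")
print(f"Mean of batch accuracies: {mean_of_batches:.4f}")
```

The 4-sample batch counts for a third of the naive average even though it holds under 6% of the samples, which is why the pooled version is the one that matches accuracy over the whole epoch.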
This manual approach gives you full control but can become repetitive and error-prone, especially for more complex metrics like AUC or F1-score that require careful state management across batches.
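To make the state-management point concrete, here is a minimal sketch of a hand-rolled stateful accuracy metric following an update/compute/reset pattern. The class name RunningAccuracy is hypothetical (it is not part of PyTorch or any library), and it only handles single-label multi-class accuracy; the toy dataset and linear model are stand-ins.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


class RunningAccuracy:
    """Hypothetical stateful accuracy metric that accumulates counts across batches."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Clear accumulated state (call at the start of each epoch or evaluation run)
        self.correct = 0
        self.total = 0

    def update(self, logits: torch.Tensor, labels: torch.Tensor):
        # Convert logits of shape (N, C) to predicted class indices of shape (N,)
        preds = logits.argmax(dim=1)
        self.correct += (preds == labels).sum().item()
        self.total += labels.numel()

    def compute(self) -> float:
        if self.total == 0:
            raise RuntimeError("compute() called before any update()")
        return self.correct / self.total


# Usage: 10 toy samples, 3 classes, batches of sizes 4, 4 and 2
torch.manual_seed(0)
inputs = torch.randn(10, 5)
labels = torch.randint(0, 3, (10,))
loader = DataLoader(TensorDataset(inputs, labels), batch_size=4)
model = torch.nn.Linear(5, 3)  # stand-in for a real classifier

metric = RunningAccuracy()
with torch.no_grad():
    for x, y in loader:
        metric.update(model(x), y)
epoch_accuracy = metric.compute()
print(f"Epoch accuracy: {epoch_accuracy:.4f}")
metric.reset()  # ready for the next epoch
```

Even this simple case needs care (resetting at the right time, guarding against an empty state); metrics like F1 or AUC need more state than two counters, which is the bookkeeping a library such as TorchMetrics takes over.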
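To bridge this gap and offer a more convenient, Keras-like experience for metrics, the PyTorch ecosystem provides TorchMetrics. This library, developed by the PyTorch Lightning team, can be used independently in any PyTorch project. It offers a wide array of common machine learning metrics implemented as stateful objects that accumulate statistics across batches, much like Keras metrics, and that can be moved to the same device as your model and data.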
You'll first need to install it:
pip install torchmetrics
Here's how you might use TorchMetrics for accuracy in a multi-class classification problem:
import torch
import torchmetrics
# Define number of classes and device
NUM_CLASSES = 10  # Example for CIFAR-10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Initialize the metric and move it to the same device as the model and data
accuracy_metric = torchmetrics.Accuracy(task="multiclass", num_classes=NUM_CLASSES).to(device)
# --- Inside your training or evaluation loop ---
# (assumes model and data_loader are already defined)
for images, labels in data_loader:
    images, labels = images.to(device), labels.to(device)
    outputs = model(images)  # raw logits, shape (batch_size, NUM_CLASSES)
    # Update the metric state with this batch's predictions and targets.
    # With task="multiclass", Accuracy accepts logits or probabilities of shape
    # (batch_size, NUM_CLASSES) and applies argmax internally,
    # or integer class predictions of shape (batch_size,).
    accuracy_metric.update(outputs, labels)
# --- End of batch loop ---
# After processing all batches in an epoch:
# Compute the metric over all accumulated batches
epoch_accuracy = accuracy_metric.compute()
print(f"Epoch Accuracy: {epoch_accuracy.item():.4f}")
# Reset the metric's internal state for the next epoch or evaluation phase
accuracy_metric.reset()
Using TorchMetrics involves four main steps:
1. Initialization: create an instance of the metric class (e.g., torchmetrics.Accuracy, torchmetrics.Precision, torchmetrics.F1Score). You typically specify parameters like task (e.g., "binary", "multiclass", "multilabel") and num_classes if applicable (multilabel metrics take num_labels instead). Send the metric object to the same device as your data and model.
2. Update: for each batch, call the update(predictions, targets) method. This accumulates the necessary statistics from the current batch.
3. Compute: after all batches have been processed, call the compute() method to get the final metric value.
4. Reset: call reset() to clear the internal state of the metric, making it ready for the next epoch or a new evaluation run.
TorchMetrics offers a comprehensive suite of metrics for classification, regression, image analysis, and more, significantly simplifying metric tracking in PyTorch.
Another alternative, especially if a specific metric isn't available in TorchMetrics or for quick, ad-hoc evaluations, is to use sklearn.metrics. Since scikit-learn functions operate on NumPy arrays, you'll need to:
1. Detach the tensors from the computation graph with .detach().
2. Move them to the CPU with .cpu().
3. Convert them to NumPy arrays with .numpy().
from sklearn.metrics import precision_score
import torch
# Assumes all_preds_list and all_labels_list hold, for every batch in the epoch,
# tensors of predicted class indices (e.g., outputs.argmax(dim=1)) and ground-truth labels
all_preds = torch.cat(all_preds_list).detach().cpu().numpy()
all_labels = torch.cat(all_labels_list).detach().cpu().numpy()
precision = precision_score(all_labels, all_preds, average='macro')
print(f"Scikit-learn Precision (macro): {precision:.4f}")
While viable, this approach is generally less efficient than TorchMetrics for per-batch updates within a GPU-accelerated training loop, because every conversion copies data from the GPU to the CPU. It's often more practical for end-of-epoch or final model evaluation.
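As a sketch of that end-of-epoch pattern, the snippet below stores predicted class indices per batch as CPU tensors, concatenates them once, and hands NumPy arrays to scikit-learn. The two batches of logits and labels are hypothetical stand-ins for a real model's outputs.

```python
import torch
from sklearn.metrics import precision_score

# Hypothetical per-batch logits and labels (3 classes), as an evaluation loop might produce
batches = [
    (torch.tensor([[2.0, 0.1, 0.0], [0.1, 2.0, 0.0], [0.0, 0.1, 2.0]]), torch.tensor([0, 1, 1])),
    (torch.tensor([[0.0, 2.0, 0.1], [0.1, 0.0, 2.0], [2.0, 0.0, 0.1]]), torch.tensor([1, 2, 2])),
]

all_preds_list, all_labels_list = [], []
for logits, labels in batches:
    # Keep class indices, detached from the graph and moved to the CPU
    all_preds_list.append(logits.argmax(dim=1).detach().cpu())
    all_labels_list.append(labels.detach().cpu())

# Concatenate once at the end of the epoch, then convert to NumPy
all_preds = torch.cat(all_preds_list).numpy()    # [0, 1, 2, 1, 2, 0]
all_labels = torch.cat(all_labels_list).numpy()  # [0, 1, 1, 1, 2, 2]

# Per-class precision: class 0 -> 1/2, class 1 -> 2/2, class 2 -> 1/2; macro mean = 2/3
precision = precision_score(all_labels, all_preds, average="macro")
print(f"Scikit-learn Precision (macro): {precision:.4f}")
```

Collecting small index tensors per batch and converting once keeps the GPU-to-CPU traffic to a minimum compared with calling scikit-learn inside the batch loop.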
Let's put this together and see how TorchMetrics fits into a standard PyTorch evaluation loop:
import torch
import torchmetrics
# Assumes model, val_dataloader, criterion (loss function), device, and NUM_CLASSES are defined
val_accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=NUM_CLASSES).to(device)
val_f1 = torchmetrics.F1Score(task="multiclass", num_classes=NUM_CLASSES, average="macro").to(device)
model.eval()  # Set the model to evaluation mode
running_val_loss = 0.0
# Reset metrics at the start of evaluation
val_accuracy.reset()
val_f1.reset()
with torch.no_grad():  # Disable gradient computations for evaluation
    for inputs, labels in val_dataloader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        # Weight the batch's mean loss by its size (assumes reduction='mean', the default)
        running_val_loss += loss.item() * inputs.size(0)
        # Update metrics with this batch
        val_accuracy.update(outputs, labels)
        val_f1.update(outputs, labels)
# Compute final metrics after iterating through all validation data
epoch_val_loss = running_val_loss / len(val_dataloader.dataset)
final_val_accuracy = val_accuracy.compute()
final_val_f1 = val_f1.compute()
print(f"Validation Loss: {epoch_val_loss:.4f}")
print(f"Validation Accuracy: {final_val_accuracy.item():.4f}")
print(f"Validation F1 Score: {final_val_f1.item():.4f}")
In this example, model.eval() is important because it switches layers like dropout and batch normalization to evaluation mode. torch.no_grad() disables gradient tracking, which speeds up computation and reduces memory usage during inference. Metrics are updated per batch and computed once at the end.
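Both effects can be checked directly. This minimal sketch uses a tiny hypothetical model (a linear layer followed by dropout): in training mode dropout draws a new random mask on every call, in evaluation mode the output is deterministic, and under torch.no_grad() the outputs no longer track gradients.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Tiny hypothetical model: a linear layer followed by dropout
model = nn.Sequential(nn.Linear(4, 8), nn.Dropout(p=0.5))
x = torch.randn(2, 4)

model.train()                     # dropout active: random elements zeroed on each call
train_a, train_b = model(x), model(x)

model.eval()                      # dropout disabled: deterministic output
eval_a, eval_b = model(x), model(x)

with torch.no_grad():             # no computation graph is recorded
    no_grad_out = model(x)

print(torch.equal(train_a, train_b))  # almost certainly False (different dropout masks)
print(torch.equal(eval_a, eval_b))    # True
print(eval_a.requires_grad, no_grad_out.requires_grad)  # True False
```

Forgetting model.eval() would make validation metrics noisy from run to run, while forgetting torch.no_grad() leaves results unchanged but wastes memory on graphs that are never used.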
The main differences in how metrics are handled can be summarized as follows. PyTorch offers fine-grained control, with TorchMetrics providing convenient, stateful metric objects similar to Keras.
- Specifying metrics: Keras takes a metrics list in model.compile(); in PyTorch you compute metrics yourself in the training and evaluation loops, either by hand or with a helper library.
- Built-in metrics: Keras ships stateful metric classes in tf.keras.metrics; in PyTorch, TorchMetrics provides many of these.
- Configuration: TorchMetrics often requires specifying the task (e.g., "binary", "multiclass", "multilabel") and other parameters like num_classes or average (for F1/Precision/Recall).
- Reporting: Keras reports metrics during fit() and evaluate(); in PyTorch you log them yourself, for example with TensorBoard (torch.utils.tensorboard.SummaryWriter) or Weights & Biases.
By understanding how to calculate and track performance metrics in PyTorch, either manually or with helper libraries like TorchMetrics, you gain deeper insight into your model's learning process and its true performance, moving beyond just the training loss. This explicit control, while requiring a bit more setup than Keras's compile/fit pattern, aligns with PyTorch's philosophy of giving developers full transparency and command over their model training pipelines.
© 2025 ApX Machine Learning