During training, Batch Normalization works its magic by standardizing the inputs to a layer using the mean and variance calculated from the current mini-batch. This approach helps stabilize training dynamics, as discussed previously. However, a question naturally arises: what happens when we want to use the trained model to make predictions on new data (inference or testing)?
At inference time, several issues make using mini-batch statistics impractical or undesirable:

- Predictions are often made for a single example at a time, so there may be no batch from which to compute a meaningful mean and variance.
- Even when test examples arrive in batches, each prediction would depend on the other examples that happen to share its batch, making outputs non-deterministic (illustrated in the sketch below).
- A test batch may not be representative of the training distribution, so its statistics can differ from those the model was trained under.
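To make the second point concrete, here is a minimal sketch (using nn.BatchNorm1d and made-up tensors x, other_a, and other_b) showing that in training mode the normalized output for a given example changes depending on which other examples share its batch:

import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(num_features=4)
bn.train()  # training mode: normalize with mini-batch statistics

x = torch.randn(1, 4)        # the example we care about
other_a = torch.randn(3, 4)  # one set of batch companions
other_b = torch.randn(3, 4)  # a different set

# The first row of each output corresponds to the same input x
out_a = bn(torch.cat([x, other_a]))[0]
out_b = bn(torch.cat([x, other_b]))[0]

# Almost surely False: x was normalized with different batch statistics each time
print(torch.allclose(out_a, out_b))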
The standard solution is to use statistics that represent the entire training dataset (or a reasonable estimate thereof) instead of just the current mini-batch. These are often referred to as population statistics. Specifically, during inference, Batch Normalization normalizes the input using a fixed estimate of the mean and variance derived from the training data.
How are these population statistics obtained? They are typically estimated during the training process using an exponential moving average. As the model trains over many mini-batches, the framework keeps track of running estimates for the mean ($\mu_{\text{pop}}$) and variance ($\sigma^2_{\text{pop}}$) for each feature dimension being normalized.
For each mini-batch $B$ during training, the framework calculates the batch mean $\mu_B$ and variance $\sigma^2_B$. These are then used to update the running estimates, often using a momentum term (let's call it momentum, typically close to 0.1):

$$\mu_{\text{pop}} \leftarrow (1 - \text{momentum}) \cdot \mu_{\text{pop}} + \text{momentum} \cdot \mu_B$$

$$\sigma^2_{\text{pop}} \leftarrow (1 - \text{momentum}) \cdot \sigma^2_{\text{pop}} + \text{momentum} \cdot \sigma^2_B$$

These updates happen at each training step, gradually refining the estimates $\mu_{\text{pop}}$ and $\sigma^2_{\text{pop}}$ to reflect the overall statistics of the activations seen during training.
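As a rough illustration of that bookkeeping, the following sketch applies the same update rule by hand (update_running_stats is a hypothetical helper, not the framework's internals; like PyTorch, it weights the new batch statistic by momentum and uses the unbiased batch variance for the running estimate):

import torch

def update_running_stats(running_mean, running_var, batch, momentum=0.1):
    # One EMA step for the per-feature running estimates
    mu_B = batch.mean(dim=0)
    var_B = batch.var(dim=0, unbiased=True)
    running_mean = (1 - momentum) * running_mean + momentum * mu_B
    running_var = (1 - momentum) * running_var + momentum * var_B
    return running_mean, running_var

# Typical initialization: mean 0, variance 1
running_mean = torch.zeros(4)
running_var = torch.ones(4)

# Simulate 100 training steps on activations with mean ~5 and std ~2
for _ in range(100):
    batch = torch.randn(32, 4) * 2.0 + 5.0
    running_mean, running_var = update_running_stats(running_mean, running_var, batch)

print(running_mean)  # close to 5
print(running_var)   # close to 4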
So, when the model is set to evaluation mode (e.g., using model.eval() in PyTorch), the Batch Normalization layer switches its behavior. Instead of calculating $\mu_B$ and $\sigma^2_B$ from the input, it uses the pre-computed running estimates $\mu_{\text{pop}}$ and $\sigma^2_{\text{pop}}$.
The normalization calculation during inference becomes:

$$\hat{x}_i = \frac{x_i - \mu_{\text{pop}}}{\sqrt{\sigma^2_{\text{pop}} + \epsilon}}$$

Here, $x_i$ is an input feature, $\mu_{\text{pop}}$ and $\sigma^2_{\text{pop}}$ are the estimated population mean and variance respectively, and $\epsilon$ is the small constant added for numerical stability.
It's important to remember that the learned scaling parameter $\gamma$ and shifting parameter $\beta$ are still applied after this normalization, just as they were during training:

$$y_i = \gamma \hat{x}_i + \beta$$

These $\gamma$ and $\beta$ parameters are part of the model's learned weights and remain fixed after training.
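To verify this arithmetic, the following sketch reproduces a layer's eval-mode output by hand from its stored buffers and parameters (running_mean, running_var, eps, weight for $\gamma$, and bias for $\beta$ are the attribute names PyTorch uses):

import torch
import torch.nn as nn

bn = nn.BatchNorm1d(num_features=4)

# Run a few training batches so the running statistics become non-trivial
bn.train()
for _ in range(10):
    bn(torch.randn(32, 4) * 3.0 + 1.0)

bn.eval()
x = torch.randn(8, 4)
out_layer = bn(x)

# The same computation by hand: normalize with the population estimates,
# then apply the learned scale (gamma) and shift (beta)
x_hat = (x - bn.running_mean) / torch.sqrt(bn.running_var + bn.eps)
out_manual = bn.weight * x_hat + bn.bias

print(torch.allclose(out_layer, out_manual, atol=1e-6))  # True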
Fortunately, deep learning frameworks like PyTorch and TensorFlow handle this switch between training and inference behavior automatically. When you define a Batch Normalization layer, it internally maintains these running statistics.
import torch
import torch.nn as nn
# Example: A simple block with Conv -> BN -> ReLU
conv_block = nn.Sequential(
    nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1),
    # Batch Norm layer for 64 channels
    nn.BatchNorm2d(num_features=64),  # Tracks running_mean and running_var
    nn.ReLU()
)
# --- During Training ---
# Set the model to training mode
conv_block.train()
# Generate some dummy input data (Batch Size=4, Channels=3, Height=32, Width=32)
input_data_train = torch.randn(4, 3, 32, 32)
# Forward pass during training: BN uses mini-batch stats & updates running stats
output_train = conv_block(input_data_train)
# --- During Inference ---
# Set the model to evaluation mode
conv_block.eval()
# Generate some dummy test data (Batch Size=1)
input_data_test = torch.randn(1, 3, 32, 32)
# Forward pass during inference: BN uses the stored running_mean and running_var
output_test = conv_block(input_data_test)
# Print the running mean tracked by the BN layer
# Note: These values are populated during the training pass(es)
print("Running Mean Shape:", conv_block[1].running_mean.shape)
print("Running Variance Shape:", conv_block[1].running_var.shape)
In the PyTorch example above, calling conv_block.train() sets the modules within the Sequential block (including BatchNorm2d) to training mode. In this mode, BatchNorm2d calculates statistics from the input batch and updates its internal running_mean and running_var. Calling conv_block.eval() switches the modules to evaluation mode. Now, BatchNorm2d no longer uses the current input batch's statistics but instead uses the running_mean and running_var that were estimated during training. This ensures consistent and deterministic output during inference.
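One quick way to observe the switch, continuing the conv_block example above, is to check whether the buffers change after a forward pass:

# Running statistics are updated in training mode...
conv_block.train()
before = conv_block[1].running_mean.clone()
conv_block(torch.randn(4, 3, 32, 32))
print(torch.equal(before, conv_block[1].running_mean))  # False: buffers updated

# ...but left untouched in evaluation mode
conv_block.eval()
before = conv_block[1].running_mean.clone()
conv_block(torch.randn(4, 3, 32, 32))
print(torch.equal(before, conv_block[1].running_mean))  # True: buffers frozen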
Most Batch Normalization implementations have a parameter like track_running_stats (usually True by default). When set to True, the layer estimates population statistics during training and uses them during evaluation. If set to False, the layer always uses the current batch's statistics, which is generally not desired for standard inference unless you have a specific reason.
In summary, Batch Normalization adapts its behavior between training and testing. During training, it normalizes with dynamic mini-batch statistics to stabilize learning while simultaneously estimating population statistics via moving averages. At test time, it uses these fixed population statistics, along with the learned scale and shift parameters, to ensure deterministic and consistent normalization of inputs.