During training, Batch Normalization works its magic by standardizing the inputs to a layer using the mean and variance calculated from the current mini-batch. This approach helps stabilize training dynamics. However, a question naturally arises: what happens when a trained model is used to make predictions on new data (inference or testing)?
At inference time, several issues make using mini-batch statistics impractical or undesirable. Predictions are often requested for a single example, leaving no batch to compute statistics from. Even when inputs do arrive in batches, the output for a given example would depend on which other examples it happened to be grouped with, making predictions non-deterministic. And small or unrepresentative batches yield noisy statistics that degrade the quality of the normalization.
The standard solution is to use statistics that represent the entire training dataset (or a reasonable estimate thereof) instead of just the current mini-batch. These are often referred to as population statistics. Specifically, during inference, Batch Normalization normalizes the input using a fixed estimate of the mean and variance derived from the training data.
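To make the distinction concrete, here is a small illustration with made-up data, comparing the statistics of one mini-batch against those of a full dataset. The specific numbers (10,000 samples, batch of 16) are arbitrary choices for the sketch:

import torch

torch.manual_seed(0)
# Hypothetical "dataset": 10,000 scalar features with mean ~2 and std ~3
dataset = torch.randn(10_000, 1) * 3.0 + 2.0

population_mean = dataset.mean()
population_var = dataset.var(unbiased=False)

# A single small mini-batch drawn from the same data
mini_batch = dataset[:16]

print(f"population: mean={population_mean:.3f}, var={population_var:.3f}")
print(f"mini-batch: mean={mini_batch.mean():.3f}, var={mini_batch.var(unbiased=False):.3f}")

The mini-batch values are noisy estimates of the population values; the larger and more representative the sample, the closer they get.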
How are these population statistics obtained? They are typically estimated during the training process using an exponential moving average. As the model trains over many mini-batches, the framework keeps track of running estimates for the mean ($\mu_{\text{running}}$) and variance ($\sigma^2_{\text{running}}$) for each feature dimension being normalized.
For each mini-batch during training, the framework calculates the batch mean $\mu_B$ and variance $\sigma_B^2$. These are then used to update the running estimates, often using a momentum term (let's call it $\text{momentum}$, typically close to 0.1):

$$\mu_{\text{running}} \leftarrow (1 - \text{momentum}) \cdot \mu_{\text{running}} + \text{momentum} \cdot \mu_B$$

$$\sigma^2_{\text{running}} \leftarrow (1 - \text{momentum}) \cdot \sigma^2_{\text{running}} + \text{momentum} \cdot \sigma_B^2$$

These updates happen at each training step, gradually refining the estimates $\mu_{\text{running}}$ and $\sigma^2_{\text{running}}$ to reflect the overall statistics of the activations seen during training.
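To make the update rule concrete, here is a minimal sketch of the moving-average bookkeeping in plain PyTorch. It follows PyTorch's convention, where momentum weights the new batch statistics; it is a simplified illustration, not the framework's actual implementation:

import torch

momentum = 0.1
num_features = 64

# PyTorch initializes the running estimates to zero mean and unit variance
running_mean = torch.zeros(num_features)
running_var = torch.ones(num_features)

for step in range(100):
    batch = torch.randn(32, num_features)  # dummy mini-batch of activations
    batch_mean = batch.mean(dim=0)
    # PyTorch uses the unbiased variance for the running estimate
    batch_var = batch.var(dim=0, unbiased=True)

    # Exponential moving average updates
    running_mean = (1 - momentum) * running_mean + momentum * batch_mean
    running_var = (1 - momentum) * running_var + momentum * batch_var

# For standard-normal activations these drift toward ~0 and ~1
print(running_mean.mean().item(), running_var.mean().item())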
So, when the model is set to evaluation mode (e.g., using model.eval() in PyTorch), the Batch Normalization layer switches its behavior. Instead of calculating $\mu_B$ and $\sigma_B^2$ from the input, it uses the pre-computed running estimates $\mu_{\text{running}}$ and $\sigma^2_{\text{running}}$.
The normalization calculation during inference becomes:

$$\hat{x} = \frac{x - \mu_{\text{running}}}{\sqrt{\sigma^2_{\text{running}} + \epsilon}}$$

Here, $x$ is an input feature, $\mu_{\text{running}}$ and $\sigma^2_{\text{running}}$ are the estimated population mean and variance respectively, and $\epsilon$ is the small constant added for numerical stability.
It's important to remember that the learned scaling parameter $\gamma$ and shifting parameter $\beta$ are still applied after this normalization, just as they were during training:

$$y = \gamma \hat{x} + \beta$$

These $\gamma$ and $\beta$ parameters are part of the model's learned weights and remain fixed after training.
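These two steps can be checked numerically. The sketch below uses an nn.BatchNorm1d layer (chosen here purely for illustration) and compares its eval-mode output against a manual computation from its stored running statistics and learned parameters:

import torch
import torch.nn as nn

bn = nn.BatchNorm1d(num_features=8)

# Train on a few mini-batches so the running statistics get populated
bn.train()
for _ in range(10):
    bn(torch.randn(32, 8) * 2.0 + 5.0)

# In eval mode, the layer normalizes with its stored running estimates
bn.eval()
x = torch.randn(4, 8)
out_layer = bn(x)

# Manual computation: normalize with running stats, then apply gamma/beta
x_hat = (x - bn.running_mean) / torch.sqrt(bn.running_var + bn.eps)
out_manual = bn.weight * x_hat + bn.bias  # weight is gamma, bias is beta

print(torch.allclose(out_layer, out_manual, atol=1e-6))  # True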
Fortunately, deep learning frameworks like PyTorch and TensorFlow handle this switch between training and inference behavior automatically. When you define a Batch Normalization layer, it internally maintains these running statistics.
import torch
import torch.nn as nn
# Example: A simple block with Conv -> BN -> ReLU
conv_block = nn.Sequential(
    nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1),
    # Batch Norm layer for 64 channels
    nn.BatchNorm2d(num_features=64),  # Tracks running_mean and running_var
    nn.ReLU()
)
# --- During Training ---
# Set the model to training mode
conv_block.train()
# Generate some dummy input data (Batch Size=4, Channels=3, Height=32, Width=32)
input_data_train = torch.randn(4, 3, 32, 32)
# Forward pass during training: BN uses mini-batch stats & updates running stats
output_train = conv_block(input_data_train)
# --- During Inference ---
# Set the model to evaluation mode
conv_block.eval()
# Generate some dummy test data (Batch Size=1)
input_data_test = torch.randn(1, 3, 32, 32)
# Forward pass during inference: BN uses the stored running_mean and running_var
output_test = conv_block(input_data_test)
# Print the running mean tracked by the BN layer
# Note: These values are populated during the training pass(es)
print("Running Mean Shape:", conv_block[1].running_mean.shape)
print("Running Variance Shape:", conv_block[1].running_var.shape)
In the PyTorch example above, calling conv_block.train() sets the modules within the Sequential block (including BatchNorm2d) to training mode. In this mode, BatchNorm2d calculates statistics from the input batch and updates its internal running_mean and running_var. Calling conv_block.eval() switches the modules to evaluation mode. Now, BatchNorm2d no longer uses the current input batch's statistics but instead uses the running_mean and running_var that were estimated during training. This ensures consistent and deterministic output during inference.
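One practical consequence is determinism. Continuing with the conv_block defined above, the following sketch shows that in evaluation mode an example's output does not depend on the other examples in its batch, whereas in training mode it does:

conv_block.eval()

sample = torch.randn(1, 3, 32, 32)
batch = torch.cat([sample, torch.randn(7, 3, 32, 32)], dim=0)

# Eval mode: the sample's output is unaffected by its batch companions
out_alone = conv_block(sample)
out_in_batch = conv_block(batch)[:1]
print(torch.allclose(out_alone, out_in_batch, atol=1e-6))  # True

# Train mode: normalization depends on the whole batch, so outputs differ
conv_block.train()
out_alone = conv_block(sample)
out_in_batch = conv_block(batch)[:1]
print(torch.allclose(out_alone, out_in_batch))  # False (in general)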
Most Batch Normalization implementations expose a parameter like track_running_stats (usually True by default). When set to True, the layer estimates population statistics during training and uses them during evaluation. When set to False, the layer always normalizes with the current batch's statistics, even in evaluation mode, which is generally not what you want for standard inference unless you have specific reasons.
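In PyTorch, for example, this flag is set when the layer is constructed. With track_running_stats=False the layer keeps no running estimates at all and normalizes with the current batch's statistics even in evaluation mode:

import torch
import torch.nn as nn

bn_no_stats = nn.BatchNorm2d(num_features=64, track_running_stats=False)
bn_no_stats.eval()  # even in eval mode...

x = torch.randn(4, 64, 8, 8)
y = bn_no_stats(x)  # ...this batch's own mean and variance are used

print(bn_no_stats.running_mean)  # None: no running statistics are kept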
In summary, Batch Normalization adapts its behavior between training and testing. During training, it normalizes with dynamic mini-batch statistics to stabilize learning while simultaneously estimating population statistics via moving averages. At test time, it uses these fixed population statistics, along with the learned scale and shift parameters, to ensure deterministic and consistent normalization of inputs.