Having established how Batch Normalization (BN) works during the forward pass, normalizing inputs using mini-batch statistics and then scaling and shifting them, we now turn to the backward pass. To train the network with gradient descent, we need to compute how the loss $L$ changes with respect to the BN layer's inputs $x_i$ and its learnable parameters $\gamma$ and $\beta$. This involves applying the chain rule through the BN transformation.
Let's recall the forward pass for a single activation $x_i$ within a mini-batch $B = \{x_1, \dots, x_m\}$, with mini-batch mean $\mu_B = \frac{1}{m}\sum_{j=1}^{m} x_j$ and variance $\sigma_B^2 = \frac{1}{m}\sum_{j=1}^{m} (x_j - \mu_B)^2$:
Normalize the input:
$$\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}$$
(where $\epsilon$ is a small constant for numerical stability)
Scale and shift:
$$y_i = \gamma \hat{x}_i + \beta$$
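As a concrete reference for the derivation below, here is a minimal sketch of this forward pass written with plain PyTorch tensor operations (the function and variable names are illustrative, not a framework API). It also returns the intermediate values that the backward pass will reuse:

import torch

def batchnorm_forward(x, gamma, beta, eps=1e-5):
    # x has shape (m, num_features); statistics are taken over the batch dimension
    mu = x.mean(dim=0)                        # mini-batch mean, mu_B
    var = x.var(dim=0, unbiased=False)        # biased mini-batch variance, sigma_B^2
    x_hat = (x - mu) / torch.sqrt(var + eps)  # normalize
    y = gamma * x_hat + beta                  # scale and shift
    return y, x_hat, mu, var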
During backpropagation, we receive the gradient of the loss with respect to the BN layer's output, $\frac{\partial L}{\partial y_i}$, from the subsequent layer. Our goal is to compute $\frac{\partial L}{\partial x_i}$, $\frac{\partial L}{\partial \gamma}$, and $\frac{\partial L}{\partial \beta}$.
Gradients for Learnable Parameters (γ and β)
These are the most straightforward gradients to compute using the chain rule:
Gradient with respect to $\beta$: The parameter $\beta$ is added directly to each output $y_i$, so $\frac{\partial y_i}{\partial \beta} = 1$ and the incoming gradients simply sum over the mini-batch:
$$\frac{\partial L}{\partial \beta} = \sum_{i=1}^{m} \frac{\partial L}{\partial y_i}$$
Gradient with respect to $\gamma$: Since $\frac{\partial y_i}{\partial \gamma} = \hat{x}_i$, the gradient for $\gamma$ is the sum of the incoming gradients, each weighted by the corresponding normalized input $\hat{x}_i$:
$$\frac{\partial L}{\partial \gamma} = \sum_{i=1}^{m} \frac{\partial L}{\partial y_i} \, \hat{x}_i$$
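In code, with dL_dy holding the upstream gradient $\frac{\partial L}{\partial y}$ for the whole mini-batch and x_hat cached from the forward pass (illustrative names again), each of these gradients is a single reduction over the batch dimension:

def bn_param_grads(dL_dy, x_hat):
    # dL_dy and x_hat both have shape (m, num_features)
    dL_dbeta = dL_dy.sum(dim=0)              # sum the upstream gradients over the batch
    dL_dgamma = (dL_dy * x_hat).sum(dim=0)   # weight each upstream gradient by x_hat
    return dL_dgamma, dL_dbeta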
Gradient with respect to the Input ($x_i$)
Calculating the gradient with respect to the input $x_i$ is more involved, because $x_i$ influences the output $y_i$ along multiple paths:
Directly through the numerator $(x_i - \mu_B)$ in $\hat{x}_i$.
Indirectly through the mini-batch mean $\mu_B$, which depends on all $x_j$ in the batch.
Indirectly through the mini-batch variance $\sigma_B^2$, which also depends on all $x_j$ (including $x_i$) and on $\mu_B$.
We need to apply the chain rule carefully, accounting for all of these paths. Let $\sigma_{B,\epsilon} = \sqrt{\sigma_B^2 + \epsilon}$ for brevity. The gradient computation proceeds backward through the operations:
Gradient w.r.t. the normalized input $\hat{x}_i$:
$$\frac{\partial L}{\partial \hat{x}_i} = \frac{\partial L}{\partial y_i} \cdot \frac{\partial y_i}{\partial \hat{x}_i} = \frac{\partial L}{\partial y_i} \cdot \gamma$$
Gradients w.r.t. $\mu_B$ and $\sigma_B^2$: These require summing contributions from all $\hat{x}_j$ in the mini-batch, as both statistics affect every normalized input:
$$\frac{\partial L}{\partial \sigma_B^2} = \sum_{j=1}^{m} \frac{\partial L}{\partial \hat{x}_j} \cdot (x_j - \mu_B) \cdot \left(-\tfrac{1}{2}\right) (\sigma_B^2 + \epsilon)^{-3/2}$$
$$\frac{\partial L}{\partial \mu_B} = \left( \sum_{j=1}^{m} \frac{\partial L}{\partial \hat{x}_j} \cdot \frac{-1}{\sqrt{\sigma_B^2 + \epsilon}} \right) + \frac{\partial L}{\partial \sigma_B^2} \cdot \frac{1}{m} \sum_{j=1}^{m} -2\,(x_j - \mu_B)$$
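These two sums translate directly into code. The sketch below assumes the upstream gradient dL_dxhat ($\frac{\partial L}{\partial \hat{x}}$) plus x, mu, and var from the forward-pass sketch above (illustrative names, not a framework API):

import torch

def bn_stat_grads(dL_dxhat, x, mu, var, eps=1e-5):
    # Gradients w.r.t. the shared mini-batch statistics; both are sums over the batch
    m = x.shape[0]
    dL_dvar = (dL_dxhat * (x - mu) * -0.5 * (var + eps) ** -1.5).sum(dim=0)
    dL_dmu = (dL_dxhat * (-1.0 / torch.sqrt(var + eps))).sum(dim=0) \
             + dL_dvar * (-2.0 / m) * (x - mu).sum(dim=0)
    return dL_dmu, dL_dvar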
Combining everything and simplifying (the full derivation is quite detailed and is often found in the appendices of papers or textbooks), the result can be expressed more compactly. A common form is:
$$\frac{\partial L}{\partial x_i} = \frac{\gamma}{m \, \sigma_{B,\epsilon}} \left( m \frac{\partial L}{\partial y_i} - \sum_{j=1}^{m} \frac{\partial L}{\partial y_j} - \hat{x}_i \sum_{j=1}^{m} \frac{\partial L}{\partial y_j} \, \hat{x}_j \right)$$
The main takeaway is that the gradient $\frac{\partial L}{\partial x_i}$ depends not only on the gradient $\frac{\partial L}{\partial y_i}$ for that specific activation, but also on the gradients and values of all other activations in the mini-batch ($j = 1, \dots, m$) through the shared mean and variance calculations.
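Because the compact form is easy to get subtly wrong, the following sketch (again with illustrative names, not a framework API) implements it per feature and checks it numerically against PyTorch's autograd:

import torch

def batchnorm_backward_input(dL_dy, x_hat, gamma, var, eps=1e-5):
    # Compact form of dL/dx: all three terms share the factor 1 / (m * sqrt(var + eps))
    m = dL_dy.shape[0]
    dL_dxhat = dL_dy * gamma                              # dL/dx_hat = dL/dy * gamma
    inv_std = 1.0 / torch.sqrt(var + eps)                 # 1 / sigma_{B,eps}
    return (inv_std / m) * (m * dL_dxhat
                            - dL_dxhat.sum(dim=0)
                            - x_hat * (dL_dxhat * x_hat).sum(dim=0))

# Numerical check against autograd on a small random batch
eps = 1e-5
x = torch.randn(4, 10, requires_grad=True)
gamma, beta = torch.ones(10), torch.zeros(10)
mu = x.mean(dim=0)
var = x.var(dim=0, unbiased=False)                        # biased variance, as in BN
x_hat = (x - mu) / torch.sqrt(var + eps)
y = gamma * x_hat + beta
upstream = torch.randn_like(y)                            # arbitrary dL/dy from the next layer
y.backward(upstream)
manual = batchnorm_backward_input(upstream, x_hat.detach(), gamma, var.detach())
print(torch.allclose(x.grad, manual, atol=1e-5))          # expected: True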
Visualizing the Gradient Flow
The dependencies during the backward pass can be visualized. Consider the computation graph for a single output $y_i$ and how the loss gradient flows back to the input $x_i$, incorporating the influence of the shared $\mu_B$ and $\sigma_B^2$.
This diagram illustrates the dependencies in the Batch Normalization calculations and the flow of gradients during backpropagation. Notice how the input $x_i$ receives gradient contributions directly from $\hat{x}_i$ and indirectly through the mini-batch statistics $\mu_B$ and $\sigma_B^2$.
Implementation in Frameworks
Fortunately, you rarely need to implement this backward pass manually. Deep learning frameworks like PyTorch and TensorFlow use automatic differentiation (autograd) to compute these gradients automatically when you define a model with BN layers. For example, in PyTorch:
import torch
import torch.nn as nn
# Example setup
batch_size = 4
num_features = 10
input_tensor = torch.randn(batch_size, num_features, requires_grad=True)
# Define a Batch Norm layer (affine=True means learnable gamma and beta)
bn_layer = nn.BatchNorm1d(num_features=num_features, affine=True)
# Forward pass
output = bn_layer(input_tensor)
# Assume some dummy loss for demonstration
loss = output.mean()
# Backward pass
loss.backward()
# Gradients are now computed and stored
# Gradient w.r.t input: input_tensor.grad
# Gradient w.r.t gamma (weight): bn_layer.weight.grad
# Gradient w.r.t beta (bias): bn_layer.bias.grad
print("Shape of input gradient:", input_tensor.grad.shape)
print("Shape of gamma gradient:", bn_layer.weight.grad.shape)
print("Shape of beta gradient:", bn_layer.bias.grad.shape)
# >>> Shape of input gradient: torch.Size([4, 10])
# >>> Shape of gamma gradient: torch.Size([10])
# >>> Shape of beta gradient: torch.Size([10])
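Continuing from the snippet above, the computed gradients can be tied back to the formulas derived earlier. Because the dummy loss is output.mean(), every upstream gradient $\frac{\partial L}{\partial y_i}$ equals 1 / (batch_size * num_features), so summing over the batch should give a $\beta$ gradient of batch_size / (batch_size * num_features) = 0.1 for every feature:

# Sanity check: dL/dbeta is the per-feature sum of the constant upstream gradient
expected_beta_grad = torch.full((num_features,), batch_size / (batch_size * num_features))
print(torch.allclose(bn_layer.bias.grad, expected_beta_grad))
# >>> True

By the same reasoning, bn_layer.weight.grad comes out numerically close to zero here, because the $\gamma$ gradient sums $\hat{x}_i$ values that have zero mean over the batch.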
While the framework handles the mechanics, understanding the underlying calculations, especially the dependence of the input gradient $\frac{\partial L}{\partial x_i}$ on the entire mini-batch, is valuable for interpreting model behavior and diagnosing issues during training. It also helps explain why BN affects training dynamics and generalization performance, which we will discuss next.
Deep Learning, Ian Goodfellow, Yoshua Bengio, and Aaron Courville, 2016 (MIT Press) - A comprehensive and authoritative textbook covering the theoretical foundations and practical aspects of deep learning, including a discussion of Batch Normalization in the context of optimization.
Dive into Deep Learning, Aston Zhang, Zachary C. Lipton, Mu Li, Alex Smola, 2023 (Cambridge University Press) - An interactive, open-source textbook that provides detailed explanations and step-by-step mathematical derivations for various deep learning components, including a dedicated section on Batch Normalization's forward and backward passes.