Having established how Batch Normalization (BN) works during the forward pass, normalizing inputs using mini-batch statistics and then scaling and shifting them, we now turn to the backward pass. To train the network using gradient descent, we need to calculate how the loss $L$ changes with respect to the BN layer's inputs $x_i$ and its learnable parameters $\gamma$ and $\beta$. This involves applying the chain rule through the BN transformation.
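Let's recall the forward pass for a single activation $x_i$ within a mini-batch $B = \{x_1, \ldots, x_m\}$:
$$\mu_B = \frac{1}{m}\sum_{i=1}^{m} x_i$$
$$\sigma_B^2 = \frac{1}{m}\sum_{i=1}^{m} (x_i - \mu_B)^2$$
$$\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}$$
$$y_i = \gamma \hat{x}_i + \beta$$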
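As a concrete reference for the derivation that follows, here is a minimal NumPy sketch of the training-mode forward pass for a single feature. The function name `batchnorm_forward`, the default `eps=1e-5`, and the sample values are illustrative assumptions, not part of any framework API.

```python
import numpy as np

def batchnorm_forward(x, gamma, beta, eps=1e-5):
    """Training-mode BN forward pass for one feature.

    x holds the m activations of that feature across the mini-batch;
    gamma and beta are scalars. All names here are illustrative.
    """
    mu = x.mean()                          # mini-batch mean
    var = ((x - mu) ** 2).mean()           # biased variance (divide by m)
    x_hat = (x - mu) / np.sqrt(var + eps)  # normalize
    y = gamma * x_hat + beta               # scale and shift
    return y, x_hat, mu, var

x = np.array([1.0, 2.0, 3.0, 4.0])
y, x_hat, mu, var = batchnorm_forward(x, gamma=2.0, beta=0.5)
print("mean:", mu, "variance:", var)       # mean 2.5, variance 1.25 for this input
print("mean of x_hat:", x_hat.mean())      # approximately zero
```

The intermediate values `x_hat`, `mu`, and `var` are returned because the backward pass reuses them.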
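During backpropagation, we receive the gradient of the loss with respect to the BN layer's output, $\frac{\partial L}{\partial y_i}$, from the subsequent layer. Our goal is to compute $\frac{\partial L}{\partial x_i}$, $\frac{\partial L}{\partial \gamma}$, and $\frac{\partial L}{\partial \beta}$.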
The gradients with respect to the learnable parameters $\gamma$ and $\beta$ are the most straightforward to compute using the chain rule:
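Gradient with respect to $\beta$: The parameter $\beta$ is added directly to the output $y_i$.
$$\frac{\partial L}{\partial \beta} = \sum_{i=1}^{m} \frac{\partial L}{\partial y_i}\frac{\partial y_i}{\partial \beta} = \sum_{i=1}^{m} \frac{\partial L}{\partial y_i}(1) = \sum_{i=1}^{m} \frac{\partial L}{\partial y_i}$$
The gradient for $\beta$ is simply the sum of the incoming gradients from the outputs $y_i$.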
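Gradient with respect to $\gamma$: The parameter $\gamma$ scales the normalized input $\hat{x}_i$.
$$\frac{\partial L}{\partial \gamma} = \sum_{i=1}^{m} \frac{\partial L}{\partial y_i}\frac{\partial y_i}{\partial \gamma} = \sum_{i=1}^{m} \frac{\partial L}{\partial y_i}(\hat{x}_i) = \sum_{i=1}^{m} \frac{\partial L}{\partial y_i}\hat{x}_i$$
The gradient for $\gamma$ is the sum of the incoming gradients, each weighted by the corresponding normalized input $\hat{x}_i$.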
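The two parameter gradients can be checked numerically. The sketch below is an illustration under stated assumptions: the upstream gradient `dy` is random placeholder data, and the surrogate loss $L = \sum_i \texttt{dy}_i \, y_i$ is chosen only because its gradient with respect to $y_i$ is exactly `dy[i]`. The helper names are hypothetical.

```python
import numpy as np

def bn_forward(x, gamma, beta, eps=1e-5):
    # np.var divides by m by default, matching the biased mini-batch variance
    x_hat = (x - x.mean()) / np.sqrt(x.var() + eps)
    return gamma * x_hat + beta, x_hat

rng = np.random.default_rng(0)
x = rng.normal(size=8)        # one feature across a mini-batch of m = 8
dy = rng.normal(size=8)       # placeholder upstream gradient dL/dy_i
gamma, beta = 1.5, -0.3

_, x_hat = bn_forward(x, gamma, beta)
dbeta = dy.sum()              # dL/dbeta: sum of incoming gradients
dgamma = (dy * x_hat).sum()   # dL/dgamma: incoming gradients weighted by x_hat

def surrogate_loss(g, b):
    y, _ = bn_forward(x, g, b)
    return (dy * y).sum()

# Central finite differences as an independent check
h = 1e-6
dgamma_fd = (surrogate_loss(gamma + h, beta) - surrogate_loss(gamma - h, beta)) / (2 * h)
dbeta_fd = (surrogate_loss(gamma, beta + h) - surrogate_loss(gamma, beta - h)) / (2 * h)

print("dgamma matches:", np.isclose(dgamma, dgamma_fd, atol=1e-6))
print("dbeta matches:", np.isclose(dbeta, dbeta_fd, atol=1e-6))
```

Because the surrogate loss is linear in both parameters, the finite-difference estimates should agree with the analytic sums up to floating-point rounding.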
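Calculating the gradient with respect to the input $x_i$ is more involved because $x_i$ influences the output $y_i$ in multiple ways:
Direct path: $x_i$ appears in the numerator of its own normalized value $\hat{x}_i$.
Path through the mean: $x_i$ contributes to $\mu_B$, which is used to normalize every input in the mini-batch.
Path through the variance: $x_i$ contributes to $\sigma_B^2$, which also scales every normalized input.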
We need to apply the chain rule carefully, considering all these paths. Let $\sigma_{B,\epsilon} = \sqrt{\sigma_B^2 + \epsilon}$. The gradient computation proceeds backward through the operations:
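Gradient w.r.t. normalized input $\hat{x}_i$:
$$\frac{\partial L}{\partial \hat{x}_i} = \frac{\partial L}{\partial y_i}\frac{\partial y_i}{\partial \hat{x}_i} = \frac{\partial L}{\partial y_i}\gamma$$
Gradients w.r.t. $\mu_B$ and $\sigma_B^2$: These require summing contributions from all $\hat{x}_j$ in the mini-batch, as both statistics affect all normalized inputs.
$$\frac{\partial L}{\partial \sigma_B^2} = \sum_{i=1}^{m} \frac{\partial L}{\partial \hat{x}_i}\frac{\partial \hat{x}_i}{\partial \sigma_B^2} = \sum_{i=1}^{m} \frac{\partial L}{\partial \hat{x}_i}(x_i - \mu_B)\left(-\frac{1}{2}\right)(\sigma_B^2 + \epsilon)^{-3/2}$$
$$\frac{\partial L}{\partial \mu_B} = \sum_{i=1}^{m} \frac{\partial L}{\partial \hat{x}_i}\frac{\partial \hat{x}_i}{\partial \mu_B} = \left(\sum_{i=1}^{m} \frac{\partial L}{\partial \hat{x}_i}\frac{-1}{\sigma_{B,\epsilon}}\right) + \frac{\partial L}{\partial \sigma_B^2}\frac{\partial \sigma_B^2}{\partial \mu_B}$$
where $\frac{\partial \sigma_B^2}{\partial \mu_B} = \frac{1}{m}\sum_{j=1}^{m} 2(x_j - \mu_B)(-1) = \frac{-2}{m}\sum_{j=1}^{m}(x_j - \mu_B) = 0$. So, the second term vanishes, simplifying the gradient for $\mu_B$:
$$\frac{\partial L}{\partial \mu_B} = \sum_{i=1}^{m} \frac{\partial L}{\partial \hat{x}_i}\frac{-1}{\sigma_{B,\epsilon}}$$
Gradient w.r.t. input $x_i$: Now we combine the paths. An input $x_i$ influences the loss through $\hat{x}_i$, $\mu_B$, and $\sigma_B^2$.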
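$$\frac{\partial L}{\partial x_i} = \frac{\partial L}{\partial \hat{x}_i}\frac{\partial \hat{x}_i}{\partial x_i} + \frac{\partial L}{\partial \sigma_B^2}\frac{\partial \sigma_B^2}{\partial x_i} + \frac{\partial L}{\partial \mu_B}\frac{\partial \mu_B}{\partial x_i}$$
We need the partial derivatives of the normalized input and of the statistics w.r.t. a single input $x_i$ (for $\hat{x}_i$, treating $\mu_B$ and $\sigma_B^2$ as fixed):
$$\frac{\partial \hat{x}_i}{\partial x_i} = \frac{1}{\sigma_{B,\epsilon}}, \qquad \frac{\partial \sigma_B^2}{\partial x_i} = \frac{2(x_i - \mu_B)}{m}, \qquad \frac{\partial \mu_B}{\partial x_i} = \frac{1}{m}$$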
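Substituting these gives the final expression for $\frac{\partial L}{\partial x_i}$:
$$\frac{\partial L}{\partial x_i} = \frac{\partial L}{\partial \hat{x}_i}\frac{1}{\sigma_{B,\epsilon}} + \frac{\partial L}{\partial \sigma_B^2}\frac{2(x_i - \mu_B)}{m} + \frac{\partial L}{\partial \mu_B}\frac{1}{m}$$
Substituting the expressions for $\frac{\partial L}{\partial \sigma_B^2}$ and $\frac{\partial L}{\partial \mu_B}$ and simplifying (the full derivation is lengthy and is often left to the appendices of papers or textbooks) yields a more compact result. A common form is:
$$\frac{\partial L}{\partial x_i} = \frac{1}{m\sigma_{B,\epsilon}}\left(m\frac{\partial L}{\partial \hat{x}_i} - \sum_{j=1}^{m}\frac{\partial L}{\partial \hat{x}_j} - \hat{x}_i\sum_{j=1}^{m}\frac{\partial L}{\partial \hat{x}_j}\hat{x}_j\right)$$
Note that $\frac{\partial L}{\partial \hat{x}_j} = \frac{\partial L}{\partial y_j}\gamma$.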
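To make the derivation concrete, the following NumPy sketch computes the input gradient two ways, first through the intermediate gradients and then with the compact form, and compares both against central finite differences. The function names, the random test data, and the surrogate loss $L = \sum_i \texttt{dy}_i \, y_i$ are assumptions made for illustration only.

```python
import numpy as np

def bn_backward_staged(dy, x, gamma, eps=1e-5):
    """Input gradient via the intermediate gradients for x_hat, variance, and mean."""
    m = x.shape[0]
    mu, var = x.mean(), x.var()       # np.var divides by m (biased variance)
    sigma = np.sqrt(var + eps)        # sigma_{B,eps}
    dx_hat = dy * gamma               # dL/dx_hat_i
    dvar = np.sum(dx_hat * (x - mu)) * (-0.5) * (var + eps) ** (-1.5)
    dmu = -np.sum(dx_hat) / sigma     # the variance term vanishes
    return dx_hat / sigma + dvar * 2.0 * (x - mu) / m + dmu / m

def bn_backward_compact(dy, x, gamma, eps=1e-5):
    """Input gradient via the compact closed form."""
    m = x.shape[0]
    sigma = np.sqrt(x.var() + eps)
    x_hat = (x - x.mean()) / sigma
    dx_hat = dy * gamma
    return (m * dx_hat - dx_hat.sum() - x_hat * np.sum(dx_hat * x_hat)) / (m * sigma)

rng = np.random.default_rng(0)
m, gamma, beta, eps = 6, 1.7, 0.2, 1e-5
x = rng.normal(size=m)
dy = rng.normal(size=m)               # placeholder upstream gradient dL/dy_i

def surrogate_loss(x_in):
    x_hat = (x_in - x_in.mean()) / np.sqrt(x_in.var() + eps)
    return np.sum(dy * (gamma * x_hat + beta))

# Perturb one input at a time (rows of the identity matrix are unit vectors)
h = 1e-6
dx_fd = np.array([
    (surrogate_loss(x + h * e) - surrogate_loss(x - h * e)) / (2 * h)
    for e in np.eye(m)
])

dx_staged = bn_backward_staged(dy, x, gamma, eps)
dx_compact = bn_backward_compact(dy, x, gamma, eps)
print("staged == compact:", np.allclose(dx_staged, dx_compact))
print("staged == finite differences:", np.allclose(dx_staged, dx_fd, atol=1e-6))
```

Both comparisons should report agreement. The compact form needs only $\hat{x}_i$ and $\sigma_{B,\epsilon}$ from the forward pass, which is one reason it is commonly preferred in hand-written implementations.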
The key takeaway is that the gradient $\frac{\partial L}{\partial x_i}$ depends not only on the gradient $\frac{\partial L}{\partial y_i}$ corresponding to that specific activation, but also on the gradients and values of all other activations in the mini-batch ($j = 1, \ldots, m$) due to the shared mean and variance calculations.
The dependencies during the backward pass can be visualized. Consider the computation graph for a single output $y_i$ and how the loss gradient flows back to input $x_i$, incorporating the influence of the shared $\mu_B$ and $\sigma_B^2$.
This diagram illustrates the dependencies in the Batch Normalization calculations and the flow of gradients during backpropagation. Notice how the input $x_i$ receives gradient contributions directly from $\hat{x}_i$ and indirectly through the mini-batch statistics $\mu_B$ and $\sigma_B^2$.
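A small experiment illustrates this coupling. In the sketch below, the upstream gradient is nonzero for only one sample, yet the compact formula produces a nonzero input gradient for every sample in the mini-batch. The function name `bn_input_grad` and the sample values are hypothetical, chosen only for this demonstration.

```python
import numpy as np

def bn_input_grad(dy, x, gamma=1.0, eps=1e-5):
    """Compact-form input gradient for one feature across a mini-batch."""
    m = x.shape[0]
    sigma = np.sqrt(x.var() + eps)    # np.var divides by m (biased variance)
    x_hat = (x - x.mean()) / sigma
    dx_hat = dy * gamma
    return (m * dx_hat - dx_hat.sum() - x_hat * np.sum(dx_hat * x_hat)) / (m * sigma)

x = np.array([0.5, -1.0, 2.0, 0.0])
dy = np.array([1.0, 0.0, 0.0, 0.0])   # only sample 0 receives an upstream gradient

dx = bn_input_grad(dy, x)
print("dL/dy:", dy)
print("dL/dx:", dx)                    # every entry is nonzero
```

For a layer that acted on each sample independently, `dx` would be nonzero only at index 0. Here the shared mean and variance spread the gradient to the other three samples as well.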
Fortunately, you rarely need to implement this backward pass manually. Deep learning frameworks like PyTorch and TensorFlow use automatic differentiation (autograd in PyTorch) to compute these gradients whenever you define a model with BN layers. For example, in PyTorch:
```python
import torch
import torch.nn as nn

# Example setup
batch_size = 4
num_features = 10
input_tensor = torch.randn(batch_size, num_features, requires_grad=True)

# Define a Batch Norm layer (affine=True means learnable gamma and beta)
# A newly constructed module is in training mode, so mini-batch statistics are used
bn_layer = nn.BatchNorm1d(num_features=num_features, affine=True)

# Forward pass
output = bn_layer(input_tensor)

# Assume some dummy loss for demonstration
loss = output.mean()

# Backward pass
loss.backward()

# Gradients are now computed and stored
# Gradient w.r.t input: input_tensor.grad
# Gradient w.r.t gamma (weight): bn_layer.weight.grad
# Gradient w.r.t beta (bias): bn_layer.bias.grad
print("Shape of input gradient:", input_tensor.grad.shape)
print("Shape of gamma gradient:", bn_layer.weight.grad.shape)
print("Shape of beta gradient:", bn_layer.bias.grad.shape)
# >>> Shape of input gradient: torch.Size([4, 10])
# >>> Shape of gamma gradient: torch.Size([10])
# >>> Shape of beta gradient: torch.Size([10])
```
While the framework handles the mechanics, understanding the underlying calculations, especially the dependency of the input gradient $\frac{\partial L}{\partial x_i}$ on the entire mini-batch, is valuable for interpreting model behavior and diagnosing potential issues during training. It also helps explain why BN affects training dynamics and generalization performance, which we will discuss next.
© 2025 ApX Machine Learning