As we discussed, training a neural network involves adjusting its weights to minimize a loss function, L. Gradient descent requires us to compute the gradient of this loss with respect to every weight and bias in the network, often denoted ∇L. For a network with potentially millions of parameters spread across many layers, how do we calculate the influence on the final loss of a weight buried deep within the network? The loss isn't a direct function of that specific weight; it's a function of the network's output, which depends on intermediate activations, which in turn depend on earlier activations and weights, forming a long chain of dependencies.
This is where a fundamental concept from calculus comes into play: the chain rule. The chain rule provides a way to calculate the derivative of composite functions, functions nested within one another. This is exactly the situation we have in a neural network.
Let's refresh the basic idea. Suppose we have a simple composition of functions. If a variable y depends on a variable u, written as y=f(u), and u itself depends on another variable x, written as u=g(x), then y indirectly depends on x through u: y=f(g(x)).
The chain rule tells us how to find the rate of change of y with respect to x, denoted as dy/dx. It states that this derivative is the product of the derivatives of the "outer" function with respect to its input and the "inner" function with respect to its input:
dy/dx = dy/du × du/dx

Think of it as propagating sensitivities. How much does y change if x changes slightly? It depends on how much u changes when x changes (du/dx), multiplied by how much y changes when u changes (dy/du).
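As a quick numerical check (an illustrative example, with f and g chosen here for demonstration), we can verify the chain rule for y = f(g(x)) with f(u) = u² and g(x) = sin(x), comparing the product of the two derivatives against a finite-difference estimate of the composite derivative:

```python
import math

def g(x):
    return math.sin(x)

def f(u):
    return u ** 2

x = 1.3
u = g(x)

# Chain rule: dy/dx = dy/du * du/dx = 2u * cos(x)
dy_du = 2 * u
du_dx = math.cos(x)
chain_rule = dy_du * du_dx

# Numerical derivative of the composition, for comparison.
eps = 1e-6
numeric = (f(g(x + eps)) - f(g(x - eps))) / (2 * eps)

print(abs(chain_rule - numeric) < 1e-6)
```

The two results agree to within floating-point noise, which is exactly the "product of sensitivities" idea described above.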
We can extend this to longer chains. If z=f(y), y=g(x), and x=h(w), then z depends on w through this chain. The derivative of z with respect to w is found by multiplying the derivatives along the path:
dz/dw = dz/dy × dy/dx × dx/dw

Now, let's connect this back to our neural network. Consider a very simple network with one input x, one hidden neuron h, and one output neuron y. Using an activation function σ, a squared-error loss, and a target value y_true, let the calculations be:

z1 = w1 × x + b1
h = σ(z1)
z2 = w2 × h + b2
y = σ(z2)
L = (y − y_true)²
Forward pass computation in a simple neural network. Arrows indicate dependencies. Calculating the gradient of the Loss L with respect to an early weight like w1 requires propagating the derivative backward through these dependencies using the chain rule.
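To make the dependency chain concrete, here is a minimal sketch of this forward pass in Python, assuming a sigmoid activation σ and a squared-error loss. The specific numbers for x, the weights, and the biases are illustrative assumptions, not values from the text:

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

# Illustrative values (assumptions, chosen for demonstration).
x, y_true = 0.5, 1.0
w1, b1 = 0.8, 0.1
w2, b2 = -0.4, 0.2

# Forward pass: each quantity depends on the previous ones.
z1 = w1 * x + b1          # pre-activation of the hidden neuron
h = sigmoid(z1)           # hidden activation
z2 = w2 * h + b2          # pre-activation of the output neuron
y = sigmoid(z2)           # network output
L = (y - y_true) ** 2     # squared-error loss

print(z1, h, z2, y, L)
```

Every intermediate value here (z1, h, z2, y) sits on the path from w1 to L, which is why the chain rule below multiplies one local derivative per link.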
Suppose we want to find the gradient of the loss L with respect to the weight w1. L depends on y, y depends on z2, z2 depends on h, h depends on z1, and z1 depends on w1. Using the chain rule, we can write the derivative ∂L/∂w1 as:
∂L/∂w1 = ∂L/∂y × ∂y/∂z2 × ∂z2/∂h × ∂h/∂z1 × ∂z1/∂w1

Let's break down each term:

- ∂L/∂y = 2(y − y_true): the derivative of the squared-error loss with respect to the network output.
- ∂y/∂z2 = σ′(z2): the derivative of the output activation, evaluated at the pre-activation z2.
- ∂z2/∂h = w2: z2 depends linearly on h, with slope w2.
- ∂h/∂z1 = σ′(z1): the derivative of the hidden activation, evaluated at z1.
- ∂z1/∂w1 = x: z1 depends linearly on w1, with slope x.
Putting it all together for this specific example:
∂L/∂w1 = 2(y − y_true) × σ′(z2) × w2 × σ′(z1) × x

Notice how the calculation involves terms computed during the forward pass (like x, h, y, z1, z2) and the derivatives of the activation functions evaluated at the pre-activation values. The chain rule gives us a systematic way to multiply these local derivatives together to find the overall sensitivity of the loss to a specific weight, no matter how deep in the network it is.
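This chain-rule product can be sketched directly in code. The snippet below assumes a sigmoid activation (whose derivative is σ(z)(1 − σ(z))) and illustrative parameter values, then checks the analytic gradient ∂L/∂w1 against a finite-difference estimate:

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def forward(w1, x, w2, b1, b2, y_true):
    z1 = w1 * x + b1
    h = sigmoid(z1)
    z2 = w2 * h + b2
    y = sigmoid(z2)
    L = (y - y_true) ** 2
    return z1, h, z2, y, L

# Illustrative values (assumptions, chosen for demonstration).
x, y_true = 0.5, 1.0
w1, b1 = 0.8, 0.1
w2, b2 = -0.4, 0.2

z1, h, z2, y, L = forward(w1, x, w2, b1, b2, y_true)

# Local derivatives, multiplied along the chain.
dL_dy = 2 * (y - y_true)
dy_dz2 = sigmoid(z2) * (1 - sigmoid(z2))   # sigma'(z2) for the sigmoid
dz2_dh = w2
dh_dz1 = sigmoid(z1) * (1 - sigmoid(z1))   # sigma'(z1)
dz1_dw1 = x

dL_dw1 = dL_dy * dy_dz2 * dz2_dh * dh_dz1 * dz1_dw1

# Sanity check against a numerical derivative of L with respect to w1.
eps = 1e-6
L_plus = forward(w1 + eps, x, w2, b1, b2, y_true)[-1]
L_minus = forward(w1 - eps, x, w2, b1, b2, y_true)[-1]
numeric = (L_plus - L_minus) / (2 * eps)

print(abs(dL_dw1 - numeric) < 1e-8)
```

The agreement between the two values illustrates that multiplying the local derivatives really does recover the overall sensitivity of L to w1.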
In multi-layer networks with many neurons per layer, the dependencies become more complex (a neuron's output affects multiple neurons in the next layer), involving sums of derivatives. However, the core principle remains the same: the chain rule allows us to compute gradients by propagating derivative information backward through the network, layer by layer. This systematic application of the chain rule is precisely what the backpropagation algorithm achieves, which we will detail in the following sections. Understanding the chain rule is the first significant step towards understanding how neural networks learn.
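As a sketch of how these sums arise, consider a hypothetical extension (not from the text) in which one hidden activation h feeds two output neurons. The gradient of the loss with respect to h then adds one chain-rule product per path from h to the loss:

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

# Hypothetical setup: one hidden activation h feeding two output neurons.
# All values are illustrative assumptions.
h = 0.6
w = [0.3, -0.7]          # weights from h to each output neuron
b = [0.1, 0.2]           # output biases
y_true = [1.0, 0.0]      # targets for each output

def loss(h):
    # Sum of squared errors over both outputs.
    total = 0.0
    for wk, bk, tk in zip(w, b, y_true):
        yk = sigmoid(wk * h + bk)
        total += (yk - tk) ** 2
    return total

# Chain rule with a sum: one term per path from h to the loss.
dL_dh = 0.0
for wk, bk, tk in zip(w, b, y_true):
    zk = wk * h + bk
    yk = sigmoid(zk)
    dL_dh += 2 * (yk - tk) * sigmoid(zk) * (1 - sigmoid(zk)) * wk

# Numerical check of the summed gradient.
eps = 1e-6
numeric = (loss(h + eps) - loss(h - eps)) / (2 * eps)
print(abs(dL_dh - numeric) < 1e-8)
```

Backpropagation organizes exactly these per-path products and sums, layer by layer, which is the subject of the following sections.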
© 2025 ApX Machine Learning