Now that we've covered the concepts of loss functions, gradient descent, and the backpropagation algorithm, let's solidify this understanding by manually calculating the gradients for a very small neural network. This exercise helps visualize how the chain rule operates to determine how each parameter contributes to the overall error.
Consider a simple network with one input, one hidden neuron (using the Sigmoid activation function), and one output neuron (also using Sigmoid). Our goal is to calculate how changes in the weights (w1, w2) and biases (b1, b2) affect the final loss for a single training example.
Figure: A simple feedforward network with one input, one hidden neuron, and one output neuron.
Network Setup
Let's define the components and initial values:
- Input: x=0.5
- Target Output: y=0.8
- Weights: w1=0.2, w2=0.9
- Biases: b1=0.1, b2=−0.3
- Activation Function: Sigmoid, σ(z) = 1 / (1 + e^(−z)). Its derivative is σ′(z) = σ(z)(1 − σ(z)).
- Loss Function: Mean Squared Error (MSE), L = ½(y − o)², where o is the network's output. Its derivative with respect to o is ∂L/∂o = o − y.
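Before working through the arithmetic, it can help to see these two building blocks in code. This is a minimal sketch; the function names are my own choice:

```python
import math

def sigmoid(z):
    """Sigmoid activation: squashes any real z into (0, 1)."""
    return 1 / (1 + math.exp(-z))

def sigmoid_derivative(z):
    """Derivative of the sigmoid, expressed via the function itself:
    sigma'(z) = sigma(z) * (1 - sigma(z))."""
    s = sigmoid(z)
    return s * (1 - s)
```

Note that sigmoid(0) = 0.5 and the derivative reaches its maximum value of 0.25 there, which is why sigmoid gradients are always relatively small.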
1. Forward Propagation
First, we compute the network's output (o) for the given input (x) and parameters.
- Hidden Layer Pre-activation (z1):
z1=w1x+b1=(0.2×0.5)+0.1=0.1+0.1=0.2
- Hidden Layer Activation (h):
h = σ(z1) = σ(0.2) = 1 / (1 + e^(−0.2)) ≈ 1 / (1 + 0.8187) ≈ 0.5498
- Output Layer Pre-activation (z2):
z2=w2h+b2=(0.9×0.5498)+(−0.3)≈0.4948−0.3=0.1948
- Output Layer Activation (o):
o = σ(z2) = σ(0.1948) = 1 / (1 + e^(−0.1948)) ≈ 1 / (1 + 0.8230) ≈ 0.5486
So, the network's prediction is o≈0.5486.
2. Loss Calculation
Now, calculate the error using the MSE loss function:
L = ½(y − o)² = ½(0.8 − 0.5486)² = ½(0.2514)² ≈ ½(0.0632) ≈ 0.0316
The loss for this example is approximately 0.0316.
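The forward pass and loss calculation above translate directly into a few lines of code. This sketch uses the same values as the walkthrough; variable names mirror the notation:

```python
import math

def sigmoid(z):
    return 1 / (1 + math.exp(-z))

# Parameters and data from the setup above
x, y = 0.5, 0.8
w1, b1 = 0.2, 0.1
w2, b2 = 0.9, -0.3

# Forward pass
z1 = w1 * x + b1          # 0.2
h = sigmoid(z1)           # ~0.5498
z2 = w2 * h + b2          # ~0.1948
o = sigmoid(z2)           # ~0.5486

# MSE loss for this single example
L = 0.5 * (y - o) ** 2    # ~0.0316
```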
3. Backward Propagation (Gradient Calculation)
Our goal is to find the gradients: ∂L/∂w2, ∂L/∂b2, ∂L/∂w1, and ∂L/∂b1. We use the chain rule, working backward from the loss.
- Derivative of Loss w.r.t. Network Output (o):
∂L/∂o = o − y ≈ 0.5486 − 0.8 = −0.2514
- Gradients for Output Layer (w2, b2):
We need the derivative of the output activation o w.r.t. its pre-activation z2.
∂o/∂z2 = σ′(z2) = o(1 − o) ≈ 0.5486 × (1 − 0.5486) ≈ 0.5486 × 0.4514 ≈ 0.2476
Now, apply the chain rule to find the gradient of the loss w.r.t. z2:
∂L/∂z2 = (∂L/∂o)(∂o/∂z2) ≈ (−0.2514) × 0.2476 ≈ −0.0622
The gradients for w2 and b2 depend on how z2 changes with respect to them:
∂z2/∂w2 = h ≈ 0.5498
∂z2/∂b2 = 1
Using the chain rule again:
∂L/∂w2 = (∂L/∂z2)(∂z2/∂w2) ≈ (−0.0622) × 0.5498 ≈ −0.0342
∂L/∂b2 = (∂L/∂z2)(∂z2/∂b2) ≈ (−0.0622) × 1 = −0.0622
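The output-layer chain rule can be checked in a few lines. This sketch recomputes the forward pass so it is self-contained; because it avoids intermediate rounding, the last digit may differ slightly from the hand-rounded values above:

```python
import math

def sigmoid(z):
    return 1 / (1 + math.exp(-z))

x, y = 0.5, 0.8
w1, b1, w2, b2 = 0.2, 0.1, 0.9, -0.3

# Forward pass (as computed earlier)
h = sigmoid(w1 * x + b1)
o = sigmoid(w2 * h + b2)

# Backward pass for the output layer
dL_do = o - y                  # ~ -0.2514
do_dz2 = o * (1 - o)           # ~  0.2476
dL_dz2 = dL_do * do_dz2        # ~ -0.0622
dL_dw2 = dL_dz2 * h            # ~ -0.0342
dL_db2 = dL_dz2 * 1            # ~ -0.0622
```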
- Gradients for Hidden Layer (w1, b1):
We need to propagate the gradient further back. First, find the gradient of the loss w.r.t. the hidden activation h:
∂L/∂h = (∂L/∂z2)(∂z2/∂h)
We need ∂z2/∂h:
∂z2/∂h = w2 = 0.9
So,
∂L/∂h ≈ (−0.0622) × 0.9 ≈ −0.0560
Next, we need the derivative of the hidden activation h w.r.t. its pre-activation z1:
∂h/∂z1 = σ′(z1) = h(1 − h) ≈ 0.5498 × (1 − 0.5498) ≈ 0.5498 × 0.4502 ≈ 0.2475
Now, apply the chain rule to find the gradient of the loss w.r.t. z1:
∂L/∂z1 = (∂L/∂h)(∂h/∂z1) ≈ (−0.0560) × 0.2475 ≈ −0.0139
Finally, the gradients for w1 and b1 depend on how z1 changes with respect to them:
∂z1/∂w1 = x = 0.5
∂z1/∂b1 = 1
Using the chain rule one last time:
∂L/∂w1 = (∂L/∂z1)(∂z1/∂w1) ≈ (−0.0139) × 0.5 ≈ −0.0070
∂L/∂b1 = (∂L/∂z1)(∂z1/∂b1) ≈ (−0.0139) × 1 = −0.0139
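The hidden-layer step can be sketched the same way, chaining the output-layer gradient back through w2 and the sigmoid at z1. Again the code recomputes the forward pass so it runs on its own, and avoids the intermediate rounding of the hand calculation:

```python
import math

def sigmoid(z):
    return 1 / (1 + math.exp(-z))

x, y = 0.5, 0.8
w1, b1, w2, b2 = 0.2, 0.1, 0.9, -0.3

# Forward pass
h = sigmoid(w1 * x + b1)
o = sigmoid(w2 * h + b2)

# Gradient flowing back through the output layer (from the previous step)
dL_dz2 = (o - y) * o * (1 - o)

# Backward pass for the hidden layer
dL_dh = dL_dz2 * w2            # ~ -0.0560
dh_dz1 = h * (1 - h)           # ~  0.2475
dL_dz1 = dL_dh * dh_dz1        # ~ -0.0139
dL_dw1 = dL_dz1 * x            # ~ -0.0070
dL_db1 = dL_dz1 * 1            # ~ -0.0139
```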
Summary of Gradients
We have manually calculated the gradients of the loss function with respect to each parameter:
- ∂L/∂w2 ≈ −0.0342
- ∂L/∂b2 ≈ −0.0622
- ∂L/∂w1 ≈ −0.0070
- ∂L/∂b1 ≈ −0.0139
These gradients tell us the direction and magnitude of change needed for each parameter to reduce the loss. For instance, a negative gradient like ∂L/∂w2 ≈ −0.0342 indicates that increasing w2 slightly would decrease the loss: the update rule subtracts the gradient, and subtracting a negative value increases the parameter.
Next Step
In a real training scenario, these gradients would be used with a chosen learning rate (η) to update the parameters using gradient descent:
w_new = w_old − η · ∂L/∂w_old
b_new = b_old − η · ∂L/∂b_old
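Putting everything together, one full gradient descent step can be sketched as below. The learning rate η = 0.5 is an arbitrary illustrative choice, not a value from the walkthrough; the point is that after applying the updates, the loss on this same example goes down:

```python
import math

def sigmoid(z):
    return 1 / (1 + math.exp(-z))

def forward(x, w1, b1, w2, b2):
    h = sigmoid(w1 * x + b1)
    o = sigmoid(w2 * h + b2)
    return h, o

x, y, eta = 0.5, 0.8, 0.5   # eta is an illustrative learning rate
w1, b1, w2, b2 = 0.2, 0.1, 0.9, -0.3

# Forward pass and loss before the update
h, o = forward(x, w1, b1, w2, b2)
loss_before = 0.5 * (y - o) ** 2          # ~0.0316

# Backward pass: the four gradients derived above
dL_dz2 = (o - y) * o * (1 - o)
dL_dz1 = dL_dz2 * w2 * h * (1 - h)
dL_dw2, dL_db2 = dL_dz2 * h, dL_dz2
dL_dw1, dL_db1 = dL_dz1 * x, dL_dz1

# Gradient descent update: step against the gradient
w1 -= eta * dL_dw1
b1 -= eta * dL_db1
w2 -= eta * dL_dw2
b2 -= eta * dL_db2

# Loss on the same example after the update
_, o_new = forward(x, w1, b1, w2, b2)
loss_after = 0.5 * (y - o_new) ** 2
```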
This manual calculation, while tedious for larger networks, clearly demonstrates the mechanics of backpropagation and how error signals flow backward through the network to inform parameter updates. Frameworks like TensorFlow and PyTorch automate this process, but understanding the underlying calculations is essential for effective model building and debugging.