As we saw previously, Recurrent Neural Networks (RNNs) are designed to handle sequences by maintaining a hidden state that carries information from one time step to the next. This 'memory' allows them, in principle, to connect information across different points in a sequence. But how do they learn which past information is relevant? Like other neural networks, RNNs learn by adjusting their weights based on the error calculated from their outputs, typically using a variant of backpropagation.
Training an RNN involves calculating how the network's error changes with respect to its weights. Because the output at a given time step t depends on computations from all previous time steps (t−1,t−2,…,1), we need to propagate the error signal backward through the entire sequence. This process is called Backpropagation Through Time (BPTT).
Conceptually, BPTT works by 'unrolling' the RNN through its time steps. Imagine creating a separate layer for each time step in the sequence, where the weights are shared across all these 'layers'. Then, standard backpropagation is applied to this unrolled network. The total gradient for a shared weight (like the recurrent weight matrix Whh) is the sum of its gradients calculated at each time step.
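To make the unrolling concrete, here is a minimal sketch in PyTorch (not from the original text): a toy tanh RNN is run forward step by step with the same weight matrices reused at every step, and a single backward call performs BPTT through the unrolled graph. The names W_xh and W_hh and all sizes are illustrative assumptions.

```python
# A minimal sketch of BPTT via unrolling, assuming a toy tanh RNN cell.
import torch

torch.manual_seed(0)
input_size, hidden_size, seq_len = 3, 4, 5

# The same two weight matrices are reused ("shared") at every time step.
W_xh = torch.randn(hidden_size, input_size, requires_grad=True)
W_hh = torch.randn(hidden_size, hidden_size, requires_grad=True)

x = torch.randn(seq_len, input_size)   # one short input sequence
h = torch.zeros(hidden_size)           # initial hidden state

# Forward pass: unrolling the recurrence is just a loop over time steps.
for t in range(seq_len):
    h = torch.tanh(W_xh @ x[t] + W_hh @ h)

loss = h.sum()     # stand-in for a real loss computed from the final output
loss.backward()    # autograd performs BPTT through the unrolled graph

# W_hh.grad now holds the sum of its gradient contributions from every step.
print(W_hh.grad.shape)   # torch.Size([4, 4])
```

Autograd accumulates the per-time-step contributions automatically, so W_hh.grad ends up as the sum described above; writing the loop out by hand simply makes the "one layer per time step with shared weights" picture explicit.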
While elegant, BPTT faces a significant practical challenge when dealing with long sequences. The core issue arises from the repeated application of the chain rule required to propagate gradients back through many time steps.
To calculate the gradient of the loss with respect to a hidden state hk far back in time (say, k≪t), we need to multiply many Jacobian matrices (matrices of partial derivatives) together:
$$\frac{\partial L}{\partial h_k} = \frac{\partial L}{\partial h_t} \frac{\partial h_t}{\partial h_{t-1}} \frac{\partial h_{t-1}}{\partial h_{t-2}} \cdots \frac{\partial h_{k+1}}{\partial h_k}$$

Each term $\frac{\partial h_i}{\partial h_{i-1}}$ involves the recurrent weight matrix $W_{hh}$ and the derivative of the activation function used in the RNN cell.
Now, consider what happens during this repeated multiplication. If the relevant values in these Jacobian matrices are consistently smaller than 1, their product will shrink exponentially as we propagate further back in time. Multiplying 0.9 by itself 10 times gives about 0.35, but multiplying it 50 times gives roughly 0.005. The gradient signal becomes vanishingly small.
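This effect can be reproduced numerically. The following rough sketch (an assumed setup, not from the text) builds a tanh RNN cell, forms each Jacobian ∂h_i/∂h_{i-1} = diag(1 − h_i²) · W_hh, and tracks the norm of the running product; with the spectral norm of W_hh fixed at 0.9, the product shrinks much like the 0.9 example above.

```python
# A rough numerical sketch of the repeated Jacobian product for a tanh RNN,
# h_i = tanh(W_hh @ h_{i-1} + x_i). Sizes, inputs, and the 0.9 scaling are
# illustrative assumptions; only the trend in the norm matters.
import torch

torch.manual_seed(0)
hidden_size, seq_len = 8, 50

# Fix the spectral norm (largest singular value) of W_hh at 0.9.
W_hh = torch.randn(hidden_size, hidden_size)
W_hh = 0.9 * W_hh / torch.linalg.matrix_norm(W_hh, ord=2)

h = torch.zeros(hidden_size)
product = torch.eye(hidden_size)   # running product of Jacobians

for i in range(seq_len):
    h = torch.tanh(W_hh @ h + torch.randn(hidden_size))   # random input term
    # Jacobian of h_i with respect to h_{i-1}: diag(1 - h_i^2) @ W_hh
    step_jacobian = torch.diag(1.0 - h**2) @ W_hh
    product = step_jacobian @ product
    if (i + 1) % 10 == 0:
        print(f"after {i + 1:2d} steps: ||product|| ≈ {product.norm().item():.1e}")
```

Rescaling W_hh so that its spectral norm is above 1 makes the same product grow exponentially instead, which is the exploding-gradient case discussed below.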
This is the vanishing gradient problem: the contribution of information from early time steps to the gradient becomes almost zero for later time steps.
Simplified illustration of gradient flow during Backpropagation Through Time (BPTT). The error signal (gradient, shown by dashed red arrows) starts near the Loss but weakens significantly (indicated by lighter color and thinner lines) as it propagates backward through many time steps, making it hard to update weights based on early inputs.
The practical impact of vanishing gradients is significant: the RNN struggles to learn long-range dependencies. If a critical piece of information occurred early in the sequence, but the error signal related to it has vanished by the time it propagates back from a much later output, the network cannot effectively adjust its weights to capture that dependency.
Consider sentiment analysis of a long product review: "I initially ordered this product with high hopes, based on the excellent descriptions and photos. It arrived promptly, and the packaging was superb. However, after using it for a week, I found that the main component completely failed. Very disappointing."
To correctly classify the overall sentiment as negative, the model needs to connect the final "Very disappointing" statement (and the failure description) back to the beginning of the review, potentially ignoring the initially positive comments. If the gradient signal from the end of the review vanishes before it reaches the parts of the network processing the beginning, the model might overweight the early positive signals or fail to learn the importance of the component failure described mid-review. It essentially develops a short memory, unable to link events across long time intervals.
The opposite problem, while less common, can also occur: exploding gradients. If the values in the Jacobian matrices repeatedly multiplied during BPTT are consistently greater than 1, the gradient signal can grow exponentially, becoming astronomically large. This leads to unstable training, where weight updates are huge and erratic, often resulting in numerical overflow (NaN values).
Fortunately, exploding gradients are usually easier to detect and handle than vanishing gradients. A common technique called gradient clipping involves setting a maximum threshold for the norm (magnitude) of the gradients. If the gradient's norm exceeds this threshold during backpropagation, it is scaled down before the weight update is applied, preventing excessively large steps.
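As a brief illustration, here is one way clipping typically appears in a PyTorch training step, using the built-in clip_grad_norm_ utility; the model, data, and threshold of 1.0 below are placeholder choices, not values from the text.

```python
# A minimal sketch of gradient clipping in a single training step.
# Model, data, and max_norm are placeholder assumptions.
import torch
import torch.nn as nn

model = nn.RNN(input_size=10, hidden_size=20, batch_first=True)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()

x = torch.randn(4, 30, 10)        # batch of 4 sequences, 30 steps each
target = torch.randn(4, 30, 20)   # dummy targets matching the RNN output shape

output, _ = model(x)
loss = loss_fn(output, target)

optimizer.zero_grad()
loss.backward()

# If the total gradient norm exceeds max_norm, every gradient is scaled down
# proportionally so the overall norm equals max_norm; otherwise nothing changes.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
```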
The vanishing gradient problem, however, remains a more fundamental obstacle to learning long-range dependencies in simple RNNs. It significantly limits their effectiveness on tasks requiring understanding context over extended sequences. This difficulty was a primary motivation for developing more sophisticated recurrent architectures, specifically Long Short-Term Memory (LSTM) networks and Gated Recurrent Units (GRUs), which we will explore next. These architectures incorporate mechanisms explicitly designed to regulate information flow and mitigate the vanishing gradient issue.