As we saw with the SimpleRNN layer, the core idea of recurrent networks is to maintain a hidden state that carries information from previous time steps to influence the processing of the current step. This sounds great for capturing context in sequences. However, training these networks, especially the simpler ones, runs into a significant practical challenge known as the vanishing gradient problem.
Recall that neural networks learn by adjusting their weights based on the gradient of the loss function. This gradient indicates how much a small change in each weight would affect the final loss. During training, we use backpropagation to calculate these gradients. In RNNs, because the output at a given time step depends on computations from previous time steps, we need to propagate gradients back not just through the layers at the current time step, but also through time itself. This process is often called Backpropagation Through Time (BPTT).
Imagine unfolding an RNN over a sequence. The computation at time step t depends on the input at t and the hidden state from t−1. The hidden state at t−1 depends on the input at t−1 and the hidden state from t−2, and so on. When calculating the gradient of the loss with respect to a weight that influences an early hidden state (say, at time step t−k), the chain rule requires us to multiply many gradient terms together, one for each time step we propagate back through.
Specifically, the gradient calculation involves repeated multiplication by the derivative of the hidden state at one time step with respect to the hidden state at the previous time step. This derivative often involves the recurrent weight matrix and the derivative of the activation function used in the recurrent unit.
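One common way to write this down is as a product of per-step Jacobians. The notation below (h_t for the hidden state, W_hh for the recurrent weight matrix, tanh as the activation, so that h_i = tanh(W_xh x_i + W_hh h_{i-1} + b)) is an assumption of this sketch rather than notation introduced earlier:

```latex
% Gradient of the loss L with respect to an earlier hidden state h_{t-k}:
% each step back in time multiplies by the Jacobian of h_i w.r.t. h_{i-1}.
\frac{\partial L}{\partial h_{t-k}}
  = \frac{\partial L}{\partial h_t}
    \prod_{i=t-k+1}^{t} \frac{\partial h_i}{\partial h_{i-1}},
\qquad
\frac{\partial h_i}{\partial h_{i-1}}
  = \mathrm{diag}\!\left(1 - h_i^{2}\right) W_{hh}
```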
The problem arises when these terms are consistently smaller than 1. Consider a simple case where the relevant part of the gradient calculation involves repeatedly multiplying by a value, say 0.8.
As you propagate further back in time, the gradient magnitude shrinks exponentially, approaching zero.
Illustration of how the magnitude of gradients can decrease exponentially as they are propagated back through time steps in a simple RNN.
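The decay the figure describes can be reproduced in a few lines of Python; the factor of 0.8 and the step counts are illustrative choices, not values measured from a real network:

```python
# Repeatedly multiplying a gradient signal by a factor below 1
# shrinks it exponentially as we move further back in time.
factor = 0.8          # stand-in for a per-step gradient term smaller than 1
gradient = 1.0        # gradient magnitude at the current time step

for step in range(1, 51):
    gradient *= factor
    if step in (5, 10, 20, 50):
        print(f"{step:2d} steps back: gradient factor = {gradient:.6f}")

# Output:
#  5 steps back: gradient factor = 0.327680
# 10 steps back: gradient factor = 0.107374
# 20 steps back: gradient factor = 0.011529
# 50 steps back: gradient factor = 0.000014
```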
Why might these terms be less than 1? One reason is the activation functions commonly used in simple recurrent units: the derivative of tanh is at most 1 (and only exactly at zero), while the derivative of the sigmoid never exceeds 0.25. If, in addition, the recurrent weight matrix contains small values, each step back through time multiplies the gradient by yet another factor below 1.
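A quick NumPy check illustrates the activation-function part of this; the sample points below are arbitrary:

```python
import numpy as np

# Derivative of tanh is 1 - tanh(x)^2: it equals 1 only at x = 0
# and falls toward 0 as |x| grows, so it is almost always below 1.
x = np.array([0.0, 0.5, 1.0, 2.0, 4.0])
tanh_grad = 1.0 - np.tanh(x) ** 2
print("tanh'(x):   ", np.round(tanh_grad, 4))      # ~[1. 0.7864 0.42 0.0707 0.0013]

# Derivative of the sigmoid is s(x) * (1 - s(x)), which never exceeds 0.25.
sigmoid = 1.0 / (1.0 + np.exp(-x))
sigmoid_grad = sigmoid * (1.0 - sigmoid)
print("sigmoid'(x):", np.round(sigmoid_grad, 4))   # ~[0.25 0.235 0.1966 0.105 0.0177]
```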
The consequence of vanishing gradients is profound: the network struggles to learn dependencies between elements that are far apart in the sequence. If the gradient related to an early time step becomes vanishingly small, the weights associated with that step won't be updated effectively. The network essentially fails to "remember" or learn from information seen many steps ago. For tasks like text generation or sentiment analysis, where understanding the context might depend on words or phrases seen much earlier, this is a major limitation.
The SimpleRNN architecture is particularly prone to this issue because it lacks mechanisms to regulate the flow of information and gradients over long durations. The hidden state is simply overwritten at each step, making it hard to preserve crucial information from the distant past.
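To see the combined effect numerically, here is a minimal NumPy sketch of a SimpleRNN-style recurrence, h_t = tanh(W_xh x_t + W_hh h_{t-1}) with the bias omitted. The hidden size, number of steps, weight scale, and random seed are arbitrary choices for illustration, not values from any real model:

```python
import numpy as np

rng = np.random.default_rng(0)

hidden_size = 8
steps = 30

# Deliberately small random weights so the per-step Jacobians shrink gradients.
W_hh = rng.normal(scale=0.1, size=(hidden_size, hidden_size))  # recurrent weights
W_xh = rng.normal(scale=0.1, size=(hidden_size, hidden_size))  # input weights
x = rng.normal(size=(steps, hidden_size))                      # a toy input sequence

h = np.zeros(hidden_size)
jacobian_product = np.eye(hidden_size)
for t in range(steps):
    h = np.tanh(W_xh @ x[t] + W_hh @ h)
    # Jacobian of h_t with respect to h_{t-1}: diag(1 - h_t^2) @ W_hh
    step_jacobian = np.diag(1.0 - h ** 2) @ W_hh
    jacobian_product = step_jacobian @ jacobian_product
    if (t + 1) % 10 == 0:
        # Norm of the accumulated product ~ how much gradient signal survives
        # when propagated back across t+1 steps.
        print(f"after {t + 1:2d} steps, ||dh_t/dh_0|| = "
              f"{np.linalg.norm(jacobian_product):.2e}")
```

With settings like these, the printed norms typically drop by several orders of magnitude between 10 and 30 steps, which is the vanishing gradient problem in miniature.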
Although less common in practice, the opposite problem can also occur: exploding gradients. If the relevant gradient terms are consistently greater than 1, their product can grow exponentially, leading to huge gradients, unstable training, and numerical overflows (often resulting in NaN values during training). Techniques like gradient clipping (capping the gradient magnitude at a threshold) are often used to mitigate exploding gradients.
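In Keras, for instance, clipping can be requested directly on the optimizer; the threshold values below are arbitrary examples rather than recommended defaults:

```python
from tensorflow import keras

# Clip each gradient tensor so that its norm does not exceed 1.0.
optimizer_norm_clip = keras.optimizers.Adam(learning_rate=1e-3, clipnorm=1.0)

# Alternatively, clip every gradient component to the range [-0.5, 0.5].
optimizer_value_clip = keras.optimizers.Adam(learning_rate=1e-3, clipvalue=0.5)

# model.compile(optimizer=optimizer_norm_clip, loss="binary_crossentropy")
```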
The vanishing gradient problem, however, is the more persistent challenge for simple RNNs trying to capture long-range dependencies. This limitation was a primary driver for the development of more sophisticated recurrent architectures like Long Short-Term Memory (LSTM) and Gated Recurrent Units (GRU), which we will explore next. These architectures incorporate explicit gating mechanisms designed to better control the flow of information and gradients through time, allowing them to learn much longer-term patterns in sequential data.