Okay, we've seen how an RNN processes a sequence step by step, using its hidden state to carry information forward. But how does the network learn? Just like feedforward networks, RNNs learn by adjusting their weights based on the error they make. The standard algorithm for this is backpropagation. However, the recurrent connections and shared weights in RNNs require a modification: Backpropagation Through Time (BPTT).
Imagine you feed a sequence of length $T$ into your RNN. As discussed earlier, the network performs the same calculations at each step, using the same set of weights $(W_{hh}, W_{xh}, W_{hy}, b_h, b_y)$. To train the network, we first need to compute a loss, which measures how far the network's predictions $(y_1, y_2, \dots, y_T)$ are from the true target values for that sequence. This loss is typically calculated over the entire sequence, often as a sum or average of the losses at each time step.
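For example, denoting the loss at step $t$ by $L_t$ (its exact form, such as cross-entropy or squared error, depends on the task and is not fixed here), the summed version is simply:

$$L = \sum_{t=1}^{T} L_t$$

and the averaged version divides this total by $T$.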
The core idea of BPTT is to apply the chain rule of calculus, just like standard backpropagation, but to do so across the temporal sequence of operations. To visualize this, think of "unrolling" the RNN for the specific input sequence. This creates a large, feedforward-like network where each time step corresponds to a layer. Crucially, the weights are shared across these "layers".
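To make the unrolling concrete, here is a minimal NumPy sketch of the forward pass. The function name, the choice of $\tanh$ as the activation $f$, and the use of plain vectors (no batching) are illustrative assumptions; only the parameter names follow the text.

```python
import numpy as np

def rnn_forward(xs, h0, Wxh, Whh, Why, bh, by):
    """Unrolled forward pass of a simple RNN (tanh assumed for f).

    xs: list of input vectors x_1..x_T; h0: initial hidden state.
    The same parameters are reused at every step: the shared weights.
    """
    h, hs, ys = h0, [], []
    for x in xs:                              # one "layer" per time step
        h = np.tanh(Whh @ h + Wxh @ x + bh)   # hidden state update
        ys.append(Why @ h + by)               # output at this step
        hs.append(h)
    return hs, ys
```

Each iteration of the loop corresponds to one "layer" of the unrolled network, and every iteration reads the same five parameter arrays.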
Let's consider the gradient calculation. The error at the final time step $T$ depends directly on the hidden state $h_T$ and the weights $W_{hy}$ and $b_y$. The error at time step $T-1$ depends on $h_{T-1}$, which also influences the error at step $T$ via the recurrent connection ($h_T$ depends on $h_{T-1}$). BPTT works by calculating the gradient of the total loss with respect to the outputs and hidden states at each time step, starting from the last step $T$ and moving backward to the first step $t=1$.
The gradient of the loss $L$ with respect to the hidden state at time $t$, denoted $\frac{\partial L}{\partial h_t}$, depends on two things:

1. The direct contribution of $h_t$ to the loss at step $t$, through the output $y_t$.
2. The indirect contribution of $h_t$ to the losses at later steps, through its effect on the next hidden state $h_{t+1}$.
This second point is where the "through time" part comes in. The gradient signal flows backward from $h_{t+1}$ to $h_t$ via the recurrent weight matrix $W_{hh}$. Mathematically, this involves terms like:
$$\frac{\partial h_{t+1}}{\partial h_t} = \frac{\partial}{\partial h_t} f\bigl(W_{hh} h_t + W_{xh} x_{t+1} + b_h\bigr)$$
This propagation continues backward step by step.
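Combining the direct and recurrent contributions gives a backward recursion over the hidden states. As a sketch, treating gradients as column vectors and assuming the common choice $f = \tanh$ (so the Jacobian has the closed form on the right, with the square taken elementwise):

$$\frac{\partial L}{\partial h_t} = \frac{\partial L_t}{\partial h_t} + \left(\frac{\partial h_{t+1}}{\partial h_t}\right)^{\top} \frac{\partial L}{\partial h_{t+1}}, \qquad \frac{\partial h_{t+1}}{\partial h_t} = \operatorname{diag}\!\bigl(1 - h_{t+1}^{2}\bigr)\, W_{hh}$$

Here $L_t$ is the loss at step $t$, and the recursion starts at $t = T$, where only the direct term is present.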
The backward pass of BPTT. Gradients (dashed red lines) from the loss at each time step flow backward through the network outputs ($y_t$) and hidden states ($h_t$). Importantly, gradients also flow backward through the recurrent connections (via $W_{hh}$), influencing the gradient calculations at earlier time steps. The gradients with respect to the shared weights (conceptual purple lines) are accumulated across all time steps.
A significant aspect of BPTT arises from the shared weights. Since the same weight matrices $(W_{hh}, W_{xh}, W_{hy})$ and bias vectors $(b_h, b_y)$ are used at every single time step, the gradient calculated for a specific weight needs to account for its influence across the entire sequence. Therefore, the final gradient for a shared parameter, say $W_{hh}$, is the sum of the gradients computed with respect to its usage at each time step $t = 1, \dots, T$.
$$\frac{\partial L}{\partial W_{hh}} = \sum_{t=1}^{T} \frac{\partial L}{\partial h_t} \frac{\partial h_t}{\partial W_{hh}} \quad \text{(calculated during the backward pass)}$$
Similarly for $W_{xh}$, $W_{hy}$, $b_h$, and $b_y$. Once these total gradients are computed, a standard gradient descent update (or one of its variants like Adam or RMSprop) is performed to adjust the network's parameters.
$$W \leftarrow W - \eta \frac{\partial L}{\partial W}$$

where $W$ represents any of the shared parameters and $\eta$ is the learning rate.
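The sketch below continues the forward-pass example above and shows both ideas at once: gradients for the shared parameters are accumulated (summed) over all time steps as the loop walks backward, and a plain gradient descent update is applied afterward. It assumes a tanh activation and that the per-output gradients $\partial L / \partial y_t$ have already been computed from whatever loss was chosen; the variable names and learning rate are illustrative.

```python
import numpy as np

def rnn_backward(xs, hs, dys, h0, Wxh, Whh, Why):
    """BPTT backward pass matching the rnn_forward sketch above.

    dys: upstream gradients dL/dy_t from the per-step losses (assumed given).
    Gradients for each shared parameter are summed over all time steps.
    """
    dWxh, dWhh, dWhy = np.zeros_like(Wxh), np.zeros_like(Whh), np.zeros_like(Why)
    dbh, dby = np.zeros_like(hs[0]), np.zeros_like(dys[0])
    dh_next = np.zeros_like(hs[0])            # gradient arriving from step t+1

    for t in reversed(range(len(xs))):        # walk backward through time
        dy = dys[t]
        dWhy += np.outer(dy, hs[t])           # output weights as used at step t
        dby += dy
        # Total gradient at h_t: direct path via y_t plus recurrent path via h_{t+1}
        dh = Why.T @ dy + dh_next
        da = (1.0 - hs[t] ** 2) * dh          # backprop through tanh
        h_prev = hs[t - 1] if t > 0 else h0
        dWxh += np.outer(da, xs[t])
        dWhh += np.outer(da, h_prev)
        dbh += da
        dh_next = Whh.T @ da                  # send gradient back to step t-1

    return dWxh, dWhh, dWhy, dbh, dby

# Gradient descent update on the shared parameters, using the gradients
# accumulated above (xs, hs, dys, h0 and the parameters come from the
# forward-pass sketch; eta is an assumed learning rate).
eta = 0.01
dWxh, dWhh, dWhy, dbh, dby = rnn_backward(xs, hs, dys, h0, Wxh, Whh, Why)
for param, grad in zip((Wxh, Whh, Why, bh, by), (dWxh, dWhh, dWhy, dbh, dby)):
    param -= eta * grad                       # in-place update of each array
```

Note how a single `dh_next` vector carries the "through time" signal from step $t+1$ back to step $t$, while the `+=` accumulations implement the sum over time steps from the gradient formula above.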
While BPTT allows us to train RNNs, this process of propagating gradients over potentially long sequences isn't without its difficulties. As the gradient signal travels backward through many time steps, it can either shrink exponentially towards zero (vanishing gradients) or grow exponentially large (exploding gradients). We will examine these training challenges and their implications in Chapter 4. For now, the fundamental takeaway is that BPTT extends backpropagation to handle the temporal dependencies and shared parameters inherent in recurrent architectures. The next section looks more closely at the practical concept of unrolling the network to facilitate this process.