Training a Recurrent Neural Network presents a unique situation compared to standard feedforward networks. The core difficulty lies in the recurrent connection: the output of a time step feeds back into the network as input for the next time step. This creates cycles in the computational graph, which standard backpropagation isn't directly equipped to handle. How, then, do we calculate gradients when the network's state at time $t$ depends on its state at time $t-1$, which depends on $t-2$, and so on?
The solution is a technique called Backpropagation Through Time (BPTT). To understand and implement BPTT, we first need to visualize the flow of computation in a different way. We do this by conceptually unrolling the network through time.
Imagine taking the compact representation of an RNN cell with its loop and unfolding it across the sequence length. For a sequence of length $T$, we create $T$ copies of the cell, one for each time step. The recurrent connection from the cell at time $t-1$ to the cell at time $t$ is now represented as a directed connection between copy $t-1$ and copy $t$.
Figure: An RNN unrolled for three time steps. The input $x_t$ and previous hidden state $h_{t-1}$ are used to compute the current hidden state $h_t$. The same weight matrices ($W_{xh}$, $W_{hh}$, $W_{hy}$) are applied at each time step.
This unrolled network looks very much like a deep feedforward network. Each time step can be thought of as a layer receiving input from the previous layer (the previous time step's hidden state) and the external input for that specific time step. This transformation is purely conceptual; we don't actually create multiple copies of the weights in memory. It's a way to visualize the dependencies required for gradient calculation.
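To make this concrete, here is a minimal sketch of the forward pass over an unrolled sequence, written in plain NumPy (the dimensions and initialization are arbitrary placeholders, and the output projection $W_{hy}$ is omitted for brevity). The loop applies the same $W_{xh}$ and $W_{hh}$ at every step; the "copies" in the unrolled picture all share one set of parameters.

```python
import numpy as np

# Toy dimensions (arbitrary, for illustration only).
T, input_size, hidden_size = 3, 4, 5

rng = np.random.default_rng(0)
W_xh = rng.normal(scale=0.1, size=(hidden_size, input_size))   # input-to-hidden weights
W_hh = rng.normal(scale=0.1, size=(hidden_size, hidden_size))  # hidden-to-hidden (recurrent) weights
b_h = np.zeros(hidden_size)

xs = rng.normal(size=(T, input_size))  # one input vector per time step
h = np.zeros(hidden_size)              # initial hidden state h_0

hidden_states = []
for t in range(T):
    # The same W_xh and W_hh are reused at every "copy" of the cell.
    h = np.tanh(W_xh @ xs[t] + W_hh @ h + b_h)
    hidden_states.append(h)
```

Keeping `hidden_states` around mirrors what a training framework must do: each step's activations are needed later for the backward pass.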
With the network unrolled, we can now apply the principles of backpropagation. BPTT involves performing a forward pass through the unrolled network, computing the output and loss (typically at the end of the sequence, or at each step depending on the task), and then performing a backward pass to compute gradients.
During the backward pass, gradients flow backward through the unrolled graph, from the last time step toward the first. Because the same weight matrices are reused at every step, the gradient of the loss with respect to each shared parameter is obtained by summing its contributions from all time steps.
For example, the gradient of the loss $L$ with respect to the weight matrix $W_{hh}$ is computed by summing its influence across all time steps where it was used:

$$\frac{\partial L}{\partial W_{hh}} = \sum_{t=1}^{T} \frac{\partial L_t}{\partial W_{hh}}$$

where $L_t$ represents the portion of the loss dependent on the computation at time step $t$. This summation reflects the fact that the same $W_{hh}$ affects the hidden state calculation at every step.
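The sketch below implements this summation directly. It assumes a deliberately simple setup (a tanh RNN cell, a squared-error loss applied to each hidden state, toy dimensions and random targets) so the accumulation of $\partial L_t / \partial W_{hh}$ across time steps is easy to follow; it is an illustration of the idea, not production code.

```python
import numpy as np

rng = np.random.default_rng(0)
T, input_size, hidden_size = 3, 4, 5

W_xh = rng.normal(scale=0.1, size=(hidden_size, input_size))
W_hh = rng.normal(scale=0.1, size=(hidden_size, hidden_size))
b_h = np.zeros(hidden_size)

xs = rng.normal(size=(T, input_size))
ys = rng.normal(size=(T, hidden_size))   # hypothetical per-step targets

# Forward pass: store every hidden state, since the backward pass needs them.
hs = [np.zeros(hidden_size)]             # hs[0] is h_0
for t in range(T):
    hs.append(np.tanh(W_xh @ xs[t] + W_hh @ hs[-1] + b_h))

# Backward pass (BPTT): walk from t = T back to t = 1, accumulating
# each time step's contribution to dL/dW_hh, as in the sum above.
dW_hh = np.zeros_like(W_hh)
dh_next = np.zeros(hidden_size)          # gradient flowing in from step t+1
for t in reversed(range(1, T + 1)):
    dh = (hs[t] - ys[t - 1]) + dh_next   # dL/dh_t: local loss term + recurrent term
    da = dh * (1.0 - hs[t] ** 2)         # backprop through the tanh nonlinearity
    dW_hh += np.outer(da, hs[t - 1])     # this step's term in the sum over t
    dh_next = W_hh.T @ da                # pass the gradient back to h_{t-1}
```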
Unrolling the network clarifies the computational flow for training, but it also highlights some practical aspects. The activations of every time step must be stored during the forward pass so they are available for the backward pass, which means memory and computation grow with the sequence length $T$. In addition, gradients must travel back through a long chain of repeated multiplications involving the same recurrent weights, which makes them prone to shrinking or growing rapidly over long sequences.
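In practice, automatic differentiation frameworks handle this bookkeeping for you: writing the forward loop is enough to build the unrolled graph, and calling the backward pass then performs BPTT over it. A small sketch assuming PyTorch (an illustrative choice, not specified by this page):

```python
import torch

torch.manual_seed(0)
T, input_size, hidden_size = 3, 4, 5

# Shared parameters, marked as requiring gradients.
W_xh = (0.1 * torch.randn(hidden_size, input_size)).requires_grad_()
W_hh = (0.1 * torch.randn(hidden_size, hidden_size)).requires_grad_()
b_h = torch.zeros(hidden_size, requires_grad=True)

xs = torch.randn(T, input_size)
ys = torch.randn(T, hidden_size)         # hypothetical per-step targets

h = torch.zeros(hidden_size)
loss = torch.tensor(0.0)
for t in range(T):
    # Each iteration adds one "copy" of the cell to the autograd graph;
    # its activations stay in memory until backward() runs.
    h = torch.tanh(W_xh @ xs[t] + W_hh @ h + b_h)
    loss = loss + 0.5 * torch.sum((h - ys[t]) ** 2)

loss.backward()          # BPTT over the unrolled graph
print(W_hh.grad.shape)   # gradient already summed over all T time steps
```

Every loop iteration adds another set of stored activations to the graph, which is why memory use grows with the sequence length.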
In essence, unrolling is the conceptual bridge that allows us to apply gradient-based optimization techniques, derived from standard backpropagation, to the recurrent structure of RNNs. It transforms the temporal loop into a sequence of operations, making the calculation of gradients with respect to the shared parameters feasible through BPTT.