Simple Recurrent Neural Networks (RNNs), as discussed previously, struggle with learning dependencies over long sequences. The backpropagation process involves repeated multiplications through time, causing gradients to either shrink towards zero (vanish) or grow uncontrollably (explode). The vanishing gradient problem is particularly troublesome because it prevents the network from learning correlations between events that are far apart in the sequence.
Long Short-Term Memory (LSTM) networks were specifically designed to combat these issues. Introduced by Hochreiter and Schmidhuber in 1997, LSTMs incorporate a more complex internal structure than simple RNNs, featuring a dedicated cell state and gates that regulate the flow of information. This architecture allows LSTMs to maintain information over extended time intervals, effectively capturing long-range dependencies.
Think of the standard RNN hidden state $h_t$ as a working memory that gets overwritten at each time step. LSTMs introduce an additional component, the cell state $C_t$, which acts like a conveyor belt or a memory highway. Information can flow along this highway relatively unchanged, making it easier to preserve context over long durations.
The key innovation of LSTMs lies in their ability to selectively add or remove information from the cell state using structures called gates. Gates are composed of a sigmoid neural network layer and a pointwise multiplication operation. The sigmoid layer outputs numbers between 0 and 1, describing how much of each component should be let through. A value of 0 means "let nothing through," while a value of 1 means "let everything through."
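As a minimal sketch of this gating idea (the numbers below are arbitrary, chosen only for illustration), a gate is just a vector of sigmoid outputs multiplied elementwise into the values it controls:
import torch
# Sigmoid outputs act as per-element "pass-through" factors between 0 and 1.
values = torch.tensor([0.5, -1.2, 3.0, 0.8])
gate = torch.sigmoid(torch.tensor([10.0, -10.0, 0.0, 2.0]))
print(gate)           # roughly [1.00, 0.00, 0.50, 0.88]
print(gate * values)  # first value passes almost unchanged, second is nearly zeroed out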
An LSTM cell typically has three main gates: a forget gate, which decides what to discard from the cell state; an input gate, which decides what new information to store; and an output gate, which decides what to expose as the hidden state.
Let's examine how these components interact at a single time step $t$, given the input $x_t$, the previous hidden state $h_{t-1}$, and the previous cell state $C_{t-1}$.
The first step is to decide what information we're going to discard from the previous cell state $C_{t-1}$. This decision is made by the forget gate layer. It looks at $h_{t-1}$ and $x_t$, and outputs a number between 0 and 1 for each entry in the cell state $C_{t-1}$.
$$ f_t = \sigma(W_f [h_{t-1}, x_t] + b_f) $$
Here, $\sigma$ is the sigmoid activation function, $W_f$ represents the weights, and $b_f$ the biases for the forget gate. The notation $[h_{t-1}, x_t]$ indicates that the previous hidden state and the current input are concatenated.
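As a rough sketch of this computation (the toy sizes and random weights below are illustrative, not values from the text), the forget gate is one matrix-vector product followed by a sigmoid:
import torch
torch.manual_seed(0)
hidden_size, input_size = 4, 3
h_prev = torch.randn(hidden_size)                          # h_{t-1}
x_t = torch.randn(input_size)                              # x_t
W_f = torch.randn(hidden_size, hidden_size + input_size)   # forget gate weights (illustrative)
b_f = torch.zeros(hidden_size)                             # forget gate biases
concat = torch.cat([h_prev, x_t])                          # [h_{t-1}, x_t]
f_t = torch.sigmoid(W_f @ concat + b_f)                    # one value in (0, 1) per cell state entry
print(f_t)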
Next, we need to decide what new information we're going to store in the cell state. This involves two parts. First, an input gate layer (another sigmoid layer) decides which values we'll update.
$$ i_t = \sigma(W_i [h_{t-1}, x_t] + b_i) $$
Second, a tanh layer creates a vector of new candidate values, $\tilde{C}_t$, that could be added to the state:
$$ \tilde{C}_t = \tanh(W_C [h_{t-1}, x_t] + b_C) $$
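Continuing the same illustrative sketch (random weights, toy sizes), the input gate and the candidate vector are computed from the same concatenated input:
import torch
torch.manual_seed(0)
hidden_size, input_size = 4, 3
h_prev = torch.randn(hidden_size)
x_t = torch.randn(input_size)
concat = torch.cat([h_prev, x_t])                          # [h_{t-1}, x_t]
W_i = torch.randn(hidden_size, hidden_size + input_size)   # input gate weights (illustrative)
b_i = torch.zeros(hidden_size)
W_C = torch.randn(hidden_size, hidden_size + input_size)   # candidate weights (illustrative)
b_C = torch.zeros(hidden_size)
i_t = torch.sigmoid(W_i @ concat + b_i)                    # which entries to update, in (0, 1)
C_tilde = torch.tanh(W_C @ concat + b_C)                   # candidate values, in (-1, 1)
print(i_t)
print(C_tilde)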
Now, we update the old cell state $C_{t-1}$ into the new cell state $C_t$. We multiply the old state by $f_t$, forgetting the things we decided to forget earlier. Then we add $i_t * \tilde{C}_t$. This is the new candidate information, scaled by how much we decided to update each state value.
$$ C_t = f_t * C_{t-1} + i_t * \tilde{C}_t $$
The use of addition here is significant. Unlike the repeated multiplications in simple RNNs, this additive interaction makes it much easier for gradients to flow backward through time without vanishing as quickly.
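A sketch of the update itself, using arbitrary stand-ins for the gate outputs and the old cell state, shows that it is purely elementwise:
import torch
torch.manual_seed(0)
hidden_size = 4
C_prev = torch.randn(hidden_size)                  # C_{t-1} (arbitrary stand-in)
f_t = torch.sigmoid(torch.randn(hidden_size))      # forget gate output (stand-in)
i_t = torch.sigmoid(torch.randn(hidden_size))      # input gate output (stand-in)
C_tilde = torch.tanh(torch.randn(hidden_size))     # candidate values (stand-in)
C_t = f_t * C_prev + i_t * C_tilde                 # forget part of the old state, then add scaled new info
print(C_t)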
Finally, we need to decide what we're going to output. This output will be based on our cell state, but will be a filtered version. First, we run a sigmoid layer which decides what parts of the cell state we’re going to output.
$$ o_t = \sigma(W_o [h_{t-1}, x_t] + b_o) $$
Then, we put the cell state through tanh (to push the values to be between -1 and 1) and multiply it by the output of the sigmoid gate, so that we only output the parts we decided to:
$$ h_t = o_t * \tanh(C_t) $$
This final output is the hidden state $h_t$.
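Sketching this last step with the same toy setup (random weights, an arbitrary cell state), the output gate filters a tanh-squashed copy of the cell state:
import torch
torch.manual_seed(0)
hidden_size, input_size = 4, 3
h_prev = torch.randn(hidden_size)
x_t = torch.randn(input_size)
C_t = torch.randn(hidden_size)                             # updated cell state (stand-in)
W_o = torch.randn(hidden_size, hidden_size + input_size)   # output gate weights (illustrative)
b_o = torch.zeros(hidden_size)
concat = torch.cat([h_prev, x_t])
o_t = torch.sigmoid(W_o @ concat + b_o)                    # which parts of the cell state to expose
h_t = o_t * torch.tanh(C_t)                                # filtered cell state becomes the hidden state
print(h_t)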
The hidden state $h_t$ is then passed to the next time step, along with the new cell state $C_t$.
Simplified view of information flow within an LSTM cell at time step $t$. The cell state $C_t$ acts as a conveyor belt, modified by the forget and input gates. The output gate filters the cell state to produce the hidden state $h_t$.
Deep learning frameworks like PyTorch provide efficient implementations of LSTM layers. Using torch.nn.LSTM abstracts away the gate calculations.
import torch
import torch.nn as nn
# Example parameters
input_size = 10 # Dimension of input features x_t
hidden_size = 20 # Dimension of hidden state h_t and cell state C_t
num_layers = 1 # Number of stacked LSTM layers
batch_size = 5
seq_len = 7 # Length of the input sequence
# Create an LSTM layer
lstm_layer = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
# Dummy input data (batch_size, seq_len, input_size)
input_seq = torch.randn(batch_size, seq_len, input_size)
# Initial hidden and cell states (num_layers, batch_size, hidden_size)
# If not provided, they default to zeros.
h0 = torch.randn(num_layers, batch_size, hidden_size)
c0 = torch.randn(num_layers, batch_size, hidden_size)
# Forward pass
# output contains the hidden state h_t for each time step in the sequence
# hn contains the final hidden state for the last time step
# cn contains the final cell state for the last time step
output, (hn, cn) = lstm_layer(input_seq, (h0, c0))
print("Input shape:", input_seq.shape)
print("Output shape (all hidden states):", output.shape)
print("Final hidden state shape (hn):", hn.shape)
print("Final cell state shape (cn):", cn.shape)
# --- Output ---
# Input shape: torch.Size([5, 7, 10])
# Output shape (all hidden states): torch.Size([5, 7, 20])
# Final hidden state shape (hn): torch.Size([1, 5, 20])
# Final cell state shape (cn): torch.Size([1, 5, 20])
Notice that the final hidden state and cell state are returned as a tuple (hn, cn). This reflects the two distinct internal states maintained by the LSTM throughout the sequence processing. The output tensor provides the hidden state $h_t$ at every time step, which is often useful in sequence-to-sequence models.
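One small check can make this relationship concrete (assuming a single-layer, unidirectional LSTM as configured above): the last time step of output matches the final hidden state hn for every sequence in the batch.
import torch
import torch.nn as nn
lstm_layer = nn.LSTM(input_size=10, hidden_size=20, num_layers=1, batch_first=True)
input_seq = torch.randn(5, 7, 10)
output, (hn, cn) = lstm_layer(input_seq)           # initial states default to zeros
# The last time step of `output` equals hn for a single-layer, unidirectional LSTM.
print(torch.allclose(output[:, -1, :], hn[0]))     # True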
By using gates to control information flow and an additive update mechanism for the cell state, LSTMs effectively mitigate the vanishing gradient problem and allow for the learning of dependencies across much longer time spans compared to simple RNNs. This capability made them a dominant architecture for many NLP tasks before the rise of attention-based models. While Gated Recurrent Units (GRUs), which we explore next, offer a slightly simpler gated mechanism, the fundamental principles introduced by LSTMs remain influential.