In the previous sections, we discussed the SimpleRNN layer and identified its limitations, particularly the vanishing gradient problem, which makes it difficult for the network to learn dependencies across long sequences. To overcome this, more sophisticated recurrent architectures were developed. One of the most successful and widely used is the Long Short-Term Memory (LSTM) network.
LSTMs introduce a mechanism to explicitly manage the flow of information over time, allowing them to selectively remember or forget information. This is achieved through a system of gates controlling a dedicated cell state ($c_t$), which acts like a conveyor belt running through the entire sequence, carrying information with minimal manipulation.
An LSTM cell processes the input at the current timestep ($x_t$) along with the hidden state from the previous timestep ($h_{t-1}$). Unlike SimpleRNN, it uses three primary gates and updates both a hidden state ($h_t$) and the cell state ($c_t$).

The forget gate is a sigmoid layer that looks at $h_{t-1}$ and $x_t$ and outputs a value between 0 and 1 for each element of the previous cell state $c_{t-1}$, deciding how much of it to keep.

$$f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f)$$

The input gate decides which new information to store. A sigmoid layer ($i_t$) determines which values to update, while a tanh layer ($\tilde{c}_t$) creates a vector of new candidate values to be added to the state.

$$i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i)$$
$$\tilde{c}_t = \tanh(W_C \cdot [h_{t-1}, x_t] + b_C)$$

The cell state is then updated by scaling the old state with the forget gate and adding the candidate values scaled by the input gate:

$$c_t = f_t * c_{t-1} + i_t * \tilde{c}_t$$

Finally, the output gate determines what to expose as output. A sigmoid layer decides which parts of the cell state to output; the cell state is passed through tanh (to push values between -1 and 1) and multiplied by the output of the sigmoid gate. This filtered version becomes the new hidden state $h_t$.

$$o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o)$$
$$h_t = o_t * \tanh(c_t)$$

The hidden state $h_t$ is the output of the LSTM unit for the current timestep. The combination of the cell state and the gates allows LSTMs to maintain relevant information over much longer sequences compared to SimpleRNNs, mitigating the vanishing gradient problem.
A conceptual diagram of an LSTM cell showing the flow of information through the forget, input, and output gates, interacting with the cell state and hidden state.
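To make these updates concrete, here is a minimal NumPy sketch of a single LSTM timestep that follows the equations above. The weight matrices and bias vectors (W_f, b_f, and so on) are illustrative placeholders for this sketch, not the parameter layout Keras actually uses internally.

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

# One LSTM step for a single example: x_t has shape (features,),
# h_prev and c_prev have shape (units,), each W_* has shape (units, units + features)
def lstm_step(x_t, h_prev, c_prev, W_f, b_f, W_i, b_i, W_C, b_C, W_o, b_o):
    z = np.concatenate([h_prev, x_t])   # [h_{t-1}, x_t]
    f_t = sigmoid(W_f @ z + b_f)        # forget gate
    i_t = sigmoid(W_i @ z + b_i)        # input gate
    c_tilde = np.tanh(W_C @ z + b_C)    # candidate values
    c_t = f_t * c_prev + i_t * c_tilde  # updated cell state
    o_t = sigmoid(W_o @ z + b_o)        # output gate
    h_t = o_t * np.tanh(c_t)            # new hidden state
    return h_t, c_t

Running this function once per timestep, carrying h_t and c_t forward, is conceptually what the Keras LSTM layer does across an entire batch of sequences.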
Using LSTMs in Keras is straightforward, thanks to the keras.layers.LSTM layer. It functions similarly to SimpleRNN but incorporates the more complex internal logic described above.
import keras
from keras import layers
# Define an LSTM layer with 64 units
# Assuming input shape is (batch_size, timesteps, features)
# For example, (32, 10, 8) means 32 sequences, 10 timesteps each, 8 features per timestep
lstm_layer = layers.LSTM(units=64)
# You can add it to a Sequential model:
model = keras.Sequential([
    # Input shape required for the first layer
    layers.Input(shape=(None, 8)),  # (timesteps, features) - None allows variable sequence length
    layers.LSTM(units=64, return_sequences=True),  # Returns the full sequence output
    layers.LSTM(units=32),  # Returns only the last output
    layers.Dense(units=10)  # Example final classification layer
])
model.summary()
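As a sanity check that connects the summary output to the gate equations, you can predict each layer's parameter count: an LSTM layer learns four weight matrices over $[h_{t-1}, x_t]$ and four bias vectors (the forget, input, and output gates plus the candidate), giving 4 × (units × (units + input_dim) + units) trainable parameters. For the first LSTM layer above, that is 4 × (64 × (64 + 8) + 64) = 18,688 parameters.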
Key parameters for keras.layers.LSTM:

- units: The dimensionality of the output space, which also corresponds to the dimensionality of the hidden state $h_t$ and the cell state $c_t$. This is a required argument.
- activation: The activation function applied to the candidate cell state ($\tilde{c}_t$) and to the final hidden state calculation ($h_t$). The default is 'tanh'.
- recurrent_activation: The activation function used for the three gates (forget, input, output). The default is 'sigmoid'.
- return_sequences: A boolean value. If False (the default), the layer returns only the hidden state for the last timestep in the input sequence ($h_T$). This is suitable when the LSTM layer is the final recurrent layer before a Dense layer for tasks like sequence classification. If True, the layer returns the hidden state for every timestep ($h_1, h_2, ..., h_T$). This is necessary when stacking LSTM layers (so the next LSTM layer receives a sequence as input) or for sequence-to-sequence tasks where an output is needed at each step. See the shape example after this list.
- input_shape: Like other Keras layers, you need to specify the shape of the input for the first layer in a model. For recurrent layers, this is typically (timesteps, features). You can use None for the timesteps dimension if your sequences have variable lengths.

By default, the LSTM layer uses optimized cuDNN kernels when running on a compatible GPU, providing significant speedups during training; this fast path requires keeping the default activation and recurrent_activation values, among other conditions.
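The effect of return_sequences is easiest to see by comparing output shapes. The sketch below assumes a random batch of 32 sequences, each with 10 timesteps and 8 features, matching the shapes from the earlier example; the shapes in the comments are the expected results.

import numpy as np
from keras import layers

# A batch of 32 sequences, 10 timesteps each, 8 features per timestep
x = np.random.random((32, 10, 8)).astype("float32")

# Default (return_sequences=False): only the last hidden state is returned
last_state = layers.LSTM(units=64)(x)
print(last_state.shape)   # (32, 64)

# return_sequences=True: the hidden state at every timestep is returned
all_states = layers.LSTM(units=64, return_sequences=True)(x)
print(all_states.shape)   # (32, 10, 64)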
Compared to SimpleRNN, the LSTM layer involves more computations per timestep due to its internal gating mechanisms. However, this complexity is precisely what allows it to effectively learn long-range dependencies, making it a much more powerful tool for many sequence modeling tasks. In the practice section later in this chapter, you'll implement an LSTM model for text classification.