Before the advent of Transformers, Recurrent Neural Networks (RNNs) were the standard architecture for handling sequential data, such as text or time series. Unlike feed-forward networks, which process inputs independently, RNNs possess a form of memory, allowing information from previous steps in the sequence to influence the processing of the current step. This makes them naturally suited for tasks where context and order matter.
The central concept in an RNN is the hidden state, often denoted as $h_t$ for time step $t$. This hidden state acts as a compressed summary of the information seen in the sequence up to that point. At each time step $t$, the RNN takes two inputs: the current input element $x_t$ from the sequence and the hidden state from the previous time step, $h_{t-1}$. It then computes a new hidden state $h_t$ and, optionally, an output $y_t$.
Think of reading a sentence: "The cat sat on the ___". To predict the next word, you need to remember "The cat sat on the". An RNN mimics this by updating its hidden state as it processes each word, carrying forward relevant context.
This process involves a loop: the same set of operations and weights are applied at every time step, using the previous hidden state as input. This shared weight structure makes RNNs parameter-efficient, as they don't need separate parameters for each position in the sequence.
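In code, this loop amounts to repeatedly applying one step function to a running hidden state. The sketch below is purely schematic: rnn_step is a hypothetical placeholder for the computation described next, not a library function.
h = h_0                     # initial hidden state
for x_t in sequence:        # walk the sequence in order
    h = rnn_step(x_t, h)    # same function, same weights at every step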
Let's look at the computations inside a simple RNN cell at time step t:
Calculate the new hidden state $h_t$: This is typically done by combining the current input $x_t$ and the previous hidden state $h_{t-1}$ using weight matrices and an activation function (often the hyperbolic tangent, tanh).
$$h_t = \tanh(W_{xh} x_t + W_{hh} h_{t-1} + b_h)$$
Here, $W_{xh}$ is the input-to-hidden weight matrix, $W_{hh}$ is the hidden-to-hidden weight matrix, and $b_h$ is a bias vector.
Calculate the output $y_t$ (optional): Depending on the task, an output might be generated at each time step based on the current hidden state.
$$y_t = W_{hy} h_t + b_y$$
Here, $W_{hy}$ is the hidden-to-output weight matrix and $b_y$ is the output bias.
The key is that the weight matrices ($W_{xh}$, $W_{hh}$, $W_{hy}$) and biases ($b_h$, $b_y$) are the same across all time steps. The network learns a single transition function that is applied repeatedly.
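To make these equations concrete, here is a small from-scratch sketch of the recurrence in PyTorch. The sizes, parameter names, and the helper rnn_step are illustrative choices for this sketch, not part of any library API.
import torch

# Illustrative sizes for this sketch
input_size, hidden_size, output_size = 10, 20, 5

# Shared parameters, reused at every time step
W_xh = torch.randn(hidden_size, input_size) * 0.1   # input-to-hidden weights
W_hh = torch.randn(hidden_size, hidden_size) * 0.1  # hidden-to-hidden weights
W_hy = torch.randn(output_size, hidden_size) * 0.1  # hidden-to-output weights
b_h = torch.zeros(hidden_size)                      # hidden bias
b_y = torch.zeros(output_size)                      # output bias

def rnn_step(x_t, h_prev):
    # h_t = tanh(W_xh x_t + W_hh h_{t-1} + b_h)
    h_t = torch.tanh(W_xh @ x_t + W_hh @ h_prev + b_h)
    # y_t = W_hy h_t + b_y
    y_t = W_hy @ h_t + b_y
    return h_t, y_t

# Apply the same step function across a toy sequence of 3 time steps
h = torch.zeros(hidden_size)            # h_0
for t in range(3):
    x_t = torch.randn(input_size)       # stand-in for the t-th input
    h, y_t = rnn_step(x_t, h)
print(h.shape, y_t.shape)               # torch.Size([20]) torch.Size([5])
Because the same W_xh, W_hh, W_hy, b_h, and b_y are used inside every iteration, the parameter count does not grow with sequence length.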
While we often draw an RNN cell with a loop, it's useful to visualize it "unrolled" across the sequence length. This shows how the computation flows from one time step to the next.
An RNN unrolled through three time steps. The same RNN cell (representing shared weights $W_{xh}$, $W_{hh}$, $W_{hy}$) processes input $x_t$ and the previous hidden state $h_{t-1}$ to produce the current hidden state $h_t$ and output $y_t$.
PyTorch provides convenient modules for RNNs. Here's a basic example of defining and using a single-layer RNN:
import torch
import torch.nn as nn
# Define parameters
input_size = 10 # Dimension of input vector x_t
hidden_size = 20 # Dimension of hidden state h_t
sequence_length = 5
batch_size = 3
# Create an RNN layer
# batch_first=True means input/output tensors have batch dim first
# (batch, seq, feature)
rnn_layer = nn.RNN(input_size, hidden_size, batch_first=True)
# Create some dummy input data
# Shape: (batch_size, sequence_length, input_size)
input_sequence = torch.randn(batch_size, sequence_length, input_size)
# Initialize hidden state (optional, defaults to zeros)
# Shape: (num_layers * num_directions, batch_size, hidden_size)
# -> (1, 3, 20) for this case
initial_hidden_state = torch.zeros(1, batch_size, hidden_size)
# Pass the input sequence and initial hidden state through the RNN
# output contains the hidden state for *each* time step
# final_hidden_state contains only the *last* hidden state
output, final_hidden_state = rnn_layer(input_sequence, initial_hidden_state)
print("Input shape:", input_sequence.shape)
# Output shape: (batch_size, sequence_length, hidden_size)
print("Output shape:", output.shape)
# Final hidden state shape: (num_layers * num_directions, batch_size,
# hidden_size)
print("Final hidden state shape:", final_hidden_state.shape)
# Example: Accessing hidden state at the last time step from the output
last_time_step_output = output[:, -1, :]
print("Last time step hidden state from output shape:",
last_time_step_output.shape)
# Verify it matches the final_hidden_state (squeeze the first dimension)
print(
"Are final hidden state and last output step equal?",
torch.allclose(
final_hidden_state.squeeze(0),
last_time_step_output
)
)
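As a cross-check of the unrolled picture, the same sequence can also be processed one step at a time with nn.RNNCell. The loop below is a sketch that reuses the tensors defined in the example above; it will not numerically match the nn.RNN output, because the two modules are initialized with different random weights.
# Manual unrolling with nn.RNNCell: one step per loop iteration
rnn_cell = nn.RNNCell(input_size, hidden_size)

h = torch.zeros(batch_size, hidden_size)   # h_0
step_states = []
for t in range(sequence_length):
    x_t = input_sequence[:, t, :]          # (batch_size, input_size)
    h = rnn_cell(x_t, h)                   # h_t from x_t and h_{t-1}
    step_states.append(h)

# Stack the per-step hidden states: (batch_size, sequence_length, hidden_size)
manual_output = torch.stack(step_states, dim=1)
print("Manually unrolled output shape:", manual_output.shape)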
This simple structure allows RNNs to model sequential dependencies. However, as we will see in the next section, basic RNNs struggle with learning relationships between elements that are far apart in the sequence. This limitation paved the way for more complex architectures like LSTMs and GRUs.