While LSTMs effectively address the vanishing gradient problem and capture long-range dependencies, they introduce a fair amount of computational complexity with their three gates (input, forget, output) and separate cell state. In 2014, Cho et al. proposed the Gated Recurrent Unit (GRU) as a variation that achieves similar performance on many tasks but with a simpler architecture. GRUs combine the forget and input gates into a single "update gate" and merge the cell state and hidden state.
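To make the size difference concrete, here is a small sketch (with arbitrarily chosen dimensions) that builds a single-layer nn.LSTM and nn.GRU of the same width and compares their parameter counts. Because the GRU has three weight groups per unit instead of four, it comes out at roughly three-quarters of the LSTM's size.

import torch.nn as nn

# Arbitrary example sizes chosen for illustration
input_size, hidden_size = 10, 20

lstm = nn.LSTM(input_size, hidden_size)  # four weight groups per unit (input, forget, cell, output)
gru = nn.GRU(input_size, hidden_size)    # three weight groups per unit (reset, update, candidate)

lstm_params = sum(p.numel() for p in lstm.parameters())
gru_params = sum(p.numel() for p in gru.parameters())
print("LSTM parameters:", lstm_params)   # 2560 for these sizes
print("GRU parameters:", gru_params)     # 1920 for these sizes (about 3/4 of the LSTM)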
Let's look at the structure of a GRU cell. Like other RNNs, it takes the current input $x_t$ and the previous hidden state $h_{t-1}$ to produce the next hidden state $h_t$. The magic happens through two gates: the Reset Gate ($r_t$) and the Update Gate ($z_t$).
The reset gate determines how to combine the new input with the previous hidden state. Specifically, it controls how much of the previous hidden state ($h_{t-1}$) should be "forgotten" when calculating a candidate hidden state. It's computed as follows:
$$r_t = \sigma(W_r x_t + U_r h_{t-1} + b_r)$$

Here, $W_r$, $U_r$, and $b_r$ are the learnable weight matrices and bias vector for the reset gate. The sigmoid function $\sigma$ squashes the output between 0 and 1. A value close to 0 means the previous hidden state is mostly ignored, while a value close to 1 means it's mostly kept.
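As a quick, made-up illustration of the gate's range, the sigmoid maps large negative pre-activations toward 0 and large positive ones toward 1:

import torch

# Hypothetical pre-activation values for three hidden units
pre_activations = torch.tensor([-5.0, 0.0, 5.0])
print(torch.sigmoid(pre_activations))  # tensor([0.0067, 0.5000, 0.9933])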
The update gate decides how much of the previous hidden state ($h_{t-1}$) should be carried forward to the new hidden state ($h_t$) versus how much of the new candidate hidden state should be used. This gate essentially combines the roles of the LSTM's forget and input gates. It's calculated similarly to the reset gate:
$$z_t = \sigma(W_z x_t + U_z h_{t-1} + b_z)$$

Again, $W_z$, $U_z$, and $b_z$ are learnable parameters, and $\sigma$ is the sigmoid function. As the final update equation below shows, a value of $z_t$ close to 1 means the new candidate state is predominantly used, while a value close to 0 means the previous state $h_{t-1}$ is largely kept.
Before computing the final hidden state, the GRU calculates a candidate hidden state ($\tilde{h}_t$). This calculation is influenced by the reset gate, which determines how much the previous hidden state $h_{t-1}$ contributes:
$$\tilde{h}_t = \tanh(W_h x_t + U_h (r_t \odot h_{t-1}) + b_h)$$

Here, $\odot$ denotes element-wise multiplication (the Hadamard product). If the reset gate $r_t$ has values close to 0, the contribution from $h_{t-1}$ is effectively erased, allowing the candidate state to be based primarily on the current input $x_t$. $W_h$, $U_h$, and $b_h$ are another set of learnable weights and bias. The $\tanh$ function regulates the values within the network, squashing them to the range (-1, 1).
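In PyTorch, the Hadamard product is simply the element-wise * operator (as opposed to @ for matrix multiplication). A tiny, made-up example of a reset gate wiping out most of the previous state's contribution:

import torch

h_prev = torch.tensor([0.5, -1.2, 0.8])   # toy previous hidden state
r = torch.tensor([0.02, 0.9, 0.5])        # toy reset-gate activations in (0, 1)
print(r * h_prev)                          # tensor([ 0.0100, -1.0800,  0.4000])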
Finally, the update gate $z_t$ mediates between the previous hidden state $h_{t-1}$ and the candidate hidden state $\tilde{h}_t$ to produce the final hidden state $h_t$ for the current time step:
$$h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t$$

This equation acts like a weighted average. If $z_t$ is close to 1, the candidate state $\tilde{h}_t$ contributes more, effectively updating the hidden state with new information. If $z_t$ is close to 0, the previous hidden state $h_{t-1}$ is retained more, allowing information to pass through unchanged across multiple time steps. This mechanism is how GRUs can maintain long-range dependencies.
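The interpolation is easy to verify with toy numbers. In the hypothetical snippet below, $z_t = 0$ returns the previous state unchanged, $z_t = 1$ replaces it entirely with the candidate, and intermediate values blend the two:

import torch

h_prev = torch.tensor([1.0, -2.0])   # toy previous hidden state
h_cand = torch.tensor([0.5, 0.5])    # toy candidate hidden state

for z_val in (0.0, 0.5, 1.0):
    z = torch.full_like(h_prev, z_val)
    h_new = (1 - z) * h_prev + z * h_cand
    print(z_val, h_new)
# z = 0.0 -> [1.0, -2.0]   previous state kept
# z = 0.5 -> [0.75, -0.75] even blend
# z = 1.0 -> [0.5, 0.5]    candidate state used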
A simplified view of information flow within a GRU cell: $x_t$ is the input and $h_{t-1}$ is the previous hidden state. The reset gate ($r_t$) influences the candidate state ($\tilde{h}_t$), and the update gate ($z_t$) combines the candidate state with the previous state to produce the final hidden state $h_t$.
GRUs are often seen as a more streamlined alternative to LSTMs: with two gates instead of three and no separate cell state, a GRU layer has fewer parameters than an LSTM layer of the same hidden size.
In practice, the choice between LSTM and GRU often depends on the specific dataset and task. Neither consistently outperforms the other across all scenarios, although GRUs have gained popularity due to their relative simplicity and comparable performance.
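If you want to try both on your own data, PyTorch's built-in recurrent layers make the swap nearly trivial. The sketch below (with arbitrary sizes) shows that nn.GRU and nn.LSTM share the same call pattern; the only difference is that the LSTM also returns its separate cell state.

import torch
import torch.nn as nn

seq_len, batch_size, input_size, hidden_size = 5, 3, 10, 20  # arbitrary sizes
x = torch.randn(seq_len, batch_size, input_size)

gru = nn.GRU(input_size, hidden_size)
lstm = nn.LSTM(input_size, hidden_size)

gru_out, h_n = gru(x)              # output sequence and final hidden state
lstm_out, (h_n_l, c_n) = lstm(x)   # LSTM also carries a separate cell state
print(gru_out.shape, lstm_out.shape)  # both torch.Size([5, 3, 20])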
Here's a PyTorch snippet showing the core calculations for a single GRU step, assuming inputs x_t and h_tm1 and pre-defined weight/bias tensors. Note that the code works with row-vector batches, so the matrix multiplications are written as x_t @ W_r, the transpose of the $W_r x_t$ form in the equations above:
import torch
# Example tensors (batch_size, input_size/hidden_size)
# Replace with actual dimensions and initialized weights/biases
batch_size = 1
input_size = 10
hidden_size = 20
x_t = torch.randn(batch_size, input_size)
h_tm1 = torch.randn(batch_size, hidden_size) # h_{t-1}
# --- Assume weight matrices (W_*, U_*) and biases (b_*) are defined ---
# Example initialization (replace with actual learned parameters)
W_r = torch.randn(input_size, hidden_size)
U_r = torch.randn(hidden_size, hidden_size)
b_r = torch.randn(hidden_size)
W_z = torch.randn(input_size, hidden_size)
U_z = torch.randn(hidden_size, hidden_size)
b_z = torch.randn(hidden_size)
W_h = torch.randn(input_size, hidden_size)
U_h = torch.randn(hidden_size, hidden_size)
b_h = torch.randn(hidden_size)
# ---------------------------------------------------------------------
# Reset Gate calculation
r_t = torch.sigmoid(x_t @ W_r + h_tm1 @ U_r + b_r)
# Update Gate calculation
z_t = torch.sigmoid(x_t @ W_z + h_tm1 @ U_z + b_z)
# Candidate Hidden State calculation
h_tilde_t = torch.tanh(x_t @ W_h + (r_t * h_tm1) @ U_h + b_h)
# Final Hidden State calculation
h_t = (1 - z_t) * h_tm1 + z_t * h_tilde_t
print("Previous hidden state shape:", h_tm1.shape)
print("Current hidden state shape:", h_t.shape)
This simplified structure, while effective, still relies on sequential processing. The computation for time step $t$ depends on the result from time step $t-1$. This inherent sequential dependency limits parallelization during training and remains a bottleneck for processing very long sequences, setting the stage for the non-recurrent attention mechanisms used in Transformers.