The feed-forward network (FFN) sub-layer within each Transformer block plays a significant role in transforming the representations learned by the attention mechanism. A standard FFN consists of two linear transformations with a non-linear activation function in between:
$$\text{FFN}(x) = \max(0,\ xW_1 + b_1)W_2 + b_2$$
Here, $x$ is the input from the attention sub-layer, $W_1$, $b_1$, $W_2$, and $b_2$ are learnable parameters, and the activation function shown is ReLU. This non-linearity is essential; without it, the two linear layers would collapse into a single linear transformation, limiting the model's expressive power.
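To make the collapse argument concrete, here is a minimal check (using the same d_model and d_ff values as the examples below) that composes two linear layers without an activation and confirms they are equivalent to a single fused linear map.
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, d_ff = 512, 2048
lin1 = nn.Linear(d_model, d_ff)
lin2 = nn.Linear(d_ff, d_model)

x = torch.randn(4, d_model)
# Two stacked linear layers with no activation in between...
stacked = lin2(lin1(x))
# ...are equivalent to one linear map with fused weight W2 @ W1 and bias W2 @ b1 + b2
W_fused = lin2.weight @ lin1.weight
b_fused = lin2.weight @ lin1.bias + lin2.bias
fused = x @ W_fused.T + b_fused
print(torch.allclose(stacked, fused, atol=1e-4))  # True, up to floating-point error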
As models scale, the choice of this activation function becomes more than just a minor detail. It can impact gradient flow, training stability, computational cost, and ultimately, the model's final performance. Let's examine the common choices in large Transformers: ReLU, GeLU, and SwiGLU.
The Rectified Linear Unit, or ReLU, defined as ReLU(x)=max(0,x), was a foundational activation for deep learning. Its main advantages are simplicity and computational efficiency. It avoids the vanishing gradient problems often seen with sigmoid or tanh functions in deep networks.
import torch
import torch.nn as nn
# Example ReLU usage in a simplified FFN
d_model = 512
d_ff = 2048 # Typical inner dimension is 4*d_model
relu_ffn = nn.Sequential(
    nn.Linear(d_model, d_ff),
    nn.ReLU(),
    nn.Linear(d_ff, d_model)
)
# Example input
x = torch.randn(16, 128, d_model) # Batch, sequence length, dimension
output = relu_ffn(x)
print("Output shape:", output.shape)
# Output: Output shape: torch.Size([16, 128, 512])
However, ReLU is not without drawbacks. The primary issue is the "dying ReLU" problem: neurons can become inactive if their input consistently falls below zero, causing their weights to stop updating because the gradient is zero in that region. While techniques like careful initialization and lower learning rates mitigate this, it remains a consideration, especially in very deep networks. Furthermore, its non-smooth nature at x=0 can sometimes hinder optimization compared to smoother alternatives.
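As a quick, contrived illustration of the dying ReLU problem, the sketch below forces every pre-activation of a single hypothetical neuron below zero and shows that the resulting gradients are exactly zero, so ordinary gradient descent cannot revive it.
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(64, 8)
neuron = nn.Linear(8, 1)
with torch.no_grad():
    neuron.bias.fill_(-10.0)  # push every pre-activation well below zero

loss = torch.relu(neuron(x)).sum()
loss.backward()

# ReLU's gradient is zero on negative inputs, so no signal reaches the parameters
print(neuron.weight.grad.abs().max())  # tensor(0.)
print(neuron.bias.grad)                # tensor([0.])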
The Gaussian Error Linear Unit (GeLU) was introduced as a smoother alternative to ReLU and gained prominence with models like BERT and the GPT series. It weights each input by the standard Gaussian cumulative distribution function (Φ(x)); the weighting itself is deterministic, though it was originally motivated as the expected output of a stochastic regularizer that randomly drops inputs.
$$\text{GeLU}(x) = x \cdot \Phi(x)$$
Since computing the exact Gaussian CDF can be slow, a common approximation is used:
$$\text{GeLU}(x) \approx 0.5x\left(1 + \tanh\!\left[\sqrt{2/\pi}\,\left(x + 0.044715x^3\right)\right]\right)$$
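As a sanity check on this approximation, the short sketch below evaluates the exact GeLU through the Gaussian CDF (via the error function) and the tanh approximation side by side, and prints their maximum absolute difference.
import math
import torch

x = torch.linspace(-6, 6, steps=1001)

# Exact GeLU: x * Phi(x), with Phi expressed through the error function
gelu_exact = x * 0.5 * (1 + torch.erf(x / math.sqrt(2)))

# Tanh approximation from the formula above
gelu_approx = 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x**3)))

print("Max abs difference:", (gelu_exact - gelu_approx).abs().max().item())
# The two curves agree to within about 1e-3 over this range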
The intuition is that GeLU provides a smoother curve than ReLU, potentially allowing for easier optimization and better gradient flow. Empirically, it often outperformed ReLU in Transformer models.
import torch
import torch.nn as nn
# Example GeLU usage in a simplified FFN
d_model = 512
d_ff = 2048
gelu_ffn = nn.Sequential(
    nn.Linear(d_model, d_ff),
    nn.GELU(),  # Exact GeLU by default; use nn.GELU(approximate='tanh') for the tanh approximation
    nn.Linear(d_ff, d_model)
)
# Example input
x = torch.randn(16, 128, d_model)
output = gelu_ffn(x)
print("Output shape:", output.shape)
# Output: Output shape: torch.Size([16, 128, 512])
GeLU is slightly more computationally intensive than ReLU but is well-supported by hardware accelerators. Its success in many foundational LLMs made it a standard choice for several years.
More recently, variations that add a gating mechanism inside the FFN layer have shown strong performance. One popular variant is SwiGLU, introduced in Shazeer's "GLU Variants Improve Transformer" paper and adopted in models such as PaLM and Llama.
The core idea combines the Swish activation function (Swish(x)=x⋅σ(x), where σ is the sigmoid function) with a gating mechanism. Instead of a single linear layer expanding the dimension, SwiGLU typically uses two linear layers whose outputs are multiplied element-wise. One output passes through the Swish function, acting as a gate for the other.
$$\text{SwiGLU}(x, W, V, b, c) = \text{Swish}(xW + b) \otimes (xV + c)$$
Here, x is the input, W, V, b, and c are learnable parameters, and ⊗ denotes element-wise multiplication. The Swish function is defined as:
$$\text{Swish}(x) = x \cdot \sigma(\beta x)$$
Often, β is set to 1 (in which case Swish is also known as SiLU) or made a learnable parameter.
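With β = 1, Swish coincides with the SiLU activation that PyTorch provides as nn.SiLU / F.silu, which is why the SwiGLU implementation below can simply call F.silu. A small numerical check:
import torch
import torch.nn.functional as F

x = torch.linspace(-5, 5, steps=101)
swish_manual = x * torch.sigmoid(x)  # Swish with beta = 1
swish_builtin = F.silu(x)            # PyTorch's SiLU
print(torch.allclose(swish_manual, swish_builtin))  # True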
import torch
import torch.nn as nn
import torch.nn.functional as F
class SwiGLUFFN(nn.Module):
    def __init__(self, dim, hidden_dim, bias=True):
        super().__init__()
        # hidden_dim is the width of each gated projection. To match the
        # parameter count of a standard FFN with inner dimension 4 * dim,
        # it is usually scaled down, e.g., to roughly 2/3 * 4 * dim.
        # The two linear layers that form the gate mechanism
        self.w1 = nn.Linear(dim, hidden_dim, bias=bias)
        self.w2 = nn.Linear(dim, hidden_dim, bias=bias)
        # The final projection back to the model dimension
        self.w3 = nn.Linear(hidden_dim, dim, bias=bias)

    def forward(self, x):
        # Apply the two parallel linear projections
        hidden1 = self.w1(x)
        hidden2 = self.w2(x)
        # Apply Swish to one projection and use it to gate the other element-wise
        gated_hidden = F.silu(hidden1) * hidden2  # F.silu is PyTorch's Swish (beta = 1)
        # Project back down to the model dimension
        output = self.w3(gated_hidden)
        return output
# Example SwiGLU usage
d_model = 512
# The effective hidden dimension in SwiGLU needs careful consideration.
# A common practice is to use a hidden dimension like (2/3 * 4 * d_model)
# so that the parameter count is similar to a standard FFN with 4 * d_model.
# For simplicity, let's use a smaller hidden_dim here.
d_ff_swiglu = 1024 # Example hidden dim for gating
swiglu_ffn = SwiGLUFFN(d_model, d_ff_swiglu)
# Example input
x = torch.randn(16, 128, d_model)
output = swiglu_ffn(x)
print("Output shape:", output.shape)
# Output: Output shape: torch.Size([16, 128, 512])
A subtle but important point about SwiGLU (and similar gated activations like GeGLU) is its impact on parameter count. To keep the number of parameters comparable to a standard ReLU/GeLU FFN with an intermediate dimension of $d_{ff}$, the hidden_dim used in the SwiGLU implementation (like d_ff_swiglu above) is often set to around $\tfrac{2}{3}d_{ff}$. This is because SwiGLU uses two linear projections ($W$ and $V$) up to the intermediate dimension, effectively splitting the capacity usually handled by one larger matrix ($W_1$) in the standard FFN. Despite this, SwiGLU has often been found to yield better perplexity and downstream performance than GeLU or ReLU in large models, suggesting the gating mechanism offers benefits beyond parameter efficiency.
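To make the parameter accounting concrete, the sketch below (reusing the SwiGLUFFN class defined above) compares a standard FFN with inner dimension 4 * d_model against a SwiGLU FFN whose hidden dimension is scaled by 2/3; the two parameter counts come out roughly equal.
import torch.nn as nn

d_model = 512
d_ff = 4 * d_model                      # standard FFN inner dimension
d_ff_gated = int(2 / 3 * 4 * d_model)   # scaled-down SwiGLU hidden dimension

standard_ffn = nn.Sequential(
    nn.Linear(d_model, d_ff),
    nn.GELU(),
    nn.Linear(d_ff, d_model)
)
gated_ffn = SwiGLUFFN(d_model, d_ff_gated)  # class defined earlier in this section

n_standard = sum(p.numel() for p in standard_ffn.parameters())
n_gated = sum(p.numel() for p in gated_ffn.parameters())
print("Standard FFN parameters:", n_standard)
print("SwiGLU FFN parameters: ", n_gated)
# Both counts land near 2.1M for d_model = 512, so the comparison is parameter-matched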
Figure: Comparison of ReLU, GeLU (tanh approximation), and Swish activation functions. Note the increasing smoothness from ReLU to GeLU to Swish.
Choosing the right activation function involves trade-offs between simplicity, computational cost, training stability, and empirical performance. When scaling Transformers, moving from ReLU to GeLU or SwiGLU is a common architectural change aimed at improving performance. The gains from SwiGLU, despite the added complexity of the FFN implementation, have led to its adoption in several recent large-scale models. As with many architectural choices, the optimal selection may depend on the specific model size, dataset, and computational budget, often requiring empirical validation.