Having discussed the concepts behind attention mechanisms like Squeeze-and-Excitation (SE) blocks and non-local networks earlier in this chapter, we now move to the practical implementation. Adding attention modules allows a CNN to dynamically re-weight its feature channels or spatial locations, effectively focusing on more informative parts of the input. This practice section focuses on implementing the widely adopted Squeeze-and-Excitation (SE) block within a typical CNN framework using PyTorch.
The SE block performs feature recalibration, aiming to explicitly model interdependencies between channels. It consists of three steps:
Squeeze: Global information is aggregated from each channel's spatial map. This is typically done using Global Average Pooling (GAP), producing a channel descriptor vector $z \in \mathbb{R}^C$, where $C$ is the number of channels. For a channel $c$, the squeezed value $z_c$ is calculated as:

$$
z_c = \frac{1}{H \times W} \sum_{i=1}^{H} \sum_{j=1}^{W} u_c(i, j)
$$

where $u_c$ is the $c$-th feature map of the input $U$, and $H$ and $W$ are its spatial dimensions.
Excitation: The aggregated information is used to learn channel-wise attention weights. This typically involves two fully connected (Linear) layers: a dimensionality-reduction layer with reduction ratio $r$ and an activation (e.g., ReLU), followed by a dimensionality-increasing layer back to $C$ channels with a gating activation (e.g., Sigmoid). The resulting vector $s \in \mathbb{R}^C$ contains weights between 0 and 1 for each channel:

$$
s = \sigma\!\left(W_2 \, \delta(W_1 z)\right)
$$

Here, $W_1 \in \mathbb{R}^{\frac{C}{r} \times C}$ and $W_2 \in \mathbb{R}^{C \times \frac{C}{r}}$ are the weights of the linear layers, $\delta$ is the ReLU activation, and $\sigma$ is the Sigmoid activation.
Scale (or Recalibration): The original input feature map $U$ is scaled by the learned attention weights $s$. The output feature map $\tilde{X}$ is obtained by element-wise multiplication:

$$
\tilde{x}_c = s_c \cdot u_c
$$

where $\tilde{x}_c$ and $u_c$ are the $c$-th channels of the output $\tilde{X}$ and input $U$ respectively, and $s_c$ is the learned scalar weight for channel $c$.
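To make the mapping from equations to code concrete before wrapping everything in a module, here is a minimal functional sketch of these three steps on a randomly generated feature map. The tensor sizes (a batch of 2, 64 channels, 8×8 spatial maps) and the randomly initialized weight matrices are illustrative only; in a trained network, $W_1$ and $W_2$ are learned parameters.

import torch
import torch.nn.functional as F

# Illustrative feature map U: batch of 2, C = 64 channels, 8x8 spatial maps
U = torch.randn(2, 64, 8, 8)
C, r = 64, 16

# Squeeze: global average pooling over the spatial dimensions -> z of shape (2, C)
z = U.mean(dim=(2, 3))

# Excitation: reduction and expansion weights (random here, learned in practice)
W1 = torch.randn(C // r, C)   # maps C -> C/r
W2 = torch.randn(C, C // r)   # maps C/r -> C
s = torch.sigmoid(F.linear(torch.relu(F.linear(z, W1)), W2))  # shape (2, C)

# Scale: broadcast the per-channel weights over the spatial dimensions
X_tilde = U * s.view(2, C, 1, 1)
print(X_tilde.shape)  # torch.Size([2, 64, 8, 8])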
Let's implement this as a reusable PyTorch module. We define a class SEBlock inheriting from torch.nn.Module.
import torch
import torch.nn as nn


class SEBlock(nn.Module):
    """
    Squeeze-and-Excitation Block.

    Adds channel-wise attention to a convolutional block.
    """
    def __init__(self, channels, reduction_ratio=16):
        """
        Initializes the SE Block.

        Args:
            channels (int): Number of input channels.
            reduction_ratio (int): Factor by which to reduce channels
                in the intermediate layer. Default: 16.
        """
        super(SEBlock, self).__init__()
        if channels <= reduction_ratio:
            # Avoid reducing channels to zero or negative numbers
            reduced_channels = channels // 2 if channels > 1 else 1
        else:
            reduced_channels = channels // reduction_ratio

        # Squeeze operation: Global Average Pooling
        self.squeeze = nn.AdaptiveAvgPool2d(1)

        # Excitation operation: two Linear layers with a bottleneck
        self.excitation = nn.Sequential(
            nn.Linear(channels, reduced_channels, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(reduced_channels, channels, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        """
        Forward pass through the SE Block.

        Args:
            x (torch.Tensor): Input tensor of shape (batch, channels, height, width).

        Returns:
            torch.Tensor: Output tensor, input scaled by channel-wise attention weights.
        """
        batch_size, num_channels, _, _ = x.size()

        # Squeeze: (batch, channels, height, width) -> (batch, channels, 1, 1)
        squeezed = self.squeeze(x)

        # Reshape for Linear layers: (batch, channels, 1, 1) -> (batch, channels)
        squeezed = squeezed.view(batch_size, num_channels)

        # Excitation: (batch, channels) -> (batch, channels)
        channel_weights = self.excitation(squeezed)

        # Reshape weights for scaling: (batch, channels) -> (batch, channels, 1, 1)
        channel_weights = channel_weights.view(batch_size, num_channels, 1, 1)

        # Scale: multiply the original input by the learned channel weights
        scaled_output = x * channel_weights
        return scaled_output
This SEBlock module can now be easily integrated into existing CNN architectures. The reduction_ratio is a hyperparameter controlling the capacity and computational cost of the attention mechanism. A common value is 16.
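A quick sanity check (with arbitrary tensor sizes chosen for illustration) confirms that the block preserves the input shape and adds relatively few parameters:

# Sanity check: the SE block preserves the input shape
se = SEBlock(channels=64, reduction_ratio=16)
x = torch.randn(8, 64, 32, 32)                   # batch of 8 feature maps, 64 channels
out = se(x)
print(out.shape)                                 # torch.Size([8, 64, 32, 32])
print(sum(p.numel() for p in se.parameters()))   # 512 (two 64x4 weight matrices)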
SE blocks are often added after the main convolutional operations within a building block, like a ResNet block, but before the residual connection is added. Let's illustrate how to modify a basic ResNet block to include an SE layer.
Consider a simplified ResNet block structure:
class BasicResNetBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(BasicResNetBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Shortcut connection
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        identity = self.shortcut(x)  # Prepare shortcut connection

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += identity       # Add shortcut
        out = self.relu(out)  # Final ReLU
        return out
Now, let's add the SEBlock just before the residual addition:
class SEResNetBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, reduction_ratio=16):
        super(SEResNetBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Add the SE Block here
        self.se_block = SEBlock(out_channels, reduction_ratio)

        # Shortcut connection (same as before)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        identity = self.shortcut(x)  # Prepare shortcut connection

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # Apply SE Block to the main path output
        out = self.se_block(out)

        out += identity       # Add shortcut
        out = self.relu(out)  # Final ReLU
        return out
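As with the plain block, a short shape check (with illustrative sizes) verifies that the SE-augmented block behaves as expected, including when it downsamples and the 1×1 shortcut convolution must match the recalibrated main-path output:

block = SEResNetBlock(in_channels=64, out_channels=128, stride=2, reduction_ratio=16)
x = torch.randn(4, 64, 56, 56)
out = block(x)
print(out.shape)  # torch.Size([4, 128, 28, 28]) - channels doubled, spatial size halved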
The diagram below illustrates the data flow within the SEResNetBlock.
Data flow within a ResNet block augmented with a Squeeze-and-Excitation block. The SE block recalibrates the feature maps from the main convolutional path before they are combined with the shortcut connection.
The reduction ratio r controls the bottleneck complexity in the excitation phase. A smaller r (e.g., 8) gives a wider bottleneck, potentially capturing finer channel relationships but increasing the parameter count. A larger r (e.g., 32) reduces parameters but might limit the expressiveness of the attention mechanism. The default of 16 is a reasonable starting point.
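A rough way to gauge this trade-off is to count the parameters an SEBlock adds at different reduction ratios; the channel width of 256 used below is just an example size:

# Parameter count of the SE block at different reduction ratios (256 channels)
for r in (8, 16, 32):
    se = SEBlock(channels=256, reduction_ratio=r)
    n_params = sum(p.numel() for p in se.parameters())
    print(f"r={r:>2}: {n_params} parameters")
# r= 8: 16384 parameters
# r=16: 8192 parameters
# r=32: 4096 parameters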
This hands-on example demonstrates how to implement and integrate a fundamental channel attention mechanism into a standard CNN. By selectively amplifying informative features and suppressing less useful ones based on global channel context, SE blocks can often improve model accuracy across a variety of computer vision tasks. Similar principles apply when integrating other attention mechanisms, such as spatial attention or non-local blocks, although their specific implementations will differ.