Now that we've covered the foundational concepts and architectures for image segmentation, let's put theory into practice. This section guides you through building, training, and evaluating a semantic segmentation model. We will implement a simplified version of the U-Net architecture, a popular choice for segmentation tasks, particularly in domains like medical imaging, using PyTorch. While we focus on U-Net here, the principles apply broadly to other architectures like FCNs or DeepLab.
We assume you have a working Python environment with PyTorch, TorchVision, and libraries like NumPy and Matplotlib installed.
Semantic segmentation requires images and corresponding pixel-level masks. Each pixel in the mask is labeled with the class it belongs to (e.g., 0 for background, 1 for road, 2 for building).
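For instance, a tiny 4×4 mask with those three classes might look like this purely illustrative array:

import numpy as np

# Hypothetical 4x4 mask: 0 = background, 1 = road, 2 = building
toy_mask = np.array([
    [0, 0, 2, 2],
    [0, 1, 1, 2],
    [1, 1, 1, 0],
    [1, 1, 0, 0]
])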
For this exercise, you might use a standard dataset like Pascal VOC, Cityscapes, or even create a simple synthetic dataset. Let's assume we have a dataset directory structure like this:
data/
├── images/
│ ├── 0001.png
│ ├── 0002.png
│ └── ...
└── masks/
├── 0001.png
├── 0002.png
└── ...
We'll need a custom PyTorch Dataset class to load images and masks.
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
import numpy as np

class SegmentationDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None, mask_transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.image_filenames = sorted(os.listdir(image_dir))
        self.mask_filenames = sorted(os.listdir(mask_dir))
        self.transform = transform
        self.mask_transform = mask_transform
        # Basic check: ensure image and mask lists match
        assert len(self.image_filenames) == len(self.mask_filenames), \
            "Number of images and masks must be the same."
        # Optionally add more rigorous checks here (e.g., matching filenames)

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.image_filenames[idx])
        mask_path = os.path.join(self.mask_dir, self.mask_filenames[idx])
        image = Image.open(img_path).convert("RGB")
        mask = Image.open(mask_path).convert("L")  # Assuming mask is grayscale
        if self.transform:
            image = self.transform(image)
        if self.mask_transform:
            # Important: apply geometric transforms identically to image and mask,
            # but never rescale or normalize the mask's class labels.
            # Random transformations require careful paired handling (a sketch
            # follows this class); here we assume a deterministic resize plus
            # tensor conversion.
            mask = self.mask_transform(mask)
            # Convert mask to a LongTensor of class indices for CrossEntropyLoss
            mask = mask.squeeze(0).long()
        else:
            # Default conversion if no specific mask transform
            mask = torch.from_numpy(np.array(mask)).long()
        return image, mask
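As the comments in __getitem__ note, random augmentations must be applied with identical parameters to an image and its mask. A minimal sketch of one way to do this with torchvision's functional API (the helper name is ours):

import random
import torchvision.transforms.functional as TF

def paired_random_hflip(image, mask, p=0.5):
    """Apply the same random horizontal flip to an image and its mask."""
    if random.random() < p:
        image = TF.hflip(image)
        mask = TF.hflip(mask)
    return image, mask

Libraries such as albumentations handle this pairing automatically for a wide range of transforms.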
# Define transformations (adjust size and normalization as needed)
image_transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
mask_transform = transforms.Compose([
    transforms.Resize((128, 128), interpolation=transforms.InterpolationMode.NEAREST),  # Use NEAREST for masks
    transforms.PILToTensor()  # Unlike ToTensor, keeps integer class labels unscaled
])

# Create Datasets and DataLoaders
# Replace with your actual data paths
train_dataset = SegmentationDataset('data/images', 'data/masks',
                                    transform=image_transform, mask_transform=mask_transform)
# val_dataset = SegmentationDataset('data_val/images', 'data_val/masks',
#                                   transform=image_transform, mask_transform=mask_transform)  # For validation
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=4)
# val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=4)
Note the use of transforms.InterpolationMode.NEAREST for resizing masks. This prevents interpolation from creating invalid class labels between existing ones. We also use transforms.PILToTensor rather than transforms.ToTensor for the mask, since ToTensor rescales pixel values to [0, 1] and would destroy the integer class labels. Mask tensors should typically be of type LongTensor.
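Before training, it's worth sanity-checking one batch from the loader; the shapes and dtypes in the comments are what we'd expect with the settings above:

# Inspect one batch (shapes/dtypes assume the settings above)
images, masks = next(iter(train_loader))
print(images.shape, images.dtype)  # torch.Size([8, 3, 128, 128]) torch.float32
print(masks.shape, masks.dtype)    # torch.Size([8, 128, 128]) torch.int64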
Let's implement a simplified U-Net. It consists of an encoder (contracting path) that captures context and a decoder (expansive path) that enables precise localization through upsampling (bilinear interpolation or transposed convolutions) and skip connections.
import torch.nn as nn
import torch.nn.functional as F

class DoubleConv(nn.Module):
    """(Convolution => BatchNorm => ReLU) * 2"""
    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)

class Down(nn.Module):
    """Downscaling with maxpool then double conv"""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)

class Up(nn.Module):
    """Upscaling then double conv"""
    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        # If bilinear, use normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # Input is CHW; pad x1 so it matches x2's spatial size before concatenation
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        # If you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e5474a7ae105f32e70a5168b
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)

class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=True):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits

# Instantiate the model
# n_channels=3 for RGB images, n_classes = number of segmentation classes (e.g., 2 for binary)
num_classes = 2  # Example: background + foreground
model = UNet(n_channels=3, n_classes=num_classes)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
This U-Net implementation uses standard convolutional blocks, max-pooling for downsampling, and optionally bilinear upsampling or transposed convolutions for upsampling. The skip connections concatenate feature maps from the encoder path with the upsampled feature maps in the decoder path, helping to recover fine-grained details lost during downsampling.
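A quick way to confirm the network is wired correctly is to pass a dummy batch through it and check that the output has one channel per class at the input resolution:

# Sanity check with a random input (assumes the model defined above)
dummy = torch.randn(1, 3, 128, 128).to(device)
with torch.no_grad():
    print(model(dummy).shape)  # Expected: torch.Size([1, 2, 128, 128])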
For semantic segmentation with multiple classes, the standard loss function is Cross-Entropy Loss applied pixel-wise. Each pixel is treated as a classification problem. If your dataset is highly imbalanced (e.g., small objects in large backgrounds), you might consider weighted cross-entropy or Dice Loss.
$$\text{CrossEntropyLoss}(\text{output}, \text{target}) = -\sum_{c=1}^{C} \text{target}_c \log\big(\text{softmax}(\text{output})_c\big)$$

where $C$ is the number of classes, $\text{output}$ contains the raw logits from the model for a pixel, and $\text{target}$ is the one-hot encoded ground truth label for that pixel (though PyTorch's nn.CrossEntropyLoss handles integer targets directly).
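For the imbalanced case mentioned above, here is a minimal soft Dice loss sketch for multi-class segmentation (one common formulation among several; the function name is ours):

def dice_loss(logits, targets, num_classes, smooth=1.0):
    """Soft Dice loss: logits [B, C, H, W], targets [B, H, W] integer labels."""
    probs = torch.softmax(logits, dim=1)
    # One-hot encode targets to match the [B, C, H, W] layout of probs
    onehot = F.one_hot(targets, num_classes).permute(0, 3, 1, 2).float()
    dims = (0, 2, 3)  # Sum over batch and spatial dimensions, per class
    intersection = (probs * onehot).sum(dims)
    cardinality = probs.sum(dims) + onehot.sum(dims)
    dice = (2.0 * intersection + smooth) / (cardinality + smooth)
    return 1.0 - dice.mean()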
We'll use the Adam optimizer.
import torch.optim as optim

# Loss function
# `ignore_index` is useful if you have a label to ignore (e.g., border pixels):
# criterion = nn.CrossEntropyLoss(ignore_index=255)
criterion = nn.CrossEntropyLoss()

# Optimizer
optimizer = optim.Adam(model.parameters(), lr=1e-4)  # Learning rate might need tuning
The training loop iterates through the dataset, performs forward and backward passes, and updates the model weights.
num_epochs = 25  # Adjust as needed
train_losses = []

model.train()  # Set model to training mode
for epoch in range(num_epochs):
    running_loss = 0.0
    for i, (images, masks) in enumerate(train_loader):
        images = images.to(device)
        masks = masks.to(device)  # Shape: [batch_size, H, W]

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(images)  # Shape: [batch_size, num_classes, H, W]

        # Calculate loss
        loss = criterion(outputs, masks)

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if (i + 1) % 50 == 0:  # Print status every 50 mini-batches
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}')

    epoch_loss = running_loss / len(train_loader)
    train_losses.append(epoch_loss)
    print(f'Epoch [{epoch+1}/{num_epochs}] completed. Average Loss: {epoch_loss:.4f}')

print('Finished Training')

# Save the trained model (optional)
# torch.save(model.state_dict(), 'unet_segmentation_model.pth')
This is a basic training loop. In practice, you would add a validation loop to monitor performance on held-out data (a sketch follows), learning rate scheduling, checkpointing of the best-performing model, and possibly early stopping.
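A minimal sketch of such a validation step, assuming the val_loader defined earlier:

def evaluate_loss(loader, model, criterion, device):
    """Average loss over a held-out loader; a minimal validation sketch."""
    model.eval()
    total = 0.0
    with torch.no_grad():
        for images, masks in loader:
            images, masks = images.to(device), masks.to(device)
            total += criterion(model(images), masks).item()
    model.train()
    return total / len(loader)

# Called once per epoch, e.g.:
# val_loss = evaluate_loss(val_loader, model, criterion, device)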
The most common metric for segmentation is Intersection over Union (IoU), also known as the Jaccard Index. It measures the overlap between the predicted segmentation mask (A) and the ground truth mask (B) for a specific class.
$$\text{IoU} = J(A, B) = \frac{|A \cap B|}{|A \cup B|} = \frac{\text{Intersection Area}}{\text{Union Area}}$$

Mean IoU (mIoU) is often reported, which is the average IoU calculated over all classes.
def calculate_iou(pred, target, num_classes, smooth=1e-6):
    """Calculates IoU for each class."""
    pred = torch.argmax(pred, dim=1)  # Convert logits to predicted class indices [B, H, W]
    pred = pred.contiguous().view(-1)
    target = target.contiguous().view(-1)

    iou_per_class = []
    for cls in range(num_classes):  # Calculate IoU for each class
        pred_inds = (pred == cls)
        target_inds = (target == cls)
        # Intersection: predicted pixels of this class that are also ground truth
        intersection = (pred_inds[target_inds]).long().sum().item()
        union = pred_inds.long().sum().item() + target_inds.long().sum().item() - intersection
        if union == 0:
            # Class absent from both prediction and ground truth; skip via NaN
            iou_per_class.append(float('nan'))  # or 0 or 1 according to convention
        else:
            iou = (intersection + smooth) / (union + smooth)
            iou_per_class.append(iou)
    return np.array(iou_per_class)
def calculate_miou(pred_loader, model, num_classes, device):
    """Calculates mean IoU over a dataset."""
    model.eval()  # Set model to evaluation mode
    total_iou = np.zeros(num_classes)
    class_counts = np.zeros(num_classes)  # Batches in which each class appeared
    with torch.no_grad():
        for images, masks in pred_loader:
            images = images.to(device)
            outputs = model(images)  # Model predictions (logits)
            iou = calculate_iou(outputs.cpu(), masks, num_classes)
            # Accumulate per-class IoU, skipping classes absent from a batch (NaN).
            # For a fully robust mIoU, accumulate intersection and union counts
            # across batches instead; averaging batch IoUs is less accurate.
            valid = ~np.isnan(iou)
            total_iou[valid] += iou[valid]
            class_counts[valid] += 1
    # Average each class over the batches in which it appeared
    mean_iou_per_class = np.where(class_counts > 0,
                                  total_iou / np.maximum(class_counts, 1), np.nan)
    mean_iou = np.nanmean(mean_iou_per_class)  # Average across classes that were present
    print(f'Mean IoU: {mean_iou:.4f}')
    print(f'IoU per class: {mean_iou_per_class}')
    return mean_iou

# Example usage after training (assuming you have a val_loader)
# mIoU = calculate_miou(val_loader, model, num_classes, device)
Implementing a robust mIoU calculation often involves accumulating the intersection and union counts per class across all batches before dividing, rather than averaging per-batch IoUs, especially when classes might be absent in some batches.
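One common way to do this is to accumulate a confusion matrix over the whole dataset and derive per-class IoU from it at the end. A sketch of that approach (the function name is ours):

def miou_from_confusion(loader, model, num_classes, device):
    """Robust mIoU sketch: accumulate a confusion matrix across all batches."""
    model.eval()
    conf = np.zeros((num_classes, num_classes), dtype=np.int64)
    with torch.no_grad():
        for images, masks in loader:
            preds = torch.argmax(model(images.to(device)), dim=1).cpu().numpy().ravel()
            gts = masks.numpy().ravel()
            # Joint histogram of (ground truth, prediction) pairs
            conf += np.bincount(num_classes * gts + preds,
                                minlength=num_classes ** 2).reshape(num_classes, num_classes)
    intersection = np.diag(conf)
    union = conf.sum(axis=0) + conf.sum(axis=1) - intersection
    iou = intersection / np.maximum(union, 1)
    return np.nanmean(np.where(union > 0, iou, np.nan))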
Visualizing the model's output helps you understand its performance qualitatively.
import matplotlib.pyplot as plt

def visualize_predictions(dataset, model, device, num_samples=5):
    model.eval()
    samples_shown = 0
    fig, axes = plt.subplots(num_samples, 3, figsize=(10, num_samples * 3))
    fig.suptitle("Image / Ground Truth / Prediction")

    # Note: the `dataset` argument is not used directly; we rebuild a dataset
    # without normalization so the raw images can be displayed. PILToTensor is
    # still needed in the mask transform so __getitem__ can produce a LongTensor.
    vis_dataset = SegmentationDataset(
        'data/images', 'data/masks',
        transform=transforms.Compose([transforms.Resize((128, 128))]),
        mask_transform=transforms.Compose([
            transforms.Resize((128, 128), interpolation=transforms.InterpolationMode.NEAREST),
            transforms.PILToTensor()
        ]))

    # Normalize images separately for model input
    input_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    with torch.no_grad():
        for i in range(len(vis_dataset)):
            if samples_shown >= num_samples:
                break
            raw_image, raw_mask = vis_dataset[i]  # PIL image and LongTensor mask
            input_image = input_transform(raw_image).unsqueeze(0).to(device)  # Prepare for model

            output = model(input_image)
            pred_mask = torch.argmax(output.squeeze(0), dim=0).cpu().numpy()

            axes[samples_shown, 0].imshow(raw_image)
            axes[samples_shown, 0].set_title("Image")
            axes[samples_shown, 0].axis('off')

            axes[samples_shown, 1].imshow(raw_mask, cmap='gray')  # Adjust cmap if needed
            axes[samples_shown, 1].set_title("Ground Truth")
            axes[samples_shown, 1].axis('off')

            axes[samples_shown, 2].imshow(pred_mask, cmap='gray')  # Adjust cmap based on num_classes
            axes[samples_shown, 2].set_title("Prediction")
            axes[samples_shown, 2].axis('off')

            samples_shown += 1

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])  # Adjust layout to prevent title overlap
    plt.show()

# Example usage:
# visualize_predictions(train_dataset, model, device)  # Use train or val dataset
This visualization function shows the original image, the ground truth mask, and the model's predicted mask side-by-side for a few examples.
Optionally, you can plot the training loss curve to check convergence:
[Figure: hypothetical training loss curve showing a decrease over 25 epochs.]
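A minimal snippet to produce such a plot from the train_losses list collected during training:

plt.figure(figsize=(6, 4))
plt.plot(range(1, num_epochs + 1), train_losses)
plt.xlabel('Epoch')
plt.ylabel('Average training loss')
plt.title('Training loss')
plt.show()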
This practical provides a starting point. To improve your segmentation model, consider richer data augmentation applied jointly to images and masks, a pretrained encoder backbone, alternative losses such as Dice or weighted cross-entropy for imbalanced data, learning rate scheduling, and more advanced architectures such as DeepLab.
Building effective segmentation models involves careful data preparation, appropriate architecture selection, correct loss implementation, and thorough evaluation. This hands-on exercise provides the fundamental building blocks for tackling diverse segmentation challenges.