Once your data is encapsulated within a Dataset object, the next step is often to preprocess and augment it before it's fed into your neural network. In PyTorch, this is primarily handled by the torchvision.transforms module, especially when working with image data. These transforms are similar in purpose to TensorFlow's tf.image functions or Keras preprocessing layers (tf.keras.layers.Rescaling, tf.keras.layers.RandomFlip, etc.). Their place in the pipeline differs, though: they are usually attached to the Dataset and applied sample by sample during data loading, rather than being part of the model.
Transforms in PyTorch are essentially callable Python objects that take a data sample (like a PIL Image or a PyTorch tensor) and return a transformed version. They can be chained together to create a preprocessing pipeline.
The torchvision.transforms module offers a rich set of pre-built transformations. Let's look at some of the most frequently used ones.
transforms.ToTensor(): This is a fundamental transform when working with image data. It converts a PIL Image (Python Imaging Library) or a NumPy array with shape H x W x C (height by width by channels) into a tensor of shape C x H x W. Importantly, for 8-bit inputs (a PIL Image in a mode such as RGB or L, or a uint8 NumPy array) it returns a torch.FloatTensor and scales the pixel values from the range [0, 255] to [0.0, 1.0]. This is typically one of the first transforms applied to image data.
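In TensorFlow, you might achieve a similar effect by using tf.image.convert_image_dtype with tf.float32, which also scales integer images to [0, 1], or by manually casting and dividing by 255. Note that TensorFlow keeps the channels-last H x W x C layout, whereas ToTensor() moves channels first.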
transforms.Normalize(mean, std): This transform normalizes a tensor image with a given mean and standard deviation for each channel. The formula applied is output[channel] = (input[channel] - mean[channel]) / std[channel]. Normalization helps the model train more effectively by ensuring that input features have a similar range of values. The mean and std are sequences (like lists or tuples) of values, one for each channel. For RGB images, you'd provide three values for mean and three for std. Because Normalize operates on tensors, it must come after ToTensor() in a pipeline.
from torchvision import transforms

# Common mean and std for ImageNet pretrained models
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
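This is closest to a tf.keras.layers.Normalization layer configured with pre-calculated statistics (note that it takes a mean and a variance rather than a standard deviation), or one whose adapt method has computed them from a sample of your data. TensorFlow's tf.image.per_image_standardization is a looser analogue: it computes a single mean and standard deviation from each image on the fly instead of applying fixed per-channel dataset statistics.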
transforms.Resize(size): Resizes the input image to the given size. If size is an integer, the smaller edge of the image will be matched to this number, maintaining the aspect ratio. If size is a sequence like (h, w), it resizes the image to these exact dimensions.
transforms.CenterCrop(size): Crops the given image at the center. size can be an integer (for a square crop) or a sequence (h, w). This is often used in validation or testing pipelines after an initial resize.
Data augmentation is a technique to artificially increase the diversity of your training dataset by applying random transformations. This helps make the model more robust and reduces overfitting. These transforms are typically applied only to the training data.
transforms.RandomResizedCrop(size, scale=(0.08, 1.0), ratio=(3./4., 4./3.)): Crops a random portion of the image and resizes it to the given size. The scale parameter defines the range of the crop's area as a fraction of the original image's area, and ratio defines the range of aspect ratios for the crop. This is a very common augmentation for image classification.
transforms.RandomHorizontalFlip(p=0.5): Horizontally flips the given image randomly with a given probability p (default is 0.5).
transforms.RandomRotation(degrees, interpolation=...): Rotates the image by a random angle. degrees can be a single number (e.g., 30, for rotation between -30 and +30 degrees) or a sequence like (-30, 30).
transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0): Randomly changes the brightness, contrast, saturation, and hue of an image. You can specify the range for each parameter. For example, brightness=0.2 means a brightness factor is chosen uniformly at random from [max(0, 1 - 0.2), 1 + 0.2].
TensorFlow offers similar augmentation capabilities through tf.image functions (e.g., tf.image.random_flip_left_right, tf.image.random_brightness) or Keras preprocessing layers (e.g., tf.keras.layers.RandomFlip, tf.keras.layers.RandomRotation). A key difference is that Keras layers are often part of the model graph and can execute on the accelerator (GPU/TPU), while PyTorch transforms are typically applied on the CPU by the DataLoader workers.
transforms.Compose
To apply multiple transformations in sequence, PyTorch provides transforms.Compose. You pass a list of transform objects to its constructor, and Compose will apply them in the order they appear in the list.
Here's how you might define separate transform pipelines for training and validation:
from torchvision import transforms

# Training pipeline: includes augmentation
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(224),                     # Augmentation: random crop and resize
    transforms.RandomHorizontalFlip(),                     # Augmentation: random horizontal flip
    transforms.ColorJitter(brightness=0.2, contrast=0.2),  # Augmentation: color adjustment
    transforms.ToTensor(),                                 # Convert to tensor and scale to [0, 1]
    transforms.Normalize(mean=[0.485, 0.456, 0.406],       # Normalize
                         std=[0.229, 0.224, 0.225])
])

# Validation/Testing pipeline: no augmentation, just necessary preprocessing
val_test_transforms = transforms.Compose([
    transforms.Resize(256),                                # Resize smaller edge to 256
    transforms.CenterCrop(224),                            # Crop center 224x224
    transforms.ToTensor(),                                 # Convert to tensor and scale to [0, 1]
    transforms.Normalize(mean=[0.485, 0.456, 0.406],       # Normalize
                         std=[0.229, 0.224, 0.225])
])
You would then pass these composed transform objects to your Dataset instance, usually as an argument to its constructor; the dataset then applies them to each sample inside its __getitem__ method. Many built-in datasets in torchvision.datasets (like ImageFolder) accept a transform argument directly.
# Example with torchvision.datasets.ImageFolder
# from torchvision.datasets import ImageFolder
# train_dataset = ImageFolder(root='path/to/train_data', transform=train_transforms)
# val_dataset = ImageFolder(root='path/to/val_data', transform=val_test_transforms)
If you're writing a custom Dataset, you'd typically apply the transform in the __getitem__ method:
from PIL import Image
from torch.utils.data import Dataset

class MyCustomDataset(Dataset):
    def __init__(self, data_paths, labels, transform=None):
        self.data_paths = data_paths
        self.labels = labels
        self.transform = transform

    def __getitem__(self, index):
        img_path = self.data_paths[index]
        image = Image.open(img_path).convert("RGB")  # Load image as PIL
        label = self.labels[index]

        if self.transform:
            image = self.transform(image)  # Apply transforms

        return image, label

    def __len__(self):
        return len(self.data_paths)
While torchvision.transforms covers many common use cases, you might need specialized preprocessing steps. You can create custom transforms easily in PyTorch. A transform can be any callable that accepts a sample and returns a transformed sample. Often, this is done by defining a class with a __call__ method.
Let's say you want to add a specific type of noise to your images:
import numpy as np
from PIL import Image  # Assuming input is PIL Image for this example
from torchvision import transforms

class AddSaltAndPepperNoise:
    def __init__(self, amount=0.05):
        self.amount = amount  # Fraction of pixels to corrupt (half salt, half pepper)

    def __call__(self, img):
        # Ensure img is a PIL Image, then convert it to a (writable) NumPy array
        if not isinstance(img, Image.Image):
            raise TypeError("Input must be a PIL Image.")
        np_img = np.array(img)
        height, width = np_img.shape[:2]
        num_pixels = height * width  # Count pixels, not pixels x channels

        # Add salt: set randomly chosen pixels to white
        num_salt = int(np.ceil(self.amount * num_pixels * 0.5))
        rows = np.random.randint(0, height, num_salt)  # High bound is exclusive
        cols = np.random.randint(0, width, num_salt)
        np_img[rows, cols] = 255  # Works for grayscale (H, W) and color (H, W, C)

        # Add pepper: set randomly chosen pixels to black
        num_pepper = int(np.ceil(self.amount * num_pixels * 0.5))
        rows = np.random.randint(0, height, num_pepper)
        cols = np.random.randint(0, width, num_pepper)
        np_img[rows, cols] = 0

        return Image.fromarray(np_img)  # Return PIL Image

    def __repr__(self):
        return self.__class__.__name__ + f'(amount={self.amount})'
# Using the custom transform in a pipeline
# Assume image is loaded as PIL Image before this point
custom_pipeline = transforms.Compose([
    AddSaltAndPepperNoise(amount=0.03),
    transforms.ToTensor(),
    # ... other transforms like Normalize
])
This custom transform AddSaltAndPepperNoise can now be seamlessly integrated into a transforms.Compose pipeline along with built-in transforms. Remember that the input and output types of your custom transform should be compatible with adjacent transforms in the pipeline (e.g., if it expects a PIL image, ensure it's placed before ToTensor()).
Transitioning from TensorFlow's data preprocessing, keep these points in mind:
Unlike Keras preprocessing layers placed inside a model, which can run on the accelerator, torchvision.transforms are typically applied by the DataLoader on the CPU, in separate worker processes when num_workers is greater than 0. This difference can affect overall training throughput, depending on the complexity of the transformations and the balance of CPU and GPU workload. CPU-based preprocessing with tf.data.Dataset.map and tf.image functions is a more direct parallel to the common PyTorch approach.
Transforms are usually either randomized without any learned state (e.g., RandomHorizontalFlip) or configured with fixed parameters (e.g., Normalize with pre-defined mean/std). If you need to compute statistics from your data (like the mean and standard deviation for normalization), you typically do this once offline and then hardcode these values into the Normalize transform. This contrasts with Keras's Normalization layer, which has an adapt() method that computes these statistics from a sample of your data before training.
Because many torchvision transforms also accept tensors, augmentation can be moved to the GPU when CPU preprocessing becomes a bottleneck (for example, by applying nn.Module based transforms to batched tensors on the GPU or in the model's forward pass, though this is less common for standard input preprocessing).
By understanding and utilizing torchvision.transforms, you can build flexible and efficient data preprocessing pipelines in PyTorch, adapting the techniques you've learned in the TensorFlow ecosystem. Remember to always apply augmentations only to your training data and ensure consistent preprocessing for your validation and test sets. Experimentation is often necessary to find the optimal set of transformations for your specific task and dataset.
© 2025 ApX Machine Learning