Preprocessing and augmenting data is a common requirement for Dataset objects before they are fed into a neural network. PyTorch primarily handles this through the torchvision.transforms module, especially when working with image data. These transforms are similar in purpose to TensorFlow's tf.image functions or Keras preprocessing layers (tf.keras.layers.Rescaling, tf.keras.layers.RandomFlip, etc.), but their integration into the data loading pipeline is distinct.
Transforms in PyTorch are essentially callable Python objects that take a data sample (like a PIL Image or a PyTorch tensor) and return a transformed version. They can be chained together to create a preprocessing pipeline.
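Here's a minimal sketch of that callable behavior (using a blank placeholder image rather than a real photo):

from PIL import Image
from torchvision import transforms

# Each transform is a callable object: construct it once, then apply it
# to as many samples as you like
flip = transforms.RandomHorizontalFlip(p=1.0)  # p=1.0 flips on every call
img = Image.new("RGB", (64, 64))               # placeholder PIL Image
flipped = flip(img)                            # call the object like a function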
The torchvision.transforms module offers a rich set of pre-built transformations. Let's look at some of the most frequently used ones.
transforms.ToTensor(): This is a fundamental transform when working with image data. It converts a PIL Image (Python Imaging Library) or a NumPy array (with shape H x W x C, height by width by channels) into a torch.FloatTensor of shape C x H x W. Importantly, it also scales the pixel values from the range [0, 255] to [0.0, 1.0]. This is typically one of the first transforms applied to image data.
In TensorFlow, you might achieve a similar effect by using tf.image.convert_image_dtype with tf.float32 which also scales to [0,1], or by manually casting and dividing.
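As a quick sketch of the PyTorch side, the shape and value-range conversion can be verified directly (the random array below just stands in for a real image):

import numpy as np
from PIL import Image
from torchvision import transforms

# A dummy 8-bit RGB image with shape H x W x C
pil_img = Image.fromarray(np.random.randint(0, 256, (32, 48, 3), dtype=np.uint8))

tensor_img = transforms.ToTensor()(pil_img)
print(tensor_img.shape)  # torch.Size([3, 32, 48]) -- now C x H x W
print(tensor_img.min().item() >= 0.0, tensor_img.max().item() <= 1.0)  # True True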
transforms.Normalize(mean, std): This transform normalizes a tensor image with a given mean and standard deviation for each channel. The formula applied is output[channel] = (input[channel] - mean[channel]) / std[channel]. Normalization helps the model train more effectively by ensuring that input features have a similar range of values. The mean and std are sequences (like lists or tuples) of values, one for each channel. For RGB images, you'd provide three values for mean and three for std.
# Common mean and std for ImageNet pretrained models
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
This is analogous to TensorFlow's tf.image.per_image_standardization if you compute mean and variance per image, or more commonly, applying a tf.keras.layers.Normalization layer with pre-calculated statistics or by using its adapt method on a sample of your data.
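Since there is no built-in adapt() equivalent in PyTorch, computing those statistics offline is a short loop. Here is a minimal sketch, assuming train_dataset yields (image_tensor, label) pairs whose images have already been scaled to [0, 1] by ToTensor():

import torch
from torch.utils.data import DataLoader

loader = DataLoader(train_dataset, batch_size=64, num_workers=2)

channel_sum = torch.zeros(3)
channel_sq_sum = torch.zeros(3)
num_pixels = 0
for images, _ in loader:                       # images: B x 3 x H x W
    channel_sum += images.sum(dim=[0, 2, 3])
    channel_sq_sum += (images ** 2).sum(dim=[0, 2, 3])
    num_pixels += images.numel() // 3          # pixels per channel

mean = channel_sum / num_pixels
std = (channel_sq_sum / num_pixels - mean ** 2).sqrt()
# Hardcode these values into transforms.Normalize(mean=..., std=...)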
transforms.Resize(size): Resizes the input image to the given size. If size is an integer, the smaller edge of the image will be matched to this number, maintaining the aspect ratio. If size is a sequence like (h, w), it resizes the image to these exact dimensions.
transforms.CenterCrop(size): Crops the given image at the center. size can be an integer (for a square crop) or a sequence (h, w). This is often used in validation or testing pipelines after an initial resize.
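The two behaviors are easy to confirm with a placeholder image:

from PIL import Image
from torchvision import transforms

img = Image.new("RGB", (640, 480))  # PIL reports size as (W, H)

resized = transforms.Resize(256)(img)
print(resized.size)                 # (341, 256): smaller edge -> 256, aspect ratio kept

cropped = transforms.CenterCrop(224)(resized)
print(cropped.size)                 # (224, 224)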
Data augmentation is a technique to artificially increase the diversity of your training dataset by applying random transformations. It helps the model generalize better and reduces overfitting. These transforms are typically applied only to the training data.
transforms.RandomResizedCrop(size, scale=(0.08, 1.0), ratio=(3./4., 4./3.)): Crops a random portion of the image and resizes it to the given size. The scale parameter defines the range of the area of the original image to crop, and ratio defines the range of aspect ratios for the crop. This is a very common augmentation for image classification.
transforms.RandomHorizontalFlip(p=0.5): Horizontally flips the given image randomly with a given probability p (default is 0.5).
transforms.RandomRotation(degrees, interpolation=...): Rotates the image by a random angle. degrees can be a single number (e.g., 30, for rotation between -30 and +30 degrees) or a sequence like (-30, 30).
transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0): Randomly changes the brightness, contrast, saturation, and hue of an image. You can specify the range for each parameter. For example, brightness=0.2 means pick a brightness factor randomly from [max(0, 1 - 0.2), 1 + 0.2].
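One point worth making explicit for readers coming from graph-based pipelines: the randomness in these transforms is re-sampled on every call, so the same sample yields a different augmented view each epoch. A small sketch:

from PIL import Image
from torchvision import transforms

augment = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2),
])

img = Image.new("RGB", (640, 480))  # placeholder image
view1 = augment(img)                # two calls on the same image produce
view2 = augment(img)                # two (almost certainly) different crops/flips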
TensorFlow offers similar augmentation capabilities through tf.image functions (e.g., tf.image.random_flip_left_right, tf.image.random_brightness) or Keras preprocessing layers (e.g., tf.keras.layers.RandomFlip, tf.keras.layers.RandomRotation). A main difference is that Keras layers are often part of the model graph and can execute on the accelerator (GPU/TPU), while PyTorch transforms are typically applied on the CPU by the DataLoader workers.
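That said, recent torchvision transforms subclass nn.Module and accept tensor inputs, so they can be moved onto the accelerator too. A minimal sketch (note that a transform applied to a batched tensor samples its randomness once for the whole batch):

import torch
import torch.nn as nn
from torchvision import transforms

# Tensor-based transforms composed as a module can run on the GPU
gpu_augment = nn.Sequential(
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2),
)

batch = torch.rand(8, 3, 224, 224)  # a dummy batch of images
if torch.cuda.is_available():
    gpu_augment = gpu_augment.to("cuda")
    batch = batch.to("cuda")
augmented = gpu_augment(batch)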
transforms.Compose: To apply multiple transformations in sequence, PyTorch provides transforms.Compose. You pass a list of transform objects to its constructor, and Compose applies them in the order they appear in the list.
Here's how you might define separate transform pipelines for training and validation:
from torchvision import transforms
# Training pipeline: includes augmentation
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(224),                     # Augmentation: random crop and resize
    transforms.RandomHorizontalFlip(),                     # Augmentation: random horizontal flip
    transforms.ColorJitter(brightness=0.2, contrast=0.2),  # Augmentation: color adjustment
    transforms.ToTensor(),                                 # Convert to tensor and scale to [0, 1]
    transforms.Normalize(mean=[0.485, 0.456, 0.406],       # Normalize with ImageNet statistics
                         std=[0.229, 0.224, 0.225])
])
# Validation/Testing pipeline: no augmentation, just necessary preprocessing
val_test_transforms = transforms.Compose([
    transforms.Resize(256),         # Resize smaller edge to 256
    transforms.CenterCrop(224),     # Crop center 224x224
    transforms.ToTensor(),          # Convert to tensor and scale to [0, 1]
    transforms.Normalize(mean=[0.485, 0.456, 0.406],  # Normalize with the same statistics
                         std=[0.229, 0.224, 0.225])
])
You would then pass these composed transform objects to your Dataset instance, usually at initialization; the dataset then applies them inside its __getitem__ method. Many built-in datasets in torchvision.datasets (like ImageFolder) accept a transform argument directly.
# Example with torchvision.datasets.ImageFolder
from torchvision.datasets import ImageFolder

# The paths below are placeholders; point them at your own data directories
train_dataset = ImageFolder(root='path/to/train_data', transform=train_transforms)
val_dataset = ImageFolder(root='path/to/val_data', transform=val_test_transforms)
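From there, the datasets plug into DataLoader as usual; the transforms then run inside the loader's CPU worker processes as each batch is assembled:

from torch.utils.data import DataLoader

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True,
                          num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False,
                        num_workers=4)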
If you're writing a custom Dataset, you'd typically apply the transform in the __getitem__ method:
from PIL import Image
from torch.utils.data import Dataset

class MyCustomDataset(Dataset):
    def __init__(self, data_paths, labels, transform=None):
        self.data_paths = data_paths
        self.labels = labels
        self.transform = transform

    def __getitem__(self, index):
        img_path = self.data_paths[index]
        image = Image.open(img_path).convert("RGB")  # Load the image as a PIL Image
        label = self.labels[index]

        if self.transform:
            image = self.transform(image)  # Apply the transform pipeline

        return image, label

    def __len__(self):
        return len(self.data_paths)
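A quick hypothetical usage sketch, assuming paths is a list of image file paths and labels a parallel list of integer class labels:

train_ds = MyCustomDataset(paths, labels, transform=train_transforms)
image, label = train_ds[0]  # __getitem__ loads the image and runs the pipeline
print(image.shape)          # torch.Size([3, 224, 224]) with the training pipeline above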
While torchvision.transforms covers many common use cases, you might need specialized preprocessing steps. You can create custom transforms easily in PyTorch. A transform can be any callable that accepts a sample and returns a transformed sample. Often, this is done by defining a class with a __call__ method.
Let's say you want to add a specific type of noise to your images:
import numpy as np
from PIL import Image  # Assuming the input is a PIL Image for this example

class AddSaltAndPepperNoise:
    def __init__(self, amount=0.05):
        self.amount = amount

    def __call__(self, img):
        # Ensure img is a PIL Image, then convert to a NumPy array
        if not isinstance(img, Image.Image):
            raise TypeError("Input must be a PIL Image.")
        np_img = np.array(img)
        height, width = np_img.shape[:2]
        # Half of the affected pixels become salt, half become pepper
        num_corrupted = int(np.ceil(self.amount * height * width * 0.5))

        # Salt: set randomly chosen pixel locations to white
        rows = np.random.randint(0, height, num_corrupted)
        cols = np.random.randint(0, width, num_corrupted)
        np_img[rows, cols] = 255  # broadcasts across channels for color images

        # Pepper: set randomly chosen pixel locations to black
        rows = np.random.randint(0, height, num_corrupted)
        cols = np.random.randint(0, width, num_corrupted)
        np_img[rows, cols] = 0

        return Image.fromarray(np_img)  # Return a PIL Image

    def __repr__(self):
        return self.__class__.__name__ + f'(amount={self.amount})'
# Using the custom transform in a pipeline
# Assume the image is loaded as a PIL Image before this point
custom_pipeline = transforms.Compose([
    AddSaltAndPepperNoise(amount=0.03),
    transforms.ToTensor(),
    # ... other transforms such as Normalize
])
This custom transform AddSaltAndPepperNoise can now be integrated into a transforms.Compose pipeline along with built-in transforms. Remember that the input and output types of your custom transform should be compatible with adjacent transforms in the pipeline (e.g., if it expects a PIL image, ensure it's placed before ToTensor()).
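For example, here is a sketch of a hypothetical tensor-based transform; because it operates on tensors rather than PIL Images, it must be placed after ToTensor():

import torch

class AddGaussianNoise:
    def __init__(self, std=0.05):
        self.std = std

    def __call__(self, tensor):
        # Expects a tensor, so place this after ToTensor() in the pipeline
        return tensor + torch.randn_like(tensor) * self.std

tensor_pipeline = transforms.Compose([
    transforms.ToTensor(),        # PIL Image -> tensor in [0, 1]
    AddGaussianNoise(std=0.05),   # operates on the tensor output
])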
Transitioning from TensorFlow's data preprocessing, keep these points in mind:
torchvision.transforms are typically applied by the DataLoader on the CPU, in separate worker processes. This can affect overall training throughput depending on the complexity of the transformations and the balance of CPU/GPU workload. CPU-based preprocessing with tf.data.map and tf.image functions is the more direct parallel to the common PyTorch approach.

Random transforms re-sample their randomness on every call (e.g., RandomHorizontalFlip), while others are configured with fixed parameters (e.g., Normalize with pre-defined mean/std). If you need to compute statistics from your data (like the mean and standard deviation for normalization), you typically do this once offline and then hardcode these values into the Normalize transform. This contrasts with Keras's Normalization layer, which has an adapt() method to compute these statistics online from a batch of data.

Preprocessing in PyTorch is not part of the model graph by default. If you want transforms to execute on the accelerator, you can move them into the model itself (for example, by adding nn.Module based transforms to the model's forward pass, though this is less common for standard input preprocessing).

By understanding and utilizing torchvision.transforms, you can build flexible and efficient data preprocessing pipelines in PyTorch, adapting the techniques you've learned in the TensorFlow ecosystem. Remember to apply augmentations only to your training data and to keep preprocessing consistent for your validation and test sets. Experimentation is often necessary to find the optimal set of transformations for your specific task and dataset.