While the default DataLoader provides convenient batching and shuffling, many real-world applications require finer control over how data is sampled and collated into batches. PyTorch offers flexibility through custom samplers and collate functions, allowing you to tailor the data loading process to your specific needs, such as handling imbalanced datasets or working with variable-sized inputs.
The DataLoader uses a sampler object to determine the order in which indices are drawn from the Dataset. By default, if shuffle=True, it uses a RandomSampler, and if shuffle=False, it uses a SequentialSampler. However, you can explicitly pass your own sampler instance via the sampler argument (note: if you provide a sampler, you must leave shuffle at its default of False, since the sampler itself defines the ordering).
PyTorch provides several built-in samplers in torch.utils.data:

- SequentialSampler: Samples elements sequentially, always in the same order.
- RandomSampler: Samples elements randomly. If replacement=True, samples are drawn with replacement.
- SubsetRandomSampler: Samples elements randomly from a given list of indices. Useful for creating validation splits without modifying the original dataset (see the sketch just after this list).
- WeightedRandomSampler: Samples elements from [0, ..., len(weights)-1] with the given probabilities (weights). This is particularly useful for handling imbalanced datasets, where you want to oversample minority classes or undersample majority classes.
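For instance, here is a minimal sketch of an 80/20 train/validation split built with SubsetRandomSampler; the stand-in TensorDataset and the 80/20 ratio are illustrative assumptions:

import torch
from torch.utils.data import DataLoader, SubsetRandomSampler, TensorDataset

# A small stand-in dataset (illustrative only): 1000 samples, 10 features each.
dataset = TensorDataset(torch.randn(1000, 10), torch.randint(0, 2, (1000,)))

indices = torch.randperm(len(dataset)).tolist()  # shuffle indices once
split = int(0.8 * len(dataset))                  # 80/20 train/validation split

train_sampler = SubsetRandomSampler(indices[:split])
val_sampler = SubsetRandomSampler(indices[split:])

# Both loaders read from the same Dataset but see disjoint index subsets;
# shuffle stays at its default (False) because each sampler defines the order.
train_loader = DataLoader(dataset, batch_size=32, sampler=train_sampler)
val_loader = DataLoader(dataset, batch_size=32, sampler=val_sampler)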
Example: Using WeightedRandomSampler for Imbalanced Data
Imagine a classification dataset where class '0' has 900 samples and class '1' has 100 samples. Simple random sampling would lead to batches heavily biased towards class '0'. We can use WeightedRandomSampler to give samples from class '1' a higher probability of being selected.
import torch
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
# Assume 'dataset' is your torch.utils.data.Dataset instance
# Assume 'targets' is a list or tensor containing the class label for each sample
# e.g., targets = [0, 0, 1, 0, ..., 1, 0]
# Calculate weights for each sample
class_counts = torch.bincount(torch.tensor(targets)) # Counts per class: e.g., [900, 100]
num_samples = len(targets) # Total samples: 1000
# Weight for each sample is 1 / (number of samples in its class)
sample_weights = 1.0 / class_counts[torch.as_tensor(targets)]
# Create the sampler
sampler = WeightedRandomSampler(weights=sample_weights, num_samples=num_samples, replacement=True)
# Create the DataLoader using the custom sampler
# Note: shuffle must be False when using a sampler
dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)
# Now, batches drawn from this dataloader will have a more balanced
# representation of classes over time.
# for batch_features, batch_labels in dataloader:
#     # Training steps...
#     pass
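As an optional sanity check (assuming, as above, that each batch yields a (features, labels) pair with integer class labels), you can count how often each class appears over the first few batches:

# Count class occurrences over the first 10 batches drawn with the sampler.
counts = torch.zeros(2, dtype=torch.long)
for i, (_, batch_labels) in enumerate(dataloader):
    counts += torch.bincount(batch_labels, minlength=2)
    if i == 9:
        break
print(counts)  # both classes should now appear at roughly comparable rates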
You can also create entirely custom sampling strategies by inheriting from torch.utils.data.Sampler and implementing the __iter__ and __len__ methods.
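As a minimal sketch, a custom sampler that simply yields indices in reverse order might look like this (the ReverseSampler name is made up for illustration):

import torch
from torch.utils.data import Sampler

class ReverseSampler(Sampler):
    """Illustrative sampler that yields dataset indices from last to first."""

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        # Yield indices len-1, len-2, ..., 0
        return iter(range(len(self.data_source) - 1, -1, -1))

    def __len__(self):
        return len(self.data_source)

# Usage: DataLoader(dataset, batch_size=32, sampler=ReverseSampler(dataset))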
collate_fn
Once the sampler provides a list of indices for a batch, the DataLoader fetches the corresponding samples from the Dataset using dataset[index]. It then needs to assemble these individual samples into a single batch. This assembly process is handled by the collate_fn argument.
The default collate_fn works well for many standard cases. It attempts to:

- Stack tensors along a new batch dimension (dimension 0), which requires every tensor in the batch to have the same shape.
- Convert NumPy arrays and Python numbers into tensors before batching them.
- Recursively apply the same logic to collections such as lists, tuples, and dictionaries (e.g., if Dataset.__getitem__ returns a dictionary, the collated batch will be a dictionary where each value is a batch of the corresponding items).

However, the default collate_fn might fail or produce undesirable results if your samples have varying sizes (e.g., sequences of different lengths) or contain data types it doesn't know how to stack.
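To make the default behavior concrete, here is a small sketch using default_collate (exposed as torch.utils.data.default_collate in recent PyTorch releases) on dictionary samples; the field names are illustrative:

import torch
from torch.utils.data import default_collate

# Two samples with identically shaped tensors collate cleanly into one batch.
samples = [
    {"image": torch.zeros(3, 4, 4), "label": 0},
    {"image": torch.ones(3, 4, 4), "label": 1},
]
batch = default_collate(samples)
print(batch["image"].shape)  # torch.Size([2, 3, 4, 4])
print(batch["label"])        # tensor([0, 1])

# If the "image" tensors had different shapes, default_collate would raise
# a RuntimeError, which is exactly where a custom collate_fn comes in.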
In such cases, you can provide a custom function to the collate_fn argument of the DataLoader. This function receives a list of samples (where each sample is the output of Dataset.__getitem__) and is responsible for returning the collated batch in the desired format.
Example: Padding Variable-Length Sequences
A common scenario involves sequences (like sentences in NLP) that have different lengths. The default collate function cannot directly stack these into a single tensor. A custom collate_fn can pad the sequences within each batch to the maximum length in that batch.
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
# Example Dataset returning variable-length tensors
class VariableSequenceDataset(Dataset):
    def __init__(self, data):
        # data is a list of tensors, e.g., [torch.randn(5), torch.randn(8), ...]
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # For simplicity, assume each item also has a label (here, its length)
        sequence = self.data[idx]
        label = len(sequence)
        return sequence, label

# Custom collate function
def pad_collate(batch):
    # batch is a list of tuples: [(sequence1, label1), (sequence2, label2), ...]
    # Sorting by sequence length is optional (sometimes done for RNN packing
    # efficiency) and not strictly necessary for padding:
    # batch.sort(key=lambda x: len(x[0]), reverse=True)
    # Separate sequences and labels
    sequences = [item[0] for item in batch]
    labels = [item[1] for item in batch]
    # Pad sequences to the length of the longest sequence in the batch.
    # With `batch_first=True`, 1-D sequences collate to (batch_size, max_seq_len);
    # sequences with a feature dimension give (batch_size, max_seq_len, features).
    padded_sequences = pad_sequence(sequences, batch_first=True, padding_value=0.0)
    # Stack labels (assuming they are simple scalars)
    labels = torch.tensor(labels)
    return padded_sequences, labels
# Create dataset and dataloader
sequences = [torch.randn(torch.randint(5, 15, (1,)).item()) for _ in range(100)]
dataset = VariableSequenceDataset(sequences)
dataloader = DataLoader(dataset, batch_size=4, collate_fn=pad_collate)
# Iterate through the dataloader
# for padded_batch, label_batch in dataloader:
#     # padded_batch shape: (4, max_len_in_this_batch) for these 1-D sequences
#     # label_batch shape: (4,)
#     # Model processing...
#     pass
This custom collate_fn uses torch.nn.utils.rnn.pad_sequence to handle the padding, ensuring that all sequences in a batch have the same length and are therefore suitable for processing by models like RNNs.
Besides sampler and collate_fn, other DataLoader arguments offer performance and behavior tuning (a combined example follows this list):

- num_workers (int, optional): Specifies how many subprocesses to use for data loading. Setting this to a positive integer enables multi-process data loading, which can significantly speed up data fetching, especially if data loading involves disk I/O or non-trivial preprocessing on the CPU. A common starting point is the number of available CPU cores. Default: 0 (data loading happens in the main process).
- pin_memory (bool, optional): If True, the DataLoader will copy tensors into CUDA pinned memory before returning them. Pinned memory enables faster data transfer from CPU to GPU, so this is only effective if you are training on a GPU. Default: False.
- drop_last (bool, optional): If True, drops the last incomplete batch if the dataset size is not divisible by the batch size. If False (the default), the last batch might be smaller than batch_size.
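Putting a few of these options together, a loader might be configured like the following sketch (the specific values are illustrative, not recommendations, and dataset is assumed to be an existing map-style Dataset):

from torch.utils.data import DataLoader

# Illustrative values only; tune num_workers and batch_size for your hardware.
dataloader = DataLoader(
    dataset,                 # any map-style Dataset (assumed to exist)
    batch_size=64,
    shuffle=True,            # fine here because no custom sampler is passed
    num_workers=4,           # 4 worker processes for loading/preprocessing
    pin_memory=True,         # speeds up host-to-GPU copies when training on CUDA
    drop_last=True,          # keep every batch at exactly 64 samples
)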
By understanding and utilizing samplers, custom collate functions, and the other DataLoader arguments, you gain precise control over your data pipeline, enabling efficient handling of diverse data types and structures, addressing dataset imbalances, and optimizing data loading performance for faster model training.