While the default DataLoader provides convenient batching and shuffling, many applications require finer control over how data is sampled and collated into batches. PyTorch offers flexibility through custom samplers and collate functions, allowing you to tailor the data loading process to your specific needs, such as handling imbalanced datasets or working with variable-sized inputs.
The DataLoader uses a sampler object to determine the order in which indices are drawn from the Dataset. By default, if shuffle=True, it uses a RandomSampler, and if shuffle=False, it uses a SequentialSampler. However, you can explicitly pass your own sampler instance via the sampler argument (note: if you provide a sampler, you must leave shuffle=False, as shuffling is defined by the sampler itself).
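To make this relationship concrete, here is a small sketch (assuming a dataset variable already exists) showing that shuffle=True is shorthand for passing a RandomSampler, and that the two options cannot be combined:
from torch.utils.data import DataLoader, RandomSampler
# Assume 'dataset' is a torch.utils.data.Dataset instance
# These two loaders draw indices in an equivalent random order:
loader_a = DataLoader(dataset, batch_size=32, shuffle=True)
loader_b = DataLoader(dataset, batch_size=32, sampler=RandomSampler(dataset))
# Passing both options together raises an error, because they are mutually exclusive:
# DataLoader(dataset, batch_size=32, shuffle=True, sampler=RandomSampler(dataset))  # ValueError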
PyTorch provides several built-in samplers in torch.utils.data:
- SequentialSampler: Samples elements sequentially, always in the same order.
- RandomSampler: Samples elements randomly. If replacement=True, samples are drawn with replacement.
- SubsetRandomSampler: Samples elements randomly from a given list of indices. Useful for creating validation splits without modifying the original dataset (a short sketch appears after the weighted-sampler example below).
- WeightedRandomSampler: Samples elements from [0, ..., len(weights)-1] with the given probabilities (weights). This is particularly useful for handling imbalanced datasets, where you want to oversample minority classes or undersample majority classes.

Example: Using WeightedRandomSampler for Imbalanced Data
Imagine a classification dataset where class '0' has 900 samples and class '1' has 100 samples. Simple random sampling would lead to batches heavily biased towards class '0'. We can use WeightedRandomSampler to give samples from class '1' a higher probability of being selected.
import torch
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
# Assume 'dataset' is your torch.utils.data.Dataset instance
# Assume 'targets' is a list or tensor containing the class label for each sample
# e.g., targets = [0, 0, 1, 0, ..., 1, 0]
# Calculate weights for each sample
class_counts = torch.bincount(torch.tensor(targets)) # Counts per class: e.g., [900, 100]
num_samples = len(targets) # Total samples: 1000
# Weight for each sample is 1 / (number of samples in its class)
class_weights = 1.0 / class_counts.float()
sample_weights = class_weights[torch.tensor(targets)]
# Create the sampler
sampler = WeightedRandomSampler(weights=sample_weights, num_samples=num_samples, replacement=True)
# Create the DataLoader using the custom sampler
# Note: shuffle must be False when using a sampler
dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)
# Now, batches drawn from this dataloader will have a more balanced
# representation of classes over time.
# for batch_features, batch_labels in dataloader:
# # Training steps...
# pass
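The other built-in samplers follow the same pattern. For instance, here is a minimal sketch of the SubsetRandomSampler mentioned earlier, used to carve a validation split out of the same dataset (the 90/10 ratio is just an illustrative assumption):
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler
# Assume 'dataset' is the same Dataset instance as above
indices = torch.randperm(len(dataset)).tolist()
split = int(0.9 * len(dataset))  # 90/10 split chosen only for illustration
train_indices, val_indices = indices[:split], indices[split:]
# Each loader draws batches only from its own subset of indices
train_loader = DataLoader(dataset, batch_size=32, sampler=SubsetRandomSampler(train_indices))
val_loader = DataLoader(dataset, batch_size=32, sampler=SubsetRandomSampler(val_indices))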
You can also create entirely custom sampling strategies by inheriting from torch.utils.data.Sampler and implementing the __iter__ and __len__ methods.
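As a sketch of what such a custom sampler might look like (the even-indices-first ordering here is purely illustrative):
from torch.utils.data import Sampler
class EvenFirstSampler(Sampler):
    """Hypothetical sampler: yields all even indices first, then all odd indices."""
    def __init__(self, data_source):
        self.data_source = data_source
    def __iter__(self):
        n = len(self.data_source)
        return iter(list(range(0, n, 2)) + list(range(1, n, 2)))
    def __len__(self):
        return len(self.data_source)
# Usage:
# dataloader = DataLoader(dataset, batch_size=32, sampler=EvenFirstSampler(dataset))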
collate_fn

Once the sampler provides a list of indices for a batch, the DataLoader fetches the corresponding samples from the Dataset using dataset[index]. It then needs to assemble these individual samples into a single batch. This assembly process is handled by the collate_fn argument.
The default collate_fn works well for many standard cases. Given the list of samples for a batch, it attempts to:

- Stack tensors along a new batch dimension (dimension 0), which requires every tensor in the batch to have the same shape.
- Convert NumPy arrays and Python numbers into tensors and batch them in the same way.
- Apply this logic recursively to tuples, lists, and dictionaries (e.g., if Dataset.__getitem__ returns a dictionary, the collated batch will be a dictionary where each value is a batch of the corresponding items).

However, the default collate_fn might fail or produce undesirable results if your samples have varying sizes (e.g., sequences of different lengths) or contain data types it doesn't know how to stack. In such cases, you can provide a custom function to the collate_fn argument of the DataLoader. This function receives a list of samples (where each sample is the output of Dataset.__getitem__) and is responsible for returning the collated batch in the desired format.
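To see the default recursive behavior in isolation before writing a custom function, here is a small sketch that calls the default collation directly on dictionary samples (the field names are made up; default_collate is exposed under torch.utils.data in recent PyTorch versions):
import torch
from torch.utils.data import default_collate
# Hypothetical samples, as if returned by Dataset.__getitem__
samples = [
    {"image": torch.randn(3, 4, 4), "label": 0},
    {"image": torch.randn(3, 4, 4), "label": 1},
]
batch = default_collate(samples)
print(batch["image"].shape)  # torch.Size([2, 3, 4, 4])
print(batch["label"])        # tensor([0, 1])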
Example: Padding Variable-Length Sequences
A common scenario involves sequences (like sentences in NLP) that have different lengths. The default collate function cannot directly stack these into a single tensor. A custom collate_fn can pad the sequences within each batch to the maximum length in that batch.
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
# Example Dataset returning variable-length tensors
class VariableSequenceDataset(Dataset):
    def __init__(self, data):
        # data is a list of 1D tensors, e.g., [torch.randn(5), torch.randn(8), ...]
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # For simplicity, assume each item also has a label (here, its length)
        sequence = self.data[idx]
        label = len(sequence)
        return sequence, label
# Custom collate function
def pad_collate(batch):
    # batch is a list of tuples: [(sequence1, label1), (sequence2, label2), ...]
    # Sorting by length is optional and sometimes done for RNN efficiency:
    # batch.sort(key=lambda x: len(x[0]), reverse=True)

    # Separate sequences and labels
    sequences = [item[0] for item in batch]
    labels = [item[1] for item in batch]

    # Pad sequences to the length of the longest sequence in the batch.
    # batch_first=True gives shape (batch_size, max_seq_len) for 1D sequences,
    # or (batch_size, max_seq_len, features) for sequences with a feature dimension.
    padded_sequences = pad_sequence(sequences, batch_first=True, padding_value=0.0)

    # Stack labels (assuming they are simple scalars)
    labels = torch.tensor(labels)
    return padded_sequences, labels
# Create dataset and dataloader
sequences = [torch.randn(torch.randint(5, 15, (1,)).item()) for _ in range(100)]
dataset = VariableSequenceDataset(sequences)
dataloader = DataLoader(dataset, batch_size=4, collate_fn=pad_collate)
# Iterate through the dataloader
# for padded_batch, label_batch in dataloader:
# # padded_batch shape: (4, max_len_in_this_batch) for these 1D sequences
# # label_batch shape: (4,)
# # Model processing...
# pass
This custom collate_fn uses torch.nn.utils.rnn.pad_sequence to handle the padding, so every sequence in a batch ends up with the same length and the batch can be processed by models such as RNNs.
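If these padded batches feed a recurrent model, a common follow-up is to pack them so the RNN skips the padded positions. A short sketch, assuming the dataset and dataloader defined above (the GRU dimensions are arbitrary):
import torch
from torch.nn.utils.rnn import pack_padded_sequence
rnn = torch.nn.GRU(input_size=1, hidden_size=8, batch_first=True)
for padded_batch, label_batch in dataloader:
    inputs = padded_batch.unsqueeze(-1)  # (batch, max_len) -> (batch, max_len, 1)
    # label_batch holds the true lengths, so the RNN can ignore padded positions;
    # enforce_sorted=False avoids sorting the batch by length in the collate function.
    packed = pack_padded_sequence(inputs, label_batch, batch_first=True, enforce_sorted=False)
    output, hidden = rnn(packed)
    break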
Besides sampler and collate_fn, other arguments offer performance and behavior tuning:
- num_workers (int, optional): Specifies how many subprocesses to use for data loading. Setting this to a positive integer enables multi-process data loading, which can significantly speed up data fetching, especially if loading involves disk I/O or non-trivial preprocessing on the CPU. A common starting point is the number of available CPU cores. Default: 0 (data loading happens in the main process).
- pin_memory (bool, optional): If True, the DataLoader copies tensors into CUDA pinned memory before returning them, enabling faster data transfer from CPU to GPU. This is only effective when training on a GPU. Default: False.
- drop_last (bool, optional): If True, drops the last incomplete batch when the dataset size is not divisible by the batch size. If False (default), the last batch may be smaller than batch_size.

By understanding and using samplers, custom collate functions, and the other DataLoader arguments combined in the sketch below, you gain precise control over your data pipeline: you can handle diverse data types and structures, address dataset imbalances, and optimize data loading performance for faster model training.
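As a closing sketch, here is one way these options might be combined with the pad_collate function from the previous example (the batch size and worker count are arbitrary illustrative choices):
import torch
from torch.utils.data import DataLoader
# Assume 'dataset' and 'pad_collate' are defined as in the earlier examples
dataloader = DataLoader(
    dataset,
    batch_size=64,                          # arbitrary example value
    shuffle=True,                           # allowed here because no custom sampler is passed
    collate_fn=pad_collate,                 # custom batch assembly
    num_workers=4,                          # e.g., one worker per available CPU core
    pin_memory=torch.cuda.is_available(),   # only beneficial when training on a GPU
    drop_last=True,                         # keep every batch the same size
)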
References

- torch.utils.data.DataLoader, PyTorch Authors, 2025 (PyTorch Foundation). Official documentation describing the DataLoader class, its arguments, and integration with datasets and samplers.
- torch.utils.data.Sampler, PyTorch Authors, 2025 (PyTorch Foundation). Official documentation detailing various built-in samplers and providing guidance for creating custom sampling strategies.
- torch.nn.utils.rnn.pad_sequence, PyTorch Authors, 2024. Official documentation for padding sequences, demonstrating a specific utility for custom collate functions.