While creating custom Dataset classes gives you maximum flexibility for your specific data, many deep learning tasks, especially in research and benchmarking, use standardized datasets. Preparing these datasets manually involves downloading, extracting, organizing files, and writing parsing logic, which can be time-consuming and error-prone.
Fortunately, PyTorch offers companion libraries that streamline this process for common domains. For computer vision, the torchvision package is an indispensable tool. It contains not only popular datasets but also pre-trained models and common image transformation functions. This section focuses on accessing and using the datasets provided by torchvision.datasets.
torchvision.datasets
The torchvision.datasets module provides convenient access to many widely used computer vision datasets, such as MNIST, Fashion-MNIST, CIFAR-10/100, ImageNet, COCO, and others. Using these datasets is straightforward: you typically import the specific dataset class from torchvision.datasets and instantiate it.
Let's look at an example using the CIFAR-10 dataset, which consists of 60,000 32x32 color images in 10 classes.
import torchvision
import torchvision.transforms as transforms
# Define a simple transformation to convert images to PyTorch Tensors
transform = transforms.Compose([transforms.ToTensor()])
# Load the training dataset
# root: directory where data will be stored/found
# train=True: specifies the training set
# download=True: downloads the data if not found locally
# transform: applies the defined transformation to each image
train_dataset = torchvision.datasets.CIFAR10(root='./data',
                                              train=True,
                                              download=True,
                                              transform=transform)
# Load the test dataset
test_dataset = torchvision.datasets.CIFAR10(root='./data',
                                             train=False,
                                             download=True,
                                             transform=transform)
print(f"CIFAR-10 training dataset size: {len(train_dataset)}")
print(f"CIFAR-10 test dataset size: {len(test_dataset)}")
# Accessing a single data point (image, label)
img, label = train_dataset[0]
print(f"Image shape: {img.shape}") # Output typically: torch.Size([3, 32, 32])
print(f"Label: {label}") # Output: An integer representing the class
When you first run this code, torchvision checks the specified root directory (./data in this case). If the CIFAR-10 data is not present, setting download=True instructs torchvision to automatically download and extract the dataset into that directory. Subsequent runs will find the data locally and skip the download.
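If you know the data is already on disk, you can pass download=False instead; a small sketch, using the same root directory and transform as above (this raises an error if the files are missing):

# Reuse the locally cached copy; no network access is attempted
train_dataset = torchvision.datasets.CIFAR10(root='./data',
                                              train=True,
                                              download=False,
                                              transform=transform)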
Notice the transform argument. This is where you specify preprocessing steps to be applied to each data sample after it is loaded but before it is returned by __getitem__. We used transforms.ToTensor(), which converts the PIL Image format (commonly used by torchvision datasets) into a PyTorch Tensor. Data transformations are covered in more detail in the next section.
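As a preview of that section, transforms.Compose chains multiple steps; for example, converting to a tensor and then normalizing each channel. The mean and standard deviation values below are illustrative placeholders, not tuned CIFAR-10 statistics:

# Chain ToTensor with a per-channel normalization step.
# The (0.5, 0.5, 0.5) values are placeholders for illustration only.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])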
Importantly, the objects returned by torchvision.datasets (like train_dataset and test_dataset above) are instances of classes that inherit from torch.utils.data.Dataset. This means they implement the necessary __len__ and __getitem__ methods, making them fully compatible with PyTorch's DataLoader.
- len(train_dataset) returns the total number of samples in the dataset.
- train_dataset[i] returns the i-th sample, typically as a tuple (data, target), where data is the preprocessed input (e.g., an image tensor) and target is the corresponding label or annotation.

Here's a simple visualization of the class distribution in the CIFAR-10 training set:

[Chart: number of training images per CIFAR-10 class]

The CIFAR-10 dataset is balanced, with exactly 5,000 training images per class.
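You can verify this balance yourself by counting labels. The sketch below uses the dataset's targets attribute, where torchvision's CIFAR-10 class stores the integer labels (iterating over the dataset itself works too, just more slowly):

from collections import Counter

# Count how many training images fall into each class
label_counts = Counter(train_dataset.targets)
for class_index, count in sorted(label_counts.items()):
    print(f"{train_dataset.classes[class_index]}: {count}")  # 5000 each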
While torchvision is the most established, similar libraries exist for other domains:

- torchaudio: Provides datasets (SpeechCommands, LJSpeech, etc.), models, and transformations for audio processing tasks.
- torchtext: Offers datasets (sentiment analysis like IMDb, language modeling like WikiText), tokenizers, and vocabulary tools for natural language processing. Note: torchtext has undergone significant API changes, so consult its documentation for current usage patterns.

Using these libraries follows similar principles: import the desired dataset class, instantiate it (often with options for downloading and preprocessing), and then use the resulting Dataset object with a DataLoader, as sketched below.
Leveraging these built-in datasets significantly speeds up development and experimentation, allowing you to focus on model architecture and training rather than data acquisition and preparation, especially when working with standard benchmarks. Remember that these dataset objects integrate directly with the DataLoader discussed later in this chapter, enabling efficient batching and shuffling.
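A minimal sketch of that integration, reusing the train_dataset created earlier:

from torch.utils.data import DataLoader

# Batch and shuffle the CIFAR-10 training set
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# Pull one batch to inspect its shape
images, labels = next(iter(train_loader))
print(images.shape)  # torch.Size([64, 3, 32, 32])
print(labels.shape)  # torch.Size([64])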