As introduced, loading and processing data efficiently is fundamental for training deep learning models. PyTorch provides a standardized way to handle datasets through its torch.utils.data.Dataset abstract class. Think of Dataset as a contract: it defines a standard interface for accessing your data, regardless of whether it resides in memory, on disk, or needs to be generated on the fly.

The Dataset Abstract Class

At its core, torch.utils.data.Dataset is an abstract class representing a dataset. Any custom dataset you create in PyTorch should inherit from this class. Why use this structure? It ensures that different datasets, whether built-in or custom, present a consistent API to other PyTorch components, most notably the DataLoader, which we'll cover later. This standardization simplifies the process of swapping datasets or using different data sources with the same training code.
To create your own custom dataset, you need to subclass torch.utils.data.Dataset and override two essential methods:

- __len__(self): Returns the total number of samples in your dataset. The DataLoader uses this to determine the size of the dataset.
- __getitem__(self, idx): Loads and returns a single sample from the dataset given an index idx. This is where the actual data loading logic resides (e.g., reading an image file, retrieving a row from a CSV, accessing an element in a list). The DataLoader calls this method repeatedly to construct batches.

Let's illustrate this with a simple example. Imagine you have your features and corresponding labels stored in Python lists or NumPy arrays.
import torch
from torch.utils.data import Dataset
import numpy as np

class SimpleCustomDataset(Dataset):
    """A simple example dataset with features and labels."""

    def __init__(self, features, labels):
        """
        Args:
            features (list or np.array): A list or array of features.
            labels (list or np.array): A list or array of labels.
        """
        # Basic check: features and labels must have the same length
        assert len(features) == len(labels), "Features and labels must have the same length."
        self.features = features
        self.labels = labels

    def __len__(self):
        """Returns the total number of samples."""
        return len(self.features)

    def __getitem__(self, idx):
        """
        Generates one sample of data.

        Args:
            idx (int): The index of the item.

        Returns:
            tuple: (feature, label) for the given index.
        """
        # Retrieve the feature and label for the given index
        feature = self.features[idx]
        label = self.labels[idx]
        # Often, you'll convert data to PyTorch tensors here.
        # We assume features/labels might not be tensors yet.
        sample = (torch.tensor(feature, dtype=torch.float32),
                  torch.tensor(label, dtype=torch.long))  # Assuming a classification label
        return sample
# --- Example Usage ---
# Sample data (replace with your actual data)
num_samples = 100
num_features = 10
features_data = np.random.randn(num_samples, num_features)
labels_data = np.random.randint(0, 5, size=num_samples) # Example: 5 classes
# Create an instance of the custom dataset
my_dataset = SimpleCustomDataset(features_data, labels_data)
# Access dataset properties and elements
print(f"Dataset size: {len(my_dataset)}")
# Get the first sample
first_sample = my_dataset[0]
feature_sample, label_sample = first_sample
print(f"\nFirst sample features:\n{feature_sample}")
print(f"First sample shape: {feature_sample.shape}")
print(f"First sample label: {label_sample}")
# Get the tenth sample
tenth_sample = my_dataset[9]
print(f"\nTenth sample label: {tenth_sample[1]}")
In this example:

- The __init__ method stores the feature and label data passed during instantiation.
- __len__ simply returns the length of the features list (which is the same as the labels list).
- __getitem__ takes an index idx, retrieves the corresponding feature and label, converts them into PyTorch tensors, and returns them as a tuple. This conversion to tensors is a common practice within __getitem__.

The real utility of the custom Dataset comes when dealing with data that isn't readily available in memory. For instance, you might have image file paths and labels stored in a CSV file.
import torch
from torch.utils.data import Dataset
from PIL import Image  # Python Imaging Library for image loading
import pandas as pd
import numpy as np
import os

class ImageFilelistDataset(Dataset):
    """Dataset for loading image paths and labels from a CSV file."""

    def __init__(self, csv_file, root_dir, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
                Assumes columns: 'image_path', 'label'
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.annotations = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform  # We'll discuss transforms later

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        # Get the image path relative to root_dir from the CSV
        img_rel_path = self.annotations.iloc[idx, 0]  # Assuming the first column is the path
        img_full_path = os.path.join(self.root_dir, img_rel_path)

        # Load the image using PIL
        try:
            image = Image.open(img_full_path).convert('RGB')  # Ensure 3 channels
        except FileNotFoundError:
            print(f"Error: Image not found at {img_full_path}")
            # Handle the error appropriately, e.g., return None or raise an exception.
            # For simplicity here, we return None values and rely on the DataLoader's
            # collate_fn to potentially handle them (or filter later). A better
            # approach might be to clean the CSV beforehand.
            return None, None

        # Get the label from the CSV
        label = self.annotations.iloc[idx, 1]  # Assuming the second column is the label
        label = torch.tensor(int(label), dtype=torch.long)

        # Apply transformations, if any
        if self.transform:
            image = self.transform(image)  # Transforms usually convert a PIL Image to a Tensor

        # If no transform converted the image to a tensor, do it manually
        if not isinstance(image, torch.Tensor):
            # Basic conversion: HWC uint8 array -> CHW float tensor in [0, 1]
            image = torch.tensor(np.array(image), dtype=torch.float32).permute(2, 0, 1) / 255.0

        return image, label
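
# The except branch above returns (None, None) when a file is missing and relies
# on the DataLoader's collate_fn to cope. As one hedged sketch of that idea, the
# hypothetical helper below drops failed samples before batching. It assumes
# torch.utils.data.default_collate is available (PyTorch 1.11+).
from torch.utils.data import default_collate

def collate_skip_none(batch):
    """Filter out (None, None) samples from failed loads, then batch the rest."""
    batch = [sample for sample in batch if sample[0] is not None]
    if not batch:
        # Every sample in this batch failed to load; fail loudly rather than silently
        raise RuntimeError("All samples in the batch failed to load.")
    return default_collate(batch)

# Conceptual usage: DataLoader(image_dataset, batch_size=32, collate_fn=collate_skip_none)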
# --- Example Usage (Conceptual - requires actual images and CSV) ---
# Assume you have:
# 1. A folder 'data/images/' containing image files (e.g., cat1.jpg, dog1.png)
# 2. A CSV file 'data/annotations.csv' with content like:
# image_path,label
# images/cat1.jpg,0
# images/dog1.png,1
# ...
# image_dataset = ImageFilelistDataset(csv_file='data/annotations.csv',
# root_dir='data/')
# Accessing would be similar:
# print(f"Image dataset size: {len(image_dataset)}")
# if len(image_dataset) > 0:
# img, lbl = image_dataset[0]
# if img is not None:
# print(f"First image shape: {img.shape}") # Shape depends on transforms
# print(f"First image label: {lbl}")
In this ImageFilelistDataset example:

- __init__ reads the CSV using pandas and stores the file paths and the root directory. It also accepts an optional transform argument (we'll see its use shortly).
- __len__ returns the number of rows in the CSV file.
- __getitem__ constructs the full image path, loads the image using PIL, retrieves the label, applies any specified transformations, ensures the image is a tensor, and returns the image tensor and label tensor.

Notice that the Dataset itself only defines how to get a single item. It doesn't load the entire dataset into memory at once (unless your __init__ explicitly does so, which is generally avoided for large datasets). It also doesn't handle batching, shuffling, or parallel loading. That's where the DataLoader comes in, building directly upon the foundation laid by the Dataset. By implementing __len__ and __getitem__, you provide the necessary structure for the DataLoader to efficiently access your data samples.
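As a preview of that connection, here is a minimal sketch wrapping the SimpleCustomDataset from earlier in a DataLoader; the batch size and shuffle flag are arbitrary illustrative choices.

from torch.utils.data import DataLoader

# DataLoader drives the dataset through __len__ and __getitem__
loader = DataLoader(my_dataset, batch_size=16, shuffle=True)

# Each iteration yields a batch of stacked feature and label tensors
for batch_features, batch_labels in loader:
    print(batch_features.shape)  # torch.Size([16, 10]) for full batches
    print(batch_labels.shape)    # torch.Size([16])
    break  # Inspect just the first batch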