When working with TensorFlow, you're accustomed to using tf.data.Dataset to represent a sequence of elements, such as tensors or tuples of tensors, and applying transformations like map, batch, and shuffle directly to the dataset object to construct an input pipeline. This creates a graph of operations that processes data efficiently.
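For instance, a minimal tf.data pipeline might be sketched like this; the tensors and the transformation chosen here are illustrative:

import tensorflow as tf

# Illustrative in-memory data.
features = tf.constant([[1.0], [2.0], [3.0], [4.0]])
labels = tf.constant([0, 1, 0, 1])

# Transformations are chained directly on the dataset object.
dataset = (
    tf.data.Dataset.from_tensor_slices((features, labels))
    .map(lambda x, y: (x * 2.0, y))  # per-element transformation
    .shuffle(buffer_size=4)          # shuffle with an in-memory buffer
    .batch(2)                        # group samples into batches
)

for batch_features, batch_labels in dataset:
    print(batch_features.shape, batch_labels.shape)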
PyTorch offers a different, yet equally powerful, abstraction for handling datasets through its torch.utils.data.Dataset class. Instead of a chain of transformation methods applied to a dataset object, the PyTorch Dataset is an abstract class that you typically subclass to create your own custom dataset. The core idea is to provide a standardized way to access individual data samples.
Any custom dataset in PyTorch that inherits from torch.utils.data.Dataset must implement two essential methods:

- __len__(self): This method should return the total number of samples in the dataset. PyTorch's data loading utilities rely on this to determine the extent of the dataset.
- __getitem__(self, index): This method is responsible for retrieving a single data sample (e.g., an image and its corresponding label) at the given index. The index will be an integer ranging from 0 to len(self) - 1.

This approach gives you fine-grained control over how data is loaded and processed on a per-sample basis. While tf.data pipelines are defined by chaining operations on the Dataset object itself, in PyTorch, data transformations are often encapsulated within the __getitem__ method of your custom Dataset or applied using composable transform objects, as we'll see when discussing torchvision.transforms.
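As a brief preview of the composable-transform approach (torchvision.transforms is discussed in detail later), a transform pipeline like the following sketch can be passed to a dataset and applied inside __getitem__; the specific operations and sizes here are illustrative choices:

from torchvision import transforms

# Illustrative transform pipeline; the operations and target size are arbitrary.
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),  # resize the PIL image
    transforms.ToTensor(),          # convert it to a float tensor in [0, 1]
])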
Let's consider a simple example. Suppose you have a list of image file paths and their corresponding labels. In TensorFlow, you might use tf.data.Dataset.from_tensor_slices((filepaths, labels)) and then a map function to load and preprocess images.
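A sketch of that TensorFlow approach might look like the following; the decoding and resizing details are illustrative and the file paths are hypothetical:

import tensorflow as tf

filepaths = ["data/cat_001.jpg", "data/dog_001.jpg"]  # hypothetical image files
labels = [0, 1]

def load_and_preprocess(filepath, label):
    # Read and decode the image file, then resize and rescale it.
    image = tf.io.read_file(filepath)
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.resize(image, [224, 224]) / 255.0
    return image, label

tf_dataset = (
    tf.data.Dataset.from_tensor_slices((filepaths, labels))
    .map(load_and_preprocess)
    .batch(2)
)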
In PyTorch, you would define a class:
import torch
from torch.utils.data import Dataset
from PIL import Image  # For image loading

class CustomImageDataset(Dataset):
    def __init__(self, image_paths, labels, transform=None):
        """
        Args:
            image_paths (list): List of paths to images.
            labels (list): List of corresponding labels.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        # Total number of samples in the dataset.
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Load the image for this index and convert it to RGB.
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert("RGB")
        label = self.labels[idx]

        # Apply any preprocessing or augmentation transform to the image.
        if self.transform:
            image = self.transform(image)

        # __getitem__ typically returns a (features, label) tuple, although
        # returning a dictionary such as {"image": image, "label": label}
        # is also common. For training, both should end up as PyTorch tensors
        # (e.g., via torchvision's ToTensor transform for the image).
        return image, torch.tensor(label)
In this CustomImageDataset, __init__ stores the data sources (file paths and labels) and any transformation. __len__ simply returns the number of items. __getitem__ is where the logic for fetching and potentially transforming a single item resides: it takes an index idx, retrieves the corresponding image path and label, loads the image with PIL, applies any specified transform, and returns the sample.
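To see these pieces in action, here is a quick usage sketch; the file paths and labels are hypothetical and would point to real images in practice:

from torchvision import transforms

# Hypothetical paths and labels; replace with real image files.
image_paths = ["data/cat_001.jpg", "data/dog_001.jpg"]
labels = [0, 1]

dataset = CustomImageDataset(
    image_paths,
    labels,
    transform=transforms.ToTensor(),  # convert each PIL image to a float tensor
)

print(len(dataset))        # calls __len__, prints 2
image, label = dataset[0]  # calls __getitem__ with index 0
print(image.shape, label)  # e.g. torch.Size([3, H, W]) and tensor(0)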
PyTorch supports two main types of datasets:

- Map-style datasets: subclasses of torch.utils.data.Dataset that implement __getitem__() and __len__(). They represent a map from (integer) indices to data samples. The example CustomImageDataset above is a map-style dataset.
- Iterable-style datasets: subclasses of torch.utils.data.IterableDataset that implement __iter__(). They are suitable for situations where random access is difficult or data is streamed, as they represent an iterable over data samples; a minimal sketch follows below.

For most common use cases, especially when transitioning from TensorFlow where you might have data readily available in lists or files, map-style datasets are the direct counterpart to consider.
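Here is a minimal sketch of an iterable-style dataset; the streamed values are illustrative, and the point is that samples are yielded rather than served by index:

import torch
from torch.utils.data import IterableDataset

class StreamingNumbersDataset(IterableDataset):
    """Illustrative iterable-style dataset that streams samples instead of indexing them."""

    def __init__(self, limit):
        self.limit = limit

    def __iter__(self):
        # Yield one sample at a time; no __getitem__ or __len__ is required.
        for value in range(self.limit):
            yield torch.tensor(value, dtype=torch.float32)

stream = StreamingNumbersDataset(limit=5)
for sample in stream:
    print(sample)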
The following diagram illustrates the difference in how dataset structures are approached in TensorFlow and PyTorch:
High-Level Comparison of Dataset Abstractions in TensorFlow and PyTorch. TensorFlow's tf.data.Dataset involves chaining operations to form a pipeline. PyTorch's torch.utils.data.Dataset is typically a custom class defining indexed data access, which is then used by a DataLoader.
In essence, tf.data.Dataset is more about defining a data transformation pipeline as a single, stateful object that yields processed (and often batched) data. In contrast, PyTorch's torch.utils.data.Dataset is primarily concerned with providing indexed access to individual raw or lightly processed data samples. The tasks of batching, shuffling, and parallel data loading are then delegated to the torch.utils.data.DataLoader, which we will cover in the next section. This separation of concerns allows for flexibility: the Dataset defines what the data is and how to get one item, while the DataLoader defines how to iterate over many items in batches.
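As a brief preview of that separation (the DataLoader is covered in the next section), wrapping the map-style dataset from the earlier usage sketch in a DataLoader is enough to get shuffled batches; the batch size here is arbitrary:

from torch.utils.data import DataLoader

# The Dataset answers "what is sample i?"; the DataLoader handles batching,
# shuffling, and (optionally) parallel loading via num_workers. This assumes
# the transform yields same-sized image tensors so they can be stacked.
loader = DataLoader(dataset, batch_size=2, shuffle=True)

for images, batch_labels in loader:
    # images: tensor of shape [batch_size, C, H, W]; batch_labels: shape [batch_size]
    print(images.shape, batch_labels.shape)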