When working with TensorFlow, you're accustomed to using tf.data.Dataset to represent a sequence of elements, such as tensors or tuples of tensors, and applying transformations like map, batch, and shuffle directly to the dataset object to construct an input pipeline. This creates a graph of operations that processes data efficiently.
PyTorch offers a different, yet equally powerful, abstraction for handling datasets through its torch.utils.data.Dataset class. Instead of a chain of transformation methods applied to a dataset object, the PyTorch Dataset is an abstract class that you typically subclass to create your own custom dataset. The core idea is to provide a standardized way to access individual data samples.
Any custom dataset in PyTorch that inherits from torch.utils.data.Dataset must implement two essential methods:
__len__(self): This method should return the total number of samples in the dataset. PyTorch's data loading utilities rely on this to determine the extent of the dataset.
__getitem__(self, index): This method is responsible for retrieving a single data sample (e.g., an image and its corresponding label) at the given index. The index will be an integer ranging from 0 to len(self) - 1.
This approach gives you fine-grained control over how data is loaded and processed on a per-sample basis. While tf.data pipelines are defined by chaining operations on the Dataset object itself, in PyTorch, data transformations are often encapsulated within the __getitem__ method of your custom Dataset or applied using composable transform objects, as we'll see when discussing torchvision.transforms.
Let's consider a simple example. Suppose you have a list of image file paths and their corresponding labels. In TensorFlow, you might use tf.data.Dataset.from_tensor_slices((filepaths, labels)) and then a map function to load and preprocess images.
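For comparison, here is a minimal sketch of what that tf.data pipeline might look like; the file paths, image size, and decoding choices are assumptions for illustration only:

import tensorflow as tf

# Hypothetical file paths and labels, purely for illustration.
filepaths = ["images/cat_001.jpg", "images/dog_001.jpg"]
labels = [0, 1]

def load_and_preprocess(path, label):
    # Read the file, decode it as a JPEG, and resize to a fixed shape.
    image = tf.io.read_file(path)
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.resize(image, [224, 224])
    return image, label

dataset = tf.data.Dataset.from_tensor_slices((filepaths, labels))
dataset = dataset.map(load_and_preprocess).batch(32)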
In PyTorch, you would define a class:
import torch
from torch.utils.data import Dataset
from PIL import Image  # For image loading, as an example


class CustomImageDataset(Dataset):
    def __init__(self, image_paths, labels, transform=None):
        """
        Args:
            image_paths (list): List of paths to images.
            labels (list): List of corresponding labels.
            transform (callable, optional): Optional transform to be applied
                to the loaded image.
        """
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        # Total number of samples; PyTorch's loading utilities rely on this.
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Fetch a single sample: load the image at position idx from disk.
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert("RGB")
        label = self.labels[idx]

        # Apply any per-sample transform (e.g. torchvision.transforms),
        # which typically also converts the PIL image to a tensor.
        if self.transform:
            image = self.transform(image)

        # __getitem__ usually returns a (features, label) tuple of tensors;
        # returning a dictionary such as {"image": ..., "label": ...} also works.
        return image, torch.tensor(label)
In this CustomImageDataset, __init__ stores the data sources (file paths and labels) and an optional transform. __len__ simply returns the number of items. __getitem__ is where the logic for fetching and potentially transforming a single item resides: it takes an index idx, retrieves the corresponding image path and label, loads the image with PIL, and applies any specified transform before returning the sample.
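With the class defined, a short usage sketch shows how the two required methods are exercised. The file paths, labels, and transform below are placeholders for illustration, not part of the class above:

from torchvision import transforms

# Hypothetical inputs; in practice these would be real image files on disk.
image_paths = ["images/cat_001.jpg", "images/dog_001.jpg"]
labels = [0, 1]

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),  # converts the PIL image to a float tensor
])

dataset = CustomImageDataset(image_paths, labels, transform=transform)
print(len(dataset))        # calls __len__, prints 2
image, label = dataset[0]  # calls __getitem__(0)
print(image.shape, label)  # e.g. torch.Size([3, 224, 224]) tensor(0)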
PyTorch supports two main types of datasets:
Map-style datasets subclass torch.utils.data.Dataset and implement __getitem__() and __len__(). They represent a map from (integer) indices to data samples. The example CustomImageDataset above is a map-style dataset.
Iterable-style datasets subclass torch.utils.data.IterableDataset and implement __iter__(). They are suitable for situations where random access is difficult or data is streamed, as they represent an iterable over data samples; a minimal sketch follows this list.
For most common use cases, especially when transitioning from TensorFlow where you might have data readily available in lists or files, map-style datasets are the direct counterpart to consider.
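As a rough illustration of the iterable-style variant (the generator and sample shapes here are invented for demonstration), an IterableDataset only needs to yield samples:

import torch
from torch.utils.data import IterableDataset

class StreamingDataset(IterableDataset):
    """Minimal iterable-style dataset: samples are produced one by one,
    as they might be when reading from a large file or a network stream."""

    def __init__(self, num_samples=10):
        self.num_samples = num_samples  # stand-in for a real data stream

    def __iter__(self):
        # No random access by index; samples are simply yielded in order.
        for i in range(self.num_samples):
            features = torch.randn(4)  # placeholder feature vector
            label = i % 2              # placeholder label
            yield features, label

for features, label in StreamingDataset(num_samples=3):
    print(features.shape, label)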
The following diagram illustrates the difference in how dataset structures are approached in TensorFlow and PyTorch:
High-Level Comparison of Dataset Abstractions in TensorFlow and PyTorch. TensorFlow's tf.data.Dataset involves chaining operations to form a pipeline. PyTorch's torch.utils.data.Dataset is typically a custom class defining indexed data access, which is then used by a DataLoader.
In essence, tf.data.Dataset is more about defining a data transformation pipeline as a single, stateful object that yields processed (and often batched) data. In contrast, PyTorch's torch.utils.data.Dataset is primarily concerned with providing indexed access to individual raw or lightly processed data samples. The tasks of batching, shuffling, and parallel data loading are then delegated to the torch.utils.data.DataLoader, which we will cover in the next section. This separation of concerns allows for flexibility: the Dataset defines what the data is and how to get one item, while the DataLoader defines how to iterate over many items in batches.
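As a brief preview of that separation, here is a sketch that reuses the dataset constructed above; the batch size and worker count are arbitrary choices for illustration:

from torch.utils.data import DataLoader

# The DataLoader takes care of batching, shuffling, and optional parallel
# loading via worker processes; the Dataset only defines how to fetch one sample.
loader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=0)

for images, targets in loader:
    # images: (batch_size, 3, 224, 224); targets: (batch_size,)
    print(images.shape, targets)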