趋近智
requires_grad)backward()).grad)torch.nn 搭建模型torch.nn.Module 基类torch.nn 损失)torch.optim)torch.utils.data.Datasettorchvision.transforms)torch.utils.data.DataLoader尽管创建自定义 Dataset 类为您的特定数据提供了最大的灵活性,但许多深度学习任务,特别是在研究和基准测试中,使用标准化数据集。手动准备这些数据集涉及下载、解压、组织文件和编写解析逻辑,这可能既耗时又容易出错。
幸运的是,PyTorch 提供了配套库,可以简化常见领域的数据处理过程。对于计算机视觉,torchvision 包是一个不可或缺的工具。它不仅包含流行的数据集,还包含预训练模型和常用的图像转换函数。本节主要介绍如何访问和使用 torchvision.datasets 提供的数据集。
torchvision.datasets 访问数据集torchvision.datasets 模块提供了对许多广泛使用的计算机视觉数据集的便捷访问,例如 MNIST、Fashion MNIST、CIFAR 10/100、ImageNet、COCO 等。使用这些数据集很简单。通常,您会从 torchvision.datasets 导入特定的数据集类并实例化它。
让我们看一个使用 CIFAR 10 数据集的例子,它包含 60,000 张 32x32 彩色图像,分为 10 个类别。
import torchvision
import torchvision.transforms as transforms
# 定义一个简单的转换,将图像转换为 PyTorch 张量
transform = transforms.Compose([transforms.ToTensor()])
# 加载训练数据集
# root: 数据将被存储/查找的目录
# train=True: 指定训练集
# download=True: 如果本地未找到数据,则下载
# transform: 将定义的转换应用于每张图像
train_dataset = torchvision.datasets.CIFAR10(root='./data',
train=True,
download=True,
transform=transform)
# 加载测试数据集
test_dataset = torchvision.datasets.CIFAR10(root='./data',
train=False,
download=True,
transform=transform)
print(f"CIFAR-10 training dataset size: {len(train_dataset)}")
print(f"CIFAR-10 test dataset size: {len(test_dataset)}")
# 访问单个数据点(图像、标签)
img, label = train_dataset[0]
print(f"Image shape: {img.shape}") # 通常输出:torch.Size([3, 32, 32])
print(f"Label: {label}") # 输出:表示类别的整数
当您首次运行此代码时,torchvision 会检查指定的 root 目录(在本例中为 ./data)。如果 CIFAR 10 数据不存在,设置 download=True 会指示 torchvision 自动将数据集下载并解压到该目录中。后续运行将发现数据已存在于本地并跳过下载。
注意 transform 参数。您可以在此处指定数据预处理步骤,这些步骤在数据样本加载后但在 __getitem__ 返回之前应用于每个样本。我们使用了 transforms.ToTensor(),它将 PIL 图像格式(torchvision 数据集常用)转换为 PyTorch 张量。数据转换将在下一节中进行更详细的介绍。
重要的是,torchvision.datasets 返回的对象(如上文的 train_dataset 和 test_dataset)是继承自 torch.utils.data.Dataset 的类实例。这意味着它们实现了必需的 __len__ 和 __getitem__ 方法,使其与 PyTorch 的 DataLoader 完全兼容。
len(train_dataset) 返回数据集中样本的总数。train_dataset[i] 返回第 i 个样本,通常是一个元组 (data, target),其中 data 是预处理后的输入(例如,图像张量),target 是对应的标签或标注。以下是 CIFAR-10 训练集中类别分布的简单可视化:
CIFAR-10 数据集是平衡的,每个类别恰好有 5,000 张训练图像。
尽管 torchvision 最为完善,但其他领域也存在类似的库:
torchaudio: 为音频处理任务提供数据集(如 SpeechCommands、LJSpeech 等)、模型和转换功能。torchtext: 为自然语言处理提供数据集(如 IMDb 情感分析、WikiText 语言建模)、分词器和词汇工具。注意:torchtext 经历了重大的 API 变更,因此请查阅其文档以了解当前的使用模式。使用这些库遵循相似的原则:导入所需的数据集类,实例化它(通常带有下载和预处理选项),然后将生成的 Dataset 对象与 DataLoader 一起使用。
依靠这些内置数据集可以显著加快开发和实验速度,使您能够专注于模型架构和训练,而不是数据获取和准备,尤其是在使用标准基准时。请记住,这些数据集对象直接与本章后面讨论的 DataLoader 集成,从而实现高效的批处理和洗牌。
这部分内容有帮助吗?
torch.utils.data - PyTorch Documentation, PyTorch Foundation, 2025 - PyTorch数据加载工具的官方文档,包括支持torchvision数据集的Dataset和DataLoader。torchvision.datasets - PyTorch Documentation, PyTorch Foundation, 2024 - 描述TorchVision中内置数据集、其用法和配置参数的官方文档。Dataset、DataLoader以及有效使用torchvision数据集。© 2026 ApX Machine Learning用心打造