趋近智
state_dict逐步完成为合成表格数据创建自定义 Dataset 的过程,然后使用 DataLoader 为模型训练准备数据。这项实践操作展示了 PyTorch 如何处理数据管道,并与使用 tf.data 达到类似效果的方式进行对比。
假设我们有一个简单的二元分类任务。我们的数据包含数值特征和对应的二元标签(0 或 1)。为简单起见,我们将人工生成这些数据。
第一步是定义一个继承自 torch.utils.data.Dataset 的类。这个类必须实现三个特殊方法:
__init__(self, ...): 构造函数。你通常在此处加载数据(例如,从文件、数据库或动态生成)。你也可以在此处执行任何一次性预处理。__len__(self): 此方法应返回数据集中样本的总数。DataLoader 使用此方法来了解数据集的大小。__getitem__(self, idx): 此方法负责从数据集中获取给定索引 idx 处的单个样本(特征和对应的标签)。你通常也在此处应用针对单个样本的转换,例如将数据转换为 PyTorch 张量。让我们创建 SyntheticTabularDataset:
import torch
from torch.utils.data import Dataset
import numpy as np
class SyntheticTabularDataset(Dataset):
def __init__(self, num_samples=1000, num_features=10):
"""
Constructor for the SyntheticTabularDataset.
Args:
num_samples (int): The total number of samples to generate.
num_features (int): The number of features for each sample.
"""
super().__init__() # 良好实践:调用父类构造函数
self.num_samples = num_samples
self.num_features = num_features
# 生成合成特征(来自正态分布的随机数)
# 在实际场景中,你会从文件或其他来源加载特征。
self.features = np.random.randn(num_samples, num_features).astype(np.float32)
# 生成合成标签(随机二元标签:0 或 1)
self.labels = np.random.randint(0, 2, num_samples).astype(np.int64)
# 对于 TensorFlow 用户:在此阶段,你的数据可能以 NumPy 数组的形式存在。
# 你可能会使用 tf.data.Dataset.from_tensor_slices((self.features, self.labels))
# 来创建一个 TensorFlow Dataset。在这里,我们正在定义一个 PyTorch Dataset 类。
def __len__(self):
"""
Returns the total number of samples in the dataset.
"""
return self.num_samples
def __getitem__(self, idx):
"""
Retrieves the sample (features and label) at the given index.
Args:
idx (int): The index of the sample to retrieve.
Returns:
tuple: (features, label) where features is a PyTorch tensor
and label is a PyTorch tensor.
"""
# 获取指定样本
sample_features = self.features[idx]
sample_label = self.labels[idx]
# 将 NumPy 数组转换为 PyTorch 张量
# 这是在 __getitem__ 中常见的操作
# 在 TensorFlow 中,tf.data 在数据传递给 from_tensor_slices 或通过 map 转换时,会隐式处理张量转换。
return torch.from_numpy(sample_features), torch.tensor(sample_label)
在 __init__ 方法中,我们使用 NumPy 生成特征和标签,并将它们存储为实例属性。np.float32 是特征的常见数据类型,而 np.int64 则是分类标签的典型类型(尤其在使用 CrossEntropyLoss 等损失函数时)。
__len__ 方法很简单;它只返回 self.num_samples。
__getitem__ 方法接受一个索引 idx,从 NumPy 数组中获取对应的特征和标签,然后将它们转换为 PyTorch 张量。torch.from_numpy 用于数组到张量的转换,torch.tensor 适用于标量标签。这种转换为张量的操作很重要,因为 PyTorch 模型期望张量输入。
既然我们已经定义了 SyntheticTabularDataset,现在让我们创建一个它的实例,看看如何访问单个样本:
# 创建自定义数据集的实例
dataset = SyntheticTabularDataset(num_samples=50, num_features=3)
# 检查数据集的长度
print(f"数据集长度: {len(dataset)}")
# 获取单个样本(例如,索引为 0 的第一个样本)
features_sample, label_sample = dataset[0]
print(f"\n第一个样本:")
print(f" 特征: {features_sample}")
print(f" 标签: {label_sample}")
print(f" 特征形状: {features_sample.shape}, 数据类型: {features_sample.dtype}")
print(f" 标签形状: {label_sample.shape}, 数据类型: {label_sample.dtype}")
# 获取另一个样本
features_sample_2, label_sample_2 = dataset[10]
print(f"\n索引 10 处的样本:")
print(f" 特征: {features_sample_2}")
print(f" 标签: {label_sample_2}")
运行这段代码将显示样本总数以及从数据集中获取的单个样本的结构,它们已方便地转换为 PyTorch 张量。
虽然访问单个样本对检查有用,但对于模型训练,我们需要以批次处理数据、打乱数据,并可能并行加载数据。这就是 torch.utils.data.DataLoader 发挥作用的地方。
DataLoader 接受一个 Dataset 对象作为输入,并提供一个可迭代对象。主要参数包括:
dataset: 用于加载数据的 Dataset 对象。batch_size: 每个批次的样本数量。shuffle: 如果为 True,数据会在每个 epoch 重新打乱。这通常建议用于训练数据,以防止模型学习样本的顺序。num_workers: 用于数据加载的子进程数量。设置 num_workers > 0 可启用多进程数据加载,这可以显著加快数据获取速度,尤其是在 __getitem__ 涉及 I/O 操作或大量计算时。对于像我们合成示例这样简单的内存中数据集,num_workers=0(默认值,在主进程中加载)通常没有问题。让我们为 SyntheticTabularDataset 创建一个 DataLoader:
from torch.utils.data import DataLoader
# 重新实例化数据集,也许增加更多样本以适应典型的训练场景
train_dataset = SyntheticTabularDataset(num_samples=1000, num_features=5)
# 创建一个 DataLoader
batch_size = 32
# 对于 TensorFlow 用户:DataLoader 结合了以下功能:
# tf.data.Dataset.shuffle() 和 tf.data.Dataset.batch()。
train_dataloader = DataLoader(dataset=train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=0) # 为简单起见,这里使用 0;实际任务中尝试 >0
# 遍历 DataLoader 以获取批次
print(f"\n遍历 DataLoader(以头两个批次为例):")
for i, (batch_features, batch_labels) in enumerate(train_dataloader):
if i < 2: # 如果是前两个批次则打印信息
print(f" 批次 {i+1}:")
print(f" 特征批次形状: {batch_features.shape}") # [batch_size, num_features]
print(f" 标签批次形状: {batch_labels.shape}") # [batch_size]
print(f" 特征批次数据类型: {batch_features.dtype}")
print(f" 标签批次数据类型: {batch_labels.dtype}")
else:
break # 为简洁起见,显示两个批次后停止
# 你也可以在典型的训练循环中遍历它:
# num_epochs = 3
# for epoch in range(num_epochs):
# print(f"\n周期 {epoch+1}/{num_epochs}")
# for batch_idx, (features, labels) in enumerate(train_dataloader):
# # 在实际训练循环中:
# # 1. 将数据移至设备(例如,GPU)
# # features, labels = features.to(device), labels.to(device)
# # 2. 前向传播:model_output = model(features)
# # 3. 计算损失:loss = criterion(model_output, labels)
# # 4. 反向传播:loss.backward()
# # 5. 优化器步进:optimizer.step()
# # 6. 梯度清零:optimizer.zero_grad()
# if batch_idx % 10 == 0: # 每 10 个批次打印一次进度
# print(f" 已处理批次 {batch_idx+1}/{len(train_dataloader)}")
# print("-" * 30)
运行此代码时,你会看到 batch_features 是一个形状为 (batch_size, num_features) 的张量,而 batch_labels 是一个形状为 (batch_size,) 的张量。DataLoader 已有效地将 Dataset 中的单个样本分组到这些批次中。如果 shuffle=True,每次你遍历 train_dataloader 时(例如,在每个训练周期开始时),这些批次内的样本顺序(以及批次本身的顺序)都会不同。
在我们的 SyntheticTabularDataset 中,我们直接在 __getitem__ 方法中将 NumPy 数组转换为 PyTorch 张量。对于更复杂的预处理或数据增强(在图像数据中尤其常见),PyTorch Dataset 类通常在其 __init__ 方法中接受一个 transform 参数。这个 transform 通常是一个可调用对象(例如一个函数或一个带有 __call__ 方法的对象),它会在样本返回之前在 __getitem__ 中应用于样本。
例如,如果你处理图像,你可能会将 torchvision.transforms 中的一系列转换(如调整大小、裁剪、归一化和转换为张量)传递给你的自定义图像数据集。
# 转换如何使用的示例
# class MyImageDataset(Dataset):
# def __init__(self, image_paths, labels, transform=None):
# self.image_paths = image_paths
# self.labels = labels
# self.transform = transform
#
# def __len__(self):
# return len(self.image_paths)
#
# def __getitem__(self, idx):
# image = Image.open(self.image_paths[idx]) # 加载图像
# label = self.labels[idx]
# if self.transform:
# image = self.transform(image) # 应用转换
# return image, label
# from torchvision import transforms
# image_transform = transforms.Compose([
# transforms.Resize((256, 256)),
# transforms.RandomCrop(224),
# transforms.ToTensor(),
# transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
# ])
# image_dataset = MyImageDataset(paths, labels, transform=image_transform)
这种方法使数据加载逻辑清晰,并允许灵活组合预处理步骤。对于我们的表格数据,__getitem__ 中的直接张量转换已足够,但了解这种常见模式对于更高级的用例会很有帮助。
这项实践练习展示了 PyTorch 数据处理的基本模式:使用 Dataset 定义如何获取单个处理后的项,然后使用 DataLoader 高效地批处理、打乱和迭代这些项。这种关注点分离提供了很大的灵活性,类似于你在 TensorFlow 的 tf.data API 中定义数据源然后应用 .batch() 和 .shuffle() 等转换的方式,但 Dataset 本身具有更明确的基于 Python 类的结构。你现在拥有了构建块,可以在 PyTorch 中为各种数据类型创建高效的数据管道。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造