趋近智
让我们将本章提到的方法付诸实践。我们讨论了权重初始化的作用,学习率调度如何帮助模型收敛,以及网格搜索和随机搜索等寻找合适超参数值的方法。现在,您将亲身体验如何使用这些方法来调整深度学习模型。
调整超参数更多的是一门艺术而非科学,需要多加尝试。然而,系统化的方法会大大提高您找到一个能带来更好模型性能和泛化能力的配置的可能性。
一个常见的场景是使用简单的卷积神经网络(CNN)在CIFAR-10数据集上进行图像分类。CIFAR-10包含60,000张32x32彩色图像,分为10个类别。为完成此任务,需要具备基本的PyTorch环境,并熟悉模型定义、数据加载和训练循环的编写。
我们的目的不是构建一个绝对最优的CIFAR-10分类器,而是呈现超参数调整的过程。
首先,我们使用PyTorch定义一个简单的CNN架构。这将是我们的基础模型:
import torch
import torch.nn as nn
import torch.nn.functional as F
class SimpleCNN(nn.Module):
def __init__(self, dropout_rate=0.5):
super().__init__()
# 卷积层
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1) # 输入: 3x32x32 -> 输出: 16x32x32
self.pool = nn.MaxPool2d(2, 2) # 输出: 16x16x16
self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1) # 输出: 32x16x16
# 池化后: 32x8x8
# 全连接层
self.fc1 = nn.Linear(32 * 8 * 8, 128) # 展平尺寸: 32*8*8 = 2048
self.dropout = nn.Dropout(dropout_rate) # 应用Dropout
self.fc2 = nn.Linear(128, 10) # 10个输出类别
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = torch.flatten(x, 1) # 展平除批量维度外的所有维度
x = F.relu(self.fc1(x))
x = self.dropout(x)
x = self.fc2(x)
return x
# 注意:权重初始化(如He初始化)通常由PyTorch层默认处理
# 但如果需要,也可以在此处明确设置。
我们还需要用于CIFAR-10的标准数据加载和转换流程。我们假设您有函数load_cifar10_data(batch_size),它返回用于训练集和验证集的PyTorch DataLoader实例。请记住包含归一化处理。
根据本章内容,有几个超参数是值得调整的候选:
Adam和带有动量的SGD。Adam通常是一个不错的默认选择,但SGD+动量在细致调整后有时能取得稍好的泛化能力。为简单起见,我们在此使用Adam,但调整其学习率。SimpleCNN中。值通常介于到之间。我们将使用随机搜索以提高效率。让我们定义搜索范围:
主要思想是运行多个训练实验(试验),每个实验都使用从我们定义的范围内随机采样的超参数集。我们进行固定且相对较少轮次的训练(例如10-15轮),以快速获得信号,记录验证性能,然后比较不同试验的结果。
以下是调整循环的概要:
import random
import numpy as np
import torch.optim as optim
# 假设SimpleCNN, load_cifar10_data已定义
# 假设train_one_epoch()和evaluate()函数存在
num_trials = 20 # 要尝试的随机配置数量
num_epochs_per_trial = 10 # 短时间训练
results = []
for trial in range(num_trials):
print(f"--- 试验 {trial+1}/{num_trials} ---")
# 1. 采样超参数
lr = 10**np.random.uniform(-4, -2) # 学习率的对数均匀采样
weight_decay = 10**np.random.uniform(-5, -3) # 权重衰减的对数均匀采样
dropout_rate = random.uniform(0.1, 0.5)
batch_size = random.choice([64, 128, 256])
print(f"采样结果: lr={lr:.6f}, wd={weight_decay:.6f}, dropout={dropout_rate:.4f}, batch_size={batch_size}")
# 2. 设置数据加载器、模型、优化器
train_loader, val_loader = load_cifar10_data(batch_size=batch_size)
model = SimpleCNN(dropout_rate=dropout_rate)
# 如果可用,考虑使用CUDA
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss()
best_val_accuracy = 0.0
# 3. 固定轮次训练
for epoch in range(num_epochs_per_trial):
# train_one_epoch(model, train_loader, criterion, optimizer, device)
# val_loss, val_accuracy = evaluate(model, val_loader, criterion, device)
# 用于说明结构的模拟训练/评估
print(f" 轮次 {epoch+1}/{num_epochs_per_trial} - 模拟训练...")
# 在实际运行时,根据evaluate()结果更新best_val_accuracy
# 对于本示例,我们模拟一个结果
simulated_val_accuracy = 0.3 + trial*0.01 + epoch*0.02 + random.uniform(-0.05, 0.05) # 占位符
best_val_accuracy = max(best_val_accuracy, simulated_val_accuracy)
print(f"试验 {trial+1} 完成。最佳验证准确率: {best_val_accuracy:.4f}")
# 4. 记录结果
results.append({
'trial': trial + 1,
'lr': lr,
'weight_decay': weight_decay,
'dropout_rate': dropout_rate,
'batch_size': batch_size,
'best_val_accuracy': best_val_accuracy
})
# 5. 分析结果(见下一节)
print("\n--- 调整完成 ---")
# 按验证准确率排序结果
results.sort(key=lambda x: x['best_val_accuracy'], reverse=True)
print("前5个配置:")
for i in range(min(5, len(results))):
print(f"排名 {i+1}: 准确率={results[i]['best_val_accuracy']:.4f}, "
f"学习率={results[i]['lr']:.6f}, 权重衰减={results[i]['weight_decay']:.6f}, "
f"Dropout={results[i]['dropout_rate']:.4f}, 批量大小={results[i]['batch_size']}")
注意:train_one_epoch和evaluate函数是标准的PyTorch训练组件,为简洁起见此处省略。您可以像往常一样实现它们。
运行调整循环后,results列表包含每个超参数配置的性能。简单地按验证准确率排序,就能得到搜索过程中表现最优的配置集。
将超参数与性能之间的关系可视化可以提供一些发现。例如,让我们绘制验证准确率与学习率(对数尺度)的关系图:
经过10个训练轮次后,不同随机采样学习率所达到的验证准确率。在此模拟运行中,大约在到之间的值似乎表现最佳。
可以为权重衰减和Dropout率制作类似的图表。例如,您可能会发现,非常低或非常高的Dropout率会损害性能,或者适度的权重衰减是有益的。分析表现最佳的试验有助于您把握哪些超参数值(或范围)最有前景。
这项实践表明了改进深度学习模型的基础工作流程。虽然存在自动化超参数优化工具(例如Optuna、Ray Tune),但理解手动过程能提供宝贵的直觉,以有效设置搜索范围和解读结果。尝试是必不可少的,因此请尝试将此过程应用于您自己的模型和数据集。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造