趋近智
在进入训练的迭代过程之前,我们需要准备好核心部件:模型本身、衡量其误差的方法(损失函数 (loss function))以及根据误差更新模型的机制(优化器)。这个准备阶段确保所有必需的组件都已初始化并为训练循环做好了准备。
首先,你需要一个神经网络 (neural network)模型的实例。定义自定义网络结构通常涉及继承torch.nn.Module。创建一个模型类的对象即可:
# 假设 'SimpleNet' 是你之前定义的自定义 nn.Module 类
model = SimpleNet(input_size=784, hidden_size=128, output_size=10)
print(model)
这会创建网络结构,包括其所有层和参数 (parameter)(权重 (weight)和偏置 (bias))。最初,这些参数具有随机值(或者如果你实现了特定的初始化方案,则由这些方案确定的值)。
深度学习 (deep learning)计算,特别是训练,在GPU上速度要快得多。PyTorch使得将模型移动到合适的设备(CPU或GPU)变得简单。一种好的做法是尽早定义目标设备,然后始终将模型和数据都移动到该设备上。
import torch
# 确定可用设备
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")
# 将模型移动到所选设备
model.to(device)
执行model.to(device)会原地修改模型,如果CUDA可用,则将其所有参数和缓冲区移动到GPU内存中,否则保留在CPU上。请记住,任何参与模型计算的张量(如输入数据)也必须位于相同的设备上。我们将在训练循环内部处理数据张量的移动。
损失函数,常被称为准则函数,衡量模型预测与实际目标值之间的距离。PyTorch在torch.nn模块中提供了许多标准损失函数。选择哪种损失函数很大程度上取决于你正在解决的问题类型(例如,回归、分类)。
对于多分类问题,nn.CrossEntropyLoss很常用。它在一个高效的类中结合了nn.LogSoftmax和nn.NLLLoss(负对数似然损失)。
# 用于多分类
loss_fn = torch.nn.CrossEntropyLoss()
# 用于回归问题(预测连续值)
# loss_fn = torch.nn.MSELoss() # 均方误差损失
你像实例化模型一样实例化所选的损失函数。这个loss_fn对象稍后将在训练循环中被调用,通常接收模型的输出和真实标签作为输入,以计算一个标量损失值。
优化器实现了一种算法(如随机梯度下降 (gradient descent)或Adam),用于根据反向传播 (backpropagation)期间计算的梯度调整模型的参数 (parameter)。其作用是使损失函数 (loss function)最小化。优化器位于torch.optim包中。
初始化优化器时,你必须提供两个重要的参数:
model.parameters()轻松完成,该方法返回模型中所有可训练参数的迭代器。lr): 这个超参数 (hyperparameter)控制参数更新的步长。找到一个合适的学习率对有效的训练很有帮助。这通常需要尝试。import torch.optim as optim
# 使用随机梯度下降 (SGD)
learning_rate = 0.01
optimizer = optim.SGD(model.parameters(), lr=learning_rate)
# 或者,使用Adam优化器
# optimizer = optim.Adam(model.parameters(), lr=0.001)
在这里,我们创建了一个SGD优化器实例。它现在持有model所有参数的引用,并且知道当其step()方法稍后被调用时要使用的学习率。不同的优化器可能还有额外的超参数(如SGD的momentum或Adam的betas),你可以在初始化时进行配置。
随着模型被实例化并移动到正确设备,损失函数被定义,以及优化器被配置来更新模型的参数,我们已经设置好了所有必需的组件。我们现在准备继续进行训练过程的核心:在训练循环中迭代数据并执行前向传播、损失计算、反向传播和参数更新。
这部分内容有帮助吗?
torch.nn - PyTorch Documentation, PyTorch Authors, 2024 - PyTorch神经网络模块的官方文档,涵盖模型定义、层以及诸如交叉熵损失和均方误差损失等多种损失函数。torch.optim - PyTorch Documentation, PyTorch Authors, 2025 (PyTorch Foundation) - PyTorch优化算法的官方文档,包含随机梯度下降(SGD)和Adam等优化器,并解释了它们的用法和参数。© 2026 ApX Machine LearningAI伦理与透明度•