趋近智
在训练神经网络 (neural network)时,我们需要一种方式来量化 (quantization)模型的表现。具体来说,我们需要衡量模型预测与实际目标值之间的差异。这个衡量标准由损失函数 (loss function)提供,它也被称为误差函数或成本函数。它计算一个单一的标量值,表示模型当前预测对于给定数据集的“不良程度”。损失越小,模型表现越好。这个损失值非常重要,因为优化器会通过调整模型参数 (parameter)来努力使其最小化。
Flux.jl 提供了一套预定义的损失函数,可以方便地在 Flux 模块中取用。损失函数的选择在很大程度上取决于您的机器学习 (machine learning)任务性质,主要是回归问题还是分类问题。
当您的模型旨在预测连续数值时,通常会使用回归损失函数 (loss function)。
对于回归任务,目标是预测连续值,均方误差是一个非常普遍的选择。它计算预测值 () 和真实值 () 之间平方差的平均值。
其中 是样本数量。对误差项进行平方有两个作用:它确保误差始终为正,并且对较大的误差比对较小的误差施加更大的惩罚。在 Flux 中,您可以使用 Flux.mse(y_pred, y_true)。
另一个常用的回归损失函数是平均绝对误差。MAE 计算预测与真实值之间绝对差值的平均值。
MAE 衡量一组预测中误差的平均大小,而不考虑其方向。与 MSE 不同,MAE 对误差的惩罚是线性增加的。这使得 MAE 相对于 MSE 对异常值不那么敏感。您可以在 Flux 中使用 Flux.mae(y_pred, y_true)。
均方误差 (MSE) 和平均绝对误差 (MAE) 如何惩罚不同大小预测误差的比较。MSE 的二次方特性导致对较大误差的损失增加更陡峭。
对于模型从一组离散类别中预测类别标签的任务,分类损失函数 (loss function)是合适的。
对于二元分类问题,即只有两个可能的输出类别(例如 0 或 1,垃圾邮件或非垃圾邮件),二元交叉熵是标准损失函数。它衡量分类模型的表现,该模型的输出是介于 0 和 1 之间的概率值。
这里, 是真实标签(0 或 1), 是类别 1 的预测概率。当预测概率 接近真实标签 时,此损失函数被最小化。
在 Flux 中,您使用 Flux.binarycrossentropy(y_pred_probs, y_true)。y_pred_probs 必须是概率,通常通过将原始模型输出(logits)经过 sigmoid 激活函数 (activation function)获得。Flux 的 Flux.logitbinarycrossentropy(y_pred_logits, y_true) 可以直接将 logits 作为输入,这通常在数值上更稳定,并且通常更推荐。
在处理多类别分类问题(超过两个类别)时,分类交叉熵是首选的损失函数。它衡量真实类别分布与预测的类别概率分布之间的差异。
在此公式中, 是样本数量, 是类别数量, 是二元指示符(如果样本 属于类别 则为 1,否则为 0,通常是独热编码), 是模型预测样本 属于类别 的概率。
预测值 通常是应用于网络最后一层的 softmax 激活函数的输出,确保它们对于每个样本在所有类别上加和为 1。Flux 提供了 Flux.crossentropy(y_pred_probs, y_true_onehot)。与二元交叉熵类似,Flux 还提供了 Flux.logitcrossentropy(y_pred_logits, y_true_onehot),它接受原始 logits 并内部应用 log-softmax 变换,以提高数值稳定性和效率。当处理原始模型输出时,这是推荐的函数。对于 y_true_onehot,您通常会使用独热编码来表示您的目标标签(例如,使用 Flux.onehot 或 Flux.onehotbatch)。
损失函数是神经网络 (neural network)在训练期间正向传播的最终计算。它将有关模型在一批数据上表现的所有信息提炼成一个单一的数字。然后,这个标量损失值被 Flux 的自动微分引擎 Zygote.jl 用于计算损失相对于模型所有参数 (parameter)(权重 (weight)和偏置 (bias))的梯度。这些梯度指示每个参数应如何调整以减少损失,从而为优化器的更新步骤提供了依据。因此,损失函数的选择和正确实现与训练过程的成功直接相关。
选择合适的损失函数取决于您的具体问题:
Flux.mse:一个不错的默认选择。对异常值敏感。Flux.mae:相对于 MSE,对异常值的敏感度较低。Flux.binarycrossentropy:当模型输出是概率时(经过 sigmoid 后)使用。Flux.logitbinarycrossentropy:当模型输出是原始 logits 时(sigmoid 之前)首选,因为它在数值上更稳定。目标标签应为 0 或 1。Flux.crossentropy:当模型输出是概率分布时(经过 softmax 后)使用。目标标签应为独热编码。Flux.logitcrossentropy:当模型输出是原始 logits 时(softmax 之前)首选。目标标签应为独热编码。此函数结合了 softmax 激活和交叉熵计算,以获得更好的数值稳定性。务必确保模型的输出激活函数 (activation function)(如果使用基于 logit 的损失,则可能没有)与所选损失函数兼容。
让我们看一个在 Flux 中如何使用这些函数的快速示例:
using Flux
using Statistics: mean # 稍后用于自定义损失示例
# 回归任务的样本数据
y_actual_reg = [1.5f0, 2.0f0, 3.5f0]
y_predicted_reg = [1.3f0, 2.4f0, 3.1f0]
# 计算均方误差
mse_loss = Flux.mse(y_predicted_reg, y_actual_reg)
println("MSE 损失:", mse_loss)
# 计算平均绝对误差
mae_loss = Flux.mae(y_predicted_reg, y_actual_reg)
println("MAE 损失:", mae_loss)
# 多类别分类任务的样本数据(3 个样本,2 个类别)
# 原始模型输出 (logits)
# 维度:(类别数量, 样本数量)
y_predicted_logits_clf = Float32[0.2 0.8; -0.5 0.5; 1.2 -0.1]'
# 真实标签(独热向量形式)
# 样本 1:类别 2,样本 2:类别 1,样本 3:类别 1
y_actual_clf = Flux.onehotbatch([2, 1, 1], 1:2) # 创建一个 2x3 的独热矩阵
# 从 logits 计算分类交叉熵
# Flux.logitcrossentropy 期望 (logits, 目标)
ce_loss = Flux.logitcrossentropy(y_predicted_logits_clf, y_actual_clf)
println("分类交叉熵损失:", ce_loss)
# 对于二元分类(1 个样本,预测类别 1 的 logit)
y_predicted_logit_bin = 0.7f0 # 单个样本的原始 logit 输出
y_actual_bin = 1.0f0 # 真实标签为 1(浮点数,符合 logitbinarycrossentropy 的期望)
# 从 logit 计算二元交叉熵
bce_loss = Flux.logitbinarycrossentropy(y_predicted_logit_bin, y_actual_bin)
println("二元交叉熵损失(来自 logit):", bce_loss)
在这些示例中,y_predicted_reg 和 y_predicted_logits_clf 通常是您的 Flux 模型的输出(例如,model(input_data))。实际标签 y_actual_reg 和 y_actual_clf 来自您的数据集。
虽然 Flux 提供了一套完整的标准损失函数,但 Julia 的灵活性允许您在问题需要特定误差度量时轻松定义自己的函数。Flux 中的自定义损失函数只是一个普通的 Julia 函数,它接受模型的预测和真实目标作为输入,并返回一个表示损失的单一标量值。
例如,如果您需要均方误差的加权版本,可以这样定义它:
function my_weighted_mse(y_pred, y_true, weights)
# 确保 weights, y_pred, 和 y_true 可以广播
# 并计算加权平方误差的均值。
return mean(weights .* (y_pred .- y_true).^2)
end
# 示例用法(假设 model_outputs, y_labels, sample_weights 已定义):
# model_outputs = model(x_batch)
# loss_value = my_weighted_mse(model_outputs, y_batch_labels, batch_sample_weights)
这个自定义函数可以像任何内置的 Flux 损失函数一样在您的训练循环中使用。Zygote.jl 将能够对其进行微分,前提是函数内部的所有操作本身都是可微分的。这种可组合性是 Julia 深度学习 (deep learning)生态系统的一个重要特点。
这部分内容有帮助吗?
© 2026 ApX Machine LearningAI伦理与透明度•