趋近智
随着您学会构建神经网络 (neural network)架构并了解训练循环的核心机制,您会发现循环本身可能会变得相当繁忙。除了计算梯度和更新权重 (weight)之外,您还会希望监测训练进展、保存最佳模型、在训练不再有效时停止,甚至动态调整学习率等参数 (parameter)。将所有这些逻辑直接塞入主训练代码中可能会导致混乱。这正是回调机制的用武之地,它提供了一种简洁、模块化的方式,将自定义操作插入到训练生命周期中。
回调本质上是Flux.jl在训练过程中预定义点执行的函数,例如,在每个训练周期的开始或结束时,甚至在处理每个小批量数据之后。它们充当钩子,允许您扩展训练功能,而无需直接修改核心的Flux.train!逻辑。这种方法使您的主要训练脚本保持专注,并使您的辅助任务可重复使用且更易于管理。
在训练程序中使用回调带来诸多实际益处:
通过将这些任务委托给回调,您的主训练循环保持整洁,并专注于核心的学习过程。
Flux.jl 的 Flux.train! 函数通过其 cb 关键字参数 (parameter)支持回调。您通常传递一个函数或可调用对象的数组,Flux 将在适当的时间间隔执行它们,通常在每个训练步骤(参数更新)之后。
# 使用 Flux.train! 的回调简化示例
# Flux.train!(loss_function, parameters, data_iterable, optimizer, cb = my_callbacks)
my_callbacks 中的每个元素都是 Flux.train! 将调用的函数。Flux 还提供了诸如 Flux.throttle 之类的实用工具,它们常与回调结合使用,以控制操作的频率,例如记录日志。
让我们看一些广泛使用的回调模式以及如何实现它们。
您经常会希望记录训练损失。Flux.throttle 在这里是一个方便的工具。它包装一个函数,并确保只有在其上次执行后经过一定时间才会被调用。
using Flux, Logging, Statistics
# 假设模型 (model)、损失函数 (loss_function)、训练数据加载器 (train_loader) 和优化器 (opt) 已定义
# loss_function(x, y) 计算一个批次的损失
# 用于记录当前批次损失的回调,每5秒最多执行一次
# 注意:为了让此回调直接在 Flux.train! 中工作,loss_function 需要在
# x_batch 和 y_batch 可访问的范围内定义,或者更常见的是,
# 损失是在 Flux.train! 调用内部计算的。
# 对于 Flux.train! 更常见的模式是直接传递损失函数,
# Flux 自己处理评估。此时回调可能不需要重新计算损失。
# 让我们显示一个简单的信息日志:
iter_count = 0
log_callback = () -> begin
global iter_count += 1
# 您通常会从 Flux 获取损失,或在需要时计算它
# 在此示例中,我们仅记录迭代次数。
@info "迭代次数: $(iter_count)"
end
# 限制日志回调的执行频率,使其每5秒最多运行一次
throttled_logger = Flux.throttle(log_callback, 5)
# 示例用法(实际数据流取决于您的循环)
# Flux.train!(loss_function, params(model), train_loader, opt, cb = throttled_logger)
当调用 Flux.train! 时,throttled_logger 将在参数 (parameter)更新后被执行。Flux.throttle 确保 log_callback 逻辑不会过于频繁地执行,从而避免您的控制台被大量消息淹没。
对于训练周期级别的日志记录或更复杂的指标,您通常会围绕 Flux.gradient 和 Flux.update! 编写一个更结构化的训练循环,这为您提供了明确的点来调用自定义的训练周期结束回调。
在长时间的训练运行中保存模型至关重要。检查点回调可以自动完成这项工作,例如,每当验证损失改善时就保存模型。
using Flux, BSON
using Dates # 用于文件名时间戳
# 假设 `model` 是您的 Flux 模型
# `val_loss_fn()` 是一个计算验证损失的函数
# `current_epoch` 由您的训练循环追踪
# 检查点回调的状态(通常封装在一个可调用结构体中)
best_validation_loss = Float32(Inf)
checkpoint_dir = "model_checkpoints/"
mkpath(checkpoint_dir) # 如果目录不存在则创建
function save_best_model_callback(epoch_num::Int) # 传入当前训练周期
global best_validation_loss # 访问全局状态(或使用可调用结构体)
current_val_loss = val_loss_fn() # 用户定义的获取验证损失的函数
@info "训练周期 $(epoch_num) - 验证损失: $(current_val_loss)"
if current_val_loss < best_validation_loss
best_validation_loss = current_val_loss
model_cpu = cpu(model) # 在保存前将模型移至 CPU
timestamp = Dates.format(now(), "YYYYmmdd_HHMMSS")
filename = joinpath(checkpoint_dir, "model_epoch_$(epoch_num)_val_$(best_validation_loss)_$(timestamp).bson")
BSON.@save filename model_state=Flux.state(model_cpu) epoch=epoch_num val_loss=best_validation_loss
@info "检查点已保存: $(filename)"
end
end
# 此回调将在自定义训练循环中,于每个训练周期结束时被调用。
# 如果使用 Flux.train!,您需要进行调整或使用辅助库,
# 因为 Flux.train! 的 `cb` 是按步骤执行的。
在此示例中,save_best_model_callback 检查验证损失,并在取得新的最佳损失时使用 BSON.jl 保存模型(其状态)。通常倾向于使用 Flux.state(model) 而不是直接保存整个模型对象,因为它对于 Flux 版本或模型定义的更改更可靠。
下面的图表显示了不同类型的回调如何集成到训练过程中:
回调在训练的各个阶段提供钩子:整体开始/结束、训练周期开始/结束以及批次开始/结束。这使得微调 (fine-tuning)控制和监控成为可能。
提前停止是一个重要的回调,用于防止过拟合 (overfitting)和节省计算。如果被监控的指标(如验证损失)在“容忍”的训练周期数内停止改善,它将停止训练。
尽管 Flux.jl 本身不提供直接插入 Flux.train! 以停止训练的内置 EarlyStopping 回调,如果您正在编写自己的训练循环,可以轻松实现该逻辑,或者使用 FluxTraining.jl 等提供此功能的库。
您可以这样将 EarlyStopper 构建为一个可调用结构体:
mutable struct EarlyStopper
patience::Int
min_delta::Float64
best_val_loss::Float64
epochs_without_improvement::Int
verbose::Bool
triggered::Bool # 用于标记是否停止
end
EarlyStopper(patience::Int=5, min_delta::Float64=0.001; verbose::Bool=true) =
EarlyStopper(patience, min_delta, Inf, 0, verbose, false)
function (es::EarlyStopper)(current_val_loss::Float64; epoch::Int=0)
if es.triggered return true end # 已触发
if es.verbose
@info "训练周期 $epoch: 提前停止器正在检查验证损失 $(current_val_loss)。最佳: $(es.best_val_loss)。连续 $(es.epochs_without_improvement) 个训练周期无改善。"
end
if current_val_loss < es.best_val_loss - es.min_delta
es.best_val_loss = current_val_loss
es.epochs_without_improvement = 0
if es.verbose
@info "训练周期 $epoch: 验证损失改善至 $(es.best_val_loss)。"
end
else
es.epochs_without_improvement += 1
if es.epochs_without_improvement >= es.patience
if es.verbose
@info "训练周期 $epoch: 提前停止已触发,连续 $(es.epochs_without_improvement) 个训练周期无改善(验证损失: $(current_val_loss))。"
end
es.triggered = true
return true # 发出停止信号
end
end
return false # 发出继续信号
end
# 在自定义循环中的用法:
# stopper = EarlyStopper(patience=3, verbose=true)
# for epoch = 1:num_epochs
# # ... 进行一个训练周期的训练 ...
# val_loss = calculate_validation_loss() # 您的函数
# if stopper(val_loss, epoch=epoch)
# @info "提前停止训练。"
# break
# end
# end
这个 EarlyStopper 记录最佳验证损失以及自上次改善以来的训练周期数。如果容忍期已过,它会发出信号表示训练应停止。
下面的图表说明了提前停止如何通过在验证损失开始下降时停止训练来防止过拟合。
训练在第10个训练周期停止,此时验证损失(橙色线)开始增加,即使训练损失(蓝色线)持续下降。这能防止模型进一步过拟合训练数据。
回调也可以管理学习率的调整。例如,如果验证损失趋于平稳,您可能希望降低学习率。
using Flux
# 假设 `opt` 是您的优化器,例如 opt = Adam(0.001)
# 且 `val_loss_fn()` 计算验证损失
mutable struct ReduceLROnPlateau
optimizer::Any
factor::Float64
patience::Int
min_lr::Float64
best_val_loss::Float64
epochs_without_improvement::Int
verbose::Bool
end
ReduceLROnPlateau(optimizer; factor=0.1, patience=3, min_lr=1e-6, verbose=true) =
ReduceLROnPlateau(optimizer, factor, patience, min_lr, Inf, 0, verbose)
function (rlrop::ReduceLROnPlateau)(current_val_loss::Float64; epoch::Int=0)
if current_val_loss < rlrop.best_val_loss
rlrop.best_val_loss = current_val_loss
rlrop.epochs_without_improvement = 0
else
rlrop.epochs_without_improvement += 1
if rlrop.epochs_without_improvement >= rlrop.patience
# 检查优化器是否有 'eta' 字段(学习率常用)
# 这是简化版;实际优化器可能具有嵌套结构。
current_lr = Flux.Optimise.getattr(rlrop.optimizer, :eta)
if current_lr === nothing # 尝试从链中的第一个优化器获取
if !isempty(rlrop.optimizer.os) && Flux.Optimise.getattr(rlrop.optimizer.os[1], :eta) !== nothing
current_lr = Flux.Optimise.getattr(rlrop.optimizer.os[1], :eta)
end
end
if current_lr !== nothing && current_lr > rlrop.min_lr
new_lr = max(current_lr * rlrop.factor, rlrop.min_lr)
Flux.Optimise.setattr!(rlrop.optimizer, :eta, new_lr) # 这可能需要针对特定优化器进行调整
if !isempty(rlrop.optimizer.os) && Flux.Optimise.getattr(rlrop.optimizer.os[1], :eta) !== nothing
Flux.Optimise.setattr!(rlrop.optimizer.os[1], :eta, new_lr)
end
if rlrop.verbose
@info "训练周期 $epoch: 由于性能停滞,学习率从 $current_lr 降低到 $new_lr。"
end
rlrop.epochs_without_improvement = 0 # 降低后重置容忍期
end
end
end
end
# 在自定义循环中的用法:
# lr_scheduler = ReduceLROnPlateau(opt, factor=0.5, patience=2)
# for epoch = 1:num_epochs
# # ... 训练 ...
# val_loss = calculate_validation_loss()
# lr_scheduler(val_loss, epoch=epoch)
# end
这个 ReduceLROnPlateau 回调监控验证损失,如果在设定的训练周期数内没有看到改善,则降低优化器的学习率。访问和修改学习率(opt.eta 或类似属性)取决于所使用的具体优化器。Flux 的 OptimiserChain 需要仔细处理,以调整相应内部优化器的学习率。
使用 Flux.train! 时,您可以传递一个回调函数数组。在自定义循环中,您只需在适当的阶段调用每个已注册的回调。
# 自定义循环结构的示例
# all_callbacks = [logging_cb, checkpointing_cb, early_stopping_cb, lr_scheduler_cb]
# for epoch in 1:max_epochs
# # ... 运行一个训练周期的训练 ...
# # 在训练周期结束时:
# current_val_loss = calculate_validation_loss()
# for cb in all_callbacks
# if cb isa EarlyStopper || cb isa ReduceLROnPlateau # 需要验证损失的回调
# if cb(current_val_loss, epoch=epoch) && cb isa EarlyStopper && cb.triggered
# # 处理提前停止信号
# break_training_loop = true
# end
# else # 其他回调可能不需要验证损失或训练周期
# cb()
# end
# end
# if break_training_loop break end
# end
best_val_loss、epochs_without_improvement),可调用结构体是封装此状态的绝佳方式。on_epoch_begin、on_batch_end)的回调系统。回调对于高效地协调复杂训练过程是不可或缺的。它们使您能够自动化监控、动态调整训练,并确保您的模型开发既有效又富有洞察力。随着您构建更多的模型,您很可能会创建一套个性化的回调工具包,从而简化您在 Julia 中的常见深度学习 (deep learning)任务。
这部分内容有帮助吗?
Flux.train!和实用工具的使用。© 2026 ApX Machine LearningAI伦理与透明度•