经验回放有助于打破连续训练样本之间的相关性,从而提高训练的稳定性。然而,在训练Q网络时,另一个主要的不稳定性来源在于时序差分 (TD) 更新中使用的目标值不断变化。
回顾应用于函数近似的标准Q学习更新规则。我们的目标是最小化当前Q值估计 Q(St,At;θ) 与目标值之间的差异。这个目标通常使用奖励 Rt+1 和下一状态 St+1 的估计值来计算。在Q学习中,这涉及使用当前网络参数 θ 来寻找下一状态的最大Q值:
目标值t=Rt+1+γa′maxQ(St+1,a′;θt)
损失函数,通常是均方误差 (MSE),则会是类似如下的形式:
L(θt)=E[((Rt+1+γa′maxQ(St+1,a′;θt))−Q(St,At;θt))2]
注意到参数 θt 同时出现在目标计算和我们试图调整的值 (Q(St,At;θt)) 中。当我们执行梯度下降来更新 θt 时,我们本质上是在追逐一个移动的目标。随着网络权重 θt 在每一步中变化,目标值本身也会移动。这种相互依赖性可能导致训练期间的震荡甚至发散,使学习不稳定。这就像你每次调整瞄准时,目标也会移动,你很难击中它。
使用目标网络稳定目标
为了解决这个“移动目标”问题,DQN算法引入了第二个神经网络:目标网络。这个目标网络(我们用 θ− 表示其参数)本质上是在线Q网络(我们正在积极训练的那个,参数为 θ)的克隆。
其工作方式如下:
- 初始化: 目标网络参数 θ− 被初始化为与在线网络参数 θ 相同。
- 目标计算: 当为损失函数计算TD目标时,我们使用目标网络 θ− 来估计下一状态的值。目标值 yt 变为:
yt=Rt+1+γa′maxQ(St+1,a′;θ−)
注意此处使用了 θ− 而非 θt。
- 损失计算: 损失随后使用这个固定目标 yt 和在线网络对当前状态-动作对的预测来计算:
L(θt)=E(St,At,Rt+1,St+1)∼D[(yt−Q(St,At;θt))2]
这里,D 表示用于采样转换的经验回放缓冲区。
- 参数更新: 只有在线网络参数 θt 会通过使用这个损失进行梯度下降来更新。目标网络参数 θ− 在这些更新期间保持不变。
- 定期更新: 在固定数量的训练步骤(我们称此频率为 C)后,在线网络的权重会被复制到目标网络:θ−←θt。
这种机制提供了稳定性,因为在在线网络 θ 的 C 次连续更新中,目标值 yt 保持固定。在线网络现在正在学习近似一个静态目标,这大大简化了学习动态,并降低了震荡和发散的可能性。更新目标网络的频率 C 是一个需要选择的超参数;典型值可能从数百到数千步不等,具体取决于特定问题。
交互流程图显示了在线网络 (Q(St,A;θ)) 和目标网络 (Q(St+1,a′;θ−)) 在DQN更新步骤中的使用方式。在线网络参数 θ 通过梯度下降频繁更新,而目标网络参数 θ− 则仅定期通过从 θ 复制来更新。
通过将固定Q目标与经验回放结合,DQN解决了应用Q学习与深度神经网络等复杂函数近似器时固有的两个主要不稳定性来源。经验回放解除了数据样本的相关性,而目标网络则为学习更新提供了稳定目标。这些技术共同构成了早期DQN从原始像素输入中学习玩Atari游戏取得成功的基础。