标准经验回放将回放缓冲区中存储的所有转移视为同等重要。在抽取用于训练的小批量时,每个转移 (s,a,r,s′) 被选中的概率是均等的。然而,直观上看,并非所有经验都提供相同的学习潜力。有些转移可能只是例行公事,或者确认了智能体已熟知的内容,而另一些则可能代表令人意外的结果,或者智能体当前Q值估计明显不准确的紧要关头。从这些“令人意外的”转移中学习可以带来更快、更有效的更新。
优先经验回放(PER)建立在这一想法之上,通过根据转移的重要性进行采样,优先选择那些智能体能学到最多的转移。其主要思想是使用时序差分(TD)误差的大小作为衡量转移“令人意外”或提供多少有益信息的替代指标。回忆一下,使用目标网络参数 θ− 和当前网络参数 θ 计算转移 (s,a,r,s′) 的TD误差:
δ=r+γa′maxQ(s′,a′;θ−)−Q(s,a;θ)
较大的绝对TD误差 ∣δ∣ 表明预测的Q值 Q(s,a;θ) 与估计的目标值 r+γmaxa′Q(s′,a′;θ−) 之间存在很大差异。此类转移表示当前Q函数可能不准确的情况,因此提供了有力的学习信号。
分配优先级
在PER中,我们不采用均匀采样,而是为回放缓冲区中的每个转移 i 分配一个优先级 pi。这种优先级的一个常见选择与绝对TD误差直接相关:
pi=∣δi∣+ϵ
其中,ϵ 是一个小的正常量,用于以保证TD误差为零的转移仍然有非零的采样概率。这可以防止它们永远不会被回放。
一旦分配了优先级,我们就将其转换为每个转移 i 的采样概率 P(i):
P(i)=∑kpkαpiα
指数 α≥0 控制优先级的程度。
- 如果 α=0,那么对于所有 i,piα=1(假设 pi>0),我们就恢复了均匀采样:P(i)=1/N,其中 N 是缓冲区中的转移数量。
- 如果 α=1,我们得到基于 pi 的直接比例优先级。
- α 在0到1之间的值允许在均匀采样和完全优先级之间进行插值,提供了一种调节对高优先级样本关注度的方式。
使用SumTree进行高效采样
如果采用简单方法,计算归一化项 ∑kpkα 并从可能数百万的转移中按比例采样,其计算代价会很高。PER通常采用一种名为 SumTree 的专用数据结构(一种线段树的变体)来快速实施。
SumTree是一种二叉树,其中每个叶节点存储单个转移的优先级 piα。每个内部节点存储其子节点的优先级之和。因此,根节点存储总和 ∑kpkα。
一个简化的SumTree结构。叶节点保存转移优先级(例如,piα)。父节点保存其子节点优先级的总和。根节点保存总和。采样涉及在 [0,根总和] 范围内抽取一个随机值,然后遍历树以找到对应的叶节点/转移。
这个结构允许:
- 快速采样: 要采样一个转移,生成一个介于0和根节点存储的总和之间的随机数 z。然后,向下遍历树。在每个节点,如果 z 小于左子节点的总和,则向左走;否则,从 z 中减去左子节点的总和,然后向右走。这个过程在 O(logN) 时间内找到与采样值 z 对应的转移。
- 快速优先级更新: 当一个转移的TD误差(及其优先级)在用于训练后被更新时,只需要更新其在树中的祖先节点的优先级。这同样需要 O(logN) 时间。
使用重要性采样纠正偏差
优先采样并非没有影响。通过优先选择TD误差高的转移,我们引入了偏差。网络将主要看到其表现不佳的转移,这可能扭曲了预期值更新,因为采样分布不再与经验生成时的分布匹配。
为了抵消这种偏差,PER将**重要性采样(IS)**权重纳入学习更新中。其思想是降低那些比在均匀采样下更频繁被采样的转移的更新权重。转移 i 的IS权重 wi 计算如下:
wi=(N1⋅P(i)1)β
其中,N 是回放缓冲区的大小,P(i) 是在优先方案下采样转移 i 的概率,β≥0 是一个控制应用多少校正的指数。
- 如果 β=0,则不应用校正 (wi=1)。
- 如果 β=1,则优先化引入的偏差被完全校正。
在实践中,β 通常在训练过程中从一个初始值(例如0.4)线性退火到1。从较小的 β 开始,使得智能体能在Q函数估计尚不准确的早期阶段从高优先级样本中获得更多益处,而后期增加 β 则有助于在估计变得更准确时减少偏差。
为了保持稳定性,这些权重通常通过除以小批量中的最大权重进行归一化:
wi←maxjwjwi
这以保证更新只会被按比例缩小,而不是可能被大幅增加。然后,归一化权重 wi 用于在梯度更新步骤中缩放转移 i 的TD误差(或损失)。例如,使用均方误差损失时,转移 i 的贡献将被加权:
损失i=wi⋅δi2
PER总结
优先经验回放修改了标准DQN训练循环,具体如下:
- 存储优先级: 将新转移添加到回放缓冲区时,最初为其分配最高优先级,以保证它至少被采样一次。将优先级与转移一起存储,通常使用SumTree。
- 优先采样: 使用从存储的优先级导出的概率 P(i)∝piα 从缓冲区中采样小批量。
- 计算IS权重: 对于每个采样的转移 i,计算其重要性采样权重 wi=(N⋅P(i))−β。归一化这些权重。
- 加权更新: 计算小批量中每个转移的TD误差 δi。在对Q网络执行梯度下降步骤时,使用加权TD误差 wi⋅δi(或将 wi 纳入损失计算)。
- 更新优先级: 使用新计算的绝对TD误差 ∣δi∣ 更新SumTree中采样转移的优先级 pi。
通过将学习重点放在最令人意外或最有用的转移上,PER与标准均匀采样的经验回放相比,通常能显著提升数据利用率并加快收敛速度。然而,它在实现上增加了额外的难度(SumTree管理),并引入了需要调整的新超参数(α,β,ϵ)。它是一种精巧的改进,直接应对了DQN中从存储经验中学习的有效性问题。