Standard Experience Replay, as discussed in the context of DQN, treats all transitions stored in the replay buffer as equally important. When sampling a mini-batch for training, each transition (s,a,r,s′) has an equal probability of being selected. However, intuitively, not all experiences offer the same learning potential. Some transitions might be routine or confirm what the agent already knows well, while others might represent surprising outcomes or critical moments where the agent's current Q-value estimates are significantly inaccurate. Learning from these "surprising" transitions could lead to faster and more effective updates.
Prioritized Experience Replay (PER) builds on this intuition by sampling transitions based on their significance, giving preference to those from which the agent can learn the most. The core idea is to use the magnitude of the Temporal Difference (TD) error as a proxy for how "surprising" or informative a transition is. Recall the TD error for a transition $(s, a, r, s')$ using target network parameters $\theta^-$ and current network parameters $\theta$:

$$\delta = r + \gamma \max_{a'} Q(s', a'; \theta^-) - Q(s, a; \theta)$$

A large absolute TD error, $|\delta|$, indicates a large discrepancy between the predicted Q-value $Q(s, a; \theta)$ and the estimated target value $r + \gamma \max_{a'} Q(s', a'; \theta^-)$. Such transitions represent situations where the current Q-function is likely inaccurate and thus provide a strong learning signal.
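As a concrete illustration, here is a minimal PyTorch-style sketch of this computation for a mini-batch; the `td_errors` name, the batch dictionary layout, and the terminal-state (`dones`) handling are assumptions made for the example rather than part of any specific implementation.

```python
import torch

def td_errors(batch, q_net, target_net, gamma=0.99):
    """Compute delta = r + gamma * max_a' Q(s', a'; theta-) - Q(s, a; theta)."""
    states, actions = batch["states"], batch["actions"]
    rewards, next_states, dones = batch["rewards"], batch["next_states"], batch["dones"]

    # Q(s, a; theta) for the actions that were actually taken
    q_sa = q_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)

    with torch.no_grad():
        # max_a' Q(s', a'; theta-) from the frozen target network
        max_q_next = target_net(next_states).max(dim=1).values
        targets = rewards + gamma * (1.0 - dones) * max_q_next

    return targets - q_sa  # one TD error per transition
```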
In PER, instead of uniform sampling, we assign a priority $p_i$ to each transition $i$ in the replay buffer. A common choice for this priority is directly related to the absolute TD error:

$$p_i = |\delta_i| + \epsilon$$

Here, $\epsilon$ is a small positive constant added to ensure that transitions with zero TD error still have a non-zero probability of being sampled. This prevents them from never being replayed.
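A minimal helper for this step might look like the following; the array-based interface and the particular value of $\epsilon$ are illustrative assumptions.

```python
import numpy as np

def priorities_from_td_errors(deltas, epsilon=1e-6):
    """p_i = |delta_i| + epsilon; epsilon keeps zero-error transitions samplable."""
    return np.abs(np.asarray(deltas)) + epsilon
```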
Once priorities are assigned, we convert them into sampling probabilities P(i) for each transition i:
$$P(i) = \frac{p_i^\alpha}{\sum_k p_k^\alpha}$$

The exponent $\alpha \ge 0$ controls the degree of prioritization: $\alpha = 0$ recovers uniform sampling, while $\alpha = 1$ corresponds to sampling fully proportional to priority.
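In code, this conversion (and proportional sampling from it) could be sketched as follows; the value of $\alpha$, the batch size, and the `buffer_priorities` array are arbitrary example choices.

```python
import numpy as np

def sampling_probabilities(priorities, alpha=0.6):
    """P(i) = p_i**alpha / sum_k p_k**alpha; alpha=0 recovers uniform sampling."""
    scaled = np.asarray(priorities, dtype=np.float64) ** alpha
    return scaled / scaled.sum()

# Drawing a mini-batch of 32 indices in proportion to P(i):
# probs = sampling_probabilities(buffer_priorities)  # buffer_priorities: assumed array
# indices = np.random.choice(len(probs), size=32, p=probs)
```

This direct approach recomputes the full distribution on every draw, which is exactly the naive cost that the SumTree described next is designed to avoid.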
Calculating the normalization term $\sum_k p_k^\alpha$ and sampling proportionally from potentially millions of transitions can be computationally expensive if done naively. PER typically employs a specialized data structure called a SumTree (a variant of a segment tree) for efficient implementation.
A SumTree is a binary tree where each leaf node stores the priority $p_i^\alpha$ of a single transition. Each internal node stores the sum of the priorities of its children. The root node, therefore, stores the total sum $\sum_k p_k^\alpha$.
A simplified SumTree structure. Leaf nodes hold transition priorities (e.g., $p_i^\alpha$). Parent nodes hold the sum of their children's priorities. The root holds the total sum. Sampling involves drawing a random value in $[0, \text{root sum}]$ and traversing the tree to find the corresponding leaf/transition.
This structure allows for two operations, each in O(log N) time: updating the priority of a single transition (change a leaf value and propagate the difference up to the root) and sampling a transition in proportion to its priority (draw a uniform value in [0, total sum] and descend the tree to the matching leaf).
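Below is a compact, array-backed SumTree sketch along these lines; the class layout and method names are illustrative rather than a reference implementation.

```python
import numpy as np

class SumTree:
    """Leaves store transition priorities (p_i^alpha); each internal node
    stores the sum of its children, so the root holds the total sum."""

    def __init__(self, capacity):
        self.capacity = capacity                   # max number of transitions
        self.tree = np.zeros(2 * capacity - 1)     # internal nodes + leaves
        self.write = 0                             # next leaf slot to overwrite

    def update(self, leaf_index, priority):
        """Set one leaf's priority and propagate the change to the root (O(log N))."""
        tree_index = leaf_index + self.capacity - 1
        change = priority - self.tree[tree_index]
        self.tree[tree_index] = priority
        while tree_index != 0:                     # walk up, adjusting parent sums
            tree_index = (tree_index - 1) // 2
            self.tree[tree_index] += change

    def add(self, priority):
        """Store a new transition's priority in the next slot (circular overwrite)."""
        self.update(self.write, priority)
        self.write = (self.write + 1) % self.capacity

    def total(self):
        return self.tree[0]                        # root: sum of all priorities

    def sample(self, value):
        """Descend from the root to the leaf whose cumulative-priority range
        contains `value`, drawn uniformly from [0, total()); O(log N)."""
        index = 0
        while True:
            left, right = 2 * index + 1, 2 * index + 2
            if left >= len(self.tree):             # reached a leaf
                break
            if value <= self.tree[left]:
                index = left
            else:
                value -= self.tree[left]
                index = right
        return index - (self.capacity - 1), self.tree[index]
```

A mini-batch is typically drawn by splitting [0, total()) into equal segments and calling `sample` with one uniform draw per segment, which yields indices distributed approximately according to P(i).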
Prioritized sampling is not without consequences. By preferentially selecting transitions with high TD errors, we introduce bias. The network will primarily see transitions where it performs poorly, potentially distorting the expected value updates because the sampled distribution no longer matches the distribution under which the experiences were generated.
To counteract this bias, PER incorporates Importance Sampling (IS) weights into the learning update. The intuition is to down-weight the updates for transitions that were sampled more frequently than they would have been under uniform sampling. The IS weight $w_i$ for transition $i$ is calculated as:

$$w_i = \left( \frac{1}{N} \cdot \frac{1}{P(i)} \right)^\beta$$

Here, $N$ is the size of the replay buffer, $P(i)$ is the probability of sampling transition $i$ under the prioritized scheme, and $\beta \ge 0$ is an exponent that controls how much correction is applied.
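A direct translation of this formula, assuming the sampling probabilities of the mini-batch are available as an array:

```python
import numpy as np

def importance_weights(sample_probs, buffer_size, beta):
    """w_i = (1/N * 1/P(i))**beta for each transition in the sampled mini-batch."""
    return (buffer_size * np.asarray(sample_probs, dtype=np.float64)) ** (-beta)
```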
In practice, β is often annealed from an initial value (e.g., 0.4) linearly towards 1 over the course of training. Starting with a smaller β allows the agent to benefit more strongly from the prioritized samples early on when the Q-function is still poorly estimated, while increasing β later helps reduce bias as the estimates become more refined.
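One common way to implement this schedule, with the start value, end value, and schedule length as illustrative hyperparameters:

```python
def annealed_beta(step, beta_start=0.4, beta_end=1.0, total_steps=100_000):
    """Linearly increase beta from beta_start to beta_end over total_steps."""
    fraction = min(step / total_steps, 1.0)
    return beta_start + fraction * (beta_end - beta_start)
```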
For stability, these weights are typically normalized by dividing by the maximum weight in the mini-batch:
$$w_i \leftarrow \frac{w_i}{\max_j w_j}$$

This ensures that the updates are only scaled down, not potentially scaled up significantly. The normalized weight $w_i$ is then used to scale the TD error (or the loss) for transition $i$ during the gradient update step. For instance, using mean squared error loss, the contribution of transition $i$ would be weighted:
$$\text{Loss}_i = w_i \cdot \delta_i^2$$

Prioritized Experience Replay modifies the standard DQN training loop in the following ways: new transitions are stored with the maximum priority seen so far, so they are guaranteed to be replayed at least once; mini-batches are sampled according to $P(i)$, typically via the SumTree; each sampled transition's loss is scaled by its normalized importance sampling weight $w_i$; and after the gradient step, the priorities of the sampled transitions are updated using their newly computed absolute TD errors. A sketch of this modified update step follows below.
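The sketch below shows how the weighted loss and the priority refresh fit together, assuming `deltas` and `is_weights` are per-transition tensors produced by the earlier steps; the function name and the $\epsilon$ value are illustrative.

```python
import torch

def per_weighted_loss(deltas, is_weights, epsilon=1e-6):
    """Mean of w_i * delta_i^2, plus the new priorities for the SumTree."""
    is_weights = is_weights / is_weights.max()        # normalize: scale down only
    loss = (is_weights * deltas.pow(2)).mean()        # weighted squared TD errors
    new_priorities = deltas.detach().abs() + epsilon  # p_i = |delta_i| + epsilon
    return loss, new_priorities
```

After `loss.backward()` and the optimizer step, `new_priorities` would be written back into the SumTree for the sampled indices via `update`.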
By focusing learning on transitions that are most surprising or informative, PER often leads to significant improvements in data efficiency and faster convergence compared to standard Experience Replay with uniform sampling. However, it introduces additional complexity in implementation (SumTree management) and new hyperparameters (α, β, ϵ) that need tuning. It represents a sophisticated enhancement that directly addresses the efficiency of learning from stored experiences in DQN.