In the standard Deep Q-Network approach, we introduced Experience Replay as a way to break correlations between consecutive samples and reuse past experiences. The replay buffer stores transitions (s,a,r,s′) and the agent samples mini-batches uniformly at random from this buffer to perform Q-network updates. While effective, this uniform sampling treats all transitions as equally important. Intuitively, however, some experiences might offer much more learning potential than others. Imagine an agent encountering a situation that leads to a surprisingly high or low reward, or where its current Q-value prediction is way off. These "surprising" transitions seem more valuable for learning than routine ones where the predictions are already accurate.
Prioritized Experience Replay (PER) leverages this intuition by moving away from uniform sampling. Instead of picking transitions randomly, PER samples transitions proportionally to their learning potential, typically measured by the magnitude of their Temporal Difference (TD) error. The TD error represents how "surprising" a transition was to the network:
$$\delta_t = r_t + \gamma \max_{a'} Q_{\text{target}}(s_{t+1}, a') - Q(s_t, a_t)$$

A large absolute TD error, $|\delta_t|$, signifies that the network's prediction for the state-action pair $(s_t, a_t)$ was inaccurate compared to the observed reward and the estimated value of the next state. These are precisely the transitions we want the agent to focus on.
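To make the computation concrete, here is a minimal NumPy sketch of the TD error for a batch of transitions. The array names (`rewards`, `dones`, `q_values`, `q_target_next`) and the discount factor `gamma` are illustrative assumptions, not part of any specific library API.

```python
import numpy as np

def td_errors(rewards, dones, q_values, q_target_next, gamma=0.99):
    """Compute TD errors delta_t for a batch of transitions.

    rewards       : (B,) rewards r_t
    dones         : (B,) 1.0 if s_{t+1} is terminal, else 0.0
    q_values      : (B,) Q(s_t, a_t) from the online network
    q_target_next : (B, num_actions) Q_target(s_{t+1}, .) from the target network
    """
    # Bootstrapped target: r_t + gamma * max_a' Q_target(s_{t+1}, a'), zeroed at terminals
    targets = rewards + gamma * (1.0 - dones) * q_target_next.max(axis=1)
    return targets - q_values
```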
In PER, each transition $i$ in the replay buffer is assigned a priority $p_i$. A common way to define this priority is based on the absolute TD error:
$$p_i = |\delta_i| + \epsilon$$

Here, $\epsilon$ is a small positive constant added to ensure that transitions with zero TD error still have a non-zero probability of being sampled.
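As a small sketch, the priorities for a batch of TD errors follow directly from this definition; the value of `epsilon` here is an assumed constant, not a prescribed one.

```python
import numpy as np

def priorities_from_td(td_errs, epsilon=1e-6):
    # p_i = |delta_i| + epsilon keeps every transition sampleable
    return np.abs(td_errs) + epsilon
```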
The probability $P(i)$ of sampling transition $i$ is then defined based on its priority:

$$P(i) = \frac{p_i^\alpha}{\sum_k p_k^\alpha}$$

The exponent $\alpha \ge 0$ is a hyperparameter that controls the degree of prioritization. When $\alpha = 0$, we recover the original uniform sampling strategy ($P(i) = 1/N$, where $N$ is the buffer size). As $\alpha$ increases, the sampling becomes more heavily skewed towards transitions with high TD errors. Efficiently sampling according to these probabilities often involves specialized data structures like SumTrees.
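The sketch below samples a mini-batch proportionally to $p_i^\alpha$ by normalizing explicitly and calling `np.random.choice`; a production buffer would typically replace this with a SumTree so sampling stays $O(\log N)$. All names here are illustrative.

```python
import numpy as np

def sample_indices(priorities, batch_size, alpha=0.6, rng=None):
    """Sample transition indices with probability P(i) = p_i^alpha / sum_k p_k^alpha."""
    rng = np.random.default_rng() if rng is None else rng
    scaled = priorities ** alpha
    probs = scaled / scaled.sum()
    indices = rng.choice(len(priorities), size=batch_size, p=probs)
    # Return the sampling probabilities too; they are needed for the IS weights later
    return indices, probs[indices]
```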
Sampling transitions based on their priority introduces a bias because the updates no longer reflect the original distribution of experiences. Transitions with high TD errors are over-represented in the training batches. To counteract this bias, PER uses Importance Sampling (IS) weights when calculating the loss for the sampled transitions.
The IS weight $w_i$ for transition $i$ corrects for its non-uniform sampling probability $P(i)$:

$$w_i = \left( \frac{1}{N \cdot P(i)} \right)^\beta$$

Here, $N$ is the size of the replay buffer. The hyperparameter $\beta \in [0, 1]$ controls how much correction is applied. $\beta = 1$ fully compensates for the non-uniform probabilities, while $\beta = 0$ applies no correction. In practice, $\beta$ is often annealed from an initial value (e.g., 0.4) towards 1 over the course of training.
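A minimal sketch of the weight computation, including the linear annealing of $\beta$ mentioned above; the schedule parameters (`beta_start`, `total_steps`) are assumptions for illustration.

```python
import numpy as np

def beta_by_step(step, total_steps, beta_start=0.4):
    # Linearly anneal beta from beta_start towards 1 over training
    return min(1.0, beta_start + (1.0 - beta_start) * step / total_steps)

def importance_weights(sample_probs, buffer_size, beta):
    # w_i = (1 / (N * P(i)))^beta for the sampled transitions
    return (buffer_size * sample_probs) ** (-beta)
```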
These weights are then incorporated into the loss function, typically by multiplying each sampled transition's TD error $\delta_i$ (or its squared-error loss term) by $w_i$ during the gradient update step. This ensures that while we sample important transitions more often, the magnitude of their update is scaled down to prevent overfitting to these specific samples and to keep the updates unbiased in expectation (fully so when $\beta = 1$). For numerical stability, the weights are often normalized by dividing by the maximum weight in the mini-batch: $w_i \leftarrow w_i / \max_j w_j$.
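As an illustration, a PyTorch-style weighted loss for a sampled mini-batch might look like the following sketch; `q_pred`, `q_target`, and `weights` are assumed tensors, and the max-normalization described above is applied before weighting.

```python
import torch

def per_weighted_loss(q_pred, q_target, weights):
    """Importance-weighted TD loss for a PER mini-batch.

    q_pred   : (B,) Q(s_t, a_t) predictions from the online network
    q_target : (B,) bootstrapped targets r_t + gamma * max_a' Q_target(s_{t+1}, a')
    weights  : (B,) importance sampling weights w_i
    """
    weights = weights / weights.max()        # normalize by the max weight in the batch
    td_err = q_target - q_pred               # delta_i, also used to refresh priorities
    loss = (weights * td_err.pow(2)).mean()  # weighted mean squared TD error
    return loss, td_err.detach().abs()       # return |delta_i| for the new priorities
```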
Prioritized Experience Replay is an enhancement to the standard experience replay mechanism used in DQN. By sampling transitions based on their TD error magnitude, it focuses the learning process on the most informative experiences. While it introduces some complexity through priority calculation, non-uniform sampling, and importance sampling weights, PER often leads to significant improvements in learning speed and data efficiency compared to uniform sampling. It's frequently combined with other DQN improvements like Double DQN and Dueling Networks to build highly effective RL agents.