As introduced earlier in this chapter, Stochastic Gradient Descent (SGD) is a workhorse for large-scale machine learning, but its reliance on noisy gradients computed from small mini-batches can lead to slow or erratic convergence. The variance introduced by sampling only a small fraction of the data at each step often requires careful tuning of learning rates and schedules.
The Stochastic Average Gradient (SAG) algorithm offers a compelling alternative aimed directly at reducing this variance. Proposed by Schmidt, Le Roux, and Bach in 2012, SAG attempts to combine the low iteration cost of SGD with the faster convergence properties often associated with batch gradient descent, particularly for strongly convex problems.
The insight behind SAG is elegantly simple: instead of using just the gradient from the current mini-batch (or single data point in the simplest case), why not leverage information from gradients computed in previous iterations? SAG maintains a memory of the most recently computed gradient for each individual data point in the training set.
Let the objective function be the average loss over $N$ data points: $L(\theta) = \frac{1}{N} \sum_{i=1}^{N} L_i(\theta)$, where $L_i(\theta)$ is the loss associated with the $i$-th data point and $\theta$ represents the model parameters.
Full batch gradient descent uses the gradient $\nabla L(\theta) = \frac{1}{N} \sum_{i=1}^{N} \nabla L_i(\theta)$. SGD, in its simplest form (batch size 1), picks a random index $i_k$ at iteration $k$ and updates using $\nabla L_{i_k}(\theta_k)$.
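Written as explicit update rules (with $\alpha$ denoting a step size, a symbol introduced here only for illustration), the two methods contrast as follows:

$$\theta_{k+1} = \theta_k - \frac{\alpha}{N} \sum_{i=1}^{N} \nabla L_i(\theta_k) \qquad \text{(batch gradient descent)}$$

$$\theta_{k+1} = \theta_k - \alpha \, \nabla L_{i_k}(\theta_k) \qquad \text{(SGD, batch size 1)}$$

The first step is accurate but costs $N$ gradient evaluations; the second costs a single evaluation but uses a noisy, high-variance estimate of the full gradient.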
SAG works differently. It maintains a table containing $N$ gradient vectors, $g_1, g_2, \dots, g_N$. Each $g_i$ stores the most recent gradient $\nabla L_i(\theta)$ computed for data point $i$ at some past iteration when $i$ was selected.
At each iteration $k+1$:

1. Sample an index $i_k$ uniformly at random from $\{1, \dots, N\}$.
2. Compute the gradient $\nabla L_{i_k}(\theta_k)$ at the current parameters and overwrite the stored entry: $g_{i_k} \leftarrow \nabla L_{i_k}(\theta_k)$.
3. Update the parameters using the average of all stored gradients, $\theta_{k+1} = \theta_k - \frac{\alpha}{N} \sum_{i=1}^{N} g_i$, where $\alpha$ is the step size. In practice this average is maintained as a running estimate $G_{k+1}$ that is adjusted in constant time whenever a single table entry changes.
The stored gradients $g_i$ may be initialized to zero vectors or computed once from an initial parameter guess.
Figure: Flow of a single SAG iteration. A data point $i_k$ is sampled, its current gradient is computed, and this new gradient updates both the running average gradient estimate ($G_{k+1}$) and the stored gradient memory ($g_{i_k}$).
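To make the bookkeeping concrete, here is a minimal NumPy sketch of the SAG loop. The callback `grad_fn(theta, i)`, the zero initialization of the table, and all variable names are illustrative assumptions, not part of any specific library.

```python
import numpy as np

def sag(grad_fn, theta0, n_samples, n_iters=1000, step_size=0.01, rng=None):
    """Minimal SAG sketch: one stored gradient per data point; each parameter
    update uses the running average of all stored gradients."""
    rng = np.random.default_rng() if rng is None else rng
    theta = np.asarray(theta0, dtype=float).copy()
    # Table of the most recently computed gradient for each data point,
    # initialized here to zero vectors (one common choice).
    grad_table = np.zeros((n_samples, theta.size))
    avg_grad = np.zeros(theta.size)   # running average G of the stored gradients

    for _ in range(n_iters):
        i = rng.integers(n_samples)          # sample a data point i_k
        new_grad = grad_fn(theta, i)         # gradient of L_i at current theta
        # Refresh the running average in O(d): remove the old entry, add the new one.
        avg_grad += (new_grad - grad_table[i]) / n_samples
        grad_table[i] = new_grad             # overwrite the stored gradient
        theta -= step_size * avg_grad        # SAG parameter step
    return theta
```

For example, for a least-squares problem with an assumed data matrix `X` and targets `y`, `grad_fn` could be `lambda theta, i: (X[i] @ theta - y[i]) * X[i]`. Note that `grad_table` holds one gradient per data point ($N \times d$ numbers), which is exactly the memory cost discussed at the end of this section.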
One of the main advantages of SAG is its convergence rate. For strongly convex and smooth objective functions, SAG achieves a linear convergence rate, a significant improvement over the typically sublinear rate of standard SGD on such problems. In other words, SAG needs far fewer iterations to reach a given accuracy, approaching the behavior of full batch gradient descent. At the same time, each SAG iteration evaluates only a single gradient, so the per-iteration cost does not grow with the dataset size N, unlike batch gradient descent whose cost per iteration scales linearly with N.
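Schematically (constants and problem-dependent factors omitted; the precise statements are given in the SAG analysis), the suboptimality after $k$ iterations behaves like:

$$\mathbb{E}\big[L(\theta_k)\big] - L(\theta^\star) = O\!\left(\tfrac{1}{k}\right) \quad \text{for SGD with decaying step sizes,}$$

$$\mathbb{E}\big[L(\theta_k)\big] - L(\theta^\star) = O\!\left(\rho^{k}\right), \quad 0 < \rho < 1, \quad \text{for SAG,}$$

where the geometric factor $\rho$ is what makes the rate linear rather than sublinear.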
SAG represents an important step towards bridging the gap between the scalability of SGD and the fast convergence of batch methods. Its memory requirement, however, limits its applicability in scenarios with extremely large N. This limitation motivated the development of subsequent variance reduction techniques like SVRG, which we will discuss next, aiming to retain the fast convergence benefits while mitigating the memory burden.