Stochastic Gradient Descent (SGD) is a workhorse for large-scale machine learning. However, its reliance on noisy gradients computed from small mini-batches can lead to slow or erratic convergence. The variance introduced by sampling only a small fraction of the data at each step often requires careful tuning of learning rates and schedules.
The Stochastic Average Gradient (SAG) algorithm offers a compelling alternative aimed directly at reducing this variance. Proposed by Schmidt, Le Roux, and Bach in 2012, SAG attempts to combine the low iteration cost of SGD with the faster convergence properties often associated with batch gradient descent, particularly for strongly convex problems.
The Core Idea: Averaging Past Gradients
The insight behind SAG is elegantly simple: instead of using just the gradient from the current mini-batch (or single data point in the simplest case), why not leverage information from gradients computed in previous iterations? SAG maintains a memory of the most recently computed gradient for each individual data point in the training set.
Let the objective function be the average loss over N data points:
L(θ) = (1/N) ∑_{i=1}^{N} L_i(θ)
where L_i(θ) is the loss associated with the i-th data point and θ represents the model parameters.
Full batch gradient descent uses the gradient ∇L(θ) = (1/N) ∑_{i=1}^{N} ∇L_i(θ). SGD, in its simplest form (batch size 1), picks a random index i_k at iteration k and updates using ∇L_{i_k}(θ_k).
SAG works differently. It maintains a table containing N gradient vectors, g_1, g_2, …, g_N. Each g_i stores the most recent gradient ∇L_i(θ) computed for data point i at some past iteration when i was selected.
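To make the memory table concrete, here is a minimal sketch in Python, assuming a simple least-squares loss L_i(θ) = ½(x_iᵀθ − y_i)²; the names (grad_Li, g, G) and the toy data are illustrative choices, not part of the original algorithm description.

```python
import numpy as np

# Illustrative setup: least-squares loss L_i(theta) = 0.5 * (x_i . theta - y_i)^2.
rng = np.random.default_rng(0)
N, d = 200, 5                        # number of data points, parameter dimension
X = rng.normal(size=(N, d))
y = X @ rng.normal(size=d)

def grad_Li(theta, i):
    """Gradient of the i-th loss term at theta: (x_i . theta - y_i) * x_i."""
    return (X[i] @ theta - y[i]) * X[i]

g = np.zeros((N, d))                 # gradient memory: g[i] holds the last gradient computed for L_i
G = g.mean(axis=0)                   # average of the stored gradients, (1/N) * sum_j g[j]
```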
The SAG Update Rule
At each iteration k+1:
- Sample: Randomly select an index i_k from {1, 2, …, N}.
- Compute Current Gradient: Calculate the gradient for the selected data point using the current parameters θ_k: v_{i_k} = ∇L_{i_k}(θ_k).
- Update Average Gradient Estimate: The core of SAG is using an average of the stored gradients as the update direction. This average can be updated efficiently. Let G_k = (1/N) ∑_{j=1}^{N} g_j be the average of the gradients stored up to iteration k. The new average G_{k+1} is calculated by effectively replacing the old stored gradient g_{i_k} with the newly computed gradient v_{i_k} in the sum:
G_{k+1} = G_k + (1/N)(v_{i_k} − g_{i_k})
This avoids re-summing all N stored gradients at every step.
- Update Parameters: Update the parameters using this less noisy average gradient estimate:
θ_{k+1} = θ_k − η G_{k+1}
where η is the learning rate.
- Update Memory: Store the newly computed gradient v_{i_k} in the memory table, replacing the previous value of g_{i_k}:
g_{i_k} ← v_{i_k}
The stored gradients g_i can be initialized to zero vectors or computed from an initial parameter guess.
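Putting the steps above together, the following is a minimal, self-contained sketch of SAG on a toy least-squares problem; the function name sag, the step size, and the iteration count are illustrative assumptions rather than recommendations from the paper.

```python
import numpy as np

def sag(X, y, eta=0.005, n_iters=50_000, seed=0):
    """Minimal SAG sketch for L_i(theta) = 0.5 * (x_i . theta - y_i)^2."""
    rng = np.random.default_rng(seed)
    N, d = X.shape
    theta = np.zeros(d)
    g = np.zeros((N, d))              # gradient memory table (initialized to zero vectors)
    G = np.zeros(d)                   # average of the stored gradients

    for _ in range(n_iters):
        i = rng.integers(N)                       # 1. sample an index i_k
        v = (X[i] @ theta - y[i]) * X[i]          # 2. current gradient v_{i_k} = grad L_{i_k}(theta_k)
        G += (v - g[i]) / N                       # 3. G_{k+1} = G_k + (1/N)(v_{i_k} - g_{i_k})
        theta -= eta * G                          # 4. theta_{k+1} = theta_k - eta * G_{k+1}
        g[i] = v                                  # 5. overwrite the stored gradient for point i
    return theta

# Toy usage: recover a planted parameter vector from noiseless linear measurements.
rng = np.random.default_rng(1)
X = rng.normal(size=(200, 5))
theta_true = rng.normal(size=5)
theta_hat = sag(X, X @ theta_true)
print("parameter error:", np.linalg.norm(theta_hat - theta_true))  # shrinks as n_iters grows
```

Note that the direction G mixes fresh and stale per-example gradients; as the iterates stabilize, the stale entries become accurate and the variance of the update direction shrinks.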
Figure: Flow of a single SAG iteration. A data point i_k is sampled, its current gradient is computed, and this new gradient updates both the running average gradient estimate G_{k+1} and the stored gradient memory g_{i_k}.
Convergence Properties
One of the main advantages of SAG is its convergence rate. For strongly convex and smooth objective functions, SAG achieves a linear convergence rate. This is a significant improvement over the typically sublinear rate of standard SGD for such problems. In essence, SAG requires far fewer iterations to reach a given accuracy, approaching the behavior of full batch gradient descent. Crucially, it does so while keeping the cost per iteration independent of N, unlike batch gradient descent, whose per-iteration cost grows linearly with N.
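Schematically, for a strongly convex and smooth objective the contrast looks as follows; C and ρ are problem-dependent constants used purely for illustration, not the exact bounds from the SAG analysis.

```latex
% Schematic rate comparison for a strongly convex, smooth finite-sum objective.
\begin{align*}
  \text{SAG (constant step size):}\quad
    & \mathbb{E}\,[L(\theta_k)] - L(\theta^\ast) \le C\,\rho^{\,k},
      \quad 0 < \rho < 1 \quad \text{(linear rate)} \\
  \text{SGD (decaying step size):}\quad
    & \mathbb{E}\,[L(\theta_k)] - L(\theta^\ast) = O(1/k)
      \quad \text{(sublinear rate)}
\end{align*}
```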
Computational Cost and Memory
- Computation: The cost per iteration of SAG is dominated by the computation of a single gradient ∇L_{i_k}(θ_k) and a few vector additions/subtractions. This is comparable to the cost per iteration of SGD (with batch size 1). It's substantially cheaper than batch gradient descent, which requires computing N gradients per iteration.
- Memory: The primary drawback of SAG is its memory requirement. It needs to store N gradient vectors, where N is the total number of data points in the training set. If the parameter dimension d is large and N is massive (billions of samples), storing N×d floating-point numbers can become impractical or even impossible.
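To make the storage cost concrete, here is a quick back-of-the-envelope calculation with purely illustrative values of N and d.

```python
# Rough memory footprint of SAG's gradient table (illustrative numbers, float32 storage).
N = 10_000_000            # training examples
d = 1_000                 # parameter dimension
bytes_per_float = 4       # float32
table_bytes = N * d * bytes_per_float
print(f"gradient table: {table_bytes / 1e9:.0f} GB")   # ~40 GB, versus ~4 KB for theta itself
```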
SAG vs. SGD
- Variance: SAG significantly reduces the variance compared to SGD by averaging gradients. This leads to more stable updates and faster convergence.
- Memory: SGD has minimal memory overhead (just storing parameters and current mini-batch gradients), while SAG requires storing N gradients.
- Convergence Speed (Iterations): SAG typically converges much faster (linearly) than SGD (sublinearly) for strongly convex problems.
- Implementation Complexity: SAG is slightly more complex to implement due to the need to manage the gradient memory table.
SAG represents an important step towards bridging the gap between the scalability of SGD and the fast convergence of batch methods. Its memory requirement, however, limits its applicability in scenarios with extremely large N. This limitation motivated the development of subsequent variance reduction techniques like SVRG, which we will discuss next, aiming to retain the fast convergence benefits while mitigating the memory burden.