Momentum-Based Variance Reduction in Non-Convex SGD
(1905.10018v3)
Published 24 May 2019 in cs.LG, math.OC, and stat.ML
Abstract: Variance reduction has emerged in recent years as a strong competitor to stochastic gradient descent in non-convex problems, providing the first algorithms to improve upon the convergence rate of stochastic gradient descent for finding first-order critical points. However, variance reduction techniques typically require carefully tuned learning rates and willingness to use excessively large "mega-batches" in order to achieve their improved results. We present a new algorithm, STORM, that does not require any batches and makes use of adaptive learning rates, enabling simpler implementation and less hyperparameter tuning. Our technique for removing the batches uses a variant of momentum to achieve variance reduction in non-convex optimization. On smooth losses $F$, STORM finds a point $\boldsymbol{x}$ with $\mathbb{E}[\|\nabla F(\boldsymbol{x})\|]\le O(1/\sqrt{T}+\sigma^{1/3}/T^{1/3})$ in $T$ iterations with $\sigma^2$ variance in the gradients, matching the optimal rate but without requiring knowledge of $\sigma$.
The paper introduces Storm, a novel algorithm that applies momentum-based variance reduction to non-convex SGD, eliminating the need for large batches.
It employs adaptive learning rates that automatically adjust to gradient variance, simplifying hyperparameter tuning and enhancing convergence rates.
Empirical validation on CIFAR-10 using a ResNet model demonstrates Storm’s competitive convergence performance compared to AdaGrad and Adam.
Momentum-Based Variance Reduction in Non-Convex SGD: A Summary
The paper "Momentum-Based Variance Reduction in Non-Convex SGD," authored by Ashok Cutkosky and Francesco Orabona, presents a novel approach to variance reduction within the context of non-convex stochastic gradient descent (SGD). Traditionally, variance reduction techniques, while effective in enhancing the convergence rate to critical points in non-convex problems, necessitate intricately tuned learning rates and reliance on large "mega-batches." This paper proposes an advancement in the form of a new algorithm, dubbed "Storm," which seeks to circumvent these limitations through the use of momentum-based variance reduction without necessitating large batch sizes.
Key Contributions
The core contribution of the paper is the Storm algorithm. The technique applies a variant of the momentum commonly used in heuristic methods such as RMSProp and Adam to achieve variance reduction in non-convex optimization. Storm's appeal rests on several aspects (a code sketch of the resulting update follows this list):
Batch-Free Optimization: Unlike standard variance reduction methodologies that depend on structured batch gradients, Storm operates without any batch gradients, thereby potentially simplifying implementation and reducing computational overhead.
Adaptive Learning Rates: The algorithm employs adaptive learning rates that adjust automatically to the variance of the gradient estimates. This adaptation removes much of the laborious hyperparameter tuning typically required by variance reduction methods, making the algorithm easier to use in practice.
Theoretical Convergence: Storm matches the optimal convergence rate of $O(1/T^{1/3})$ in expected gradient norm for smooth losses. This is achieved without prior knowledge of the variance $\sigma^2$ of the gradient estimates, a common hurdle in deploying variance reduction techniques effectively in stochastic settings.
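As a concrete illustration of these points, here is a minimal NumPy sketch of the batch-free, adaptive update described above. It assumes a user-supplied stochastic gradient oracle `grad(x, xi)` and a sampler `draw()` (hypothetical names), and the defaults for `k`, `w`, and `c` are illustrative rather than the paper's tuned values.

```python
import numpy as np

def storm(x0, grad, draw, T, k=0.1, w=0.1, c=100.0):
    """Minimal sketch of the STORM recursion (Cutkosky & Orabona, 2019)."""
    x = np.asarray(x0, dtype=float).copy()
    xi = draw()
    g = grad(x, xi)                             # d_1 = grad f(x_1, xi_1)
    d = g.copy()
    sum_sq = float(g @ g)                       # running sum of squared gradient norms
    for _ in range(T):
        eta = k / (w + sum_sq) ** (1.0 / 3.0)   # adaptive step size, no knowledge of sigma
        x_prev, x = x, x - eta * d              # x_{t+1} = x_t - eta_t * d_t
        a = min(1.0, c * eta ** 2)              # momentum weight a_{t+1} = c * eta_t^2 (clamped to 1)
        xi = draw()                             # one fresh sample, reused at BOTH iterates
        g_new, g_old = grad(x, xi), grad(x_prev, xi)
        d = g_new + (1.0 - a) * (d - g_old)     # momentum-based variance-reduced estimator
        sum_sq += float(g_new @ g_new)
    return x
```

Note that each iteration evaluates the stochastic gradient at two points with the same sample, but never requires a mega-batch or an estimate of $\sigma$.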
Theoretical Analysis
Theoretical analysis shows that Storm achieves an expected gradient norm $\mathbb{E}[\|\nabla F(\boldsymbol{x}_t)\|]$ of $O(1/\sqrt{T}+\sigma^{1/3}/T^{1/3})$ over $T$ iterations, matching the optimal rates established in the literature. The guarantees rest on the careful construction of a Lyapunov potential function that controls the interplay between the variance of the gradient estimator and the optimization trajectory.
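To make the mechanism concrete, write $\boldsymbol{\epsilon}_t = \boldsymbol{d}_t - \nabla F(\boldsymbol{x}_t)$ for the error of the estimator $\boldsymbol{d}_t = \nabla f(\boldsymbol{x}_t,\xi_t) + (1-a_t)\bigl(\boldsymbol{d}_{t-1} - \nabla f(\boldsymbol{x}_{t-1},\xi_t)\bigr)$. A short calculation (a sketch in the paper's notation, with constants omitted) gives

\[
\boldsymbol{\epsilon}_t = (1-a_t)\,\boldsymbol{\epsilon}_{t-1}
+ a_t\bigl(\nabla f(\boldsymbol{x}_t,\xi_t)-\nabla F(\boldsymbol{x}_t)\bigr)
+ (1-a_t)\Bigl[\bigl(\nabla f(\boldsymbol{x}_t,\xi_t)-\nabla f(\boldsymbol{x}_{t-1},\xi_t)\bigr)
-\bigl(\nabla F(\boldsymbol{x}_t)-\nabla F(\boldsymbol{x}_{t-1})\bigr)\Bigr].
\]

The first term contracts the previous error, the second is fresh noise of size $O(a_t\sigma)$, and the third is controlled by smoothness since $\|\boldsymbol{x}_t-\boldsymbol{x}_{t-1}\| = \eta_{t-1}\|\boldsymbol{d}_{t-1}\|$. The Lyapunov potential trades these terms off against the decrease of $F$ along the trajectory.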
Practical Implications and Empirical Validation
On a practical level, the paper discusses implementing the Storm algorithm in settings that demand non-convex stochastic optimization, such as deep learning. Notably, Storm's momentum-based variance reduction can potentially improve training efficiency without the extensive computational resources typically required by competing variance reduction methods.
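As an illustration of what such an implementation might look like in a deep-learning training loop, the sketch below shows one STORM step in PyTorch. It keeps a copy of the previous iterate so that the correction term can be evaluated on the same minibatch, at the cost of two backward passes per step. The function name, state dictionary, and hyperparameter defaults are assumptions for illustration, not the authors' reference implementation.

```python
import copy
import torch
import torch.nn.functional as F

def storm_step(model, prev_model, batch, state, k=0.1, w=0.1, c=100.0):
    """One STORM step; `state` holds the estimator `d` and the running sum `sum_sq`."""
    inputs, targets = batch

    # Stochastic gradient at the current iterate on this minibatch.
    model.zero_grad()
    F.cross_entropy(model(inputs), targets).backward()
    g_new = [p.grad.detach().clone() for p in model.parameters()]

    # Same minibatch evaluated at the previous iterate, for the correction term.
    prev_model.zero_grad()
    F.cross_entropy(prev_model(inputs), targets).backward()
    g_old = [p.grad.detach().clone() for p in prev_model.parameters()]

    # Adaptive learning rate eta = k / (w + sum of squared gradient norms)^(1/3)
    # and momentum weight a = c * eta^2 (the paper forms a from the previous
    # step's eta; using the current one keeps this sketch self-contained).
    state["sum_sq"] += sum(float(g.pow(2).sum()) for g in g_new)
    eta = k / (w + state["sum_sq"]) ** (1.0 / 3.0)
    a = min(1.0, c * eta ** 2)

    # Momentum-based variance-reduced estimator: d = g_new + (1 - a) * (d - g_old).
    if state["d"] is None:
        state["d"] = g_new
    else:
        state["d"] = [gn + (1.0 - a) * (d - go)
                      for gn, d, go in zip(g_new, state["d"], g_old)]

    # Remember the current iterate, then update the parameters: x <- x - eta * d.
    prev_model.load_state_dict(model.state_dict())
    with torch.no_grad():
        for p, d in zip(model.parameters(), state["d"]):
            p.add_(d, alpha=-eta)
    return state

# Usage sketch: model = a CIFAR-10 ResNet, prev_model = copy.deepcopy(model),
# state = {"d": None, "sum_sq": 0.0}; call storm_step(...) once per minibatch.
```

The extra backward pass roughly doubles the per-step cost relative to SGD, which is the price paid for removing mega-batches from the variance reduction.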
The empirical performance of Storm is validated on the CIFAR-10 image classification benchmark using a ResNet model. Comparisons with AdaGrad and Adam indicate that Storm is competitive, particularly in the number of iterations needed to reach a given training loss and accuracy. However, its final test accuracy is only marginally better than AdaGrad's, suggesting that further empirical investigation is needed to conclusively validate its generalization performance.
Speculations and Future Directions
Given Storm's ability to achieve optimal convergence without extensive tuning, it presents an attractive approach in scenarios where computation cost is a significant constraint. Future research could explore extending Storm's applications to broader classes of machine learning tasks, including reinforcement learning where variance reduction is crucial for handling the inherent stochasticity. Additionally, integrating regularization techniques with Storm could enhance its capacity to generalize from training data to unseen scenarios, a potential avenue for improving test accuracy.
Conclusion
The research elucidated in this paper offers a significant advancement in the use of momentum for variance reduction within non-convex SGD. The algorithm Storm stands as a notable departure from traditional methods, providing a framework that is both theoretically sound and practically efficient, given the challenges of deploying variance reduction in deep learning and other machine learning paradigms. As variance reduction continues to be an area of active research, the insights from this paper could inspire further investigation into even more efficient and adaptive optimization techniques.