One Step of Gradient Descent is Provably the Optimal In-Context Learner with One Layer of Linear Self-Attention (2307.03576v1)
Abstract: Recent works have empirically analyzed in-context learning and shown that transformers trained on synthetic linear regression tasks can learn to implement ridge regression, the Bayes-optimal predictor, given sufficient capacity [Akyürek et al., 2023], while one-layer transformers with linear self-attention and no MLP layer learn to implement one step of gradient descent (GD) on a least-squares linear regression objective [von Oswald et al., 2022]. However, the theory behind these observations remains poorly understood. We theoretically study transformers with a single layer of linear self-attention, trained on synthetic noisy linear regression data. First, we prove that when the covariates are drawn from a standard Gaussian distribution, the one-layer transformer that minimizes the pre-training loss implements a single step of GD on the least-squares linear regression objective. Second, we find that changing the distribution of the covariates and the weight vector to a non-isotropic Gaussian has a strong effect on the learned algorithm: the global minimizer of the pre-training loss now implements a single step of *pre-conditioned* GD. In contrast, changing only the distribution of the responses has little effect on the learned algorithm: even when the responses come from a more general family of *nonlinear* functions, the global minimizer of the pre-training loss still implements a single step of GD on a least-squares linear regression objective.
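To make the claimed equivalence concrete, here is a minimal NumPy sketch in the spirit of von Oswald et al.'s construction (an illustration under assumed conventions, not the paper's construction verbatim): context tokens e_i = [x_i; y_i] and a query token [x_q; 0] pass through one layer of linear (softmax-free) self-attention whose key and query maps project onto the covariate coordinates and whose value map scales the response coordinate by a step size eta. The attention output at the query token then matches the prediction of one GD step from w = 0 on the least-squares loss. The variable names and the specific choice of weight matrices below are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n, eta = 5, 20, 0.1          # dimension, context length, GD step size

# Synthetic in-context regression task: y_i = <w_star, x_i> + noise.
w_star = rng.normal(size=d)
X = rng.normal(size=(n, d))      # context covariates x_1, ..., x_n
y = X @ w_star + 0.1 * rng.normal(size=n)
x_q = rng.normal(size=d)         # query covariate

# One step of GD from w = 0 on L(w) = (1/2) * sum_i (<w, x_i> - y_i)^2.
# The gradient at w = 0 is -X^T y, so one step gives w = eta * X^T y.
w_gd = eta * (X.T @ y)
pred_gd = x_q @ w_gd

# One layer of linear (softmax-free) self-attention on tokens
# e_i = [x_i; y_i], with query token e_q = [x_q; 0].
E = np.hstack([X, y[:, None]])            # context tokens, shape (n, d+1)
e_q = np.concatenate([x_q, [0.0]])

W_K = np.zeros((d + 1, d + 1))            # key map: project onto x-coords
W_K[:d, :d] = np.eye(d)
W_Q = W_K.copy()                          # query map: same projection
W_V = np.zeros((d + 1, d + 1))            # value map: scale y-coord by eta
W_V[d, d] = eta

# Attention output at the query token: sum_i (W_V e_i) * <W_K e_i, W_Q e_q>.
scores = (E @ W_K.T) @ (W_Q @ e_q)        # <W_K e_i, W_Q e_q> for each i
attn_out = (W_V @ E.T) @ scores
pred_attn = attn_out[d]                   # read prediction off y-coordinate

print(np.allclose(pred_gd, pred_attn))    # True: the two predictions agree
```

With these weights the attention prediction is eta * x_q^T X^T y, exactly the one-step-GD prediction. The sketch only shows that such weights are *realizable*; the paper's contribution is to show that, for isotropic Gaussian covariates, weights of this form are the global minimizer of the pre-training loss, and that a pre-conditioned variant takes their place when the covariate distribution is non-isotropic.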
- Kwangjun Ahn et al. Transformers learn to implement preconditioned gradient descent for in-context learning, 2023.
- Ekin Akyürek et al. What learning algorithm is in-context learning? Investigations with linear models. In The Eleventh International Conference on Learning Representations (ICLR), 2023. URL https://openreview.net/forum?id=0g0X4H8yN4I.
- Tom B. Brown et al. Language models are few-shot learners. CoRR, abs/2005.14165, 2020. URL https://arxiv.org/abs/2005.14165.
- Damai Dai et al. Why can GPT learn in-context? Language models implicitly perform gradient descent as meta-optimizers, 2023.
- Shivam Garg et al. What can transformers learn in-context? A case study of simple function classes. In NeurIPS, 2022. URL http://papers.nips.cc/paper_files/paper/2022/hash/c529dba08a146ea8d6cf715ae8930cbe-Abstract-Conference.html.
- Angeliki Giannou et al. Looped transformers as programmable computers, 2023.
- Opher Lieber et al. Jurassic-1: Technical details and evaluation. Technical report, AI21 Labs, 2021.
- Bingbin Liu et al. Transformers learn shortcuts to automata. In The Eleventh International Conference on Learning Representations (ICLR), Kigali, Rwanda, May 1-5, 2023. OpenReview.net, 2023. URL https://openreview.net/pdf?id=De4FYqjFueZ.
- Sewon Min et al. Rethinking the role of demonstrations: What makes in-context learning work? In EMNLP, 2022.
- Alec Radford et al. Language models are unsupervised multitask learners. Technical report, OpenAI, 2019.
- Johannes von Oswald et al. Transformers learn in-context by gradient descent, 2022.
- Ben Wang and Aran Komatsuzaki. GPT-J-6B: A 6 billion parameter autoregressive language model, 2021.
- Jerry Wei et al. Larger language models do in-context learning differently, 2023.
- Sang Michael Xie et al. An explanation of in-context learning as implicit Bayesian inference. In The Tenth International Conference on Learning Representations (ICLR), Virtual Event, April 25-29, 2022. OpenReview.net, 2022. URL https://openreview.net/forum?id=RdJVFCHjUMI.
- Ruiqi Zhang, Spencer Frei, and Peter L. Bartlett. Trained transformers learn linear models in-context, 2023.