Gated Linear Attention Transformers with Hardware-Efficient Training (2312.06635v6)
Abstract: Transformers with linear attention allow for efficient parallel training but can simultaneously be formulated as an RNN with 2D (matrix-valued) hidden states, thus enjoying linear-time inference complexity. However, linear attention generally underperforms ordinary softmax attention. Moreover, current implementations of linear attention lack I/O-awareness and are thus slower than highly optimized implementations of softmax attention. This work describes a hardware-efficient algorithm for linear attention that trades off memory movement against parallelizability. The resulting implementation, dubbed FLASHLINEARATTENTION, is faster than FLASHATTENTION-2 (Dao, 2023) as a standalone layer even on short sequence lengths (e.g., 1K). We then generalize this algorithm to a more expressive variant of linear attention with data-dependent gates. When used as a replacement for the standard attention layer in Transformers, the resulting gated linear attention (GLA) Transformer is found to perform competitively against the LLaMA-architecture Transformer (Touvron et al., 2023) as well as recent linear-time-inference baselines such as RetNet (Sun et al., 2023a) and Mamba (Gu & Dao, 2023) in moderate-scale language modeling experiments. The GLA Transformer is especially effective at length generalization, enabling a model trained on a 2K context length to generalize to sequences longer than 20K without significant perplexity degradation. In terms of training speed, the GLA Transformer has higher throughput than a similarly-sized Mamba model.
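To make the recurrent view described in the abstract concrete, here is a minimal sketch (plain PyTorch, not the paper's hardware-efficient FLASHLINEARATTENTION kernel) of gated linear attention in its step-by-step inference form. The tensor shapes, the sigmoid gate parameterization, and the function name are illustrative assumptions rather than the paper's exact formulation.

```python
import torch

def gated_linear_attention_recurrent(q, k, v, g):
    """Recurrent (inference-style) form of gated linear attention.

    q, k: (T, d_k)  queries and keys
    v:    (T, d_v)  values
    g:    (T, d_k)  data-dependent gates in (0, 1), e.g. sigmoid outputs

    The hidden state S is a (d_k, d_v) matrix updated once per step, so
    inference takes O(T * d_k * d_v) time and only O(d_k * d_v) memory.
    """
    T, d_k = k.shape
    d_v = v.shape[-1]
    S = torch.zeros(d_k, d_v)
    outputs = []
    for t in range(T):
        # Decay the matrix-valued state with the data-dependent gate
        # (broadcast over the value dimension), then write the new
        # key-value outer product into the state.
        S = g[t].unsqueeze(-1) * S + torch.outer(k[t], v[t])
        # Read out by querying the state.
        outputs.append(q[t] @ S)  # (d_v,)
    return torch.stack(outputs)  # (T, d_v)

# Usage with random inputs (hypothetical shapes):
T, d_k, d_v = 8, 16, 32
q, k, v = (torch.randn(T, d) for d in (d_k, d_k, d_v))
g = torch.sigmoid(torch.randn(T, d_k))
out = gated_linear_attention_recurrent(q, k, v, g)
print(out.shape)  # torch.Size([8, 32])
```

Setting the gate to all ones recovers vanilla linear attention; making it a function of the current input (sketched here via a sigmoid) gives the data-dependent decay that GLA adds. The paper's training-time efficiency comes from a chunkwise parallel, I/O-aware formulation of this same recurrence rather than the explicit loop above.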
- The SciQA scientific question answering benchmark for scholarly knowledge. Scientific Reports, 13(1):7240, May 2023. ISSN 2045-2322. doi: 10.1038/s41598-023-33607-z. URL https://doi.org/10.1038/s41598-023-33607-z.
- Using fast weights to attend to the recent past. Advances in Neural Information Processing Systems, 29, 2016.
- PIQA: Reasoning about physical commonsense in natural language. In Proceedings of the AAAI Conference on Artificial Intelligence, volume 34, pp. 7432–7439, 2020.
- Guy E. Blelloch. Prefix sums and their applications. 1990. URL https://api.semanticscholar.org/CorpusID:60459178.
- Striped attention: Faster ring attention for causal transformers. ArXiv, abs/2311.09431, 2023. URL https://api.semanticscholar.org/CorpusID:265220849.
- Learning phrase representations using RNN encoder-decoder for statistical machine translation. arXiv preprint arXiv:1406.1078, 2014.
- 3.2 The A100 datacenter GPU and Ampere architecture. In 2021 IEEE International Solid-State Circuits Conference (ISSCC), volume 64, pp. 48–50, 2021. doi: 10.1109/ISSCC42613.2021.9365803.
- Rethinking attention with performers. arXiv preprint arXiv:2009.14794, 2020.
- BoolQ: Exploring the surprising difficulty of natural yes/no questions. arXiv preprint arXiv:1905.10044, 2019.
- Think you have solved question answering? Try ARC, the AI2 reasoning challenge. arXiv preprint arXiv:1803.05457, 2018.
- Accelerating reduction and scan using tensor core units. In Rudolf Eigenmann, Chen Ding, and Sally A. McKee (eds.), Proceedings of the ACM International Conference on Supercomputing, ICS 2019, Phoenix, AZ, USA, June 26-28, 2019, pp. 46–57. ACM, 2019. doi: 10.1145/3330345.3331057. URL https://doi.org/10.1145/3330345.3331057.
- Tri Dao. FlashAttention-2: Faster attention with better parallelism and work partitioning. CoRR, abs/2307.08691, 2023. doi: 10.48550/ARXIV.2307.08691. URL https://doi.org/10.48550/arXiv.2307.08691.
- FlashAttention: Fast and memory-efficient exact attention with IO-awareness. In NeurIPS, 2022. URL http://papers.nips.cc/paper_files/paper/2022/hash/67d57c32e20fd0a7a302cb81d36e40d5-Abstract-Conference.html.
- FlashFFTConv: Efficient convolutions for long sequences with tensor cores. CoRR, abs/2311.05908, 2023. doi: 10.48550/ARXIV.2311.05908. URL https://doi.org/10.48550/arXiv.2311.05908.
- A framework for few-shot language model evaluation, September 2021. URL https://doi.org/10.5281/zenodo.5371628.
- Learning to forget: Continual prediction with LSTM. Neural Comput., 12(10):2451–2471, 2000. doi: 10.1162/089976600300015015. URL https://doi.org/10.1162/089976600300015015.
- Mamba: Linear-time sequence modeling with selective state spaces. 2023. URL https://api.semanticscholar.org/CorpusID:265551773.
- Efficiently modeling long sequences with structured state spaces. In The Tenth International Conference on Learning Representations, ICLR 2022, Virtual Event, April 25-29, 2022. OpenReview.net, 2022a. URL https://openreview.net/forum?id=uYLFoz1vlAC.
- Efficiently modeling long sequences with structured state spaces, 2022b.
- Franz A. Heinsen. Efficient parallelization of a ubiquitous sequential computation. 2023. URL https://api.semanticscholar.org/CorpusID:265149785.
- Using fast weights to deblur old memories. In Proceedings of the ninth annual conference of the Cognitive Science Society, pp. 177–186, 1987.
- Long short-term memory. Neural Computation, 9(8):1735–1780, 1997.
- Transformer quality in linear time. In Kamalika Chaudhuri, Stefanie Jegelka, Le Song, Csaba Szepesvári, Gang Niu, and Sivan Sabato (eds.), International Conference on Machine Learning, ICML 2022, 17-23 July 2022, Baltimore, Maryland, USA, volume 162 of Proceedings of Machine Learning Research, pp. 9099–9117. PMLR, 2022. URL https://proceedings.mlr.press/v162/hua22a.html.
- Going beyond linear transformers with recurrent fast weight programmers. Advances in Neural Information Processing Systems, 34:7703–7717, 2021.
- Finetuning pretrained transformers into rnns. In Marie-Francine Moens, Xuanjing Huang, Lucia Specia, and Scott Wen-tau Yih (eds.), Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing, EMNLP 2021, Virtual Event / Punta Cana, Dominican Republic, 7-11 November, 2021, pp. 10630–10643. Association for Computational Linguistics, 2021a. doi: 10.18653/V1/2021.EMNLP-MAIN.830. URL https://doi.org/10.18653/v1/2021.emnlp-main.830.
- Transformers are RNNs: Fast autoregressive transformers with linear attention. In International Conference on Machine Learning, pp. 5156–5165. PMLR, 2020.
- Tobias Katsch. GateLoop: Fully data-controlled linear recurrence for sequence modeling. ArXiv, abs/2311.01927, 2023. URL https://api.semanticscholar.org/CorpusID:265018962.
- tcFFT: Accelerating half-precision FFT through tensor cores. CoRR, abs/2104.11471, 2021. URL https://arxiv.org/abs/2104.11471.
- LightSeq: Sequence level parallelism for distributed training of long context transformers. ArXiv, abs/2310.03294, 2023a. URL https://api.semanticscholar.org/CorpusID:263671659.
- Sequence parallelism: Long sequence training from system perspective. In Anna Rogers, Jordan Boyd-Graber, and Naoaki Okazaki (eds.), Proceedings of the 61st Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), pp. 2391–2404, Toronto, Canada, July 2023b. Association for Computational Linguistics. doi: 10.18653/v1/2023.acl-long.134. URL https://aclanthology.org/2023.acl-long.134.
- Lucas D. Lingle. Transformer-VQ: Linear-time transformers via vector quantization. CoRR, abs/2309.16354, 2023. doi: 10.48550/ARXIV.2309.16354. URL https://doi.org/10.48550/arXiv.2309.16354.
- Ring attention with blockwise transformers for near-infinite context. ArXiv, abs/2310.01889, 2023. URL https://api.semanticscholar.org/CorpusID:263608461.
- Fixing weight decay regularization in Adam. 2018.
- Huanru Henry Mao. Fine-tuning pre-trained transformers into decaying fast weights. In Proceedings of the 2022 Conference on Empirical Methods in Natural Language Processing, pp. 10236–10242, Abu Dhabi, United Arab Emirates, December 2022. Association for Computational Linguistics. doi: 10.18653/v1/2022.emnlp-main.697. URL https://aclanthology.org/2022.emnlp-main.697.
- Parallelizing linear recurrent neural nets over sequence length. In 6th International Conference on Learning Representations, ICLR 2018, Vancouver, BC, Canada, April 30 - May 3, 2018, Conference Track Proceedings. OpenReview.net, 2018. URL https://openreview.net/forum?id=HyUNwulC-.
- Long range language modeling via gated state spaces. In The Eleventh International Conference on Learning Representations, ICLR 2023, Kigali, Rwanda, May 1-5, 2023. OpenReview.net, 2023. URL https://openreview.net/pdf?id=5MkYIYCbva.
- Can a suit of armor conduct electricity? a new dataset for open book question answering. arXiv preprint arXiv:1809.02789, 2018.
- The LAMBADA dataset: Word prediction requiring a broad discourse context. arXiv preprint arXiv:1606.06031, 2016.
- RWKV: Reinventing RNNs for the transformer era. CoRR, abs/2305.13048, 2023. doi: 10.48550/ARXIV.2305.13048. URL https://doi.org/10.48550/arXiv.2305.13048.
- Random feature attention. arXiv preprint arXiv:2103.02143, 2021.
- Accelerating non-power-of-2 size fourier transforms with GPU tensor cores. In 35th IEEE International Parallel and Distributed Processing Symposium, IPDPS 2021, Portland, OR, USA, May 17-21, 2021, pp. 507–516. IEEE, 2021. doi: 10.1109/IPDPS49936.2021.00059. URL https://doi.org/10.1109/IPDPS49936.2021.00059.
- Recurrent linear transformers. CoRR, abs/2310.15719, 2023. doi: 10.48550/ARXIV.2310.15719. URL https://doi.org/10.48550/arXiv.2310.15719.
- The devil in linear transformer. arXiv preprint arXiv:2210.10340, 2022.
- Scaling TransNormer to 175 billion parameters. arXiv preprint arXiv:2307.14995, 2023a.
- Hierarchically gated recurrent neural network for sequence modeling. CoRR, abs/2311.04823, 2023b. doi: 10.48550/ARXIV.2311.04823. URL https://doi.org/10.48550/arXiv.2311.04823.
- Swish: a self-gated activation function. arXiv: Neural and Evolutionary Computing, 2017. URL https://api.semanticscholar.org/CorpusID:196158220.
- CoQA: A conversational question answering challenge. Transactions of the Association for Computational Linguistics, 7:249–266, 2019.
- WinoGrande: An adversarial Winograd schema challenge at scale. Communications of the ACM, 64(9):99–106, 2021.
- Linear transformers are secretly fast weight programmers. In Marina Meila and Tong Zhang (eds.), Proceedings of the 38th International Conference on Machine Learning, ICML 2021, 18-24 July 2021, Virtual Event, volume 139 of Proceedings of Machine Learning Research, pp. 9355–9366. PMLR, 2021. URL http://proceedings.mlr.press/v139/schlag21a.html.
- Jürgen Schmidhuber. Learning to control fast-weight memories: An alternative to dynamic recurrent networks. Neural Computation, 4(1):131–139, 1992.
- Noam Shazeer. GLU variants improve transformer. arXiv preprint arXiv:2002.05202, 2020.
- Simplified state space layers for sequence modeling. In The Eleventh International Conference on Learning Representations, ICLR 2023, Kigali, Rwanda, May 1-5, 2023. OpenReview.net, 2023. URL https://openreview.net/pdf?id=Ai8Hw3AXqks.
- RoFormer: Enhanced transformer with rotary position embedding. CoRR, abs/2104.09864, 2021. URL https://arxiv.org/abs/2104.09864.
- Retentive network: A successor to transformer for large language models. arXiv preprint arXiv:2307.08621, 2023.
- Triton: an intermediate language and compiler for tiled neural network computations. In Tim Mattson, Abdullah Muzahid, and Armando Solar-Lezama (eds.), Proceedings of the 3rd ACM SIGPLAN International Workshop on Machine Learning and Programming Languages, MAPL@PLDI 2019, Phoenix, AZ, USA, June 22, 2019, pp. 10–19. ACM, 2019. doi: 10.1145/3315508.3329973. URL https://doi.org/10.1145/3315508.3329973.
- LLaMA: Open and efficient foundation language models. arXiv preprint arXiv:2302.13971, 2023.
- Jos van der Westhuizen and Joan Lasenby. The unreasonable effectiveness of the forget gate. CoRR, abs/1804.04849, 2018. URL http://arxiv.org/abs/1804.04849.
- Attention is all you need. Advances in Neural Information Processing Systems, 30, 2017.
- Pretraining without attention. CoRR, abs/2212.10544, 2022. doi: 10.48550/ARXIV.2212.10544. URL https://doi.org/10.48550/arXiv.2212.10544.
- Diffusion models without attention. 2023. URL https://api.semanticscholar.org/CorpusID:265506646.
- HellaSwag: Can a machine really finish your sentence? arXiv preprint arXiv:1905.07830, 2019.
- Root mean square layer normalization. Advances in Neural Information Processing Systems, 32, 2019.