Fast Attention Requires Bounded Entries

(2302.13214)
Published Feb 26, 2023 in cs.LG, cs.CC, cs.DS, and stat.ML

Abstract

In modern machine learning, inner product attention computation is a fundamental task for training LLMs such as Transformer, GPT-1, BERT, GPT-2, GPT-3 and ChatGPT. Formally, in this problem, one is given as input three matrices $Q, K, V \in [-B,B]^{n \times d}$, and the goal is to construct the matrix $\mathrm{Att}(Q,K,V) := \mathrm{diag}(A {\bf 1}_n)^{-1} A V \in \mathbb{R}^{n \times d}$, where $A = \exp(QK^\top/d)$ is the `attention matrix', and $\exp$ is applied entry-wise. Straightforward methods for this problem explicitly compute the $n \times n$ attention matrix $A$, and hence require time $\Omega(n^2)$ even when $d = n^{o(1)}$ is small. In this paper, we investigate whether faster algorithms are possible by implicitly making use of the matrix $A$. We present two results, showing that there is a sharp transition at $B = \Theta(\sqrt{\log n})$.

$\bullet$ If $d = O(\log n)$ and $B = o(\sqrt{\log n})$, there is an $n^{1+o(1)}$ time algorithm to approximate $\mathrm{Att}(Q,K,V)$ up to $1/\mathrm{poly}(n)$ additive error.

$\bullet$ If $d = O(\log n)$ and $B = \Theta(\sqrt{\log n})$, assuming the Strong Exponential Time Hypothesis from fine-grained complexity theory, it is impossible to approximate $\mathrm{Att}(Q,K,V)$ up to $1/\mathrm{poly}(n)$ additive error in truly subquadratic time $n^{2 - \Omega(1)}$.

This gives a theoretical explanation for the phenomenon observed in practice that attention computation is much more efficient when the input matrices have smaller entries.
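To make the problem statement concrete, the following is a minimal NumPy sketch of the straightforward quadratic-time baseline the abstract refers to: it explicitly forms the $n \times n$ attention matrix $A = \exp(QK^\top/d)$ and returns $\mathrm{diag}(A \mathbf{1}_n)^{-1} A V$. It is not the paper's fast algorithm; the function name and the example parameters (including the choice of $B$) are illustrative assumptions.

```python
import numpy as np

def naive_attention(Q, K, V):
    """Straightforward O(n^2 d) computation of Att(Q, K, V).

    Explicitly builds the n x n attention matrix A = exp(Q K^T / d)
    and returns diag(A 1_n)^{-1} A V, as defined in the abstract.
    """
    n, d = Q.shape
    A = np.exp(Q @ K.T / d)          # n x n attention matrix, entry-wise exp
    D_inv = 1.0 / A.sum(axis=1)      # diagonal of diag(A 1_n)^{-1}
    return (A * D_inv[:, None]) @ V  # row-normalized attention applied to V

# Illustrative usage: entries drawn from [-B, B]; the paper's fast-algorithm
# regime is B = o(sqrt(log n)) with d = O(log n).
rng = np.random.default_rng(0)
n, d, B = 1024, 10, 1.0
Q, K, V = (rng.uniform(-B, B, size=(n, d)) for _ in range(3))
out = naive_attention(Q, K, V)       # shape (n, d)
```

The point of the paper is that this explicit construction of $A$ is essentially unavoidable once $B = \Theta(\sqrt{\log n})$ (under SETH), whereas for smaller entries the same output can be approximated in almost linear time without ever materializing $A$.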
