Linear Attention: Eliminating the Quadratic Term

n = 6 tokens  ·  how reordering matrix multiplication changes complexity

Step 1 of 3 — the quadratic bottleneck

[Figure: for $n = 6$ tokens, the causal score matrix $QK^\top$ is built row by row, each query row $i$ scored against every key column $j \leq i$. The $n^2$ entries each cost $O(d)$ FLOPs, for $O(n^2 d)$ total.]
Standard attention computes $QK^\top$, an $(n \times d) \times (d \times n)$ multiply. Under causal masking, every query $q_i$ is scored against every key $k_j$ with $j \leq i$, producing $\tfrac{n(n+1)}{2}$ dot products, each costing $O(d)$. In practice the full matrix is formed, and the attention-weighted output $AV$ costs the same again; at two FLOPs per multiply-add this gives the well-known $$\text{Scores} + \text{AttnOut} = 4n^2 d \;\text{ FLOPs per layer.}$$ Doubling the sequence length multiplies this cost by four.
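A minimal NumPy sketch makes the bottleneck concrete: the `scores` array below is the $n \times n$ matrix whose construction and use each cost $O(n^2 d)$ (the function name and toy sizes are illustrative, not from the original):

```python
import numpy as np

def softmax_attention(Q, K, V):
    """Causal softmax attention; materializes the full n x n score matrix."""
    n, d = Q.shape
    scores = Q @ K.T / np.sqrt(d)               # (n, n): the O(n^2 d) bottleneck
    mask = np.tril(np.ones((n, n), dtype=bool))
    scores = np.where(mask, scores, -np.inf)    # causal mask: keep j <= i only
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V                          # (n, d): second O(n^2 d) multiply

rng = np.random.default_rng(0)
n, d = 6, 4
Q, K, V = (rng.standard_normal((n, d)) for _ in range(3))
out = softmax_attention(Q, K, V)
print(out.shape)  # (6, 4)
```

Note that the first row of the output equals $v_1$ exactly: token 1 can only attend to itself.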
Step 2 of 3 — the linear-time scan

[Figure: a left-to-right scan over t1…t6 building the cumulative state $S_t = \sum_{i \leq t} \phi(k_i)\, v_i^\top \in \mathbb{R}^{r \times d}$. Each rank-1 update costs $O(rd)$; each query reads $\phi(q_t)^\top S_t$ for another $O(rd)$. Total: $O(nrd)$, and no $n^2$ matrix is ever formed.]
Replace $\exp(q^\top k)$ with a kernel approximation $\phi(q)^\top \phi(k)$ where $\phi : \mathbb{R}^d \to \mathbb{R}^r$ is a positive-valued feature map. Then use associativity to reorder the computation: $$o_t = \frac{\phi(q_t)^\top \overbrace{\Bigl(\textstyle\sum_{i \leq t} \phi(k_i)\, v_i^\top\Bigr)}^{S_t \in \mathbb{R}^{r \times d}}} {\phi(q_t)^\top \bigl(\sum_{i \leq t} \phi(k_i)\bigr)}$$ $S_t$ is updated by adding one rank-1 term $\phi(k_t)v_t^\top$ per step — cost $O(rd)$. Querying with $\phi(q_t)^\top S_t$ also costs $O(rd)$. The total attention cost across all $n$ positions is $O(nrd)$, linear in sequence length.
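The scan can be sketched in a few lines of NumPy. The feature map below uses $\phi(x) = \mathrm{ELU}(x) + 1$ (so $r = d$); this is one common choice, assumed here for illustration, and the function names are not from the original:

```python
import numpy as np

def phi(x):
    # ELU(x) + 1: a simple positive-valued feature map (here r = d).
    return np.where(x > 0, x + 1.0, np.exp(x))

def linear_attention(Q, K, V):
    """Causal linear attention via a left-to-right scan; O(n r d) total."""
    n, d = V.shape
    fQ, fK = phi(Q), phi(K)       # feature-mapped queries/keys, shape (n, r)
    r = fK.shape[-1]
    S = np.zeros((r, d))          # cumulative state S_t = sum_{i<=t} phi(k_i) v_i^T
    z = np.zeros(r)               # normalizer      z_t = sum_{i<=t} phi(k_i)
    out = np.empty((n, d))
    for t in range(n):
        S += np.outer(fK[t], V[t])       # rank-1 update, O(r d) per step
        z += fK[t]
        out[t] = (fQ[t] @ S) / (fQ[t] @ z)  # query, O(r d); no n x n matrix
    return out

rng = np.random.default_rng(1)
n, d = 6, 4
Q, K, V = (rng.standard_normal((n, d)) for _ in range(3))
o_lin = linear_attention(Q, K, V)
```

By associativity, the scan output matches the explicit (quadratic) kernel attention $\bigl[\phi(q_i)^\top \phi(k_j)\bigr]_{j \leq i}$ with row normalization; only the order of operations changes.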
Step 3 of 3 — when does linear win?

[Figure: FLOPs vs. sequence length $n$. Softmax attention grows as $O(n^2 d)$ while linear attention grows as $O(nrd)$; the curves cross near $n \approx r$.]
| Property | Softmax attention | Linear attention |
| --- | --- | --- |
| Attention FLOPs | $O(n^2 d)$ | $O(nrd)$ |
| Attention memory | $O(n^2)$ (score matrix) | $O(rd)$ (state $S$) |
| Autoregressive inference | KV cache grows with $n$ | Fixed-size state $S$ |
| Approximation quality | Exact | Kernel approx. of softmax |
| Long-range retrieval | Strong (peaked softmax) | Weaker (low-rank kernel) |
| MLP FLOPs (unchanged) | $O(nd^2)$ | $O(nd^2)$ |
When does it matter?  Linear attention saves the $4n^2 d$ attention term only when it dominates $16nd^2$ (the MLP), i.e. when $n \gtrsim 4d$. For $d = 4096$ that is $n \gtrsim 16{,}000$ tokens — beyond typical fine-tuning lengths. At shorter contexts, the MLP dominates and the quadratic attention term is a minority cost. The real benefit is inference: the fixed-size state $S$ makes autoregressive generation $O(1)$ per step instead of $O(n)$, which is the design motivation behind RetNet, Mamba, and RWKV.
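The break-even arithmetic is easy to check directly. A sketch, using the $4n^2 d$ and $16nd^2$ constants from the text (i.e. a standard MLP with a $4d$ hidden layer):

```python
# Per-layer attention vs. MLP FLOPs; the ratio simplifies to n / (4d),
# so attention overtakes the MLP once n > 4d (16,384 tokens at d = 4096).
d = 4096
for n in (2048, 8192, 16384, 65536):
    attn = 4 * n**2 * d        # score + output matmuls
    mlp = 16 * n * d**2        # two projections through the 4d hidden layer
    print(f"n={n:>6}  attn/mlp = {attn / mlp:g}")
```

At $n = 16{,}384$ the two terms are exactly equal, confirming the $n \gtrsim 4d$ threshold.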