Linear Attention: Backward Pass

3 conditional tokens  ·  3 generated tokens  ·  serial recurrence

Step 1 of 3 — softmax backward: the n² bottleneck

[Figure: softmax attention over conditional tokens ($n_c = 3$, t1–t3) and generated tokens ($n_g = 3$, t4–t6); the full $n \times n$ matrix $A$ is stored for the softmax backward, $O(n^2)$ per layer, alongside the $n \times n$ gradient $\partial\mathcal{L}/\partial A$.]
Softmax attention computes an $n \times n$ weight matrix $A = \text{softmax}(QK^\top/\sqrt{d})$. The backward pass through softmax requires both $A$ and $\partial\mathcal{L}/\partial A$, which are both $n \times n$ — so the full $A$ must be stored during the forward pass. This gives $O(n^2)$ activation memory per attention layer, growing quadratically with sequence length.
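The dependence on the stored $A$ is visible directly in code. A minimal NumPy sketch (shapes and variable names are illustrative, not from any particular library):

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

n, d = 6, 4                      # 6 tokens, head dim 4 (toy sizes)
rng = np.random.default_rng(0)
Q = rng.standard_normal((n, d))
K = rng.standard_normal((n, d))
V = rng.standard_normal((n, d))

# Forward: the full n x n matrix A must be kept for the backward pass.
A = softmax(Q @ K.T / np.sqrt(d))   # (n, n) -- stored activation
O = A @ V

# Backward through softmax: both dL/dA and A are n x n.
dO = rng.standard_normal((n, d))    # upstream gradient dL/dO
dA = dO @ V.T                       # (n, n)
# Row-wise softmax Jacobian-vector product:
dS = A * (dA - (A * dA).sum(axis=-1, keepdims=True))   # (n, n)
dQ = dS @ K / np.sqrt(d)

assert A.shape == dA.shape == (n, n)   # the O(n^2) activations
```

Both `A` and `dA` live simultaneously in memory during the backward step, which is the quadratic cost the table below accounts for.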
[Figure: the serial recurrence over t1–t6; the state advances $S_0 \to S_3$ through the conditional tokens t1–t3, then updates $+\phi(k_4)v_4^\top$, $+\phi(k_5)v_5^\top$, $+\phi(k_6)v_6^\top$ to reach $S_4, S_5, S_6$; $\partial\mathcal{L}/\partial S$ flows back through a fixed-size $O(rd)$ state, with no $n \times n$ matrix anywhere in forward or backward.]
The serial recurrence $S_t = S_{t-1} + \phi(k_t)v_t^\top$ has a clean backward: $$\frac{\partial\mathcal{L}}{\partial S_{t-1}} = \frac{\partial\mathcal{L}}{\partial S_t} + \phi(q_t)\left(\frac{\partial\mathcal{L}}{\partial o_t}\right)^{\top} \mathbf{1}[t \in \mathcal{G}]$$ The gradient passes additively through each step — no matrix inversion, no softmax Jacobian. At conditional positions (t = 1, 2, 3) there is no $\partial\mathcal{L}/\partial o_t$ term, so the gradient simply passes through unchanged. The state $S$ is $O(rd)$ at every step, never $O(n^2)$.
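The recurrence and its backward can be sketched in a few lines of NumPy. This is a toy illustration under stated assumptions: the feature map $\phi$ is taken to be elu+1 (a common but not the only choice), outputs are read only at generated positions, and all names are hypothetical:

```python
import numpy as np

def phi(x):
    # Example kernel feature map (elu + 1): positive, elementwise.
    return np.where(x > 0, x + 1.0, np.exp(x))

r, d, n = 5, 4, 6          # feature dim r, value dim d, 6 tokens
gen = {3, 4, 5}            # generated positions (0-indexed: t4, t5, t6)
rng = np.random.default_rng(1)
Q = rng.standard_normal((n, r))
K = rng.standard_normal((n, r))
V = rng.standard_normal((n, d))

# Forward: fixed-size state S (r x d); outputs only at generated positions.
S = np.zeros((r, d))
O = np.zeros((n, d))
for t in range(n):
    S = S + np.outer(phi(K[t]), V[t])
    if t in gen:
        O[t] = S.T @ phi(Q[t])

# Backward: dL/dS accumulates additively -- never an n x n matrix.
dO = rng.standard_normal((n, d))
dS = np.zeros((r, d))
dV = np.zeros((n, d))
for t in reversed(range(n)):
    if t in gen:
        # Direct term phi(q_t) (dL/do_t)^T; at conditional positions the
        # gradient passes through unchanged (identity Jacobian of S update).
        dS = dS + np.outer(phi(Q[t]), dO[t])
    dV[t] = phi(K[t]) @ dS     # per-token value gradient from this step

assert dS.shape == (r, d)      # state gradient stays O(rd) throughout
```

Note the backward loop touches only $(r \times d)$ quantities at every step, mirroring the equation above.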
[Figure: activation memory per attention layer, softmax $O(n^2)$ vs. linear $O(rd)$.]
| Property | Softmax attention | Linear attention |
|---|---|---|
| Activation memory (attn) | $O(n^2)$ (store $A$) | $O(rd)$ (store $S$) |
| Backward attn FLOPs | $O(n^2 d)$ (uses $A^\top$) | $O(nrd)$ (additive) |
| Gradient checkpointing | Needed at long context | Largely unnecessary |
| Autoregressive inference | KV cache grows: $O(n)$/step | Fixed state: $O(1)$/step |
| Gradient quality | Exact softmax | Kernel approximation |
Memory is the main practical win. Even when linear attention's forward-pass FLOP advantage is small (at short context where $n \ll 4d$), eliminating the $O(n^2)$ activation matrix reduces peak memory substantially. At $n = 4096$, $d = 4096$, $L = 32$: softmax stores roughly $4096^2 \times 32 \approx 500\text{M}$ attention weights; linear attention stores $32 \times rd$ — a factor of $n/r$ smaller (since $d = n$ in this example). The fixed-size inference state is the second major win, enabling constant-memory autoregressive generation.
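The arithmetic, with an assumed state feature dimension $r = 64$ (the text does not fix a value of $r$; per batch element and ignoring the head split):

```python
# Stored attention activations for the example in the text.
n, d, L, r = 4096, 4096, 32, 64   # r = 64 is an assumed feature dimension

softmax_acts = n * n * L          # one n x n matrix A per layer
linear_acts = r * d * L           # one r x d state S per layer

print(softmax_acts / 1e6)         # ~537M stored attention weights
print(softmax_acts / linear_acts) # = n / r = 64.0, since d == n here
```

At 2 bytes per element (bf16), the softmax side alone is about 1 GiB of activations per sequence, before batching.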