*Setup: 3 conditional tokens · 3 generated tokens · serial recurrence*
Step 1 of 3 — softmax backward: the n² bottleneck
Softmax attention computes an $n \times n$ weight matrix $A = \text{softmax}(QK^\top/\sqrt{d})$.
The backward pass through softmax needs both $A$ and $\partial\mathcal{L}/\partial A$, each $n \times n$, so the full $A$ must be stored during the forward pass.
Activation memory is therefore $O(n^2)$ per attention layer, quadratic in sequence length.
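A minimal NumPy sketch makes the storage requirement concrete. The sizes `n`, `d` and the upstream gradient `dO` are illustrative, not from the source; the point is that the backward through softmax touches two full $n \times n$ arrays:

```python
import numpy as np

# Illustrative sizes: n is sequence length, d is head dimension.
n, d = 64, 16
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((n, d)) for _ in range(3))

# Forward: the full n x n matrix A must be kept for the backward pass.
Z = Q @ K.T / np.sqrt(d)
A = np.exp(Z - Z.max(axis=-1, keepdims=True))
A /= A.sum(axis=-1, keepdims=True)      # softmax rows, shape (n, n)
O = A @ V

# Backward through softmax: needs both A and dL/dA, each n x n.
dO = rng.standard_normal((n, d))        # upstream gradient (assumed)
dA = dO @ V.T                           # dL/dA, shape (n, n)
# Row-wise softmax Jacobian: dZ = A * (dA - rowsum(dA * A))
dZ = A * (dA - (dA * A).sum(axis=-1, keepdims=True))
```

Both `A` and `dA` scale as $n^2$; neither can be discarded without recomputation (which is what gradient checkpointing trades FLOPs for).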
The serial recurrence $S_t = S_{t-1} + \phi(k_t)v_t^\top$ has a clean backward:
$$\frac{\partial\mathcal{L}}{\partial S_{t-1}} = \frac{\partial\mathcal{L}}{\partial S_t}
+ \phi(q_t)\,\frac{\partial\mathcal{L}}{\partial o_t}^\top \cdot \mathbf{1}[t \in \mathcal{G}]$$
The gradient passes additively through each step — no matrix inversion, no softmax Jacobian.
At conditional positions ($t = 1, 2, 3$) the indicator is zero: there is no $\partial\mathcal{L}/\partial o_t$ contribution,
and the gradient simply passes through unchanged.
The state $S$ is $O(rd)$ at every step, never $O(n^2)$.
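The additive backward can be sketched in a few lines of NumPy. Everything here is illustrative: the feature map (softplus, standing in for $\phi$), the split into 3 conditional and 3 generated positions `G`, and the choice to check only the gradient into $v_t$:

```python
import numpy as np

rng = np.random.default_rng(1)
n, r, d = 6, 4, 3                      # 3 conditional + 3 generated tokens
G = {3, 4, 5}                          # generated positions emit an output o_t
phi = lambda x: np.log1p(np.exp(x))    # assumed feature map (softplus)

q = phi(rng.standard_normal((n, r)))   # precomputed features phi(q_t)
k = phi(rng.standard_normal((n, r)))   # precomputed features phi(k_t)
v = rng.standard_normal((n, d))
dO = rng.standard_normal((n, d))       # upstream grads, used only for t in G

def forward(v):
    S, loss = np.zeros((r, d)), 0.0
    for t in range(n):
        S = S + np.outer(k[t], v[t])       # S_t = S_{t-1} + phi(k_t) v_t^T
        if t in G:
            loss += (S.T @ q[t]) @ dO[t]   # o_t = S_t^T phi(q_t)
    return loss

# Backward: dL/dS passes through the additive update unchanged.
dS, dv = np.zeros((r, d)), np.zeros_like(v)
for t in reversed(range(n)):
    if t in G:
        dS += np.outer(q[t], dO[t])    # output term phi(q_t) dL/do_t^T
    dv[t] = dS.T @ k[t]                # gradient into v_t
    # conditional positions (t not in G) add nothing: dS_{t-1} = dS_t
```

The backward loop carries only the $r \times d$ matrix `dS`; no step ever materializes an $n \times n$ object.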
| Property | Softmax attention | Linear attention |
|---|---|---|
| Activation memory (attn) | $O(n^2)$ — store $A$ | $O(rd)$ — store $S$ |
| Backward attn FLOPs | $O(n^2 d)$ — transpose $A$ | $O(nrd)$ — additive |
| Gradient checkpointing | Needed at long context | Largely unnecessary |
| Autoregressive inference | KV cache grows: $O(n)$/step | Fixed state: $O(1)$/step |
| Gradient quality | Exact softmax | Kernel approx. |
Memory is the main practical win.
Even when linear attention's forward-pass FLOP advantage is small (at short context where
$n \ll 4d$), eliminating the $O(n^2)$ activation matrix reduces peak memory substantially.
At $n = 4096$, $d = 4096$, $L = 32$: softmax stores roughly $4096^2 \times 32 \approx 500\text{M}$
attention weights; linear attention stores $32 \times rd$ state entries, a factor of $n^2/(rd) = n/r$ smaller since $d = n$ here.
The fixed-size inference state is the second major win, enabling constant-memory autoregressive generation.
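The arithmetic above can be checked directly. Note `r = 256` is an assumed feature dimension; the text leaves $r$ unspecified:

```python
# Back-of-envelope activation counts from the text.
# r = 256 is an assumed feature dimension (not given in the source).
n, d, L, r = 4096, 4096, 32, 256

softmax_acts = n * n * L    # one n x n attention matrix per layer
linear_acts = L * r * d     # one r x d state per layer
ratio = softmax_acts / linear_acts

print(f"softmax: {softmax_acts / 1e6:.0f}M attention weights")  # ≈ 537M
print(f"linear:  {linear_acts / 1e6:.0f}M state entries")
print(f"ratio:   {ratio:.0f}x  (equals n/r when d = n)")
```

With these numbers the ratio is $n/r = 4096/256 = 16\times$; a smaller assumed $r$ widens the gap further.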