LLM Optimization Basics: Time
In supervised fine-tuning and RLHF, a model is trained on sequences that mix a conditional prefix (the prompt or instruction) with generated tokens (the target response). The loss is computed only on the generated portion, but the full sequence still flows through the forward pass. This asymmetry — compute everywhere, loss only on some positions — has non-obvious implications for where gradients flow and how much memory is required. This post works through the FLOP arithmetic for both standard (softmax) attention and linear attention, tracing forward and backward costs separately.
Conditional Generation Setup
Let \(n_c\) denote the number of conditional tokens (the prefix) and \(n_g\) the number of generated tokens (the target). The total sequence length is \(n = n_c + n_g\). During training, the entire sequence is packed into a single forward pass under teacher forcing: the model sees the ground-truth token at each position as context, regardless of what it would have predicted.
The cross-entropy loss is then masked to the generated positions:
\[\mathcal{L} = -\frac{1}{n_g} \sum_{t = n_c + 1}^{n_c + n_g} \log p_\theta(x_t \mid x_{<t})\]Conditional positions contribute to context (they appear as keys and values in attention) but contribute zero to \(\mathcal{L}\).
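A minimal NumPy sketch of this masked loss follows. It assumes the usual causal-LM convention that row \(t\) of the logits predicts token \(t+1\), so the row that scores the first generated token is the last prefix position. `masked_nll` and the random data are hypothetical and only for illustration. The check at the end confirms that changing the logits for prefix targets leaves \(\mathcal{L}\) untouched.

```python
import numpy as np

def masked_nll(logits: np.ndarray, tokens: np.ndarray, n_c: int) -> float:
    """Mean negative log-likelihood over the n_g generated tokens only.

    Hypothetical helper. logits[t] is assumed to predict tokens[t + 1];
    tokens[:n_c] is the conditional prefix and tokens[n_c:] are the targets.
    """
    # Numerically stable log-softmax over the vocabulary axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Row t scores tokens[t + 1]; the last row predicts past the end and is unused.
    target_lp = log_probs[np.arange(len(tokens) - 1), tokens[1:]]
    # Loss mask over target indices 1 .. n-1: keep only generated targets (>= n_c).
    is_generated = np.arange(1, len(tokens)) >= n_c
    n_g = len(tokens) - n_c
    return float(-np.where(is_generated, target_lp, 0.0).sum() / n_g)

rng = np.random.default_rng(0)
n_c, n_g, vocab = 10, 20, 50
tokens = rng.integers(0, vocab, size=n_c + n_g)
logits = rng.normal(size=(n_c + n_g, vocab))
loss = masked_nll(logits, tokens, n_c)

# Rows 0 .. n_c - 2 only predict prefix tokens, so perturbing them changes nothing.
perturbed = logits.copy()
perturbed[: n_c - 1] += rng.normal(size=(n_c - 1, vocab))
print(loss, masked_nll(perturbed, tokens, n_c))
```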
Transformers: Quadratic Attention
Forward Pass
Consider a transformer with \(L\) layers, hidden dimension \(d\), MLP expansion factor 4, and single-head attention for simplicity. For a sequence of length \(n\), the dominant FLOP terms per layer are:
QKV projection. Each of the three projections maps \(n \times d \to n \times d\) at a cost of \(2nd^2\) FLOPs:
\[\text{QKV} = 3 \times 2nd^2 = 6nd^2\]
Attention scores. Computing \(QK^\top\) is an \((n \times d) \times (d \times n)\) matrix multiply:
\[\text{Scores} = 2n^2 d\]
Weighted sum over values. Multiplying the \(n \times n\) attention weights by \(V \in \mathbb{R}^{n \times d}\):
\[\text{AttnOut} = 2n^2 d\]
Output projection. \(n \times d \to n \times d\):
\[\text{OutProj} = 2nd^2\]
MLP. Two linear layers \(d \to 4d \to d\), each costing \(2n \cdot d \cdot 4d = 8nd^2\):
\[\text{MLP} = 16nd^2\]
Summing the terms within a layer and multiplying by the \(L\) layers:
\[C_{\text{fwd}} = L\bigl(24nd^2 + 4n^2 d\bigr)\]
When \(d \gg n\) the linear term dominates; when \(n \gg d\) the quadratic attention term \(4n^2 d\) takes over, and the two are equal at \(n = 6d\), where \(4n^2 d = 24nd^2\). Critically, the forward pass computes representations for all \(n = n_c + n_g\) tokens. The loss mask applied afterward does not change what the forward pass touches: every position's QKV projections are computed, and every key and value vector appears in the attention of all later positions.
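The per-layer accounting above can be written as a small FLOP counter. This is a sketch under the text's own simplifications (2 FLOPs per multiply-add, a full \(n \times n\) score matrix, softmax and norms ignored); `forward_flops` and `split_terms` are hypothetical names. It also makes the crossover visible: the quadratic-to-linear ratio is \(n/(6d)\).

```python
def forward_flops(n: int, d: int, n_layers: int) -> dict:
    """Forward FLOPs per term, summed over layers (hypothetical helper).

    Follows the text: 2 FLOPs per multiply-add, single-head attention,
    full n x n scores (no causal savings), softmax/norms/embeddings ignored.
    """
    per_layer = {
        "qkv": 3 * 2 * n * d * d,          # three d x d projections
        "scores": 2 * n * n * d,           # Q K^T
        "attn_out": 2 * n * n * d,         # (attention weights) V
        "out_proj": 2 * n * d * d,         # W_O
        "mlp": 2 * n * d * (4 * d) * 2,    # d -> 4d and 4d -> d, 8 n d^2 each
    }
    return {name: n_layers * f for name, f in per_layer.items()}

def split_terms(flops: dict) -> tuple:
    """Return (linear-in-n term, quadratic-in-n term)."""
    linear = flops["qkv"] + flops["out_proj"] + flops["mlp"]
    quadratic = flops["scores"] + flops["attn_out"]
    return linear, quadratic

d = 4096
for n in (d, 4 * d, 6 * d):
    linear, quadratic = split_terms(forward_flops(n, d, n_layers=32))
    print(f"n = {n:>6}: quadratic / linear = {quadratic / linear:.3f}")
```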
Sequence packing. An orthogonal inefficiency arises in batched training when sequences have variable lengths. The naive approach pads each sequence to the maximum length, wasting FLOPs on dummy tokens and inflating the \(n^2\) attention term. Sequence packing (also called sample packing or bin packing) eliminates this by concatenating multiple sequences end-to-end into one dense sequence, separated by end-of-sequence tokens. A block-diagonal attention mask, applied on top of the causal mask, prevents cross-sequence attention: if sequence \(j\) occupies positions \([s_j, e_j)\), then \(M_{ab} = -\infty\) whenever \(a\) and \(b\) belong to different sequences. FlashAttention-2's varlen mode implements this without materializing the full mask, computing attention only within each sequence's block. If the average sequence occupies a fraction \(\rho\) of the padded length, packing reduces the linear-term FLOPs to a fraction \(\rho\) of the padded cost and the quadratic attention FLOPs to about \(\rho^2\) (exactly \(\mathbb{E}[n_j^2]/n_{\max}^2\), which is at least \(\rho^2\) and larger when lengths vary). For instruction-tuning datasets with short and variable-length responses, \(\rho\) can be as low as 0.3–0.5, making packing effectively mandatory. Each packed example still has its own conditional prefix and generated suffix, so two masks are involved, applied at different points in the computation graph: the attention mask enforces cross-sequence exclusion inside attention, and the loss mask zeroes the conditional positions of every packed example when the cross-entropy is computed.
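The two masks and the padding savings can be sketched in NumPy. `pack` is a hypothetical helper. Its boundary array plays the role of the cumulative sequence-length (cu_seqlens) arrays that FlashAttention-2's varlen interface takes, but this code builds dense boolean masks for inspection, which is exactly what a varlen kernel avoids at scale. For simplicity, the loss mask marks generated positions before the one-token shift a causal LM applies.

```python
import numpy as np

def pack(examples):
    """Pack (n_c, n_g) examples into one row (hypothetical helper).

    Returns sequence boundaries [0, e_1, e_2, ...], a block-diagonal causal
    attention mask (True = may attend), and a loss mask (True = generated).
    """
    lengths = [n_c + n_g for n_c, n_g in examples]
    bounds = np.concatenate([[0], np.cumsum(lengths)])
    total = int(bounds[-1])
    seq_id = np.repeat(np.arange(len(examples)), lengths)
    pos = np.arange(total)
    same_seq = seq_id[:, None] == seq_id[None, :]   # cross-sequence exclusion
    causal = pos[None, :] <= pos[:, None]            # usual causal mask
    attn_mask = same_seq & causal
    loss_mask = np.zeros(total, dtype=bool)
    for (n_c, n_g), start in zip(examples, bounds[:-1]):
        loss_mask[start + n_c : start + n_c + n_g] = True
    return bounds, attn_mask, loss_mask

examples = [(10, 20), (40, 10), (5, 35)]
bounds, attn_mask, loss_mask = pack(examples)

# Quadratic attention FLOPs (4 n^2 d per sequence): padded batch vs varlen packing.
d = 64
lengths = np.diff(bounds)
padded = len(examples) * 4 * int(lengths.max()) ** 2 * d
packed = int(np.sum(4 * lengths**2 * d))
rho = lengths.mean() / lengths.max()
print(f"packed/padded = {packed / padded:.3f}, rho^2 = {rho**2:.3f}")
```

With lengths 30, 50 and 40, the quadratic term drops to two thirds of the padded cost, slightly above \(\rho^2 = 0.64\) because the lengths vary.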
Backward Pass
The backward pass costs approximately 2× the forward pass FLOP:
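\[C_{\text{bwd}} \approx 2\,C_{\text{fwd}}, \qquad C_{\text{total}} \approx 3\,C_{\text{fwd}}\]
This factor of two arises because every matrix multiply in the forward pass needs two in the backward, one for the gradient with respect to each operand; for a weight matrix these are the gradient w.r.t. the input and the gradient w.r.t. the weight. The same rule covers the two activation-by-activation products in attention. \(QK^\top\) needs gradients w.r.t. both \(Q\) and \(K\), and the weighted sum \(PV\) needs gradients w.r.t. both \(P\) (using \(V^\top\)) and \(V\) (using the transposed \(n \times n\) matrix \(P^\top\)). Each of these four products costs \(2n^2 d\), so the quadratic term doubles as well. For the attention block (QKV and output projections plus the attention core):
\[C_{\text{attn, bwd}} \approx 2 \times \bigl(8nd^2 + 4n^2 d\bigr) = 16nd^2 + 8n^2 d\]
The same \(n = n_c + n_g\) appears in both terms: all positions participate in the backward FLOP, not just the generated ones.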
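A NumPy sketch of the attention core's backward makes the rule concrete. It assumes a single head, omits the \(1/\sqrt{d}\) scaling, and materializes the full causal softmax. The forward performs two \(n \times n \times d\) products and the backward four, and the softmax backward is written purely in terms of its output \(P\). The function names are hypothetical; a finite-difference check confirms the gradients.

```python
import numpy as np

def causal_softmax(s):
    n = s.shape[0]
    s = np.where(np.tril(np.ones((n, n), dtype=bool)), s, -np.inf)
    p = np.exp(s - s.max(axis=1, keepdims=True))
    return p / p.sum(axis=1, keepdims=True)

def attn_forward(q, k, v):
    p = causal_softmax(q @ k.T)   # Q K^T: 2 n^2 d FLOPs
    return p @ v, p               # P V: 2 n^2 d FLOPs; P is kept for the backward

def attn_backward(q, k, v, p, d_out):
    """Four n x n x d matmuls (8 n^2 d FLOPs), twice the forward's two."""
    d_v = p.T @ d_out             # dL/dV uses the transposed n x n matrix P^T
    d_p = d_out @ v.T             # dL/dP
    # Softmax backward needs only its output P (masked entries have P = 0).
    d_s = p * (d_p - (d_p * p).sum(axis=1, keepdims=True))
    d_q = d_s @ k                 # dL/dQ
    d_k = d_s.T @ q               # dL/dK
    return d_q, d_k, d_v

rng = np.random.default_rng(0)
n, d = 6, 4
q, k, v = (rng.normal(size=(n, d)) for _ in range(3))
g = rng.normal(size=(n, d))       # upstream gradient for the scalar L = sum(O * g)
out, p = attn_forward(q, k, v)
grads = attn_backward(q, k, v, p, g)

def numeric_grad(which, i, j, eps=1e-6):
    """Central finite difference of L w.r.t. entry (i, j) of q, k, or v."""
    args = [q.copy(), k.copy(), v.copy()]
    args[which][i, j] += eps
    up = float((attn_forward(*args)[0] * g).sum())
    args[which][i, j] -= 2 * eps
    down = float((attn_forward(*args)[0] * g).sum())
    return (up - down) / (2 * eps)

max_err = max(abs(numeric_grad(w, i, j) - grads[w][i, j])
              for w in range(3) for i in range(n) for j in range(d))
print(f"max |analytic - numeric| = {max_err:.2e}")
```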
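Gradient flow. Consider the concrete case \(n_c = 10\), \(n_g = 20\), total \(n = 30\). Positions \(1\text{--}10\) are conditional; positions \(11\text{--}30\) are generated. Since \(\mathcal{L}\) contains loss terms only for positions \(11\text{--}30\), backpropagation traces backward only through computation paths that connect to those positions. (Each loss term is indexed here by the token it scores; in a causal LM the logit for \(x_t\) is produced at position \(t-1\), which shifts these ranges by one without changing the argument.) In a single-layer model, or in the last layer of a deep one, \(W_O\), \(W_Q\), and the MLP weights are touched only by generated positions. \(W_K\) and \(W_V\) are different. The attention output at each generated position \(t\) is \(o_t = \sum_{i \leq t} \alpha_{ti} v_i\), where \(v_i = W_V h_i\) and \(\alpha_{ti}\) depends on \(k_i = W_K h_i\). Backpropagating through \(o_t\) therefore produces gradient contributions to \(W_K\) and \(W_V\) from every attended position \(i \leq t\), including all \(n_c\) conditional tokens. Below the last layer the picture widens further: a conditional token's output at layer \(\ell\) becomes its key and value at layer \(\ell + 1\), so in layers \(1\) through \(L-1\) every weight, including \(W_Q\), \(W_O\), and the MLP, receives contributions from conditional positions. The conditional tokens' own input embeddings also receive non-zero gradients, despite contributing zero to \(\mathcal{L}\).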
Figure 1. Gradient flow from the loss positions back through attention to all tokens.
| Component | Positions contributing gradient: last (or only) layer | Layers below the last |
|---|---|---|
| Loss | \(11\text{--}30\) (generated only) | n/a |
| \(W_O\) (output proj.) | \(11\text{--}30\) | \(1\text{--}30\) |
| \(W_Q\) (query proj.) | \(11\text{--}30\) | \(1\text{--}30\) |
| \(W_K\) (key proj.) | \(1\text{--}30\) (all tokens) | \(1\text{--}30\) |
| \(W_V\) (value proj.) | \(1\text{--}30\) (all tokens) | \(1\text{--}30\) |
| MLP weights | \(11\text{--}30\) | \(1\text{--}30\) |
| Input embeddings | \(1\text{--}30\) (all tokens, via key/value paths) | n/a |
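The gradient-flow claims can be checked numerically without an autodiff library. The NumPy sketch below uses a hypothetical toy setup: random weights, a squared-error loss on the generated rows, and residual connections, with no MLP or normalization. It nudges the query, key, or value of the last conditional position in the first layer and measures how the loss moves. A nonzero change means that position contributes to the corresponding weight's gradient, because each projection weight's gradient is a sum of per-position outer products between \(h_t\) and the gradient of that position's projection. With one layer the query nudge has no effect. With two layers it does, because the perturbed output becomes a key and value in layer 2.

```python
import numpy as np

rng = np.random.default_rng(1)
n_c, n_g, d = 10, 20, 8
n = n_c + n_g
x = rng.normal(size=(n, d))                       # input embeddings
weights = [{p: rng.normal(size=(d, d)) / np.sqrt(d) for p in "QKV"} for _ in range(2)]
target = rng.normal(size=(n_g, d))

def causal_attn(q, k, v):
    s = np.where(np.tril(np.ones((n, n), dtype=bool)), q @ k.T, -np.inf)
    w = np.exp(s - s.max(axis=1, keepdims=True))
    return (w / w.sum(axis=1, keepdims=True)) @ v

def loss(num_layers, pos, which, delta):
    """Squared error on generated rows after adding `delta` to the layer-1
    query, key, or value (`which`) at position `pos`. Hypothetical toy model."""
    h = x
    for layer, w in enumerate(weights[:num_layers]):
        proj = {p: h @ w[p] for p in "QKV"}
        if layer == 0:
            proj[which][pos] += delta
        h = h + causal_attn(proj["Q"], proj["K"], proj["V"])   # residual
    return float(((h[n_c:] - target) ** 2).sum())

delta = 1e-3 * rng.normal(size=d)
pos = n_c - 1                                      # last conditional position
for num_layers in (1, 2):
    changes = {p: loss(num_layers, pos, p, delta) - loss(num_layers, pos, p, 0.0)
               for p in "QKV"}
    print(num_layers, {p: f"{c:+.2e}" for p, c in changes.items()})
```

The last conditional position is used rather than position 1 because a causal query at position 1 attends only to itself, so its softmax weight is always 1 and the query cannot matter there in any layer.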
Activation memory. The backward pass needs the intermediate activations stored during the forward pass to compute \(\partial \mathcal{L} / \partial W\). For softmax attention implemented naively, this includes the \(n \times n\) attention probability matrix \(P\): the backward through softmax is expressed in terms of its output, and \(P^\top\) is needed for the gradient w.r.t. \(V\). Activation memory per layer is therefore roughly \(O(n^2 + nd)\), dominated by \(O(n^2)\) at long context. (Fused kernels such as FlashAttention avoid storing \(P\) by recomputing it blockwise during the backward pass, keeping only \(O(n)\) softmax statistics.) Since \(W_K\) and \(W_V\) receive gradients from all positions, the activations of all \(n_c\) conditional tokens must be retained even though they produce no loss.
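Detaching conditional tokens. One common optimization is to process the prefix under torch.no_grad(), computing its KV cache without storing activations, then re-attach for the generated portion. The logit that scores the first generated token is produced at the last prefix position, so the sketch below runs \(x_1, \dots, x_{n_c-1}\) without gradients and feeds \(x_{n_c}, \dots, x_{n-1}\) with gradients. It assumes a Hugging Face-style model that accepts past_key_values and returns .logits, plus input_ids of shape (batch, n):

```python
import torch
import torch.nn.functional as F

# input_ids: (batch, n) token ids; the first n_c columns are the conditional prefix.
with torch.no_grad():  # build the prefix KV cache without recording activations
    past_kv = model(input_ids[:, : n_c - 1], use_cache=True).past_key_values
# Positions n_c .. n-1 (1-indexed) produce the logits for the n_g generated targets.
logits = model(input_ids[:, n_c - 1 : -1], past_key_values=past_kv).logits
loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), input_ids[:, n_c:].reshape(-1))
loss.backward()
```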
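This cuts every gradient path through the conditional positions. The forward cost is unchanged, since the prefix still runs, but the backward now covers only the \(n_g\) generated rows. The projection and MLP backward drops from \(48nd^2\) to \(48n_g d^2\) per layer. The backward attention FLOP drops from about \(8n^2 d\) to about \(8n_g n d\), because the generated queries still attend over all \(n\) keys. Both reductions are a factor of roughly \(n_g/n\), which is significant when \(n_c \gg n_g\), and the prefix activations no longer need to be stored. The tradeoff is that the gradient is no longer exact: \(W_K\) and \(W_V\) (and, below the last layer, every weight) lose the contributions routed through the conditional positions, so this loss can no longer shape how the prompt is encoded. How much that alters training dynamics depends on the task.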
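Under the text's FLOP accounting, the effect of detaching can be tabulated with a small helper. `train_flops` is hypothetical. It assumes the backward for the generated rows costs twice their forward, with their queries still spanning all \(n\) keys, and it ignores any savings a kernel might get from skipping the key and value gradients for prefix positions.

```python
def train_flops(n_c: int, n_g: int, d: int, n_layers: int, detach_prefix: bool) -> dict:
    """Forward and backward FLOPs with or without a no-grad prefix (hypothetical).

    Same accounting as the text: 24 n d^2 for projections and MLP, 4 n^2 d for
    attention (full n x n scores), backward = 2x the forward of the rows it covers.
    """
    n = n_c + n_g
    fwd = n_layers * (24 * n * d * d + 4 * n * n * d)      # the prefix still runs
    rows = n_g if detach_prefix else n                      # rows that need a backward
    bwd_dense = 2 * n_layers * 24 * rows * d * d
    bwd_attn = 2 * n_layers * 4 * rows * n * d              # queries see all n keys
    return {"fwd": fwd, "bwd_dense": bwd_dense, "bwd_attn": bwd_attn}

full = train_flops(3000, 1000, 4096, 32, detach_prefix=False)
detached = train_flops(3000, 1000, 4096, 32, detach_prefix=True)
for key in full:
    print(f"{key:>9}: {detached[key] / full[key]:.3f} of the full cost")
```

Because generated queries still attend to the prefix keys, the attention term scales with \(n_g n\) rather than \(n_g^2\), so both backward terms shrink by the same factor \(n_g/n\) (here 0.25).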
Gradient checkpointing. Because activation memory scales as \(O(n^2 L)\) (from the attention matrices), long sequences can exhaust GPU memory before the backward pass begins. Gradient checkpointing stores only a sparse set of activations during the forward pass (typically the inputs to selected layers) and recomputes the rest on demand during the backward pass. Each checkpointed segment is recomputed once, which amounts to roughly one extra forward pass and raises total training FLOP from \(3C_{\text{fwd}}\) to about \(4C_{\text{fwd}}\). Activation memory drops from \(O\bigl(L(n^2 + nd)\bigr)\) to \(O\bigl(\sqrt{L}\,(n^2 + nd)\bigr)\) with about \(\sqrt{L}\) evenly spaced checkpoints, or to \(O(Lnd + n^2)\) when every layer's input is checkpointed and only one layer's internals are live at a time. If conditional tokens are detached, their activations need not be checkpointed at all, compounding the savings.
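A rough accounting of peak activation memory makes the checkpointing options concrete. `activation_elements` is a hypothetical helper that counts retained elements per sequence under the text's naive-attention model. Each layer's internals are taken as \(n^2 + nd\) (the constant on the \(nd\) term is set to 1 for simplicity, an assumption), and a checkpoint is one layer input of \(nd\) elements.

```python
import math

def activation_elements(n: int, d: int, n_layers: int, strategy: str) -> int:
    """Peak retained activation elements for one sequence (hypothetical, rough)."""
    layer = n * n + n * d          # one layer's internals: attention probs + hidden
    checkpoint = n * d             # a stored layer input
    if strategy == "none":         # keep every layer's internals
        return n_layers * layer
    if strategy == "sqrt":         # ~sqrt(L) checkpoints, recompute one segment at a time
        seg = math.isqrt(n_layers)
        n_segments = -(-n_layers // seg)        # ceiling division
        return n_segments * checkpoint + seg * layer
    if strategy == "every_layer":  # checkpoint each layer input, recompute one layer
        return n_layers * checkpoint + layer
    raise ValueError(f"unknown strategy: {strategy}")

n, d, n_layers = 4096, 1024, 16
for strategy in ("none", "sqrt", "every_layer"):
    print(f"{strategy:>11}: {activation_elements(n, d, n_layers, strategy):,} elements")
```

Both checkpointed variants pay roughly the same extra forward pass in FLOPs; they differ only in how much is held live at the peak.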
Linear Attention
Forward Pass
Linear attention replaces \(\exp(q^\top k)\) with a kernel approximation \(\phi(q)^\top \phi(k)\) for a positive-valued feature map \(\phi : \mathbb{R}^d \to \mathbb{R}^r\). Using associativity, the attention output at position \(t\) becomes:
\[o_t = \frac{\phi(q_t)^\top \bigl(\sum_{i \leq t} \phi(k_i) v_i^\top\bigr)}{\phi(q_t)^\top \bigl(\sum_{i \leq t} \phi(k_i)\bigr)} = \frac{\phi(q_t)^\top S_t}{\phi(q_t)^\top z_t}\]where \(S_t = \sum_{i \leq t} \phi(k_i) v_i^\top \in \mathbb{R}^{r \times d}\) and \(z_t = \sum_{i \leq t} \phi(k_i) \in \mathbb{R}^r\) are built by scanning left to right. Each step costs \(O(rd)\) to update and \(O(rd)\) to query, making the total attention cost \(O(nrd)\) instead of \(O(n^2 d)\). Substituting into the per-layer FLOP count (with \(r = d\)):
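The scan can be sketched in NumPy and checked against the definition. The feature map \(\phi(x) = \mathrm{elu}(x) + 1\) (so \(r = d\)) is an assumption chosen only because it is positive; any positive map works for this check. `linear_attention_serial` and `linear_attention_direct` are hypothetical names.

```python
import numpy as np

def phi(x):
    """Assumed positive feature map: elu(x) + 1."""
    return np.where(x > 0, x + 1.0, np.exp(np.minimum(x, 0.0)))

def linear_attention_serial(q, k, v):
    """Causal linear attention via the running state (S_t, z_t): O(n r d)."""
    fq, fk = phi(q), phi(k)
    S = np.zeros((fk.shape[1], v.shape[1]))   # S_t in R^{r x d}
    z = np.zeros(fk.shape[1])                 # z_t in R^r
    out = np.empty_like(v)
    for t in range(len(v)):
        S += np.outer(fk[t], v[t])            # update, O(rd)
        z += fk[t]
        out[t] = (fq[t] @ S) / (fq[t] @ z)    # query, O(rd)
    return out

def linear_attention_direct(q, k, v):
    """Reference from the definition: weights phi(q_t)^T phi(k_i) over i <= t."""
    fq, fk = phi(q), phi(k)
    out = np.empty_like(v)
    for t in range(len(v)):
        w = fk[: t + 1] @ fq[t]
        out[t] = (w @ v[: t + 1]) / w.sum()
    return out

rng = np.random.default_rng(0)
q, k, v = (rng.normal(size=(12, 5)) for _ in range(3))
print(np.abs(linear_attention_serial(q, k, v) - linear_attention_direct(q, k, v)).max())
```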
\[C_{\text{layer}}^{\text{linear}} = 24nd^2 + 4nrd = 28nd^2 \qquad (r = d)\]
counting \(2rd\) FLOPs per position for the state update and \(2rd\) for the query. The quadratic term disappears entirely, though with \(r = d\) the remaining attention cost is the same order as the projections rather than negligible. The tradeoff is approximation quality: the kernel \(\phi(q)^\top \phi(k)\) is a rank-\(r\) approximation to the full softmax attention matrix. Softmax's sharply peaked distribution is difficult to reproduce with a low-rank kernel, so linear attention models tend to underperform on tasks requiring precise long-range retrieval. Methods like RetNet and Mamba abandon the kernel approximation entirely, adopting structured recurrences with explicit forgetting mechanisms as a different inductive bias.
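Counting the linear-attention terms explicitly makes the comparison with softmax concrete. This sketch assumes \(2rd\) FLOPs per position for the state update and \(2rd\) for the query, and it ignores the normalizer's \(O(r)\) work, so with \(r = d\) a layer costs \(28nd^2\). `layer_flops` is a hypothetical helper.

```python
from typing import Optional

def layer_flops(n: int, d: int, linear: bool, r: Optional[int] = None) -> int:
    """Forward FLOPs for one layer under the text's accounting (hypothetical)."""
    r = d if r is None else r
    dense = 24 * n * d * d                # QKV, output projection, MLP
    if linear:
        return dense + 2 * n * r * d + 2 * n * r * d   # state update + query
    return dense + 4 * n * n * d          # scores + weighted sum

d = 4096
for n in (d // 4, d, 4 * d, 16 * d):
    ratio = layer_flops(n, d, linear=False) / layer_flops(n, d, linear=True)
    print(f"n = {n:>6}: softmax / linear = {ratio:.2f}")
```

Under this accounting the \(r = d\) linear layer is slightly more expensive below \(n = d\), breaks even at \(n = d\), and its advantage grows linearly in \(n/d\) beyond that.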
Linear attention also admits a parallel mode for training: expanding into matrix form, all positions can be computed simultaneously as \(O = (\Phi_Q \Phi_K^\top \odot M) V\), where \(M\) is the causal mask and \(\Phi_Q, \Phi_K \in \mathbb{R}^{n \times r}\) are the feature-mapped queries and keys (the normalizer is the row sum of the same masked matrix). This costs \(2n^2 r\) for the masked Gram matrix plus \(2n^2 d\) to multiply it by \(V\), i.e. \(O(n^2 (r + d))\). That is the same quadratic order as softmax attention, and the same FLOP count when \(r = d\), but without exponentials or a row-wise softmax. It is fully parallelizable across positions, which makes it GPU-efficient during training.
Figure 2. Reordering the matrix multiplication eliminates the quadratic attention term.
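The parallel form is a few lines of NumPy. Under the same assumed feature map \(\phi = \mathrm{elu} + 1\), the sketch builds the masked Gram matrix, normalizes by its row sums, and checks the result against the serial scan. The function names are hypothetical.

```python
import numpy as np

def phi(x):
    """Assumed positive feature map: elu(x) + 1."""
    return np.where(x > 0, x + 1.0, np.exp(np.minimum(x, 0.0)))

def linear_attention_parallel(q, k, v):
    """(Phi_Q Phi_K^T * M) V, divided by the row sums of the same masked matrix."""
    n = len(v)
    a = (phi(q) @ phi(k).T) * np.tril(np.ones((n, n)))   # n x n, 2 n^2 r FLOPs
    return (a @ v) / a.sum(axis=1, keepdims=True)         # 2 n^2 d FLOPs

def linear_attention_serial(q, k, v):
    """Reference: the O(n r d) left-to-right scan over (S_t, z_t)."""
    fq, fk = phi(q), phi(k)
    S, z = np.zeros((fk.shape[1], v.shape[1])), np.zeros(fk.shape[1])
    out = np.empty_like(v)
    for t in range(len(v)):
        S += np.outer(fk[t], v[t])
        z += fk[t]
        out[t] = (fq[t] @ S) / (fq[t] @ z)
    return out

rng = np.random.default_rng(0)
q, k, v = (rng.normal(size=(16, 4)) for _ in range(3))
print(np.abs(linear_attention_parallel(q, k, v) - linear_attention_serial(q, k, v)).max())
```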
Backward Pass
Linear attention has a natural dual form: the same computation can be expressed either as a parallel matrix operation (for training) or as a sequential recurrence (for inference), and the two modes have very different backward-pass characteristics.
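Backward through the parallel mode. In parallel mode, the causal attention output is \(O = (\Phi_Q \Phi_K^\top \odot M) V\). This is differentiable in the standard way: gradients flow back through two matrix multiplications and an elementwise mask, with no softmax Jacobian. The masked Gram matrix \(A = \Phi_Q \Phi_K^\top \odot M\) is still \(n \times n\), however. The backward computes \(\partial\mathcal{L}/\partial V = A^\top\, \partial\mathcal{L}/\partial O\) and \(\partial\mathcal{L}/\partial A = (\partial\mathcal{L}/\partial O\; V^\top) \odot M\) at \(O(n^2 d)\) each, then \(\partial\mathcal{L}/\partial \Phi_Q = (\partial\mathcal{L}/\partial A)\,\Phi_K\) and \(\partial\mathcal{L}/\partial \Phi_K = (\partial\mathcal{L}/\partial A)^\top \Phi_Q\) at \(O(n^2 r)\) each. The backward therefore costs \(O(n^2(r + d))\), twice the forward as for any chain of matrix multiplies. What changes is storage: \(A\) is a pure function of \(\Phi_Q, \Phi_K \in \mathbb{R}^{n \times r}\), so an implementation can keep only those and recompute \(A\) during the backward, with no softmax output or normalization statistics to retain.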
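The parallel-mode gradients can be written out and checked by finite differences. To keep the sketch short, it differentiates with respect to the feature-mapped inputs \(\Phi_Q\) and \(\Phi_K\) directly (the feature map's own Jacobian would be applied afterwards). It also omits the normalizer, matching the unnormalized form \(O = (\Phi_Q\Phi_K^\top \odot M)V\). The setup is hypothetical. The backward recomputes the masked matrix from \(\Phi_Q\) and \(\Phi_K\) rather than reading a stored copy.

```python
import numpy as np

rng = np.random.default_rng(0)
n, r, d = 7, 3, 4
fq, fk = rng.random((n, r)), rng.random((n, r))   # stand-ins for Phi_Q, Phi_K (positive)
v = rng.normal(size=(n, d))
g = rng.normal(size=(n, d))                      # upstream dL/dO for L = sum(O * g)
mask = np.tril(np.ones((n, n)))

def forward(fq, fk, v):
    return ((fq @ fk.T) * mask) @ v

def backward(fq, fk, v, d_out):
    a = (fq @ fk.T) * mask          # recomputed from the n x r inputs, not stored
    d_v = a.T @ d_out               # 2 n^2 d
    d_a = (d_out @ v.T) * mask      # 2 n^2 d; masked entries get no gradient
    d_fq = d_a @ fk                 # 2 n^2 r
    d_fk = d_a.T @ fq               # 2 n^2 r
    return d_fq, d_fk, d_v

grads = backward(fq, fk, v, g)

def numeric_grad(which, i, j, eps=1e-6):
    """Central difference of L w.r.t. entry (i, j) of fq, fk, or v."""
    args = [fq.copy(), fk.copy(), v.copy()]
    args[which][i, j] += eps
    up = float((forward(*args) * g).sum())
    args[which][i, j] -= 2 * eps
    down = float((forward(*args) * g).sum())
    return (up - down) / (2 * eps)

max_err = max(abs(numeric_grad(w, i, j) - grads[w][i, j])
              for w in range(3) for i in range(n) for j in range(grads[w].shape[1]))
print(f"max |analytic - numeric| = {max_err:.2e}")
```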
Backward through the serial mode. The serial recurrence \(S_t = S_{t-1} + \phi(k_t) v_t^\top\) has a particularly clean gradient structure: since \(S_t\) is a simple running sum, the gradient \(\partial \mathcal{L}/\partial S_t\) propagates additively backward through time. There are no vanishing or exploding gradient paths through the state transition (unlike an RNN with multiplicative gates), though in practice the lack of a forgetting mechanism means information from early positions can accumulate and dilute later context. The per-step backward cost matches the forward: \(O(rd)\) to propagate gradients through the rank-1 update.
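The reverse sweep can be sketched directly. This hypothetical NumPy version uses the unnormalized recurrence and takes gradients with respect to the feature-mapped inputs. The adjoint \(G = \partial\mathcal{L}/\partial S_t\) only ever has terms added to it, one rank-1 term \(\phi(q_t)\,\delta o_t^\top\) per step (where \(\delta o_t = \partial\mathcal{L}/\partial o_t\)). Each \(S_t\) is recovered by undoing the update, so only two \(r \times d\) matrices are carried across steps. The result is checked against finite differences.

```python
import numpy as np

def serial_forward(fq, fk, v):
    """o_t = phi(q_t)^T S_t with S_t = S_{t-1} + phi(k_t) v_t^T (no normalizer)."""
    S = np.zeros((fk.shape[1], v.shape[1]))
    out = np.empty((len(v), v.shape[1]))
    for t in range(len(v)):
        S += np.outer(fk[t], v[t])
        out[t] = fq[t] @ S
    return out, S                          # only the final state is kept

def serial_backward(fq, fk, v, S_final, d_out):
    """Reverse scan; carries S_t and G = dL/dS_t, both r x d."""
    S, G = S_final.copy(), np.zeros_like(S_final)
    d_fq, d_fk, d_v = np.zeros_like(fq), np.zeros_like(fk), np.zeros_like(v)
    for t in reversed(range(len(v))):
        G += np.outer(fq[t], d_out[t])     # additive: o_t's contribution to dL/dS_t
        d_fq[t] = S @ d_out[t]             # needs S_t
        d_fk[t] = G @ v[t]                 # S_t .. S_n all contain phi(k_t) v_t^T
        d_v[t] = G.T @ fk[t]
        S -= np.outer(fk[t], v[t])         # recover S_{t-1} (exact up to rounding)
    return d_fq, d_fk, d_v

rng = np.random.default_rng(0)
n, r, d = 9, 3, 4
fq, fk, v = rng.random((n, r)), rng.random((n, r)), rng.normal(size=(n, d))
g = rng.normal(size=(n, d))                # upstream dL/dO for L = sum(O * g)
out, S_final = serial_forward(fq, fk, v)
grads = serial_backward(fq, fk, v, S_final, g)

def numeric_grad(which, i, j, eps=1e-6):
    args = [fq.copy(), fk.copy(), v.copy()]
    args[which][i, j] += eps
    up = float((serial_forward(*args)[0] * g).sum())
    args[which][i, j] -= 2 * eps
    return (up - float((serial_forward(*args)[0] * g).sum())) / (2 * eps)

max_err = max(abs(numeric_grad(w, i, j) - grads[w][i, j])
              for w in range(3) for i in range(n) for j in range(grads[w].shape[1]))
print(f"max |analytic - numeric| = {max_err:.2e}")
```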
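Memory. The most concrete advantage over softmax attention is activation memory. Softmax attention implemented naively stores the \(n \times n\) attention weight matrix for the backward pass through softmax, giving \(O(n^2)\) memory per layer. Linear attention in serial mode carries only the running state \(S_t \in \mathbb{R}^{r \times d}\) across steps, a fixed \(O(rd)\) regardless of sequence length. The backward still needs the per-position inputs \(\phi(q_t)\), \(\phi(k_t)\), \(v_t\), which take \(O(n(r + d))\). The intermediate states need not be stored, however: they can be recomputed during the backward sweep, for instance by undoing the additive update \(S_{t-1} = S_t - \phi(k_t) v_t^\top\). (Storing every \(S_t\) naively would cost \(O(nrd)\).) Attention activation memory is therefore linear in \(n\) with no \(O(n^2)\) term, which makes gradient checkpointing largely unnecessary for the attention block at long context. The conditional-token detaching trick still applies. The prefix is scanned without gradients to produce \(S_{n_c}\) and \(z_{n_c}\), which serve as a constant initial state. None of the prefix's intermediate activations are kept, and the backward pass only traverses the generated portion of the recurrence.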
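The state-carrying version of the detaching trick can be checked in a few lines. In this hypothetical NumPy sketch, positive random features stand in for \(\phi(q)\) and \(\phi(k)\). In an autograd framework the prefix call would run under no-grad, and \((S_{n_c}, z_{n_c})\) would enter the generated pass as constants. Starting the generated scan from the prefix state reproduces the full-sequence outputs.

```python
import numpy as np

def scan(fq, fk, v, S=None, z=None):
    """Normalized causal linear attention from an optional initial state (S, z)."""
    r, d = fk.shape[1], v.shape[1]
    S = np.zeros((r, d)) if S is None else S.copy()
    z = np.zeros(r) if z is None else z.copy()
    out = np.empty((len(v), d))
    for t in range(len(v)):
        S += np.outer(fk[t], v[t])
        z += fk[t]
        out[t] = (fq[t] @ S) / (fq[t] @ z)
    return out, S, z

rng = np.random.default_rng(0)
n_c, n_g, r, d = 10, 20, 4, 4
fq = rng.random((n_c + n_g, r)) + 0.1        # positive stand-ins for phi(q)
fk = rng.random((n_c + n_g, r)) + 0.1        # positive stand-ins for phi(k)
v = rng.normal(size=(n_c + n_g, d))

# Prefix pass (would run without gradients): keep only the final state.
_, S_c, z_c = scan(fq[:n_c], fk[:n_c], v[:n_c])
# Generated pass starts from (S_c, z_c); only these n_g steps need a backward.
out_gen, _, _ = scan(fq[n_c:], fk[n_c:], v[n_c:], S_c, z_c)
full, _, _ = scan(fq, fk, v)
print(np.abs(out_gen - full[n_c:]).max())
```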
Inference. The serial recurrence makes the per-token cost of autoregressive generation \(O(1)\) in the sequence length (\(O(rd)\) per layer): the model maintains the fixed-size state \((S_t, z_t)\) and updates it with each new token, with no KV cache that grows with sequence length. This is the primary practical motivation for linear attention. The forward-pass FLOP advantage over softmax is modest until \(n\) is several times \(d\). At \(n = 4d\) (about 16,000 tokens for \(d = 4096\)) the quadratic term is still only 40% of a softmax layer's forward cost, and it matches the \(24nd^2\) term only at \(n = 6d\). The constant per-token inference cost, by contrast, holds at any sequence length.
Figure 3. The softmax backward requires storing the full n×n attention matrix; linear attention backward propagates an O(rd) gradient through the serial recurrence.
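To put rough numbers on the inference comparison, the sketch below counts the cache size and the attention FLOPs of one decode step at position \(t\), for a single head and layer as in the text. `decode_costs` is hypothetical. The per-step FLOP counts (\(4td\) for softmax, \(4rd\) for the linear update and query) follow the same 2-FLOPs-per-multiply-add convention.

```python
from typing import Optional

def decode_costs(t: int, d: int, r: Optional[int] = None) -> dict:
    """Per-layer, single-head cost of generating the token at position t (hypothetical)."""
    r = d if r is None else r
    return {
        "softmax_cache": 2 * t * d,         # keys and values for every past token
        "linear_state": r * d + r,          # (S_t, z_t), independent of t
        "softmax_step_flops": 4 * t * d,    # scores against t keys + weighted sum
        "linear_step_flops": 4 * r * d,     # rank-1 state update + query
    }

d = 4096
for t in (1_000, 16_000, 128_000):
    c = decode_costs(t, d)
    print(t, c["softmax_cache"] / c["linear_state"],
          c["softmax_step_flops"] / c["linear_step_flops"])
```

Under this accounting the softmax per-step cost overtakes the linear one once \(t > r\), and the KV cache outgrows the recurrent state once \(t\) exceeds about \(r/2\). The linear model's costs stay flat in \(t\).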