LLM Optimization Basics: Time
In supervised fine-tuning and RLHF, a model is trained on sequences that mix a conditional prefix (the prompt or instruction) with generated tokens (the target response). The loss is computed only on the generated portion, but the full sequence still flows through the forward pass. This asymmetry — compute everywhere, loss only on some positions — has non-obvious implications for where gradients are actually calculated. This post works through the FLOP arithmetic carefully, then answers a concrete backpropagation question.
Conditional Generation Setup
Let \(n_c\) denote the number of conditional tokens (the prefix) and \(n_g\) the number of generated tokens (the target). The total sequence length is \(n = n_c + n_g\). During training, the entire sequence is packed into a single forward pass under teacher forcing: the model sees the ground-truth token at each position as context, regardless of what it would have predicted.
The cross-entropy loss is then masked to the generated positions:
\[\mathcal{L} = -\frac{1}{n_g} \sum_{t = n_c + 1}^{n_c + n_g} \log p_\theta(x_t \mid x_{<t})\]

Conditional positions contribute to context (they appear as keys and values in attention) but contribute zero to \(\mathcal{L}\).
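As a minimal sketch in pure Python (the helper name is mine, not a standard API), the mask simply restricts the average of negative log-probabilities to the generated positions:

```python
import math

def masked_nll(log_probs, n_c):
    """Average negative log-probability over generated positions only.

    log_probs[t] is the log-probability of the ground-truth token at
    position t+1 given its prefix; the first n_c entries belong to the
    conditional prefix and are excluded from the loss.
    """
    generated = log_probs[n_c:]              # loss mask: keep generated positions
    return -sum(generated) / len(generated)

# A model that is uniform over a vocabulary of 1000 tokens should score
# exactly log(1000) nats per generated token.
lp = [math.log(1.0 / 1000.0)] * 30
loss = masked_nll(lp, n_c=10)
```

Note that the prefix log-probabilities are still present in `log_probs`; the mask only decides which entries reach \(\mathcal{L}\), mirroring how the forward pass computes logits everywhere.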
FLOP Accounting for the Forward Pass
Consider a transformer with:
- \(L\) layers
- Hidden dimension \(d\)
- MLP expansion factor 4 (intermediate width \(4d\))
- Single-head attention for simplicity (multi-head scales identically)
For a sequence of length \(n\), the dominant FLOP terms per layer are:
QKV projection. Each of the three projections is \(n \times d \to n \times d\), costing \(2nd^2\) FLOPs each (factor of 2 for multiply-accumulate):

\[\text{QKV} = 3 \times 2nd^2 = 6nd^2\]

Attention scores. Computing \(QK^\top\) is an \((n \times d) \times (d \times n)\) matrix multiply:

\[\text{Scores} = 2n^2 d\]

Weighted sum over values. Multiplying the \(n \times n\) attention weights by \(V \in \mathbb{R}^{n \times d}\):

\[\text{AttnOut} = 2n^2 d\]

Output projection. \(n \times d \to n \times d\):

\[\text{OutProj} = 2nd^2\]

MLP. Two linear layers: \(d \to 4d \to d\):

\[\text{MLP} = 2 \times 2nd \times 4d = 16nd^2\]

Summing over one layer:

\[C_{\text{layer}} = (6 + 2 + 16)nd^2 + 4n^2 d = 24nd^2 + 4n^2 d\]

For \(L\) layers, the total forward FLOP count is:

\[C_{\text{fwd}} = L\bigl(24nd^2 + 4n^2 d\bigr)\]

When \(d \gg n\) (typical for large models with short sequences), the linear term \(24nd^2\) dominates; when \(n \gg d\), the quadratic attention term \(4n^2 d\) takes over.
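These formulas are easy to sanity-check numerically. A small sketch (the function name is mine, not a standard API):

```python
def forward_flops(n, d, L):
    """Forward FLOPs per the per-layer breakdown above."""
    qkv = 6 * n * d**2        # three n*d -> n*d projections
    scores = 2 * n**2 * d     # Q K^T
    attn_out = 2 * n**2 * d   # attention weights times V
    out_proj = 2 * n * d**2   # output projection
    mlp = 16 * n * d**2       # d -> 4d -> d
    per_layer = qkv + scores + attn_out + out_proj + mlp
    assert per_layer == 24 * n * d**2 + 4 * n**2 * d
    return L * per_layer

# The quadratic attention term 4n^2 d overtakes the linear term 24nd^2
# exactly when n > 6d; for n = 30, d = 4096 the linear term dominates
# by a factor of about 800.
flops = forward_flops(n=30, d=4096, L=32)
```

The crossover condition \(24nd^2 = 4n^2 d \Leftrightarrow n = 6d\) follows directly from the two terms.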
The Forward Pass is Blind to the Loss Mask
Critically, the forward pass computes representations for all \(n = n_c + n_g\) tokens. The loss mask applied afterward does not change what the forward pass touches. Every position’s QKV projections are computed, and every position’s key and value vectors appear in the attention of all later positions.
Where Backpropagation Calculates Gradients
Concrete Example: \(n_c = 10\), \(n_g = 20\)
Tokens are indexed \(1, 2, \ldots, 30\). Positions \(1\text{--}10\) are conditional; positions \(11\text{--}30\) are generated. The loss \(\mathcal{L}\) is a function of the logits at positions \(11\text{--}30\) only.
Backpropagation traces the gradient backward through the computation graph. The key question is: which quantities — hidden states, weight matrices, and input embeddings — lie in that computation graph?
Gradient w.r.t. the output projection \(W_O\). This projection maps the attention output at position \(t\) to the residual stream. A gradient contribution at \(W_O\) arises whenever position \(t\) is in the backward path from the loss. Since the loss is at positions \(11\text{--}30\), and since each such position passes through the output projection at every layer, \(W_O\) receives gradient contributions from positions \(11\text{--}30\). Positions \(1\text{--}10\) do not directly contribute here.
Gradient w.r.t. the query projection \(W_Q\). The query vector for position \(t\) is \(q_t = W_Q h_t\). A query only generates a gradient when that position’s output appears in the loss path, which again means positions \(11\text{--}30\).
Gradient w.r.t. the key and value projections \(W_K, W_V\). This is where the asymmetry appears. For each generated position \(t \in \{11, \ldots, 30\}\), the attention output is:
\[o_t = \sum_{i=1}^{t} \alpha_{ti} \, v_i, \qquad \alpha_{ti} \propto \exp\!\left(\frac{q_t^\top k_i}{\sqrt{d}}\right)\]

Every key \(k_i\) and value \(v_i\) for \(i \leq t\) participates in this computation. Backpropagating through \(o_t\) produces gradients w.r.t. \(k_i = W_K h_i\) and \(v_i = W_V h_i\) for all \(i \in \{1, \ldots, t\}\). Aggregating over all generated positions \(t \in \{11, \ldots, 30\}\):
- Positions \(1\text{--}10\) each appear as keys/values in the attention of positions \(11\text{--}30\) — every generated position attends back to the entire prefix.
- Positions \(11\text{--}30\) appear as keys/values for the generated positions at or after them.
Therefore \(W_K\) and \(W_V\) receive gradient contributions from all 30 positions. The conditional tokens \(1\text{--}10\), despite having zero loss, are fully in the backward computation graph through the key and value paths. (Strictly, the per-projection attributions above describe direct, within-layer contributions: in a multi-layer model, the gradient arriving at a prefix hidden state \(h_i\) also flows down position \(i\)'s residual stream, so lower layers' \(W_Q\), \(W_O\), and MLP weights pick up prefix-position gradients as well.)
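The aggregation can be made concrete with a tiny helper (illustrative names, not a real API) that lists, for a single layer's direct contributions, which positions feed gradients into each projection:

```python
def grad_positions(n_c, n_g):
    """Positions (1-indexed) whose activations feed direct gradient
    contributions into each projection, for one causal, loss-masked layer."""
    loss_pos = set(range(n_c + 1, n_c + n_g + 1))  # loss lives here
    # Queries and the output projection: only positions on the loss path.
    w_q = set(loss_pos)
    # Keys/values: every position i <= t for some loss position t.
    w_kv = set()
    for t in loss_pos:
        w_kv.update(range(1, t + 1))
    return {"W_Q": w_q, "W_O": set(loss_pos), "W_K": w_kv, "W_V": w_kv}

g = grad_positions(n_c=10, n_g=20)
# W_K/W_V cover all 30 positions; W_Q/W_O only the 20 generated ones.
```

The asymmetry falls out of the causal structure alone: the union \(\bigcup_{t=11}^{30} \{1, \ldots, t\}\) is the full range \(\{1, \ldots, 30\}\).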
Gradient w.r.t. input embeddings. The embedding of token \(x_i\) at position \(i\) is the root of the residual stream at that position, so it accumulates gradients from every layer and every gradient path that passes through \(h_i\). Since conditional tokens \(1\text{--}10\) serve as keys and values throughout, their embeddings also receive non-zero gradients.
Figure 1. Gradient flow traced from the loss positions back through attention to all tokens (three panels).
Summary Table
| Component | Positions receiving gradient |
|---|---|
| Loss | \(11\text{--}30\) (generated only) |
| \(W_O\) (output proj.) | \(11\text{--}30\) |
| \(W_Q\) (query proj.) | \(11\text{--}30\) |
| \(W_K\) (key proj.) | \(1\text{--}30\) (all tokens) |
| \(W_V\) (value proj.) | \(1\text{--}30\) (all tokens) |
| MLP weights | \(11\text{--}30\) (through residual from loss) |
| Input embeddings | \(1\text{--}30\) (all tokens, via key/value paths) |
The distinction matters because gradient accumulation into \(W_K\) and \(W_V\) involves the hidden states of conditional positions, which must therefore be retained in memory during the forward pass — even though those positions contribute nothing to the loss value.
Backward Pass FLOPs
The backward pass costs roughly twice the forward pass in FLOPs: it must produce both parameter gradients and input gradients (the latter needed to propagate through earlier layers). The standard rule of thumb is:

\[C_{\text{bwd}} \approx 2 \, C_{\text{fwd}}\]

giving a total training FLOP count per step of roughly \(3 C_{\text{fwd}}\).
This 2× factor arises because each weight matrix appears in one matrix multiply in the forward pass and two gradient operations in the backward pass (gradient w.r.t. input and gradient w.r.t. weight). For attention specifically, backpropagating through the weighted sum multiplies by the transposed \((n \times n)\) attention matrix, adding another \(O(n^2 d)\) matmul. The full backward pass through one attention layer (QKV, scores, weighted sum, and output projection) costs:

\[C_{\text{attn, bwd}} \approx 2 \times \bigl(8nd^2 + 4n^2 d\bigr)\]

where the same \(n\) appears in both terms: all 30 tokens participate in the backward attention FLOPs, not just the 20 generated ones.
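Continuing the counting sketch from the forward pass (again with made-up function names), the rule of thumb is one line of arithmetic:

```python
def forward_flops(n, d, L):
    """Forward cost per the earlier breakdown: C_fwd = L(24nd^2 + 4n^2 d)."""
    return L * (24 * n * d**2 + 4 * n**2 * d)

def backward_flops(n, d, L):
    """Rule of thumb: each forward matmul spawns two backward matmuls."""
    return 2 * forward_flops(n, d, L)

def training_flops_per_step(n, d, L):
    """Forward plus backward: roughly 3 * C_fwd per training step."""
    return forward_flops(n, d, L) + backward_flops(n, d, L)
```

Note that \(n\) here is the full sequence length \(n_c + n_g\); the loss mask does not reduce any of these counts.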
Memory Implication: Activation Storage
The backward pass needs access to intermediate activations stored during the forward pass (for computing \(\partial \mathcal{L} / \partial W\)). The activations that must be retained include the hidden states at every layer for all \(n\) positions — including the 10 conditional tokens.
Activation memory per token per layer scales as \(O(d)\) for the residual stream plus \(O(n)\) for the attention weight matrix (which has shape \(n \times n\) overall and requires storing the pre- and post-softmax values for the backward pass through softmax). For \(n = 30\), \(d = 4096\), \(L = 32\):

\[\text{Activation memory} \approx L \cdot n \cdot (d + n) \cdot \text{bytes/element}\]

Even though positions \(1\text{--}10\) produce no loss, their activations must be retained for the backward pass through \(W_K\) and \(W_V\).
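Plugging in the numbers (a sketch; the 2 bytes per element, as in fp16, is my assumption):

```python
def activation_bytes(L, n, d, bytes_per_element=2):
    """L * n * (d + n) * bytes, per the estimate above: O(d) residual stream
    per token plus one length-n row of the attention matrix per token."""
    return L * n * (d + n) * bytes_per_element

mem = activation_bytes(L=32, n=30, d=4096)  # roughly 7.9 MB for this toy sequence
```

For this toy example the total is small, but it grows linearly in both \(L\) and \(d\) and superlinearly in \(n\) once the attention rows dominate.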
The Optimization: Detaching Conditional Tokens
One common optimization is to process conditional tokens without gradients, then re-attach for the generated portion:
```python
# Run the prefix without building a graph; only the KV cache is kept.
with torch.no_grad():
    past_kv = model(conditional_tokens, use_cache=True).past_key_values

# The generated portion attends to the cached (detached) prefix keys/values.
outputs = model(generated_tokens, past_key_values=past_kv)
loss = criterion(outputs.logits, targets)
loss.backward()
```
Under this approach:
- The KV cache for positions \(1\text{--}10\) is computed once in `no_grad` mode; activations are not stored.
- Backpropagation only traverses the computation graph for positions \(11\text{--}30\).
- \(W_K\) and \(W_V\) still receive gradients from positions \(11\text{--}30\) (which appear as keys/values for later generated positions), but the gradient path from generated positions back to the conditional key/value vectors is cut.
The tradeoff: gradients w.r.t. \(W_K\) and \(W_V\) no longer include the contribution from conditional positions, which may slightly change training dynamics (the model’s key/value representations for the prefix are shaped less by downstream loss signal). In practice, this is often negligible for long prompts where \(n_c \gg n_g\), and the memory savings can be substantial.
FLOP Savings from Detaching
Without detaching, the dense backward attention count uses the full \(2 \times 4n^2 d = 8n^2 d\) with \(n = 30\). With detaching, the score matrix for the generated block has shape \(n_g \times n\) (generated queries still attend to all \(n\) keys), so the query-gradient and softmax-backward terms cost about \(4 n_g n d\), while key and value gradients are computed only for the \(n_g\) non-detached positions:

\[C_{\text{attn, bwd, detached}} \approx 4 n_g n d + 4 n_g^2 d\]

versus

\[C_{\text{attn, bwd, full}} \approx 8 n^2 d\]

For \(n_g = 20\), \(n = 30\) this is a ratio of roughly \(0.56\); the key/value-gradient terms alone shrink by \((n_g / n)^2 = (20/30)^2 \approx 0.44\). For very short generations relative to long prompts (e.g., chain-of-thought verification where the prompt is large), the ratio approaches \(n_g / 2n\) and the savings are much more significant.
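As a back-of-envelope check (the function name is mine), the shrink factor for the key/value gradient accumulation follows directly from the block sizes:

```python
def kv_grad_shrink(n_c, n_g):
    """Ratio of detached to full key/value-gradient FLOPs: (n_g / n)^2,
    since those terms scale with the square of the non-detached block."""
    n = n_c + n_g
    return (n_g / n) ** 2

ratio = kv_grad_shrink(n_c=10, n_g=20)  # (20/30)^2, about 0.44
```

The quadratic dependence is what makes detaching especially attractive when the prompt dwarfs the generation.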
Gradient Checkpointing and the Memory-FLOP Tradeoff
Because activations scale as \(O(nLd)\), long sequences can exhaust GPU memory before the backward pass begins. Gradient checkpointing (also called activation recomputation) addresses this by storing only a sparse set of checkpoints during the forward pass and recomputing the intermediate activations on demand during the backward pass.
The cost: one extra forward pass per checkpointed segment, increasing total training FLOP by roughly \(1/3\) (from \(3 C_{\text{fwd}}\) to \(4 C_{\text{fwd}}\)). The benefit: activation memory drops from \(O(nL)\) to \(O(n\sqrt{L})\) or \(O(n)\) depending on checkpointing granularity.
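The tradeoff in numbers, assuming \(\sqrt{L}\)-granularity checkpointing (the function name is illustrative):

```python
import math

def checkpoint_overheads(L):
    """FLOP multiplier and activation-memory divisor under sqrt(L)-granularity
    checkpointing: training cost goes 3*C_fwd -> 4*C_fwd, while activation
    memory goes O(nL) -> O(n*sqrt(L))."""
    flop_multiplier = 4.0 / 3.0        # one extra forward pass for recompute
    memory_divisor = L / math.sqrt(L)  # equals sqrt(L)
    return flop_multiplier, memory_divisor

f, m = checkpoint_overheads(32)  # ~1.33x FLOPs, ~5.7x less activation memory
```

For typical layer counts the memory savings (a factor of \(\sqrt{L}\)) far outweigh the fixed one-third FLOP overhead.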
For conditional generation, the decision of where to place checkpoints interacts with the detaching optimization: if conditional tokens are detached, their activations need not be checkpointed at all, further reducing memory overhead.