LLM Optimization Basics: Time

In supervised fine-tuning and RLHF, a model is trained on sequences that mix a conditional prefix (the prompt or instruction) with generated tokens (the target response). The loss is computed only on the generated portion, but the full sequence still flows through the forward pass. This asymmetry — compute everywhere, loss only on some positions — has non-obvious implications for where gradients flow and how much memory is required. This post works through the FLOP arithmetic for both standard (softmax) attention and linear attention, tracing forward and backward costs separately.

Conditional Generation Setup

Let \(n_c\) denote the number of conditional tokens (the prefix) and \(n_g\) the number of generated tokens (the target). The total sequence length is \(n = n_c + n_g\). During training, the entire sequence is packed into a single forward pass under teacher forcing: the model sees the ground-truth token at each position as context, regardless of what it would have predicted.

The cross-entropy loss is then masked to the generated positions:

\[\mathcal{L} = -\frac{1}{n_g} \sum_{t = n_c + 1}^{n_c + n_g} \log p_\theta(x_t \mid x_{<t})\]

Conditional positions contribute to context (they appear as keys and values in attention) but contribute zero to \(\mathcal{L}\).
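As a concrete sketch (a minimal PyTorch version with illustrative names, not from the original post), the mask enters only at the loss:

import torch
import torch.nn.functional as F

def masked_cross_entropy(logits, labels, loss_mask):
    # logits: (batch, n, vocab); labels: (batch, n)
    # loss_mask: (batch, n) float, 1.0 on generated positions, 0.0 on conditional
    per_token = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        labels.reshape(-1),
        reduction="none",
    ).reshape(labels.shape)
    # average over the n_g generated positions only
    return (per_token * loss_mask).sum() / loss_mask.sum()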

Transformers: Quadratic Attention

Forward Pass

Consider a transformer with \(L\) layers, hidden dimension \(d\), MLP expansion factor 4, and single-head attention for simplicity. For a sequence of length \(n\), the dominant FLOP terms per layer are:

QKV projection. Each of the three projections is \(n \times d \to n \times d\), costing \(2nd^2\) FLOPs each:

\[\text{QKV} = 3 \times 2nd^2 = 6nd^2\]

Attention scores. Computing \(QK^\top\) is an \((n \times d) \times (d \times n)\) matrix multiply:

\[\text{Scores} = 2n^2 d\]

Weighted sum over values. Multiplying the \(n \times n\) attention weights by \(V \in \mathbb{R}^{n \times d}\):

\[\text{AttnOut} = 2n^2 d\]

Output projection. \(n \times d \to n \times d\):

\[\text{OutProj} = 2nd^2\]

MLP. Two linear layers \(d \to 4d \to d\):

\[\text{MLP} = 16nd^2\]

Summing over one layer and \(L\) layers:

\[C_{\text{fwd}} = L\bigl(24nd^2 + 4n^2 d\bigr)\]

When \(d \gg n\) the linear term dominates; when \(n \gg d\) the quadratic attention term \(4n^2 d\) takes over (the two terms are equal at \(n = 6d\)). Critically, the forward pass computes representations for all \(n = n_c + n_g\) tokens. The loss mask applied afterward does not change what the forward pass touches: every position's QKV projections are computed, and every key and value vector appears in the attention of all later positions.
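A quick calculator makes the crossover easy to check (a minimal sketch following the formula above):

def forward_flops(n, d, L):
    # L * (24 n d^2 linear terms + 4 n^2 d attention)
    return L * (24 * n * d**2 + 4 * n**2 * d)

# at d = 4096, the attention term matches the linear term at n = 6 * d = 24576
for n in (4096, 24576, 131072):
    print(n, forward_flops(n, d=4096, L=32) / 1e12, "TFLOPs")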

Sequence packing. An orthogonal inefficiency arises in batched training when sequences have variable lengths. The naive approach pads each sequence to the maximum length, wasting FLOPs on dummy tokens and inflating the \(n^2\) attention term. Sequence packing (also called sample packing or bin packing) eliminates this by concatenating multiple sequences end-to-end into one dense sequence, separated by end-of-sequence tokens. A block-diagonal attention mask prevents cross-sequence attention: if sequence \(j\) occupies positions \([s_j, e_j)\), then \(M_{ab} = -\infty\) whenever \(a\) and \(b\) belong to different sequences. FlashAttention-2’s varlen mode implements this without materializing the full mask. If the average sequence occupies fraction \(\rho\) of the padded length, packing recovers a factor of \(\rho^2\) on the quadratic attention FLOP — for instruction-tuning datasets with short and variable-length responses, \(\rho\) can be as low as 0.3–0.5, making packing effectively mandatory. Each packed example still has its own conditional prefix and generated suffix, so the loss mask must track both cross-sequence exclusions (via the attention mask) and within-sequence conditional positions (via the loss mask), applied at different points in the computation graph.
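A minimal sketch of the block-diagonal causal mask in plain PyTorch (names are illustrative; FlashAttention-2's varlen entry points take cumulative sequence lengths rather than a materialized mask, mirrored by the cu_seqlens construction below):

import torch

def packed_causal_mask(seq_lens):
    # seq_lens: lengths of the sequences packed into one row, e.g. [5, 3, 4]
    n = sum(seq_lens)
    seq_id = torch.repeat_interleave(
        torch.arange(len(seq_lens)), torch.tensor(seq_lens)
    )  # (n,) index of the packed sequence each position belongs to
    causal = torch.tril(torch.ones(n, n, dtype=torch.bool))
    same_seq = seq_id.unsqueeze(0) == seq_id.unsqueeze(1)
    return causal & same_seq  # True where attention is allowed

# the varlen interface summarizes the same structure as boundary offsets:
seq_lens = [5, 3, 4]
cu_seqlens = [0]
for length in seq_lens:
    cu_seqlens.append(cu_seqlens[-1] + length)  # [0, 5, 8, 12]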

Backward Pass

The backward pass costs approximately 2× the forward pass FLOP:

\[C_{\text{bwd}} \approx 2\,C_{\text{fwd}}, \qquad C_{\text{total}} \approx 3\,C_{\text{fwd}}\]

This factor of two arises because each weight matrix participates in two gradient operations (gradient w.r.t. input and gradient w.r.t. weight). For the attention block, the \(n \times n\) attention weight matrix must be transposed for the backward pass through the weighted sum, adding another \(O(n^2 d)\) term:

\[C_{\text{attn, bwd}} \approx 2 \times \bigl(6nd^2 + 4n^2 d\bigr)\]

The same \(n = n_c + n_g\) appears in both terms — all positions participate in the backward FLOP, not just the generated ones.

Gradient flow. Consider the concrete case \(n_c = 10\), \(n_g = 20\), total \(n = 30\), and focus on a single attention layer. Positions \(1\text{--}10\) are conditional; positions \(11\text{--}30\) are generated. Since \(\mathcal{L}\) depends only on logits at positions \(11\text{--}30\), backpropagation traces backward only through computation paths that connect to those positions. Within the layer, \(W_O\), \(W_Q\), and the MLP weights receive gradients only from generated positions. \(W_K\) and \(W_V\) are different: the attention output at each generated position \(t\) is \(o_t = \sum_{i \leq t} \alpha_{ti} v_i\) where \(v_i = W_V h_i\), so backpropagating through \(o_t\) produces gradients w.r.t. \(W_K\) and \(W_V\) at every attended position \(i \leq t\), including all \(n_c\) conditional tokens. The conditional tokens thus receive non-zero gradients for \(W_K\), \(W_V\), and their own input embeddings, despite contributing zero to \(\mathcal{L}\). (In a multi-layer stack the reach is wider: a conditional position's hidden state at layer \(\ell+1\) depends on layer \(\ell\)'s \(W_Q\), \(W_O\), and MLP weights at that position, and that hidden state feeds later layers' keys and values, so in all but the final layer every weight matrix picks up some gradient from conditional positions. The table below describes a single layer.)
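This is easy to verify empirically. A minimal sketch (a hypothetical toy setup with a single attention layer): mask the loss to the generated positions, backpropagate, and check that the hidden states at conditional positions still carry gradient:

import torch
import torch.nn as nn

torch.manual_seed(0)
n_c, n_g, d, vocab = 10, 20, 32, 100
n = n_c + n_g

emb = nn.Embedding(vocab, d)
attn = nn.MultiheadAttention(d, num_heads=1, batch_first=True)
head = nn.Linear(d, vocab)

tokens = torch.randint(vocab, (1, n))
h = emb(tokens)
h.retain_grad()  # so we can inspect per-position gradients afterward

causal = torch.triu(torch.ones(n, n, dtype=torch.bool), diagonal=1)
out, _ = attn(h, h, h, attn_mask=causal)
logits = head(out)

# loss only on the n_g generated positions (labels are illustrative)
loss = nn.functional.cross_entropy(logits[0, n_c:], tokens[0, n_c:])
loss.backward()

# True: conditional positions receive gradient via the key/value paths
print(h.grad[0, :n_c].abs().sum() > 0)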

Figure 1. Gradient flow in three panels: from the loss positions, back through attention, to all tokens.

| Component | Positions receiving gradient |
| --- | --- |
| Loss | \(11\text{--}30\) (generated only) |
| \(W_O\) (output proj.) | \(11\text{--}30\) |
| \(W_Q\) (query proj.) | \(11\text{--}30\) |
| \(W_K\) (key proj.) | \(1\text{--}30\) (all tokens) |
| \(W_V\) (value proj.) | \(1\text{--}30\) (all tokens) |
| MLP weights | \(11\text{--}30\) (through residual from loss) |
| Input embeddings | \(1\text{--}30\) (all tokens, via key/value paths) |

Activation memory. The backward pass needs the intermediate activations stored during the forward pass to compute \(\partial \mathcal{L} / \partial W\). For softmax attention, this includes both the pre- and post-softmax \(n \times n\) attention weight matrix (needed for the backward through softmax), giving activation memory per layer of roughly \(O(n^2 + nd)\) — dominated by \(O(n^2)\) at long context. Since \(W_K\) and \(W_V\) receive gradients from all positions, the activations of all \(n_c\) conditional tokens must be retained even though they produce no loss.
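Back-of-envelope numbers make the scale vivid (illustrative configuration: fp16, single head):

n, n_layers, bytes_fp16 = 32_768, 32, 2
per_layer = n * n * bytes_fp16  # one n x n attention matrix: ~2 GiB
# ~64 GiB across layers, before the pre-softmax copy or any O(nd) terms
print(per_layer * n_layers / 2**30, "GiB")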

Detaching conditional tokens. One common optimization is to process the prefix under torch.no_grad(), computing its KV cache without storing activations, then run the generated portion with gradients attached:

# prefix forward builds no graph; only the KV cache is kept
with torch.no_grad():
    past_kv = model(conditional_tokens, use_cache=True).past_key_values

# generated-portion forward attends to the cached keys/values with gradients on
outputs = model(generated_tokens, past_key_values=past_kv)
loss = criterion(outputs.logits, targets)  # masked to generated positions
loss.backward()  # the graph covers only the generated positions

This cuts all gradient paths through the conditional positions, reducing the backward attention FLOP from \(8n^2 d\) to roughly \(4 n_g d\,(n + n_g)\): generated queries still attend over all \(n\) keys, so the backward still touches an \(n_g \times n\) score block, but key and value gradients are needed only at the \(n_g\) generated positions. When \(n_c \gg n_g\) this is a saving factor of about \(n_g / 2n\), which is substantial for long prompts. The tradeoff is that \(W_K\) and \(W_V\) no longer receive gradient contributions from conditional positions, slightly altering training dynamics; in practice this is negligible for long prompts.
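Worked numbers under this formula (illustrative shapes):

n_c, n_g, d = 4096, 512, 4096  # long prompt, short response
n = n_c + n_g
full = 8 * n**2 * d                 # backward attention, full graph
detached = 4 * n_g * d * (n + n_g)  # prefix under no_grad
print(detached / full)              # ~0.06, i.e. ~16x fewer backward attention FLOPs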

Gradient checkpointing. Because activation memory scales as \(O(n^2 L)\) (from the attention matrices), long sequences can exhaust GPU memory before the backward pass begins. Gradient checkpointing stores only a sparse set of activations during the forward pass and recomputes the rest on demand during the backward pass. The cost is one extra forward pass per checkpointed segment, raising total training FLOP from \(3C_{\text{fwd}}\) to roughly \(4C_{\text{fwd}}\). The payoff is that with \(\sqrt{L}\) evenly spaced checkpoints only the segment currently being recomputed keeps its attention matrices live, so the \(O(n^2 L)\) term drops to \(O(n^2)\) and the stored hidden states shrink from \(O(ndL)\) to \(O(nd\sqrt{L})\). If conditional tokens are detached, their activations need not be checkpointed at all, compounding the savings.
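A minimal sketch with torch.utils.checkpoint (the layer choice and sizes are illustrative):

import torch
from torch.utils.checkpoint import checkpoint

def run_layers(layers, h):
    for layer in layers:
        # store only the segment boundary; the layer's internals, including
        # its attention matrix, are recomputed during the backward pass
        h = checkpoint(layer, h, use_reentrant=False)
    return h

layers = torch.nn.ModuleList(
    torch.nn.TransformerEncoderLayer(d_model=64, nhead=4, batch_first=True)
    for _ in range(8)
)
h = torch.randn(1, 128, 64, requires_grad=True)
run_layers(layers, h).sum().backward()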

Linear Attention

Forward Pass

Linear attention replaces \(\exp(q^\top k)\) with a kernel approximation \(\phi(q)^\top \phi(k)\) for a positive-valued feature map \(\phi : \mathbb{R}^d \to \mathbb{R}^r\). Using associativity, the attention output at position \(t\) becomes:

\[o_t = \frac{\phi(q_t)^\top \bigl(\sum_{i \leq t} \phi(k_i) v_i^\top\bigr)}{\phi(q_t)^\top \bigl(\sum_{i \leq t} \phi(k_i)\bigr)} = \frac{\phi(q_t)^\top S_t}{\phi(q_t)^\top z_t}\]

where \(S_t = \sum_{i \leq t} \phi(k_i) v_i^\top \in \mathbb{R}^{r \times d}\) and \(z_t = \sum_{i \leq t} \phi(k_i) \in \mathbb{R}^r\) are built by scanning left to right. Each step costs \(O(rd)\) to update and \(O(rd)\) to query, making the total attention cost \(O(nrd)\) instead of \(O(n^2 d)\). Substituting into the per-layer FLOP count (with \(r = d\)):

\[C_{\text{layer}}^{\text{linear}} = 24nd^2 + O(nd^2) \approx 24nd^2\]

The quadratic term disappears entirely. The tradeoff is approximation quality: the kernel \(\phi(q)^\top \phi(k)\) is a rank-\(r\) approximation to the full softmax attention matrix. Softmax’s sharply peaked distribution is difficult to reproduce with a low-rank kernel, so linear attention models tend to underperform on tasks requiring precise long-range retrieval. Methods like RetNet and Mamba abandon the kernel approximation entirely, adopting structured recurrences with explicit forgetting mechanisms as a different inductive bias.

Linear attention also admits a parallel mode for training: expanding into matrix form, all positions can be computed simultaneously as \(O = (\Phi_Q \Phi_K^\top \odot M) V\) where \(M\) is the causal mask and \(\Phi_Q, \Phi_K \in \mathbb{R}^{n \times r}\) are the feature-mapped queries and keys. This costs \(O(n^2 r)\) — still quadratic in \(n\), but with a factor of \(r/d \ll 1\) compared to softmax attention — and is fully parallelizable across positions, making it GPU-efficient during training.
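A minimal sketch of the two modes (unnormalized and single-head for brevity; names are illustrative), confirming they compute the same outputs:

import torch

torch.manual_seed(0)
n, r, d = 6, 4, 5
phi_q, phi_k = torch.rand(n, r), torch.rand(n, r)  # feature-mapped Q, K (positive)
v = torch.randn(n, d)

# parallel mode: O(n^2 r) masked Gram matrix, GPU-friendly for training
mask = torch.tril(torch.ones(n, n))
o_parallel = (phi_q @ phi_k.T * mask) @ v

# serial mode: O(nrd) left-to-right scan over the running state S_t
S = torch.zeros(r, d)
o_serial = []
for t in range(n):
    S = S + torch.outer(phi_k[t], v[t])  # rank-1 state update
    o_serial.append(phi_q[t] @ S)        # query the state: O(rd)
o_serial = torch.stack(o_serial)

print(torch.allclose(o_parallel, o_serial, atol=1e-5))  # True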

Figure 2. Reordering the matrix multiplication eliminates the quadratic attention term.

Backward Pass

Linear attention has a natural dual form: the same computation can be expressed either as a parallel matrix operation (for training) or as a sequential recurrence (for inference), and the two modes have very different backward-pass characteristics.

Backward through the parallel mode. In parallel mode, the causal attention output is \(O = (\Phi_Q \Phi_K^\top \odot M) V\). This is differentiable in the standard way: gradients flow back through two plain matrix multiplications, with no softmax Jacobian in between. The \(n \times n\) Gram matrix \(\Phi_Q \Phi_K^\top\) need not be stored for the backward pass, since it can be recomputed cheaply from \(\Phi_Q, \Phi_K \in \mathbb{R}^{n \times r}\); retaining those two \(n \times r\) factors is far cheaper than retaining the \(n \times n\) softmax attention matrix. The backward pass cost mirrors the forward: \(O(n^2 r)\) for the masked Gram matrix plus \(O(n^2 d)\) for the value aggregation. Crucially, no \(O(n^2)\) activation needs to be retained for a backward through softmax, because there is no softmax.

Backward through the serial mode. The serial recurrence \(S_t = S_{t-1} + \phi(k_t) v_t^\top\) has a particularly clean gradient structure: since \(S_t\) is a simple running sum, the gradient \(\partial \mathcal{L}/\partial S_t\) propagates additively backward through time. There are no vanishing or exploding gradient paths through the state transition (unlike an RNN with multiplicative gates), though in practice the lack of a forgetting mechanism means information from early positions can accumulate and dilute later context. The per-step backward cost matches the forward: \(O(rd)\) to propagate gradients through the rank-1 update.

Memory. The most concrete advantage over softmax attention is activation memory. Softmax attention requires storing the \(n \times n\) attention weight matrix for the backward pass through softmax, giving \(O(n^2)\) memory per layer. Linear attention in serial mode needs only the running state \(S_t \in \mathbb{R}^{r \times d}\) plus the per-token inputs \(\phi(k_t), v_t\): \(O(rd + n(r + d))\) in total, with no \(O(n^2)\) term at any sequence length. This makes gradient checkpointing largely unnecessary for the attention block at long context. The conditional-token detaching trick still applies: by processing the prefix without gradients, the state \(S_{n_c}\) is computed but kept out of the autograd graph, and the backward pass only traverses the generated portion of the recurrence.

Inference. The serial recurrence makes autoregressive generation \(O(1)\) per token: the model maintains the fixed-size state \((S_t, z_t)\) and updates it with each new token, with no KV cache that grows with sequence length. This is the primary practical motivation for linear attention — the forward-pass FLOP advantage over softmax only kicks in when \(n \gtrsim 4d\) (roughly 16,000 tokens for \(d = 4096\)), but the \(O(1)\) inference cost is beneficial at any sequence length.
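A minimal decode-step sketch (a hypothetical function; the state pair \((S, z)\) plays the role of the KV cache):

import torch

def decode_step(S, z, phi_q_t, phi_k_t, v_t):
    # constant-time update: state size is O(rd) however long the sequence is
    S = S + torch.outer(phi_k_t, v_t)    # (r, d) rank-1 update
    z = z + phi_k_t                      # (r,) normalizer accumulator
    o_t = (phi_q_t @ S) / (phi_q_t @ z)  # attention output for the new token
    return o_t, S, z

r, d = 64, 64
S, z = torch.zeros(r, d), torch.zeros(r)
o, S, z = decode_step(S, z, torch.rand(r), torch.rand(r), torch.randn(d))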

Figure 3. The softmax backward requires storing the full n×n attention matrix; linear attention backward propagates an O(rd) gradient through the serial recurrence.
