Gradient Flow in Conditional Generation

4 conditional tokens  ·  4 generated tokens  ·  one attention layer

Step 1 of 3 — sequence & loss

[Figure: sequence layout. Conditional tokens C1–C4 ($n_c = 4$) are followed by generated tokens G1–G4 ($n_g = 4$); the loss region covers only the generated tokens.]
The cross-entropy loss is computed only on the generated positions: $$\mathcal{L} = -\frac{1}{n_g} \sum_{t \in \mathcal{G}} \log p_\theta\!\left(x_t \mid x_{<t}\right)$$ Conditional tokens contribute zero to $\mathcal{L}$. Yet the full sequence of length $n = n_c + n_g$ passes through every transformer layer during the forward pass.
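This loss masking is one line in practice. A minimal sketch, assuming PyTorch; the sizes and the `ignore_index` convention are illustrative:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n_c, n_g, vocab = 4, 4, 10
logits = torch.randn(n_c + n_g, vocab)           # one row of logits per position
targets = torch.randint(0, vocab, (n_c + n_g,))  # next-token targets

# Conditional positions are marked with ignore_index, so they contribute
# zero loss and do not count toward the mean.
targets[:n_c] = -100
loss = F.cross_entropy(logits, targets, ignore_index=-100)
```

`ignore_index` removes the conditional positions from both the sum and the normalizer, so the mean is taken over the $n_g$ generated positions only.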
Step 2 of 3 — attention over the prefix

[Figure: G3 issues the query and attends over C1–C4, G1, G2 (value vectors v₁–v₆, each $v_i = W_V h_i$, the value projection at each attended position); G4 is masked.]
$G_3$ sits at absolute position $t_{G_3} = n_c + 3$; its query attends over every position $i \leq t_{G_3}$: all $n_c$ conditional tokens, $G_1$, $G_2$, and itself. Its output aggregates their value vectors: $$o_{G_3} = \sum_{i \,\leq\, t_{G_3}} \alpha_{G_3,\, i} \cdot \underbrace{W_V \, h_i}_{\text{value at position }i}$$ $W_V$ is applied to $h_i$ at every attended position — including $h_1, \ldots, h_{n_c}$. This is the key dependency that determines where gradients flow.
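The aggregation can be made concrete for a single query. A toy PyTorch sketch with random weights; the width, the position of $G_3$, and the weight matrices are all illustrative:

```python
import torch

torch.manual_seed(0)
d = 8                                   # model width (illustrative)
n_c, t = 4, 7                           # 4 conditional tokens; G3 is position 7
h = torch.randn(t, d)                   # hidden states h_1 .. h_t
W_Q, W_K, W_V = (torch.randn(d, d) for _ in range(3))

q = W_Q @ h[-1]                         # query issued by G3
alpha = torch.softmax((h @ W_K.T) @ q / d**0.5, dim=0)  # α_{G3, i}
o_G3 = (alpha.unsqueeze(1) * (h @ W_V.T)).sum(dim=0)    # Σ_i α_i · W_V h_i
```

Every row of `h @ W_V.T` is a value vector $W_V h_i$, including those of the conditional prefix; `o_G3` mixes all of them.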
Step 3 of 3 — gradient flow

[Figure: gradients flow from the loss region at G1–G4 back to C1–C4; $\partial\mathcal{L}/\partial W_K$ and $\partial\mathcal{L}/\partial W_V$ accumulate via the K/V paths.]
The gradient of $\mathcal{L}$ with respect to $W_V$ sums over all attended positions at each generated step: $$\frac{\partial \mathcal{L}}{\partial W_V} = \sum_{t \in \mathcal{G}} \;\sum_{i=1}^{t} \frac{\partial \mathcal{L}}{\partial o_t} \cdot \alpha_{t,i} \cdot h_i^\top$$ Because $i$ ranges over $1, \ldots, n_c$ as well, every conditional $h_i$ contributes a nonzero outer product to $\partial\mathcal{L}/\partial W_V$. The same holds for $W_K$ through the key path. By contrast, $W_Q$, $W_O$, and the MLP weights accumulate gradient only from generated positions: their outer-product factors are the hidden states at the query and output positions, and only generated positions carry loss.
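The double sum can be checked numerically against autograd. In this sketch the attention weights are held fixed (an arbitrary causal matrix) to isolate the value path, and the loss is a toy quadratic on the generated positions; all names and sizes are illustrative:

```python
import torch

torch.manual_seed(0)
d, n_c, n_g = 4, 4, 4
n = n_c + n_g
h = torch.randn(n, d)
W_V = torch.randn(d, d, requires_grad=True)

# Fixed (arbitrary) causal attention weights, to isolate the value path.
A = torch.rand(n, n).tril()
A = A / A.sum(dim=1, keepdim=True)      # rows sum to 1: α_{t,i}

o = A @ (h @ W_V.T)                     # o_t = Σ_{i≤t} α_{t,i} · W_V h_i
loss = o[n_c:].pow(2).mean()            # toy loss on generated positions only
loss.backward()

# Manual double sum: ∂L/∂W_V = Σ_{t∈G} Σ_{i≤t} (∂L/∂o_t) α_{t,i} h_i^T
g_o = 2 * o.detach() / o[n_c:].numel()  # ∂L/∂o_t for the toy quadratic
manual = torch.zeros_like(W_V)
cond_terms = torch.zeros_like(W_V)      # the i ∈ C contributions alone
for t in range(n_c, n):
    for i in range(t + 1):
        term = A[t, i] * torch.outer(g_o[t], h[i])
        manual = manual + term
        if i < n_c:
            cond_terms = cond_terms + term
```

The manual sum matches `W_V.grad`, and the terms with $i \in \mathcal{C}$ are non-trivial.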
| Weight | Gradient from $\mathcal{C}$ (conditional) | Gradient from $\mathcal{G}$ (generated) |
| --- | --- | --- |
| $W_Q$ | ✗ none | ✓ yes |
| $W_K$ | ✓ yes | ✓ yes |
| $W_V$ | ✓ yes | ✓ yes |
| $W_O$, MLP | ✗ none | ✓ yes |
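The table's consequence is visible in autograd: even with the loss restricted to generated positions, gradient reaches the conditional hidden states through the key/value paths. A toy single-head check, assuming PyTorch and arbitrary random weights:

```python
import torch

torch.manual_seed(0)
d, n_c, n_g = 8, 4, 4
n = n_c + n_g
h = torch.randn(n, d, requires_grad=True)            # layer-input activations
W_Q, W_K, W_V = (torch.randn(d, d, requires_grad=True) for _ in range(3))

q, k, v = h @ W_Q.T, h @ W_K.T, h @ W_V.T
causal = torch.ones(n, n).triu(diagonal=1).bool()    # True above the diagonal
scores = (q @ k.T / d**0.5).masked_fill(causal, float("-inf"))
o = torch.softmax(scores, dim=-1) @ v
loss = o[n_c:].pow(2).mean()                         # toy loss on G only
loss.backward()
```

After `backward()`, `h.grad` is nonzero at the conditional rows, even though those positions carry no loss.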
Consequence for memory.  Because $h_i$ for $i \in \mathcal{C}$ must be available during the backward pass (to form $h_i^\top$ in $\partial\mathcal{L}/\partial W_V$), the activations of all $n_c$ conditional tokens must be retained after the forward pass — even though they produce no loss. Processing the prefix under `torch.no_grad()` severs this dependency, removing $\mathcal{C}$ from the gradient graph and eliminating their activation storage cost.
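A minimal sketch of that pattern, assuming PyTorch; module names, sizes, and the toy loss are illustrative. The prefix keys/values are computed under `torch.no_grad()` and concatenated with live suffix projections:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d, n_c, n_g = 16, 4, 4
W_Q, W_K, W_V = (nn.Linear(d, d, bias=False) for _ in range(3))
h = torch.randn(n_c + n_g, d, requires_grad=True)

with torch.no_grad():                     # prefix: no activation storage
    k_c, v_c = W_K(h[:n_c]), W_V(h[:n_c])

k = torch.cat([k_c, W_K(h[n_c:])])        # cached prefix K/V + live suffix
v = torch.cat([v_c, W_V(h[n_c:])])
q = W_Q(h[n_c:])                          # queries only for generated positions

# Causal mask for suffix queries over the full sequence.
mask = torch.ones(n_g, n_c + n_g).tril(diagonal=n_c).bool()
scores = (q @ k.T / d**0.5).masked_fill(~mask, float("-inf"))
o = torch.softmax(scores, dim=-1) @ v
o.pow(2).mean().backward()                # toy loss on generated outputs
```

After `backward()`, the prefix rows of `h.grad` are exactly zero, so their intermediate activations need not be stored; the tradeoff is that $W_K$ and $W_V$ no longer accumulate the conditional terms from the table above.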