Adaptive Sampling and Curriculum Methods
The Problem: Group Sampling and Signal Recovery
Why group sampling loses signal
Notation. Throughout this post: \(x\) denotes a prompt (the input question), \(a\) denotes a response (a model-generated rollout), \(\pi_\theta(a \vert x)\) is the policy (LLM), and \(r(x, a) \in \{0, 1\}\) is a binary reward (correct or not). For each prompt, we sample a group of \(n\) responses and use their rewards to estimate a learning signal.
RAFT, RLOO, GRPO, and online DPO all rely on this group sampling to estimate learning signals — for each prompt, sample a group of \(n\) responses, then use the reward distribution within the group to construct a gradient. But they all share a blind spot: no learning signal on prompts with uniform rewards.
To see why, let \(A^+\) and \(A^-\) denote the correct and incorrect rollouts within a group. The RAFT gradient only trains on correct rollouts (rejection sampling):
\[g_\theta^{RAFT}(x) = \sum_{a_i \in A^+} \nabla_\theta \log \pi_\theta(a_i \vert x) - 0 \cdot \sum_{a_j \in A^-} \nabla_\theta \log \pi_\theta(a_j \vert x)\]The GRPO gradient uses the group mean \(\mu_r\) and standard deviation \(\sigma_r\) to normalize rewards into per-response weights:
\[g_\theta^{GRPO}(x) = \sum_{a_i \in A} \frac{r(x, a_i) - \mu_r}{\sigma_r + \epsilon} \cdot \nabla_\theta \log \pi_\theta(a_i \vert x)\]If all \(n\) responses in the group are correct (\(r = 1\) for all), then \(\mu_r = 1\) and \(\sigma_r = 0\) — every GRPO weight becomes \(0/(0 + \epsilon) = 0\), and the gradient vanishes. Similarly, if all are wrong, \(\mu_r = 0\) and \(\sigma_r = 0\). For RAFT, if no response is correct, \(A^+\) is empty and there is literally nothing to train on. For DPO, which needs a winner-loser pair: if all responses have the same reward, no pair can be formed.
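A minimal Python sketch (hypothetical helper `grpo_weights`, assuming the \(\epsilon\)-smoothed normalization above) makes the vanishing signal concrete:

```python
def grpo_weights(rewards, eps=1e-6):
    """Per-response GRPO weights (r - mean) / (std + eps) for one group."""
    n = len(rewards)
    mu = sum(rewards) / n
    sigma = (sum((r - mu) ** 2 for r in rewards) / n) ** 0.5
    return [(r - mu) / (sigma + eps) for r in rewards]

mixed = grpo_weights([1, 1, 0, 0])    # mixed group: nonzero weights
uniform = grpo_weights([1, 1, 1, 1])  # all-correct group: every weight is 0
print(mixed)    # approximately [1, 1, -1, -1]
print(uniform)  # [0.0, 0.0, 0.0, 0.0]
```

The all-correct (or all-wrong) group contributes exactly zero gradient, no matter how many rollouts it contains.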
During GRPO training, the ratio of “effective prompts” (those with mixed rewards in their group) fluctuates wildly and can be as low as 30% — the model masters easy prompts quickly (all-correct groups), while still failing on hard ones (all-wrong groups). In both cases, compute is wasted with zero learning signal.
Larger groups help, but uniform scaling is too expensive
A larger group size uncovers more non-trivial reward signals. Pass@K curves show that with \(K = 256\), models can solve 81.3% of problems, and increasing \(K\) from 4 to 32 adds 16.9% coverage, while going from 32 to 256 adds only 1.7%.
However, uniformly enlarging group size is slow: GRPO-n16 achieves the highest final reward in training steps but is the slowest in wall-clock time.
Different prompts need different budgets
The Pass@K data reveals huge variance across prompts. We can partition prompts into three tiers based on when they first yield a correct answer:
| Tier | Solved by | Fraction of prompts | Samples needed |
|---|---|---|---|
| Easy | \(K = 4\) | 62.7% | 4 |
| Medium | \(K = 32\) (but not 4) | 16.9% | 32 |
| Hard | \(K = 256\) (but not 32) | 1.7% | 256 |
With an adaptive strategy — start with \(K = 4\), escalate to \(K = 32\) only for prompts that got no correct answer, then to \(K = 256\) for the remaining — the expected average cost per prompt is:
\[4 \times 0.627 + 32 \times 0.169 + 256 \times 0.017 \approx 2.5 + 5.4 + 4.4 = 12.3\]Compare this to uniform \(K = 256\) for all prompts, which costs 256 per prompt — adaptive sampling achieves similar coverage at \(\sim 5\%\) of the cost (this estimate charges each prompt only its final tier's budget and ignores prompts that remain unsolved at \(K = 256\)). The principle: only increase group size when extra samples can reveal new learning signals.
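The tier arithmetic can be reproduced in a few lines (fractions and budgets are taken from the table above):

```python
# Tier budgets and fractions from the Pass@K table; each prompt is
# charged at its final tier's budget only, as in the estimate above.
tiers = [(4, 0.627), (32, 0.169), (256, 0.017)]
adaptive_cost = sum(k * frac for k, frac in tiers)
uniform_cost = 256
print(round(adaptive_cost, 1))                  # 12.3
print(round(adaptive_cost / uniform_cost, 3))   # ~0.048, i.e. ~5% of uniform
```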
Reinforce-Ada: Adaptive Sequential Sampling
The algorithm: sample until useful signal appears
Since we don’t know in advance how many samples a prompt requires, the idea is to keep sampling until enough useful learning signal is recovered. Reinforce-Ada-Seq implements this as a sequential process: (1) initialize all prompts as active with a first batch of rollouts; (2) check whether each prompt already has enough signal; (3) keep sampling only the unresolved prompts, stopping at a maximum budget. Different stopping rules realize different implicit allocation schedules.
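A sketch of the sequential loop, with hypothetical helpers `sample_fn` and `enough_signal` standing in for the real rollout generator and stopping rule:

```python
import random

def reinforce_ada_seq(prompts, sample_fn, enough_signal, batch=4, max_budget=256):
    """Sketch of Reinforce-Ada-Seq: keep sampling only the unresolved prompts.

    sample_fn(x, n)   -> list of n binary rewards for prompt x  (assumed interface)
    enough_signal(rs) -> True once the reward list carries usable signal
    """
    rollouts = {x: [] for x in prompts}
    active = set(prompts)
    while active:
        for x in list(active):
            rollouts[x].extend(sample_fn(x, batch))
            if enough_signal(rollouts[x]) or len(rollouts[x]) >= max_budget:
                active.discard(x)   # resolved or budget exhausted
    return rollouts

# Toy usage: per-prompt pass rates; stop once both outcomes are observed.
rates = {"easy": 0.8, "hard": 0.05}
random.seed(0)
out = reinforce_ada_seq(
    rates,
    sample_fn=lambda x, n: [int(random.random() < rates[x]) for _ in range(n)],
    enough_signal=lambda rs: 0 in rs and 1 in rs,
)
print({x: len(rs) for x, rs in out.items()})  # hard prompts typically use more samples
```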
Stopping rules and implicit allocation
Two stopping rules are proposed, each with a different philosophy:
Ada-Seq-Pos: stop after collecting \(k\) positive (correct) samples. Since each sample is correct with probability \(p\), the expected number of samples needed to see \(k\) successes follows a negative binomial distribution:
\[\mathbb{E}[n] = \frac{k}{p}\]This means easy prompts (high \(p\)) stop quickly, while hard prompts (low \(p\)) get more samples — naturally focusing compute on difficult problems. For example, with \(k = 4\): a prompt with \(p = 0.8\) needs only \(\sim 5\) samples on average, while a prompt with \(p = 0.05\) needs \(\sim 80\).
What is the negative binomial distribution?
The negative binomial distribution models the number of i.i.d. Bernoulli trials needed to observe exactly \(k\) successes, where each trial succeeds with probability \(p\). If \(n\) is the total number of trials, then:
$$\Pr(n = m) = \binom{m-1}{k-1} p^k (1-p)^{m-k}, \quad m = k, k+1, \ldots$$
The expected value is \(\mathbb{E}[n] = k/p\). Intuitively: each success takes on average \(1/p\) trials, and we need \(k\) of them. The variance is \(\mathrm{Var}(n) = k(1-p)/p^2\), so harder prompts (small \(p\)) have both higher expected cost and higher variance in cost.
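The \(\mathbb{E}[n] = k/p\) claim is easy to check by simulation (a toy Bernoulli sampler, not the actual training loop):

```python
import random

def samples_until_k_successes(p, k, rng):
    """Number of Bernoulli(p) draws needed to observe k successes."""
    n = successes = 0
    while successes < k:
        n += 1
        successes += rng.random() < p
    return n

rng = random.Random(0)
k = 4
means = {p: sum(samples_until_k_successes(p, k, rng) for _ in range(20000)) / 20000
         for p in (0.8, 0.05)}
print(means)  # close to k/p: ~5 for p = 0.8, ~80 for p = 0.05
```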
Ada-Seq-Balance: stop after collecting \(k\) positive and \(k\) negative samples. Here the expected stopping time is:
\[\mathbb{E}[n] \propto \frac{1}{p(1-p)}\]This allocates the most samples to prompts with extreme pass rates (\(p \approx 0\) or \(p \approx 1\)), where it is hardest to find both positive and negative examples. Prompts near \(p = 0.5\) stop fastest because both outcomes appear frequently. The function \(1/(p(1-p))\) is U-shaped — minimal at \(p = 0.5\) (value = 4) and diverging as \(p \to 0\) or \(p \to 1\).
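The same simulation style shows the U-shaped allocation of Ada-Seq-Balance (toy Bernoulli rewards; the empirical means trace the same U shape as \(1/(p(1-p))\), though the exact constants differ):

```python
import random

def samples_until_balanced(p, k, rng):
    """Draws from Bernoulli(p) until k positives AND k negatives are seen."""
    n = pos = neg = 0
    while pos < k or neg < k:
        n += 1
        if rng.random() < p:
            pos += 1
        else:
            neg += 1
    return n

rng = random.Random(0)
k = 4
means = {p: sum(samples_until_balanced(p, k, rng) for _ in range(5000)) / 5000
         for p in (0.1, 0.3, 0.5, 0.7, 0.9)}
print(means)  # U-shaped: cheapest near p = 0.5, expensive at the extremes
```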
Infer more, train less
Training on all samples from adaptive sampling can be unstable: stochastic batch sizes cause OOM, too many negative samples can hurt or even collapse training, and training is more expensive than inference. The figure below compares two strategies on Qwen2.5-Math-1.5B: the blue curve uses coupled allocation (\(n_{\text{sample}} = 1/\hat{p},\; n_{\text{grad}} = 1/\sqrt{\hat{p}}\), where both sampling and training scale with difficulty), while the red curve decouples them (\(n_{\text{sample}} = 1/\hat{p},\; n_{\text{grad}} = 1\), i.e. adaptive sampling but fixed-size training). The decoupled approach trains significantly better.
The fix: decouple the sampling budget from the training budget. Adaptive sampling may generate many rollouts for a single hard prompt (e.g., 200 samples to find 4 correct ones), but we only need a small balanced subset for training. Specifically, for each prompt, sub-sample a fixed-size training batch of \(n/2\) positive and \(n/2\) negative samples:
\[\text{Sampling budget (variable): } n_i^{\text{sample}} \propto \frac{1}{p_i(1-p_i)} \quad \longrightarrow \quad \text{Training batch (fixed): } n^{\text{train}} = n\]This “infer more, train less” design has three benefits: (1) fixed batch size avoids OOM from stochastic sizes; (2) balanced positive/negative ratio stabilizes training; (3) inference is much cheaper than training (forward pass only, no backward pass or optimizer step), so extra sampling is cost-effective.
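A minimal sketch of the sub-sampling step (hypothetical `balanced_train_batch` helper; rollouts are stood in by integers):

```python
import random

def balanced_train_batch(rollouts, rewards, k, rng):
    """Sub-sample k positive and k negative rollouts for a fixed-size
    training batch, regardless of how many were sampled in total."""
    pos = [a for a, r in zip(rollouts, rewards) if r == 1]
    neg = [a for a, r in zip(rollouts, rewards) if r == 0]
    assert len(pos) >= k and len(neg) >= k, "adaptive sampling guarantees this"
    return rng.sample(pos, k) + rng.sample(neg, k)

# Hard prompt: 200 rollouts were needed to find 4 correct ones,
# but the training batch stays fixed at 2k = 8.
rng = random.Random(0)
rollouts = list(range(200))
rewards = [1] * 4 + [0] * 196
batch = balanced_train_batch(rollouts, rewards, k=4, rng=rng)
print(len(batch))  # 8
```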
Why is decoupling mathematically valid? The concern is: if we sub-sample \(k\) positive and \(k\) negative from a larger pool, does the gradient estimate degrade? The answer is no — the quality of the gradient estimate depends on the training batch composition, not on how many samples were needed to find it.
Concretely, after Ada-Seq-Balance collects enough rollouts for prompt \(x_i\), we sub-sample \(k\) positive responses \(\{a_j^+\}\) and \(k\) negative responses \(\{a_j^-\}\). The per-prompt gradient estimate becomes:
\[\hat{h}_i = \frac{1}{2k}\left[\sum_{j=1}^{k}(1-p_i)\,\nabla_\theta \log \pi(a_j^+ \vert x_i) + \sum_{j=1}^{k}(-p_i)\,\nabla_\theta \log \pi(a_j^- \vert x_i)\right]\]The variance of this estimate (for a fixed \(k\)) comes from two independent sources:
\[\mathrm{Var}(\hat{h}_i) = \frac{(1-p_i)^2}{k}\,\mathrm{Var}[\nabla \log \pi \vert r=1] + \frac{p_i^2}{k}\,\mathrm{Var}[\nabla \log \pi \vert r=0]\]This variance depends only on \(k\) (the training batch size) and the prompt’s pass rate \(p_i\) — it does not depend on \(n_i^{\text{sample}}\) (how many total rollouts were needed to find those \(k\) positives and \(k\) negatives). The extra samples collected during the adaptive search were necessary to find the \(k\) positive examples, but they contribute nothing additional to the gradient once found.
The cost decomposition makes this concrete:
\[\text{Total cost per prompt} = \underbrace{n_i^{\text{sample}} \cdot c_{\text{infer}}}_{\text{adaptive search}} + \underbrace{2k \cdot c_{\text{train}}}_{\text{fixed training}}\]Since \(c_{\text{train}} \gg c_{\text{infer}}\) (backward pass + optimizer vs. forward pass only), fixing the training batch at \(2k\) saves far more than the extra inference costs.
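A toy cost comparison under assumed unit costs (the 3:1 train-to-infer ratio is illustrative, not measured):

```python
# Hypothetical unit costs: training on one sample ~3x the cost of
# generating one by inference (illustrative ratio, not measured).
c_infer, c_train = 1.0, 3.0
n_sample, k = 200, 4      # hard prompt: 200 rollouts to find 4 positives

coupled = n_sample * c_infer + n_sample * c_train   # train on every rollout
decoupled = n_sample * c_infer + 2 * k * c_train    # fixed 2k training batch
print(coupled, decoupled)  # 800.0 224.0
```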
Theoretical Framework: Nonlinear Objectives and Variance Reduction
Nonlinear objectives induce prompt-dependent weights
Instead of optimizing the standard pass rate \(p_\theta(x)\) directly, consider a general nonlinear objective:
\[J_f(\theta) = \mathbb{E}_{x \sim d}\left[f\!\left(p_\theta(x)\right)\right]\]where \(p_\theta(x) = \mathbb{E}_{a \sim \pi_\theta(\cdot \vert x)}[r(x, a)]\) is the pass rate of the current policy on prompt \(x\), and \(f: [0,1] \to \mathbb{R}\) is a differentiable transform. The key insight is that by the chain rule, the policy gradient of \(J_f\) introduces a prompt-dependent weight:
\[\nabla_\theta J_f(\theta) = \mathbb{E}_{x \sim d}\left[ f'(p_\theta(x)) \cdot \nabla_\theta p_\theta(x) \right]\]Expanding \(\nabla_\theta p_\theta(x)\) via the standard policy gradient theorem (\(\nabla_\theta \mathbb{E}_{a \sim \pi_\theta}[r] = \mathbb{E}_{a \sim \pi_\theta}[r \cdot \nabla_\theta \log \pi_\theta]\)), we get the full per-sample form:
\[\nabla_\theta \mathbb{E}_{x \sim d} f(p_\theta(x)) = \mathbb{E}_{x \sim d}\left[ f'(p_\theta(x)) \cdot \mathbb{E}_{a \sim \pi_\theta}\left[ r(x, a) \cdot \nabla_\theta \log \pi_\theta(a \vert x) \right] \right]\]The derivative \(f'(p)\) acts as a per-prompt importance weight — different choices of \(f\) prioritize different difficulty levels. For example:
- Linear \(f(p) = p\): then \(f'(p) = 1\), all prompts are weighted equally (standard objective).
- Log \(f(p) = \log p\): then \(f'(p) = 1/p\), hard prompts (low \(p\)) get upweighted. This corresponds to RAFT’s implicit objective.
- GRPO’s implicit transform \(f(p) = 2\arcsin(\sqrt{p})\): applying the chain rule, \(f'(p) = 2 \cdot \frac{1}{\sqrt{1 - p}} \cdot \frac{1}{2\sqrt{p}} = \frac{1}{\sqrt{p(1-p)}}\).
This is the same U-shaped function as Ada-Seq-Balance’s allocation! Prompts with extreme pass rates (near 0 or 1) get upweighted, while \(p \approx 0.5\) prompts get the least weight. This provides a principled framework: different RL objectives implicitly prioritize prompts differently, and adaptive sampling can exploit this structure for variance reduction.
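The three weights can be tabulated directly (a small sketch; `weight` is a hypothetical helper, not from the paper):

```python
import math

def weight(f_name, p):
    """Per-prompt weight f'(p) induced by each choice of f."""
    if f_name == "linear":    # f(p) = p                 -> uniform weighting
        return 1.0
    if f_name == "log":       # f(p) = log p             -> upweight hard prompts
        return 1.0 / p
    if f_name == "arcsin":    # f(p) = 2 arcsin(sqrt(p)) -> upweight extremes
        return 1.0 / math.sqrt(p * (1.0 - p))
    raise ValueError(f_name)

for p in (0.1, 0.5, 0.9):
    print(p, weight("linear", p), weight("log", p), round(weight("arcsin", p), 3))
```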
Why arcsin? This is not a design choice but a mathematical consequence of GRPO’s per-response normalization. The derivation proceeds in three steps.
Step 1: From GRPO weights to per-prompt gradient contribution. Recall the GRPO gradient weights each response by \((r_i - \mu_r)/\sigma_r\). Since \(\sigma_r\) is constant within a group, we can pull it out of the expectation. For a prompt with pass rate \(p\) (so \(\mu_r = p\)):
\[\mathbb{E}\!\left[\frac{r - \mu_r}{\sigma_r} \nabla_\theta \log \pi\right] = \frac{1}{\sigma_r}\,\mathbb{E}\!\left[(r - p)\,\nabla_\theta \log \pi\right]\]Now we split the expectation:
\[\mathbb{E}[(r - p)\,\nabla_\theta \log \pi] = \underbrace{\mathbb{E}[r\,\nabla_\theta \log \pi]}_{\text{(A)}} - p\,\underbrace{\mathbb{E}[\nabla_\theta \log \pi]}_{\text{(B)}}\]For (A), by the REINFORCE log-derivative trick: \(\mathbb{E}_{a \sim \pi}[r(x,a)\,\nabla_\theta \log \pi_\theta(a \vert x)] = \nabla_\theta \mathbb{E}_{a \sim \pi}[r(x,a)] = \nabla_\theta p_\theta(x)\).
For (B), the score function has zero mean: \(\mathbb{E}_{a \sim \pi}[\nabla_\theta \log \pi_\theta(a \vert x)] = \nabla_\theta \sum_a \pi_\theta(a \vert x) = \nabla_\theta 1 = 0\). This is the classic result that a constant baseline does not change the expected gradient.
REINFORCE / log-derivative trick
The identity \(\nabla_\theta \mathbb{E}_{a \sim \pi_\theta}[f(a)] = \mathbb{E}_{a \sim \pi_\theta}[f(a)\,\nabla_\theta \log \pi_\theta(a)]\) follows from:
$$\nabla_\theta \mathbb{E}_{a \sim \pi_\theta}[f(a)] = \nabla_\theta \sum_a \pi_\theta(a)\,f(a) = \sum_a f(a)\,\nabla_\theta \pi_\theta(a) = \sum_a f(a)\,\pi_\theta(a)\,\frac{\nabla_\theta \pi_\theta(a)}{\pi_\theta(a)} = \mathbb{E}_{a \sim \pi_\theta}\!\left[f(a)\,\nabla_\theta \log \pi_\theta(a)\right]$$
The key step is multiplying and dividing by \(\pi_\theta(a)\) to turn the sum into an expectation. This allows us to estimate policy gradients via sampling, without knowing the normalizing constant of \(\pi_\theta\). In RL, \(f(a) = r(x, a)\) gives the REINFORCE estimator (Williams, 1992).
Why does a constant baseline not change the gradient?
For any constant \(b\) (not depending on \(a\)):
$$\mathbb{E}_{a \sim \pi_\theta}[b\,\nabla_\theta \log \pi_\theta(a)] = b\,\nabla_\theta \sum_a \pi_\theta(a) = b\,\nabla_\theta 1 = 0$$
So \(\mathbb{E}[(r - b)\nabla \log \pi] = \mathbb{E}[r\,\nabla \log \pi] - b \cdot 0 = \mathbb{E}[r\,\nabla \log \pi]\). The baseline \(b\) vanishes in expectation but reduces variance because it makes the per-sample weights \(r - b\) smaller in magnitude. The optimal baseline (minimizing variance) is \(b^* = \mathbb{E}[r \lVert\nabla \log \pi\rVert^2] / \mathbb{E}[\lVert\nabla \log \pi\rVert^2]\), but any constant works — here we use \(b = p\).
So the baseline \(p\) vanishes, and with \(\sigma_r = \sqrt{p(1-p)}\) for binary rewards:
\[\frac{1}{\sigma_r}\,\mathbb{E}[(r - p)\,\nabla_\theta \log \pi] = \frac{1}{\sqrt{p(1-p)}} \cdot \nabla_\theta p_\theta(x)\]Step 2: Read off \(f'(p)\) and integrate. Comparing with the nonlinear framework \(f'(p) \cdot \nabla_\theta p\), we identify \(f'(p) = 1/\sqrt{p(1-p)}\). To recover \(f\), we integrate using the substitution \(p = \sin^2\theta\), so \(dp = 2\sin\theta\cos\theta\,d\theta\) and \(\sqrt{p(1-p)} = \sin\theta\cos\theta\):
\[f(p) = \int \frac{dp}{\sqrt{p(1-p)}} = \int \frac{2\sin\theta\cos\theta\,d\theta}{\sin\theta\cos\theta} = 2\theta = 2\arcsin(\sqrt{p})\]Step 3: Connection to classical statistics. The function \(2\arcsin(\sqrt{p})\) is exactly the classical arcsine (angular) transformation from statistics (Fisher, 1940s), originally designed for variance stabilization of binomial proportions. Its key property: if \(\hat{p}\) is a sample proportion from \(n\) Bernoulli trials, then \(2\arcsin(\sqrt{\hat{p}})\) has variance \(\approx 1/n\) regardless of the true \(p\). In other words, GRPO’s per-response normalization implicitly applies the unique transformation that equalizes gradient variance across all difficulty levels — which is precisely why uniform sampling is already optimal for GRPO.
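The variance-stabilization property is straightforward to verify empirically (toy simulation, assuming i.i.d. Bernoulli draws):

```python
import math
import random

def arcsine_variance(p, n, trials, rng):
    """Empirical variance of 2*arcsin(sqrt(p_hat)), p_hat from n Bernoulli draws."""
    vals = []
    for _ in range(trials):
        p_hat = sum(rng.random() < p for _ in range(n)) / n
        vals.append(2 * math.asin(math.sqrt(p_hat)))
    m = sum(vals) / trials
    return sum((v - m) ** 2 for v in vals) / trials

rng = random.Random(0)
n = 100
for p in (0.2, 0.5, 0.8):
    print(p, round(arcsine_variance(p, n, 20000, rng), 4))  # all close to 1/n = 0.01
```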
Optimal allocation via variance reduction
In practice, we need to estimate the gradient from finite samples. The batch gradient estimator is derived in three steps from the population gradient \(\nabla_\theta J_f = \mathbb{E}_{x}[f'(p) \cdot \nabla_\theta p]\):
Step 1. Replace the expectation over prompts \(\mathbb{E}_{x \sim d}\) with a sample average over \(B\) prompts \(\{x_1, \ldots, x_B\}\):
\[\nabla_\theta J_f \approx \frac{1}{B} \sum_{i=1}^{B} f'(p_i) \cdot \nabla_\theta p_\theta(x_i)\]Step 2. For each prompt \(x_i\), estimate \(\nabla_\theta p_\theta(x_i)\) using the REINFORCE trick with \(n_i\) sampled rollouts \(\{a_{i1}, \ldots, a_{in_i}\}\):
\[\nabla_\theta p_\theta(x_i) = \mathbb{E}_{a \sim \pi_\theta}\!\left[r(x_i, a)\,\nabla_\theta \log \pi_\theta(a \vert x_i)\right] \approx \frac{1}{n_i} \sum_{j=1}^{n_i} r_{ij}\,\nabla_\theta \log \pi_\theta(a_{ij} \vert x_i)\]Step 3. Subtract the baseline \(p_i\) from each reward. As shown earlier, \(\mathbb{E}[\nabla_\theta \log \pi] = 0\), so \(p_i\) does not change the expected gradient but reduces variance:
\[\nabla_\theta p_\theta(x_i) \approx \frac{1}{n_i} \sum_{j=1}^{n_i} (r_{ij} - p_i)\,\nabla_\theta \log \pi_\theta(a_{ij} \vert x_i)\]Combining all three steps gives the batch estimator:
\[\hat{g}_{\text{batch}} = \frac{1}{B} \sum_{i=1}^{B} \frac{f'(p_i)}{n_i} \sum_{j=1}^{n_i} \nabla_\theta \log \pi_\theta(a_{ij} \vert x_i) \left(r_{ij} - p_i\right)\]Each prompt contributes with weight \(f'(p_i)\) (from the nonlinear objective), and within each prompt, \(n_i\) rollouts provide a Monte Carlo estimate of \(\nabla_\theta p\).
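A sketch of assembling this estimator with NumPy, assuming per-rollout score vectors \(\nabla_\theta \log \pi\) have already been computed (the array shapes and helper name are illustrative):

```python
import numpy as np

def batch_gradient(grad_logps, rewards, p_hat, fprime):
    """Assemble the batch estimator: for each prompt i, average
    f'(p_i) * (r_ij - p_i) * grad_log_pi(a_ij | x_i) over its n_i rollouts,
    then average over the B prompts."""
    d = grad_logps[0].shape[1]
    g = np.zeros(d)
    for G, r, p in zip(grad_logps, rewards, p_hat):
        g += fprime(p) * ((r - p) @ G) / len(r)   # per-prompt term
    return g / len(grad_logps)

rng = np.random.default_rng(0)
grad_logps = [rng.normal(size=(8, 3)), rng.normal(size=(4, 3))]  # (n_i, d) scores
rewards = [np.array([1, 1, 0, 0, 1, 0, 1, 0]), np.array([0, 0, 1, 0])]
p_hat = [0.5, 0.25]                        # empirical pass rates as baselines
g = batch_gradient(grad_logps, rewards, p_hat, fprime=lambda p: 1.0 / p)
print(g.shape)  # (3,)
```

Note that a prompt whose rewards all equal its baseline (e.g. all-correct with \(p = 1\)) contributes exactly zero, matching the uniform-reward blind spot discussed earlier.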
Now we derive the variance of this estimator step by step.
Step 1: Decompose into per-prompt contributions. Write the estimator as \(\hat{g}_{\text{batch}} = \frac{1}{B}\sum_{i=1}^{B} f'(p_i) \cdot \hat{h}_i\), where \(\hat{h}_i = \frac{1}{n_i}\sum_{j=1}^{n_i}(r_{ij} - p_i)\,\nabla_\theta \log \pi_\theta(a_{ij} \vert x_i)\) is the per-prompt gradient estimate. Since prompts are sampled independently, the variance of a sum of independent terms equals the sum of variances:
\[\mathrm{Var}(\hat{g}_{\text{batch}}) = \frac{1}{B^2}\sum_{i=1}^{B} f'(p_i)^2 \cdot \mathrm{Var}(\hat{h}_i)\]Step 2: Variance of the per-prompt estimate. Within each prompt, the \(n_i\) rollouts are i.i.d. from \(\pi_\theta(\cdot \vert x_i)\). For i.i.d. samples, the variance of a sample mean is \(1/n_i\) times the variance of a single sample:
\[\mathrm{Var}(\hat{h}_i) = \frac{1}{n_i}\,\mathrm{Var}_{a \sim \pi_\theta}\!\left[(r(x_i, a) - p_i)\,\nabla_\theta \log \pi_\theta(a \vert x_i)\right] \;=\; \frac{\sigma_g^2(x_i)}{n_i}\]where \(\sigma_g^2(x_i)\) denotes the single-sample gradient variance for prompt \(x_i\).
Step 3: Evaluate \(\sigma_g^2(x_i)\) for binary rewards. Expanding the variance:
\[\sigma_g^2(x_i) = \mathbb{E}\!\left[(r - p_i)^2\,\lVert\nabla_\theta \log \pi\rVert^2\right] - \lVert\underbrace{\mathbb{E}[(r - p_i)\,\nabla_\theta \log \pi]}_{= \nabla_\theta p_\theta(x_i)}\rVert^2\]The second term \(\lVert\nabla_\theta p\rVert^2\) is typically small relative to the first (this is why we need variance reduction in the first place). Dropping it and focusing on the dominant term:
\[\sigma_g^2(x_i) \approx \mathbb{E}\!\left[(r - p_i)^2\right] \cdot \mathbb{E}\!\left[\lVert\nabla_\theta \log \pi\rVert^2\right]\]For binary \(r \in \{0,1\}\) with \(\Pr(r=1) = p_i\):
\[\mathbb{E}[(r - p_i)^2] = p_i(1-p_i)^2 + (1-p_i)\,p_i^2 = p_i(1-p_i)\]This is simply the Bernoulli variance. Assuming the gradient norm \(\mathbb{E}[\lVert\nabla_\theta \log \pi\rVert^2]\) is roughly similar across prompts (a standard simplifying assumption), we get \(\sigma_g^2(x_i) \propto p_i(1-p_i)\).
Why is the Bernoulli variance \(p(1-p)\)?
For a Bernoulli random variable \(r \in \{0, 1\}\) with \(\Pr(r=1) = p\), the variance is computed directly from the definition \(\mathrm{Var}(r) = \mathbb{E}[r^2] - (\mathbb{E}[r])^2\). Since \(r^2 = r\) (because \(0^2=0\) and \(1^2=1\)), we have \(\mathbb{E}[r^2] = \mathbb{E}[r] = p\), so:
$$\mathrm{Var}(r) = p - p^2 = p(1-p)$$
This is maximized at \(p = 0.5\) (where \(\mathrm{Var} = 0.25\)) and zero at \(p = 0\) or \(p = 1\). For the centered version \(r - p\), the same holds: \(\mathbb{E}[(r-p)^2] = p(1-p)^2 + (1-p)p^2 = p(1-p)[(1-p)+p] = p(1-p)\).
Putting it all together. Substituting back (and dropping the constant \(1/B^2\), which does not affect the allocation optimization):
\[V \propto \sum_{i=1}^{B} \frac{f'(p_i)^2 \cdot p_i(1-p_i)}{n_i}\]Given a total budget of \(N = \sum_i n_i\) samples across \(B\) prompts, we want to minimize this total gradient variance:
\[\min_{n_1, \ldots, n_B} \sum_{i=1}^{B} \frac{f'(p_i)^2 \cdot p_i(1-p_i)}{n_i} \quad \text{subject to} \quad \sum_{i=1}^{B} n_i = N, \quad n_i \geq 1\]Solving the optimization. By the method of Lagrange multipliers, introduce \(\lambda\) for the budget constraint and set the partial derivative to zero:
\[\frac{\partial}{\partial n_i}\left[\frac{f'(p_i)^2\,p_i(1-p_i)}{n_i} + \lambda\,n_i\right] = -\frac{f'(p_i)^2\,p_i(1-p_i)}{n_i^2} + \lambda = 0\]Solving: \(n_i^\star = \frac{1}{\sqrt{\lambda}}\,f'(p_i)\,\sqrt{p_i(1-p_i)}\), i.e.,
\[n_i^\star \propto f'(p_i) \cdot \sqrt{p_i(1-p_i)}\]The principle: allocate more samples to prompts that contribute more variance to the gradient — both because they are heavily weighted (large \(f'(p_i)\)) and because their reward signal is noisy (\(p_i\) near 0.5).
Lagrange multipliers for constrained optimization
The method of Lagrange multipliers solves problems of the form: minimize \(f(x_1, \ldots, x_n)\) subject to \(g(x_1, \ldots, x_n) = 0\). The idea: at a constrained optimum, the gradient of \(f\) must be parallel to the gradient of \(g\) (otherwise we could improve \(f\) while staying on the constraint surface). So we solve \(\nabla f = \lambda \nabla g\) for some scalar \(\lambda\).
Here, \(f = \sum_i c_i / n_i\) (with \(c_i = f'(p_i)^2 p_i(1-p_i)\)) and \(g = \sum_i n_i - N = 0\). Setting \(\partial f / \partial n_i + \lambda \cdot \partial g / \partial n_i = 0\) gives \(-c_i / n_i^2 + \lambda = 0\), so \(n_i = \sqrt{c_i / \lambda}\). Since all \(n_i\) share the same \(\lambda\), we get \(n_i \propto \sqrt{c_i}\).
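A numeric check of the allocation rule, using the log objective as an example (fractional \(n_i\) are allowed for the comparison; the \(n_i \geq 1\) constraint is ignored in this sketch):

```python
import math

def total_variance(ns, ps, fprime):
    """Objective from the allocation problem: sum_i f'(p_i)^2 p_i(1-p_i) / n_i."""
    return sum(fprime(p) ** 2 * p * (1 - p) / n for n, p in zip(ns, ps))

ps = [0.05, 0.3, 0.5, 0.9]
fprime = lambda p: 1.0 / p          # log objective as the example
N = 64

# Optimal rule: n_i proportional to f'(p_i) * sqrt(p_i (1 - p_i))
raw = [fprime(p) * math.sqrt(p * (1 - p)) for p in ps]
opt = [N * r / sum(raw) for r in raw]
uni = [N / len(ps)] * len(ps)

print(round(total_variance(opt, ps, fprime), 3),
      "<", round(total_variance(uni, ps, fprime), 3))
```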
Concrete allocation rules for different objectives
Substituting different \(f\) into the optimal allocation formula \(n_i^\star \propto f'(p_i) \cdot \sqrt{p_i(1-p_i)}\) yields concrete allocation rules:
\(f(p) = \log p\) (RAFT’s implicit objective): \(f'(p) = 1/p\), so
\[n_i \propto \frac{1}{p} \cdot \sqrt{p(1-p)} = \sqrt{\frac{1-p}{p}} \approx \sqrt{\frac{1}{p}} \quad \text{(when } p \text{ is small)}\]Hard prompts (small \(p\)) get many more samples — under the small-\(p\) approximation, a prompt with \(p = 0.01\) gets \(10\times\) the samples of one with \(p = 1\).
\(f(p) = \log p\) without the baseline: dropping the baseline \(p\) changes the effective single-sample variance from \(p(1-p)\) to \(\mathbb{E}[r^2] = p\), so the allocation becomes \(n_i \propto (1/p) \cdot \sqrt{p} = 1/\sqrt{p}\) — the same \(1/\sqrt{p}\) behavior on hard prompts, but without decaying to zero as \(p \to 1\).
\(f(p) = p^\alpha\) (power objective): \(f'(p) = \alpha p^{\alpha-1}\), so
\[n_i \propto p^{\alpha-1} \cdot \sqrt{p(1-p)}\]For \(\alpha > 1\), this upweights easy prompts; for \(\alpha < 1\), it upweights hard prompts.
\(f(p) = 2\arcsin(\sqrt{p})\) (GRPO): \(f'(p) = 1/\sqrt{p(1-p)}\), so
\[n_i \propto \frac{1}{\sqrt{p(1-p)}} \cdot \sqrt{p(1-p)} = 1\]The weight and the Bernoulli variance exactly cancel — uniform allocation is already variance-optimal for GRPO! This is a remarkable result: GRPO’s implicit nonlinear objective is precisely the one for which no adaptive allocation can improve upon uniform sampling.
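The cancellation, and the other allocation rules, can be checked numerically (a small sketch; rule names are informal labels):

```python
import math

def bernoulli_sd(p):
    """Standard deviation of a Bernoulli(p) reward."""
    return math.sqrt(p * (1 - p))

rules = {
    "log (RAFT)":    lambda p: (1 / p) * bernoulli_sd(p),
    "power a=2":     lambda p: 2 * p * bernoulli_sd(p),
    "arcsin (GRPO)": lambda p: (1 / bernoulli_sd(p)) * bernoulli_sd(p),
}
for name, rule in rules.items():
    print(name, [round(rule(p), 3) for p in (0.1, 0.5, 0.9)])
# arcsin (GRPO) gives 1 at every p: the weight and the Bernoulli standard
# deviation cancel, so uniform allocation is already optimal for GRPO.
```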
Experiments: Comparing Allocation Strategies
logp-VR vs. Ada-Seq-Balance
Comparing the two approaches, with the per-prompt sampling budget clipped at 256:
- logp-VR: variance-reduction allocation with \(n_i \propto \sqrt{(1-p)/p}\), average per-prompt budget \(n = 128\). Both sampling and training counts follow the same allocation curve. Average cost = 512 samples.
- Ada-Seq-Balance: \(n/2\) training batch (fixed), \(2n\) maximal sampling budget. Sampling count follows the U-shaped \(1/(p(1-p))\) curve, but training size stays constant. Average cost = 383 samples.
The key difference is whether sampling and training are coupled or decoupled:
- logp-VR (coupled): for each prompt, sample \(n_i\) rollouts and train on all \(n_i\) of them. A hard prompt with \(p = 0.01\) might get \(n_i = 200\) samples, and the training step processes all 200. This means both sampling cost and training cost scale together (left plot: the two curves overlap). Total cost per prompt = \(n_i^{\text{sample}} + n_i^{\text{train}} = 2n_i\), so average cost = \(2 \times 256 = 512\).
- Ada-Seq-Balance (decoupled): for each prompt, sample adaptively until \(k\) positive and \(k\) negative rollouts are found (sampling count varies with difficulty), but then sub-sample a fixed-size training batch of \(n\) rollouts regardless of how many were sampled. A hard prompt might require 200 samples to find 4 correct ones, but training only uses those 4 correct + 4 incorrect = 8 rollouts. This is shown in the right plot: the blue curve (sampling) is U-shaped, but the orange line (training) is flat. Since training (backward pass + optimizer) is much more expensive per sample than inference (forward pass only), the fixed training size saves significant compute. Average cost = 383 (lower because training cost is constant).
In short: Ada-Seq-Balance spends extra inference compute (cheap) to recover signal from hard prompts, but does not spend extra training compute (expensive) on them.
Ada-Seq-Balance is consistently competitive
Comparing GRPO, Ada-Est (log \(p\) variance-reduction), Ada-Seq-Pos-4, and Ada-Seq-Balance-4 across Qwen2.5-Math-1.5B and 7B, Ada-Seq-Balance is consistently competitive or best. On Qwen2.5-Math-7B, it reaches 54.6% weighted test accuracy versus GRPO-n4’s 53.0%.
Ada-Seq-Balance is more compute efficient than GRPO
When matching per-step compute and training for up to 1000 steps, Ada-Seq-Balance-n8 consistently outperforms GRPO-n16 on the accuracy-vs-cost curve, and Ada-Seq-Balance-n16 outperforms GRPO-n32.
The accuracy-entropy frontier also improves: Ada-Seq-Balance achieves higher accuracy at the same entropy level, indicating better exploration-exploitation trade-off.