John Schulman: Approximating KL Divergence
The Setup: Why Monte Carlo?
We want to estimate the KL divergence from \(q\) to \(p\):
\[\mathrm{KL}[q,p] = \mathbb{E}_{x \sim q}\!\left[\log \frac{q(x)}{p(x)}\right].\]
Our options for computing KL depend on what kind of access we have to \(p\) and \(q\). Here we assume we can evaluate the probabilities (or probability densities) \(p(x)\) and \(q(x)\) for any given \(x\), but we cannot calculate the sum over \(x\) analytically. Why not?
- Computation/memory: the state space is too large to enumerate (e.g., all possible token sequences).
- No closed form: the distributions don’t belong to a family with a known KL formula.
- Code simplicity: we only store the log-prob \(\log \pi_\theta(a \vert s)\), not the full distribution. This is a reasonable design choice when KL is just used as a diagnostic, as is often the case in reinforcement learning (e.g., logging KL between the current policy and a reference policy during PPO training).
In all three cases, we turn to Monte Carlo estimation. Given samples \(x_1, x_2, \ldots \sim q\), how can we construct a good estimate?
What Makes a Good Estimator?
A good estimator has two properties:
- Unbiased: its expected value equals the true KL, i.e. \(\mathbb{E}[\hat{k}] = \mathrm{KL}[q,p]\).
- Low variance: individual samples don’t fluctuate wildly around the mean.
We’ll define the probability ratio \(r = p(x)/q(x)\), so that \(\log r = \log p(x) - \log q(x)\). All three estimators below are functions of \(r\) (or equivalently, of \(\log r\)). This is convenient because in practice we often already have \(\log p(x)\) and \(\log q(x)\) computed — e.g., the log-probability of an action under two different policies.
Three Estimators
k₁: The Naive Estimator
The most straightforward unbiased estimator follows directly from the definition of KL:
\[k_1 = -\log r = \log \frac{q(x)}{p(x)}.\]Since \(\mathbb{E}_{x \sim q}[k_1] = \mathbb{E}_{x \sim q}\!\left[\log \frac{q(x)}{p(x)}\right] = \mathrm{KL}[q,p]\), this is exactly unbiased.
However, it has high variance. To see why, note that KL divergence is always non-negative (\(\mathrm{KL}[q,p] \geq 0\)), yet \(k_1\) takes negative values whenever \(r > 1\) (i.e., whenever \(p(x) > q(x)\)). For similar distributions, this happens for roughly half the samples. An estimator that’s negative half the time for a quantity that’s always positive is clearly noisy — we’re relying on cancellation between positive and negative samples to get the right mean.
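This is easy to check numerically. The sketch below is our own illustration (not from the post), using the same Gaussian pair as the experiments section later: \(q = \mathcal{N}(0,1)\), \(p = \mathcal{N}(0.1,1)\). It counts how often a \(k_1\) sample comes out negative.

```python
import random

# Illustrative check: sample x ~ q = N(0,1), compute log r against
# p = N(0.1, 1), and count how often k1 = -log r is negative.
random.seed(0)
mu, n = 0.1, 100_000
neg = 0
for _ in range(n):
    x = random.gauss(0.0, 1.0)
    log_r = mu * x - mu ** 2 / 2      # log p(x) - log q(x) for unit-variance Gaussians
    if -log_r < 0:                    # this k1 sample is negative
        neg += 1
frac_negative = neg / n
print(f"fraction of negative k1 samples: {frac_negative:.3f}")  # roughly half
```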
Why is KL always non-negative?
By Jensen's inequality applied to the convex function \(-\log\):
$$\mathrm{KL}[q,p] = \mathbb{E}_{x \sim q}\!\left[-\log \frac{p(x)}{q(x)}\right] \geq -\log \mathbb{E}_{x \sim q}\!\left[\frac{p(x)}{q(x)}\right] = -\log 1 = 0.$$
This is known as Gibbs' inequality. The same inequality \(\log x \leq x - 1\) that we will use below to construct \(k_3\) provides an alternative proof: \(\mathrm{KL}[q,p] = \mathbb{E}_q[-\log r] \geq \mathbb{E}_q[1 - r] = 1 - 1 = 0\).
Plotting \(k_1\) against \(r\) alongside \(k_2\) and \(k_3\) (defined next) makes the comparison concrete: \(k_1\) dips below zero for \(r > 1\) — this is where its high variance comes from.
k₂: The Squared Log-Ratio
An alternative with lower variance but slight bias:
\[k_2 = \frac{1}{2}(\log r)^2.\]Intuitively, \(k_2\) seems better because:
- It is always non-negative (it’s a square).
- Each sample directly measures how far apart \(p\) and \(q\) are at point \(x\), regardless of which direction the ratio goes.
Empirically, \(k_2\) indeed has much lower variance than \(k_1\), and also has remarkably low bias. But why is the bias small? The answer comes from f-divergences.
f-Divergence Perspective
An f-divergence is a general family of divergences defined as:
\[D_f(p, q) = \mathbb{E}_{x \sim q}\!\left[f\!\left(\frac{p(x)}{q(x)}\right)\right] = \mathbb{E}_{x \sim q}[f(r)]\]for a convex function \(f\) with \(f(1) = 0\). Many well-known divergences are special cases:
- KL divergence \(\mathrm{KL}[q, p]\): \(f(r) = -\log r\)
- Reverse KL \(\mathrm{KL}[p, q]\): \(f(r) = r \log r\)
- Chi-squared divergence: \(f(r) = (r-1)^2\)
The expectation of \(k_2\) is \(\mathbb{E}_q\!\left[\frac{1}{2}(\log r)^2\right]\), which is also an f-divergence with \(f(r) = \frac{1}{2}(\log r)^2\).
Now here is the key non-obvious fact: all f-divergences with twice-differentiable \(f\) look like KL divergence up to second order when \(q\) is close to \(p\). Specifically, for a parameterized distribution \(p_\theta\):
\[D_f(p_0, p_\theta) = \frac{f''(1)}{2}\,\theta^\top F\,\theta + O(\theta^3),\]where \(F\) is the Fisher information matrix of \(p_\theta\) evaluated at \(\theta = 0\).
Both \(k_2\)’s f-divergence (\(f(r) = \frac{1}{2}(\log r)^2\)) and KL (\(f(r) = -\log r\)) have \(f''(1) = 1\). So both look like the same quadratic distance function \(\frac{1}{2}\theta^\top F\,\theta\) when \(p \approx q\). The bias of \(k_2\) only comes from third-order and higher terms, which explains why it is negligible when \(p\) and \(q\) are close.
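We can check this concretely for the Gaussian example used in the experiments section (\(q = \mathcal{N}(0,1)\), \(p = \mathcal{N}(\mu,1)\)). There \(\log r = \mu x - \mu^2/2\) is itself Gaussian under \(q\), so \(\mathbb{E}[k_2] = \mu^2/2 + \mu^4/8\) in closed form: the bias \(\mu^4/8\) is higher than second order in \(\mu\) (for this symmetric family the cubic term happens to vanish). A quick Monte Carlo sketch of our own confirms it:

```python
import random

# q = N(0,1), p = N(mu,1): log r = mu*x - mu^2/2 is Gaussian under q,
# so E[k2] = (Var[log r] + E[log r]^2)/2 = mu^2/2 + mu^4/8.
# True KL is mu^2/2, hence bias = mu^4/8.
random.seed(0)
mu, n = 0.5, 200_000
total = 0.0
for _ in range(n):
    x = random.gauss(0.0, 1.0)
    log_r = mu * x - mu ** 2 / 2
    total += 0.5 * log_r ** 2          # one k2 sample
empirical_bias = total / n - mu ** 2 / 2
analytic_bias = mu ** 4 / 8
print(empirical_bias, analytic_bias)   # should agree closely
```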
What is the Fisher information matrix, and why does it appear here?
The Fisher information matrix \(F\) of a parametric family \(p_\theta\) is defined as:
$$F_{ij} = \mathbb{E}_{x \sim p_\theta}\!\left[\frac{\partial \log p_\theta(x)}{\partial \theta_i}\,\frac{\partial \log p_\theta(x)}{\partial \theta_j}\right] = -\mathbb{E}_{x \sim p_\theta}\!\left[\frac{\partial^2 \log p_\theta(x)}{\partial \theta_i \,\partial \theta_j}\right].$$
Intuitively, \(F\) measures how sensitive the distribution is to small changes in \(\theta\). If changing \(\theta_i\) by a tiny amount causes the log-likelihood to fluctuate a lot (high Fisher information), then the distribution is very "curved" in that direction — a small step in parameter space creates a large change in distribution space.
To make this concrete, consider applying the same perturbation \(\delta\) to the mean of a Gaussian. When \(\sigma\) is small (high Fisher information, \(F = 1/\sigma^2\) for the mean), the original and perturbed distributions barely overlap; when \(\sigma\) is large (low Fisher information), the same \(\delta\) changes almost nothing.
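For a concrete worked instance (our own example): with \(p_\theta = \mathcal{N}(\theta, \sigma^2)\), the Fisher information of the mean is \(F = 1/\sigma^2\), and for a pure mean shift the KL equals the quadratic form \(\frac{F}{2}\delta^2\) exactly, not just to second order.

```python
def kl_mean_shift(delta, sigma):
    """Closed form KL[N(0, sigma^2), N(delta, sigma^2)] = delta^2 / (2 sigma^2)."""
    return delta ** 2 / (2 * sigma ** 2)

# Fisher information of the mean of N(theta, sigma^2) is F = 1/sigma^2,
# so (F/2) * delta^2 matches the KL exactly for this family.
for sigma in (0.5, 1.0, 2.0):
    F = 1.0 / sigma ** 2
    assert abs(kl_mean_shift(0.3, sigma) - 0.5 * F * 0.3 ** 2) < 1e-12

# The same delta is a much bigger move in distribution space when sigma is small:
print(kl_mean_shift(0.3, 0.5), kl_mean_shift(0.3, 2.0))
```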
Why does \(F\) appear in the f-divergence expansion? Consider \(p_\theta\) near \(p_0\) (i.e., \(\theta\) small). The ratio is:
$$r(\theta) = \frac{p_\theta(x)}{p_0(x)} = \exp\!\big(\log p_\theta(x) - \log p_0(x)\big).$$
Taylor-expanding \(\log p_\theta(x)\) around \(\theta = 0\):
$$\log p_\theta(x) = \log p_0(x) + \theta^\top \nabla_\theta \log p_0(x) + \frac{1}{2}\theta^\top \nabla^2_\theta \log p_0(x)\,\theta + O(\theta^3),$$
so \(\log r \approx \theta^\top s(x) + \frac{1}{2}\theta^\top H(x)\,\theta\), where \(s(x) = \nabla_\theta \log p_0(x)\) is the score function and \(H(x) = \nabla^2_\theta \log p_0(x)\) is its Hessian. Two key facts about the score:
- \(\mathbb{E}_{p_0}[s(x)] = 0\) (the score has zero mean), and
- \(\mathbb{E}_{p_0}[s(x)\,s(x)^\top] = F\) (its covariance is the Fisher matrix).
Substituting into \(D_f = \mathbb{E}_{p_0}[f(r)]\) and expanding \(f\) around \(r = 1\): since \(f(1) = 0\) and \(f'(1)\) contributes terms proportional to \(\mathbb{E}[s(x)] = 0\), the leading term is:
$$D_f \approx \frac{f''(1)}{2}\,\mathbb{E}_{p_0}\!\big[(\theta^\top s(x))^2\big] = \frac{f''(1)}{2}\,\theta^\top F\,\theta.$$
This is why all f-divergences share the same local geometry: they all reduce to a quadratic form in \(\theta\) weighted by the Fisher matrix, differing only by the scalar \(f''(1)\). The Fisher matrix is the unique "metric tensor" on the space of distributions (up to scale) — this is the foundation of information geometry.
In RL, the Fisher matrix of the policy \(\pi_\theta\) is exactly what defines the natural policy gradient: the direction \(F^{-1}\nabla_\theta J\) that makes the steepest improvement per unit of KL divergence, rather than per unit of Euclidean distance in parameter space.
Compared with \(k_1\), \(k_2\) is always non-negative — a square can’t be negative. Near \(r = 1\) the bias is small because \(k_2\) and KL agree to second order; it grows as the distributions diverge.
k₃: The Best of Both Worlds
Can we get an estimator that is both unbiased (like \(k_1\)) and always non-negative (like \(k_2\))?
The general technique for reducing variance of an unbiased estimator is a control variate: add something with zero expectation that is negatively correlated with the original estimator. The only interesting quantity guaranteed to have zero expectation under \(q\) is:
\[\mathbb{E}_{x \sim q}\!\left[\frac{p(x)}{q(x)} - 1\right] = \mathbb{E}_{x \sim q}[r - 1] = \sum_x q(x)\,\frac{p(x)}{q(x)} - 1 = \sum_x p(x) - 1 = 1 - 1 = 0.\]So for any \(\lambda\), the expression
\[-\log r + \lambda(r - 1)\]is an unbiased estimator of \(\mathrm{KL}[q,p]\). We could minimize the variance over \(\lambda\), but this yields an expression that depends on \(p\) and \(q\) and is hard to compute analytically.
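It is easy to probe numerically, though. For the Gaussian pair used in the experiments section, the variance-minimizing coefficient is \(\lambda^* = \mathrm{Cov}(\log r, r)/\mathrm{Var}(r)\); the sketch below (our own illustration) estimates it from samples, and it comes out very close to 1 when \(p \approx q\):

```python
import math, random

# Variance of -log r + lam*(r - 1) is minimized at
# lam* = cov(log r, r) / var(r); estimate it by sampling.
# Gaussian example (our choice): q = N(0,1), p = N(0.1, 1).
random.seed(0)
mu, n = 0.1, 200_000
s_lr = s_r = s_rr = s_lrr = 0.0
for _ in range(n):
    x = random.gauss(0.0, 1.0)
    log_r = mu * x - mu ** 2 / 2
    r = math.exp(log_r)
    s_lr += log_r
    s_r += r
    s_rr += r * r
    s_lrr += log_r * r
var_r = s_rr / n - (s_r / n) ** 2
cov_lr_r = s_lrr / n - (s_lr / n) * (s_r / n)
lam_star = cov_lr_r / var_r
print(f"variance-minimizing lambda: {lam_star:.3f}")  # near 1 when p ~ q
```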
Instead, we can choose a good \(\lambda\) using a simpler and more elegant argument. Since \(\log\) is concave, we have the fundamental inequality:
\[\log x \leq x - 1 \quad \text{for all } x > 0,\]with equality only at \(x = 1\). Setting \(\lambda = 1\), the estimator becomes:
\[k_3 = (r - 1) - \log r = \underbrace{-\log r}_{k_1} + \underbrace{(r - 1)}_{\text{control variate}}.\]By the inequality above, \((r-1) - \log r \geq 0\) for all \(r > 0\), with equality only when \(r = 1\) (i.e., \(p(x) = q(x)\)). So \(k_3\) is:
- Unbiased (since \(\mathbb{E}[r-1] = 0\), we’re just adding zero in expectation to \(k_1\)).
- Always non-negative (by the concavity of \(\log\)).
- Low variance (the control variate cancels much of \(k_1\)’s noise).
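In code, all three estimators are one-liners in \(\log r\). This is a minimal sketch with our own function names; `expm1` computes \(r - 1 = e^{\log r} - 1\) accurately when \(\log r\) is near zero, which is exactly the regime where these estimators matter.

```python
import math

def k1(log_r):
    """Naive estimator: -log r. Unbiased, high variance."""
    return -log_r

def k2(log_r):
    """Squared log-ratio: (1/2)(log r)^2. Low variance, slightly biased."""
    return 0.5 * log_r ** 2

def k3(log_r):
    """(r - 1) - log r. Unbiased, non-negative, low variance."""
    return math.expm1(log_r) - log_r

# k3 is zero exactly when p(x) = q(x), non-negative everywhere else.
for lr in (-2.0, -0.1, 0.0, 0.1, 2.0):
    assert k3(lr) >= 0.0
print(k1(0.1), k2(0.1), k3(0.1))
```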
Geometric Interpretation: Bregman Divergence
There is a beautiful geometric way to see why \(k_3\) is non-negative. Consider the convex function \(\phi(r) = -\log r\). Its tangent line at \(r = 1\) is \(\ell(r) = -(r - 1)\). Then:
\[k_3 = (r-1) - \log r = \phi(r) - \ell(r) = (-\log r) - (-(r-1)).\]This is the vertical gap between the convex function and its tangent line. Since convex functions always lie above their tangent lines, this gap is always non-negative.
This construction — measuring distance as the gap between a convex function and its tangent plane — is called a Bregman divergence. It appears throughout optimization, information theory, and machine learning, and has many beautiful properties (e.g., the “three-point identity” that generalizes the Pythagorean theorem).
Picture the curve \(-\log r\) with its tangent line at \(r = 1\): the vertical gap between them is \(k_3\), and it grows as \(r\) moves away from 1.
Numerical Experiments
To see how these estimators compare in practice, consider the Gaussian experiments from Schulman’s post. Let \(q = \mathcal{N}(0, 1)\) and \(p = \mathcal{N}(\mu, 1)\), so the true KL is \(\mu^2/2\). Two settings, \(\mu = 0.1\) and \(\mu = 1.0\), show how bias and variance change with the distance between the distributions.
Key observations:
- Small \(\mu\) (≈ 0.1): \(k_1\)’s std is ~20× the true KL — you’d need hundreds of samples for a reliable sign. \(k_2\) and \(k_3\) are nearly identical (\(k_2\)’s bias ≈ 0.2%).
- Large \(\mu\) (≈ 1.0): \(k_2\)’s bias grows to ~25% — no longer negligible. \(k_3\) stays unbiased with low variance. \(k_3\) is strictly better.
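A self-contained script reproduces the flavor of these numbers (our own reimplementation; the sample count and seed are our choices, not Schulman's):

```python
import math, random

def estimate(mu, n=200_000, seed=0):
    """Monte Carlo mean and std of k1, k2, k3 for q = N(0,1), p = N(mu,1)."""
    rng = random.Random(seed)
    sums = {"k1": 0.0, "k2": 0.0, "k3": 0.0}
    sq_sums = {"k1": 0.0, "k2": 0.0, "k3": 0.0}
    for _ in range(n):
        x = rng.gauss(0.0, 1.0)
        log_r = mu * x - mu ** 2 / 2          # log p(x) - log q(x)
        samples = {"k1": -log_r,
                   "k2": 0.5 * log_r ** 2,
                   "k3": math.expm1(log_r) - log_r}
        for name, v in samples.items():
            sums[name] += v
            sq_sums[name] += v * v
    out = {}
    for name in sums:
        mean = sums[name] / n
        var = max(sq_sums[name] / n - mean ** 2, 0.0)
        out[name] = (mean, math.sqrt(var))
    return out

for mu in (0.1, 1.0):
    true_kl = mu ** 2 / 2
    for name, (mean, std) in estimate(mu).items():
        print(f"mu={mu}: {name} bias={mean - true_kl:+.4f} std={std:.4f}")
```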
Summary
For samples \(x \sim q\) and ratio \(r = p(x)/q(x)\), the three estimators are:
| Estimator | Formula | Unbiased? | Always ≥ 0? | Variance |
|---|---|---|---|---|
| \(k_1\) | \(-\log r\) | Yes | No | High |
| \(k_2\) | \(\frac{1}{2}(\log r)^2\) | No (low bias when \(p \approx q\)) | Yes | Low |
| \(k_3\) | \((r-1) - \log r\) | Yes | Yes | Low |
\(k_3\) is the clear winner: unbiased, always non-negative, and low variance. It achieves this by adding the control variate \((r-1)\) to the naive estimator \(k_1\), and its non-negativity follows from the concavity of \(\log\) (equivalently, the Bregman divergence interpretation).
Generalization: Estimators for Any f-Divergence
The Bregman divergence trick generalizes elegantly. For any f-divergence \(D_f(p,q) = \mathbb{E}_{x \sim q}[f(r)]\) with convex \(f\), the estimator
\[f(r) - f'(1)(r - 1)\]is:
- Unbiased: because \(\mathbb{E}_q[f'(1)(r-1)] = f'(1) \cdot 0 = 0\).
- Always non-negative: because \(f\) is convex, it lies above its tangent at \(r = 1\), so \(f(r) \geq f(1) + f'(1)(r-1) = f'(1)(r-1)\) (using \(f(1) = 0\)).
This is the Bregman divergence of \(f\) at point \(r\) relative to \(r = 1\).
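As a sketch (the function names are ours), the recipe is one line, and specializing \(f\) recovers both KL estimators:

```python
import math

def bregman_estimator(f, fprime_at_1, r):
    """Unbiased, non-negative single-sample estimator f(r) - f'(1)(r - 1)."""
    return f(r) - fprime_at_1 * (r - 1.0)

def forward_kl_sample(r):
    # KL[q,p]: f(r) = -log r, f'(1) = -1  ->  (r - 1) - log r, i.e. k3.
    return bregman_estimator(lambda t: -math.log(t), -1.0, r)

def reverse_kl_sample(r):
    # KL[p,q]: f(r) = r log r, f'(1) = 1  ->  r log r - (r - 1).
    return bregman_estimator(lambda t: t * math.log(t), 1.0, r)

for r in (0.25, 0.5, 1.0, 2.0, 4.0):
    assert forward_kl_sample(r) >= 0.0   # convexity: Bregman gap >= 0
    assert reverse_kl_sample(r) >= 0.0
print(forward_kl_sample(2.0), reverse_kl_sample(2.0))
```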
Application: Reverse KL
The most notable application is to \(\mathrm{KL}[p, q]\) (note \(p\) and \(q\) are swapped). This corresponds to \(f(r) = r \log r\), which has \(f'(1) = 1\). The Bregman-based estimator becomes:
\[r\log r - (r - 1).\]Final summary: for samples \(x \sim q\) with \(r = p(x)/q(x)\), the recommended estimators are:
| Divergence | Estimator | Properties |
|---|---|---|
| \(\mathrm{KL}[q, p]\) | \((r - 1) - \log r\) | Unbiased, non-negative, low variance |
| \(\mathrm{KL}[p, q]\) | \(r\log r - (r - 1)\) | Unbiased, non-negative, low variance |
Both are special cases of the general Bregman divergence estimator \(f(r) - f'(1)(r-1)\) for their respective f-divergence generators. In practice, you can drop these into any codebase that computes log-probs — no need to store or compute full distributions.
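A minimal drop-in sketch (helper names are our own, assuming the codebase already hands us \(\log r = \log p(x) - \log q(x)\) for each sample \(x \sim q\)):

```python
import math

def kl_qp_sample(log_r):
    """Per-sample estimate of KL[q, p]: (r - 1) - log r."""
    return math.expm1(log_r) - log_r

def kl_pq_sample(log_r):
    """Per-sample estimate of KL[p, q]: r log r - (r - 1)."""
    return math.exp(log_r) * log_r - math.expm1(log_r)

# Averaging over samples x ~ q gives the Monte Carlo estimate.
toy_log_ratios = [-0.20, 0.05, 0.10, -0.03]   # hypothetical per-sample log-ratios
print(sum(kl_qp_sample(lr) for lr in toy_log_ratios) / len(toy_log_ratios))
```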