John Schulman: Approximating KL Divergence

Presenter: John Schulman
Host Institute: OpenAI
This post distills John Schulman's blog post Approximating KL Divergence (2020). Schulman describes a trick he has used in various code: approximating KL divergence via low-variance, nearly-unbiased estimators that only require log-probabilities, not full distributions. This is especially useful in reinforcement learning, where KL is often used as a diagnostic (e.g., monitoring how far the policy has drifted from a reference in PPO).

The Setup: Why Monte Carlo?

We want to estimate the KL divergence from \(q\) to \(p\):

\[\mathrm{KL}[q, p] = \sum_x q(x) \log \frac{q(x)}{p(x)} = \mathbb{E}_{x \sim q}\!\left[\log \frac{q(x)}{p(x)}\right].\]

Our options for computing KL depend on what kind of access we have to \(p\) and \(q\). Here we assume we can evaluate the probabilities (or probability densities) \(p(x)\) and \(q(x)\) for any given \(x\), but we cannot calculate the sum over \(x\) analytically. Why not?

  1. Computation/memory: the state space is too large to enumerate (e.g., all possible token sequences).
  2. No closed form: the distributions don’t belong to a family with a known KL formula.
  3. Code simplicity: we only store the log-prob \(\log \pi_\theta(a \vert s)\), not the full distribution. This is a reasonable design choice when KL is just used as a diagnostic, as is often the case in reinforcement learning (e.g., logging KL between the current policy and a reference policy during PPO training).

In all three cases, we turn to Monte Carlo estimation. Given samples \(x_1, x_2, \ldots \sim q\), how can we construct a good estimate?

What Makes a Good Estimator?

A good estimator has two properties:

  • Unbiased: its expected value equals the true KL, i.e. \(\mathbb{E}[\hat{k}] = \mathrm{KL}[q,p]\).
  • Low variance: individual samples don’t fluctuate wildly around the mean.

We’ll define the probability ratio \(r = p(x)/q(x)\), so that \(\log r = \log p(x) - \log q(x)\). All three estimators below are functions of \(r\) (or equivalently, of \(\log r\)). This is convenient because in practice we often already have \(\log p(x)\) and \(\log q(x)\) computed — e.g., the log-probability of an action under two different policies.

Three Estimators

k₁: The Naive Estimator

The most straightforward unbiased estimator follows directly from the definition of KL:

\[k_1 = -\log r = \log \frac{q(x)}{p(x)}.\]

Since \(\mathbb{E}_{x \sim q}[k_1] = \mathbb{E}_{x \sim q}\!\left[\log \frac{q(x)}{p(x)}\right] = \mathrm{KL}[q,p]\), this is exactly unbiased.

However, it has high variance. To see why, note that KL divergence is always non-negative (\(\mathrm{KL}[q,p] \geq 0\)), yet \(k_1\) takes negative values whenever \(r > 1\) (i.e., whenever \(p(x) > q(x)\)). For similar distributions, this happens for roughly half the samples. An estimator that’s negative half the time for a quantity that’s always positive is clearly noisy — we’re relying on cancellation between positive and negative samples to get the right mean.
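
To make the sign problem concrete, here is a small numerical sketch (my own toy setup, not from the post): take \(q = \mathcal{N}(0, 1)\) and \(p = \mathcal{N}(0.1, 1)\), so the true KL is \(0.1^2/2 = 0.005\).

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy setup: q = N(0, 1), p = N(0.1, 1); true KL[q, p] = 0.1**2 / 2 = 0.005.
mu = 0.1
x = rng.normal(0.0, 1.0, size=1_000_000)  # samples from q

def log_normal_pdf(x, mean):
    """Log-density of N(mean, 1) at x."""
    return -0.5 * (x - mean) ** 2 - 0.5 * np.log(2 * np.pi)

k1 = log_normal_pdf(x, 0.0) - log_normal_pdf(x, mu)  # log q(x) - log p(x)

print(k1.mean())        # close to the true KL of 0.005 (unbiased)
print((k1 < 0).mean())  # roughly half the samples are negative
print(k1.std())         # about 0.1, i.e. ~20x the quantity being estimated
```

The per-sample standard deviation dwarfs the mean, which is exactly the cancellation problem described above.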

Why is KL always non-negative?

By Jensen's inequality applied to the convex function \(-\log\):

$$\mathrm{KL}[q,p] = \mathbb{E}_{x \sim q}\!\left[-\log \frac{p(x)}{q(x)}\right] \geq -\log \mathbb{E}_{x \sim q}\!\left[\frac{p(x)}{q(x)}\right] = -\log 1 = 0.$$

This is known as Gibbs' inequality. The same inequality \(\log x \leq x - 1\) that we will use below to construct \(k_3\) provides an alternative proof: \(\mathrm{KL}[q,p] = \mathbb{E}_q[-\log r] \geq \mathbb{E}_q[1 - r] = 1 - 1 = 0\).

Plotting \(k_1\) alongside \(k_2\) and \(k_3\) (defined next) makes the comparison vivid: \(k_1\) dips below zero for \(r > 1\), which is where its high variance comes from.

k₂: The Squared Log-Ratio

An alternative with lower variance but slight bias:

\[k_2 = \frac{1}{2}(\log r)^2.\]

Intuitively, \(k_2\) seems better because:

  • It is always non-negative (it’s a square).
  • Each sample directly measures how far apart \(p\) and \(q\) are at point \(x\), regardless of which direction the ratio goes.

Empirically, \(k_2\) indeed has much lower variance than \(k_1\), and also has remarkably low bias. But why is the bias small? The answer comes from f-divergences.

f-Divergence Perspective

An f-divergence is a general family of divergences defined as:

\[D_f(p, q) = \mathbb{E}_{x \sim q}\!\left[f\!\left(\frac{p(x)}{q(x)}\right)\right] = \mathbb{E}_{x \sim q}[f(r)]\]

for a convex function \(f\) with \(f(1) = 0\). Many well-known divergences are special cases:

  • KL divergence \(\mathrm{KL}[q, p]\): \(f(r) = -\log r\)
  • Reverse KL \(\mathrm{KL}[p, q]\): \(f(r) = r \log r\)
  • Chi-squared divergence: \(f(r) = (r-1)^2\)

The expectation of \(k_2\) is \(\mathbb{E}_q\!\left[\frac{1}{2}(\log r)^2\right]\), which is also an f-divergence with \(f(r) = \frac{1}{2}(\log r)^2\).

Now here is the key non-obvious fact: all f-divergences with differentiable \(f\) look like KL divergence up to second order when \(q\) is close to \(p\). Specifically, for a parameterized distribution \(p_\theta\):

\[D_f(p_0, p_\theta) = \frac{f''(1)}{2}\,\theta^\top F\,\theta + O(\theta^3),\]

where \(F\) is the Fisher information matrix of \(p_\theta\) evaluated at \(\theta = 0\) (i.e., at \(p_\theta = p_0\)).

Both \(k_2\)’s f-divergence (\(f(r) = \frac{1}{2}(\log r)^2\)) and KL (\(f(r) = -\log r\)) have \(f''(1) = 1\). So both look like the same quadratic distance function \(\frac{1}{2}\theta^\top F\,\theta\) when \(p \approx q\). The bias of \(k_2\) only comes from third-order and higher terms, which explains why it is negligible when \(p\) and \(q\) are close.
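
For unit-variance Gaussians this can be checked directly. Under \(q = \mathcal{N}(0,1)\) and \(p = \mathcal{N}(\mu,1)\) we have \(\log r = \mu x - \mu^2/2\), so \(\mathbb{E}_q[\frac{1}{2}(\log r)^2] = \mu^2/2 + \mu^4/8\), a relative bias of \(\mu^2/4\). The sketch below (my own check, not from the post) confirms this by grid integration:

```python
import numpy as np

def kl_vs_k2_divergence(mu):
    """Compare KL[q, p] with E_q[(1/2)(log r)^2] for q = N(0,1), p = N(mu,1)."""
    x = np.linspace(-12.0, 12.0, 200_001)
    dx = x[1] - x[0]
    q = np.exp(-0.5 * x**2) / np.sqrt(2 * np.pi)   # density of q
    log_r = mu * x - mu**2 / 2                     # log p(x) - log q(x)
    kl = mu**2 / 2                                 # exact KL[q, p]
    d_f = np.sum(q * 0.5 * log_r**2) * dx          # f-divergence, f(r) = (1/2)(log r)^2
    return kl, d_f

for mu in (0.1, 1.0):
    kl, d_f = kl_vs_k2_divergence(mu)
    print(mu, d_f / kl)   # 1 + mu**2/4: ~1.0025 for mu = 0.1, ~1.25 for mu = 1.0
```

The third-order-and-higher mismatch is negligible at \(\mu = 0.1\) but already a 25% overestimate at \(\mu = 1\).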

What is the Fisher information matrix, and why does it appear here?

The Fisher information matrix \(F\) of a parametric family \(p_\theta\) is defined as:

$$F_{ij} = \mathbb{E}_{x \sim p_\theta}\!\left[\frac{\partial \log p_\theta(x)}{\partial \theta_i}\,\frac{\partial \log p_\theta(x)}{\partial \theta_j}\right] = -\mathbb{E}_{x \sim p_\theta}\!\left[\frac{\partial^2 \log p_\theta(x)}{\partial \theta_i \,\partial \theta_j}\right].$$

Intuitively, \(F\) measures how sensitive the distribution is to small changes in \(\theta\). If changing \(\theta_i\) by a tiny amount causes the log-likelihood to fluctuate a lot (high Fisher information), then the distribution is very "curved" in that direction — a small step in parameter space creates a large change in distribution space.

A Gaussian example makes this concrete: apply the same perturbation \(\delta\) to the mean. When \(\sigma\) is small (high Fisher information, \(F = 1/\sigma^2\)), the perturbed and original distributions barely overlap; when \(\sigma\) is large (low Fisher information), the same \(\delta\) changes almost nothing.

Why does \(F\) appear in the f-divergence expansion? Consider \(p_\theta\) near \(p_0\) (i.e., \(\theta\) small). The ratio is:

$$r(\theta) = \frac{p_\theta(x)}{p_0(x)} = \exp\!\big(\log p_\theta(x) - \log p_0(x)\big).$$

Taylor-expanding \(\log p_\theta(x)\) around \(\theta = 0\):

$$\log p_\theta(x) = \log p_0(x) + \theta^\top \nabla_\theta \log p_0(x) + \frac{1}{2}\theta^\top \nabla^2_\theta \log p_0(x)\,\theta + O(\theta^3),$$

so \(\log r \approx \theta^\top s(x) + \frac{1}{2}\theta^\top H(x)\,\theta\), where \(s(x) = \nabla_\theta \log p_0(x)\) is the score function and \(H(x) = \nabla^2_\theta \log p_0(x)\) is its Hessian. Two key facts about the score:

  • \(\mathbb{E}_{p_0}[s(x)] = 0\) (the score has zero mean), and
  • \(\mathbb{E}_{p_0}[s(x)\,s(x)^\top] = F\) (its covariance is the Fisher matrix).

Substituting into \(D_f = \mathbb{E}_{p_0}[f(r)]\) and expanding \(f\) around \(r = 1\): since \(f(1) = 0\) and \(f'(1)\) contributes terms proportional to \(\mathbb{E}[s(x)] = 0\), the leading term is:

$$D_f \approx \frac{f''(1)}{2}\,\mathbb{E}_{p_0}\!\big[(\theta^\top s(x))^2\big] = \frac{f''(1)}{2}\,\theta^\top F\,\theta.$$

This is why all f-divergences share the same local geometry: they all reduce to a quadratic form in \(\theta\) weighted by the Fisher matrix, differing only by the scalar \(f''(1)\). The Fisher matrix is the unique "metric tensor" on the space of distributions (up to scale) — this is the foundation of information geometry.

In RL, the Fisher matrix of the policy \(\pi_\theta\) is exactly what defines the natural policy gradient: the direction \(F^{-1}\nabla_\theta J\) that makes the steepest improvement per unit of KL divergence, rather than per unit of Euclidean distance in parameter space.
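
The two score-function facts above are easy to verify numerically for a single-parameter example (my own check; for \(p_\theta = \mathcal{N}(\theta, \sigma^2)\), the score at \(\theta = 0\) is \(s(x) = x/\sigma^2\) and \(F = 1/\sigma^2\)):

```python
import numpy as np

sigma = 0.5
x = np.linspace(-10.0, 10.0, 400_001)
dx = x[1] - x[0]
p0 = np.exp(-0.5 * (x / sigma) ** 2) / (sigma * np.sqrt(2 * np.pi))  # p_theta at theta = 0

score = x / sigma**2                  # d/dtheta log N(x; theta, sigma^2) at theta = 0

mean_score = np.sum(p0 * score) * dx  # E[s(x)], should be 0
fisher = np.sum(p0 * score**2) * dx   # E[s(x)^2], should be 1/sigma^2

print(mean_score)  # ~0
print(fisher)      # ~4.0  (= 1 / 0.5**2)
```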

Unlike \(k_1\), \(k_2\) is always non-negative: a square cannot be negative. And as shown above, \(\mathbb{E}_q[k_2]\) agrees with \(\mathrm{KL}[q,p]\) to second order near \(p = q\), so the bias is small when the distributions are close and grows as they diverge.

k₃: The Best of Both Worlds

Can we get an estimator that is both unbiased (like \(k_1\)) and always non-negative (like \(k_2\))?

The general technique for reducing variance of an unbiased estimator is a control variate: add something with zero expectation that is negatively correlated with the original estimator. The only interesting quantity guaranteed to have zero expectation under \(q\) is:

\[\mathbb{E}_{x \sim q}\!\left[\frac{p(x)}{q(x)} - 1\right] = \mathbb{E}_{x \sim q}[r - 1] = \sum_x q(x)\!\left(\frac{p(x)}{q(x)} - 1\right) = \sum_x p(x) - \sum_x q(x) = 1 - 1 = 0.\]

So for any \(\lambda\), the expression

\[-\log r + \lambda(r - 1)\]

is an unbiased estimator of \(\mathrm{KL}[q,p]\). We could minimize the variance over \(\lambda\), but this yields an expression that depends on \(p\) and \(q\) and is hard to compute analytically.

Instead, we can choose a good \(\lambda\) using a simpler and more elegant argument. Since \(\log\) is concave, we have the fundamental inequality:

\[\log x \leq x - 1 \quad \text{for all } x > 0,\]

with equality only at \(x = 1\). Setting \(\lambda = 1\), the estimator becomes:

\[k_3 = (r - 1) - \log r = \underbrace{-\log r}_{k_1} + \underbrace{(r - 1)}_{\text{control variate}}.\]

By the inequality above, \((r-1) - \log r \geq 0\) for all \(r > 0\), with equality only when \(r = 1\) (i.e., \(p(x) = q(x)\)). So \(k_3\) is:

  • Unbiased (since \(\mathbb{E}[r-1] = 0\), we’re just adding zero in expectation to \(k_1\)).
  • Always non-negative (by the concavity of \(\log\)).
  • Low variance (the control variate cancels much of \(k_1\)’s noise).
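
In code, all three estimators are a few lines given per-sample log-probabilities. A minimal sketch (the function name is mine; `np.expm1` evaluates \(r - 1 = e^{\log r} - 1\) stably when \(\log r\) is near zero):

```python
import numpy as np

def kl_estimators(logp, logq):
    """Per-sample k1, k2, k3 estimates of KL[q, p] from log-probs (x ~ q)."""
    logr = logp - logq             # log r = log p(x) - log q(x)
    k1 = -logr
    k2 = 0.5 * logr**2
    k3 = np.expm1(logr) - logr     # (r - 1) - log r, computed stably
    return k1, k2, k3

logr = np.array([-0.5, -0.01, 0.0, 0.01, 0.5])
k1, k2, k3 = kl_estimators(logr, np.zeros_like(logr))
print(k2.min() >= 0, k3.min() >= 0)   # True True: k2 and k3 are pointwise non-negative
print(k1)                             # changes sign across r = 1
```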

Geometric Interpretation: Bregman Divergence

There is a beautiful geometric way to see why \(k_3\) is non-negative. Consider the convex function \(\phi(r) = -\log r\). Its tangent line at \(r = 1\) is \(\ell(r) = -(r - 1)\). Then:

\[k_3 = (r-1) - \log r = \phi(r) - \ell(r) = (-\log r) - (-(r-1)).\]

This is the vertical gap between the convex function and its tangent line. Since convex functions always lie above their tangent lines, this gap is always non-negative.

This construction — measuring distance as the gap between a convex function and its tangent plane — is called a Bregman divergence. It appears throughout optimization, information theory, and machine learning, and has many beautiful properties (e.g., the “three-point identity” that generalizes the Pythagorean theorem).

The gap vanishes only at \(r = 1\) and grows as \(r\) moves away from 1 in either direction.

Numerical Experiments

To see how these estimators compare in practice, consider the Gaussian experiments from Schulman’s post. Let \(q = \mathcal{N}(0, 1)\) and \(p = \mathcal{N}(\mu, 1)\), so the true KL is \(\mu^2/2\). The two settings \(\mu = 0.1\) (nearby distributions) and \(\mu = 1.0\) (well-separated distributions) illustrate how bias and variance change:

Key observations:

  • Small \(\mu\) (≈ 0.1): \(k_1\)’s std is ~20× the true KL — you’d need hundreds of samples for a reliable sign. \(k_2\) and \(k_3\) are nearly identical (\(k_2\)’s bias ≈ 0.2%).
  • Large \(\mu\) (≈ 1.0): \(k_2\)’s bias grows to ~25% — no longer negligible. \(k_3\) stays unbiased with low variance. \(k_3\) is strictly better.
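
A sketch reproducing this comparison (my own toy reimplementation of the Gaussian experiment, not Schulman's script):

```python
import numpy as np

rng = np.random.default_rng(0)

def experiment(mu, n=2_000_000):
    """Relative bias and spread of k1, k2, k3 for q = N(0,1), p = N(mu,1)."""
    x = rng.normal(size=n)              # samples from q
    logr = mu * x - mu**2 / 2           # log p(x) - log q(x) for unit-variance Gaussians
    true_kl = mu**2 / 2                 # exact KL[q, p]
    estimators = {"k1": -logr, "k2": 0.5 * logr**2, "k3": np.expm1(logr) - logr}
    return {name: (k.mean() / true_kl - 1, k.std() / true_kl)
            for name, k in estimators.items()}

for mu in (0.1, 1.0):
    for name, (rel_bias, rel_std) in experiment(mu).items():
        print(f"mu={mu} {name}: bias/KL={rel_bias:+.4f}, std/KL={rel_std:.2f}")
```

At \(\mu = 0.1\), \(k_1\)'s std/KL ratio is about 20 while \(k_2\) and \(k_3\) sit near 1.4; at \(\mu = 1.0\), \(k_2\)'s relative bias is about 0.25 while \(k_3\) remains unbiased.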

Summary

For samples \(x \sim q\) and ratio \(r = p(x)/q(x)\), the three estimators are:

  Estimator | Formula | Unbiased? | Always ≥ 0? | Variance
  \(k_1\) | \(-\log r\) | Yes | No | High
  \(k_2\) | \(\frac{1}{2}(\log r)^2\) | No (low bias when \(p \approx q\)) | Yes | Low
  \(k_3\) | \((r-1) - \log r\) | Yes | Yes | Low

\(k_3\) is the clear winner: unbiased, always non-negative, and low variance. It achieves this by adding the control variate \((r-1)\) to the naive estimator \(k_1\), and its non-negativity follows from the concavity of \(\log\) (equivalently, the Bregman divergence interpretation).

Generalization: Estimators for Any f-Divergence

The Bregman divergence trick generalizes elegantly. For any f-divergence \(D_f(p,q) = \mathbb{E}_{x \sim q}[f(r)]\) with convex \(f\), the estimator

\[f(r) - f'(1)(r - 1)\]

is:

  • Unbiased: because \(\mathbb{E}_q[f'(1)(r-1)] = f'(1) \cdot 0 = 0\).
  • Always non-negative: because \(f\) is convex, it lies above its tangent at \(r = 1\), so \(f(r) \geq f(1) + f'(1)(r-1) = f'(1)(r-1)\) (using \(f(1) = 0\)).

This is the Bregman divergence of \(f\) at point \(r\) relative to \(r = 1\).
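
The general recipe is a one-liner. The sketch below (names are mine) takes \(f\) and its derivative at 1, and recovers the forward- and reverse-KL estimators as special cases:

```python
import numpy as np

def bregman_estimator(f, fprime_at_1, r):
    """Per-sample estimate of D_f(p, q) from r = p(x)/q(x), x ~ q.

    Unbiased because E_q[r - 1] = 0; non-negative because convex f
    lies above its tangent line at r = 1.
    """
    return f(r) - fprime_at_1 * (r - 1)

r = np.linspace(0.05, 5.0, 1000)

# Forward KL[q, p]: f(r) = -log r, f'(1) = -1  ->  (r - 1) - log r  (this is k3)
k3 = bregman_estimator(lambda r: -np.log(r), -1.0, r)

# Reverse KL[p, q]: f(r) = r log r, f'(1) = 1  ->  r log r - (r - 1)
rev = bregman_estimator(lambda r: r * np.log(r), 1.0, r)

print(k3.min() >= 0, rev.min() >= 0)   # True True
```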

Application: Reverse KL

The most notable application is to \(\mathrm{KL}[p, q]\) (note \(p\) and \(q\) are swapped). This corresponds to \(f(r) = r \log r\), which has \(f'(1) = 1\). The Bregman-based estimator becomes:

\[r\log r - (r - 1).\]

Final summary: for samples \(x \sim q\) with \(r = p(x)/q(x)\), the recommended estimators are:

Divergence | Estimator | Properties
\(\mathrm{KL}[q, p]\) | \((r - 1) - \log r\) | Unbiased, non-negative, low variance
\(\mathrm{KL}[p, q]\) | \(r\log r - (r - 1)\) | Unbiased, non-negative, low variance

Both are special cases of the general Bregman divergence estimator \(f(r) - f'(1)(r-1)\) for their respective f-divergence generators. In practice, you can drop these into any codebase that computes log-probs — no need to store or compute full distributions.
