The Setup
You have a true distribution $P$, but you are using an approximate distribution $Q$ to reason about the world.
How much are you paying for this approximation? How many extra bits do you waste every time you use $Q$ when the truth is $P$?
The KL divergence answers this precisely.
If you have two probability distributions — what your model predicts versus what actually happens — how do you measure how wrong the model is? Subtracting them misses how the mismatch compounds across low-probability events. You need a number that captures the full shape of the disagreement. KL divergence is that number, and it turns out to be exactly what you're minimizing every time you train a model with cross-entropy loss.
Definition
The KL divergence from $Q$ to $P$ is:

$$D_{\mathrm{KL}}(P \,\|\, Q) = \sum_{x} P(x) \log \frac{P(x)}{Q(x)}$$

- $D_{\mathrm{KL}}(P \,\|\, Q)$: KL divergence from Q to P — read as 'KL of P from Q'
- $P(x)$: probability under the true distribution P
- $Q(x)$: probability under the approximate distribution Q
- $\sum_{x}$: sum over all values x where P(x) > 0
Expanding using $\log \frac{P(x)}{Q(x)} = \log P(x) - \log Q(x)$:

$$D_{\mathrm{KL}}(P \,\|\, Q) = -\sum_x P(x) \log Q(x) - \left(-\sum_x P(x) \log P(x)\right) = H(P, Q) - H(P)$$

- $H(P) = -\sum_x P(x) \log P(x)$: entropy of P
- $H(P, Q) = -\sum_x P(x) \log Q(x)$: cross-entropy of P and Q

where $H(P)$ is the average code length of the optimal code for $P$.
So: $D_{\mathrm{KL}}(P \,\|\, Q) = H(P, Q) - H(P)$. KL divergence is the extra cost beyond what you'd pay if you used the optimal code.
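To see the decomposition numerically, here is a minimal sketch with an arbitrary pair of three-outcome distributions (the values are illustrative, not from the text above):

```python
import torch

p = torch.tensor([0.6, 0.3, 0.1])   # "true" distribution P (illustrative values)
q = torch.tensor([0.4, 0.4, 0.2])   # "model" distribution Q (illustrative values)

entropy_p = -(p * p.log()).sum()        # H(P)
cross_entropy = -(p * q.log()).sum()    # H(P, Q)
kl_direct = (p * (p / q).log()).sum()   # sum_x P(x) log(P(x) / Q(x))

print(kl_direct.item())                    # direct definition
print((cross_entropy - entropy_p).item())  # H(P, Q) - H(P): same number
```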
KL ≥ 0 Always (via Jensen's Inequality)
This is a key result. We need Jensen's inequality: for any convex function $f$, $f(\mathbb{E}[X]) \le \mathbb{E}[f(X)]$.
Since $-\log$ is convex:

$$D_{\mathrm{KL}}(P \,\|\, Q) = \mathbb{E}_P\left[-\log \frac{Q(x)}{P(x)}\right] \ge -\log \mathbb{E}_P\left[\frac{Q(x)}{P(x)}\right] = -\log \sum_x P(x)\,\frac{Q(x)}{P(x)} = -\log \sum_x Q(x) \ge -\log 1 = 0$$

- $\mathbb{E}_P[\cdot]$: expectation taken under distribution P

Therefore $D_{\mathrm{KL}}(P \,\|\, Q) \ge 0$ always, with equality if and only if $P(x) = Q(x)$ everywhere. This is sometimes called Gibbs' inequality.
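As a quick empirical sanity check of this result (an illustrative sketch; the distributions are just random softmax vectors):

```python
import torch

torch.manual_seed(0)
# Gibbs' inequality in action: KL between random distribution pairs is never negative.
for _ in range(5):
    p = torch.softmax(torch.randn(10), dim=0)  # random "true" distribution
    q = torch.softmax(torch.randn(10), dim=0)  # random "model" distribution
    kl = (p * (p / q).log()).sum()
    print(kl.item() >= 0)  # prints True every time
```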
Worked Numerical Example
Suppose there are only two outcomes (like a biased coin), and:
| Outcome | $P$ (true) | $Q$ (model) |
|---|---|---|
| Heads | 0.7 | 0.5 |
| Tails | 0.3 | 0.5 |
Compute $D_{\mathrm{KL}}(P \,\|\, Q)$ (using natural log):

- $0.7 \ln \frac{0.7}{0.5} \approx 0.7 \times 0.336 \approx 0.235$: contribution from Heads
- $0.3 \ln \frac{0.3}{0.5} \approx 0.3 \times (-0.511) \approx -0.153$: contribution from Tails

Adding them up: $D_{\mathrm{KL}}(P \,\|\, Q) \approx 0.235 - 0.153 = 0.082$ nats.

Now compute $D_{\mathrm{KL}}(Q \,\|\, P)$ (reversed):

$D_{\mathrm{KL}}(Q \,\|\, P) = 0.5 \ln \frac{0.5}{0.7} + 0.5 \ln \frac{0.5}{0.3} \approx -0.168 + 0.255 = 0.087$ nats.
The two numbers differ ($0.082$ vs $0.087$ nats): KL divergence is not symmetric, and it is not a true distance metric.
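To double-check both directions of the coin example, here is a small sketch using plain tensor arithmetic (the library form appears in the code section below):

```python
import torch

p = torch.tensor([0.7, 0.3])  # true coin
q = torch.tensor([0.5, 0.5])  # model coin

kl_pq = (p * (p / q).log()).sum()  # forward: D_KL(P || Q)
kl_qp = (q * (q / p).log()).sum()  # reverse: D_KL(Q || P)
print(kl_pq.item(), kl_qp.item())  # ≈ 0.082 and ≈ 0.087 nats
```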
Why Asymmetry Matters: Forward vs Reverse KL
The asymmetry has deep consequences in practice. Minimizing the forward KL $D_{\mathrm{KL}}(P \,\|\, Q)$ over $Q$ punishes $Q$ for assigning too little probability to outcomes that $P$ cares about, so the fitted $Q$ spreads out to cover all of $P$'s modes (mass-covering). Minimizing the reverse KL $D_{\mathrm{KL}}(Q \,\|\, P)$ punishes $Q$ for putting probability where $P$ has almost none, so the fitted $Q$ tends to lock onto a single mode (mode-seeking). The sketch below makes this concrete.
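Here is a minimal, illustrative sketch of that behavior (the grid, the bimodal target, and the single-Gaussian family are assumptions made for this example, not part of the lesson): it fits one Gaussian $Q$ to a two-mode $P$ by gradient descent, once under each direction of the KL.

```python
import torch

# Discretize the problem on a grid so both KLs are plain sums.
x = torch.linspace(-4.0, 4.0, 201, dtype=torch.float64)

def normalize(w):
    return w / w.sum()

# Bimodal "true" distribution P: equal-weight bumps at -2 and +2.
# The tiny floor avoids log(0) at the edges of the grid.
p = normalize(torch.exp(-(x + 2) ** 2 / 0.5) + torch.exp(-(x - 2) ** 2 / 0.5) + 1e-12)

def fit(reverse):
    mu = torch.tensor(0.5, dtype=torch.float64, requires_grad=True)        # offset init breaks the symmetry
    log_sigma = torch.tensor(0.0, dtype=torch.float64, requires_grad=True)
    opt = torch.optim.Adam([mu, log_sigma], lr=0.05)
    for _ in range(2000):
        q = normalize(torch.exp(-(x - mu) ** 2 / (2 * torch.exp(2 * log_sigma))) + 1e-12)
        if reverse:
            loss = (q * (q / p).log()).sum()   # D_KL(Q || P): mode-seeking
        else:
            loss = (p * (p / q).log()).sum()   # D_KL(P || Q): mass-covering
        opt.zero_grad()
        loss.backward()
        opt.step()
    return round(mu.item(), 2), round(torch.exp(log_sigma).item(), 2)

print("forward KL fit (mu, sigma):", fit(reverse=False))  # broad Q straddling both modes
print("reverse KL fit (mu, sigma):", fit(reverse=True))   # narrow Q locked onto one mode
```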
KL Divergence in the Wild
Variational Autoencoders (VAEs): the ELBO loss contains $D_{\mathrm{KL}}(q_\phi(z \mid x) \,\|\, p(z))$ — it forces the approximate posterior to stay close to the prior (a closed-form sketch of this term appears at the end of this section).
RLHF (Reinforcement Learning from Human Feedback): the fine-tuning objective contains a KL penalty to prevent the model from drifting too far from the reference policy.
Information bottleneck: the tradeoff between compression and prediction is formalized as minimizing KL divergence between the learned representation and a target distribution.
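For the VAE term specifically, the KL has a well-known closed form when the approximate posterior is a diagonal Gaussian and the prior is a standard normal. A minimal sketch (tensor shapes and names here are illustrative):

```python
import torch

def gaussian_kl_to_standard_normal(mu, log_var):
    # D_KL( N(mu, sigma^2) || N(0, 1) ) = 0.5 * sum(mu^2 + sigma^2 - log(sigma^2) - 1),
    # summed over the latent dimensions.
    return 0.5 * torch.sum(mu ** 2 + log_var.exp() - log_var - 1, dim=-1)

mu = torch.zeros(4, 8)        # batch of 4, latent dimension 8 (illustrative shapes)
log_var = torch.zeros(4, 8)   # sigma = 1 everywhere
print(gaussian_kl_to_standard_normal(mu, log_var))  # all zeros: q already equals the prior
```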
In Code
```python
import torch
import torch.nn.functional as F

# KL divergence: F.kl_div expects log-probabilities for input
log_q = torch.log(torch.tensor([0.5, 0.5]))
p = torch.tensor([0.7, 0.3])

# F.kl_div(log_q, p) computes sum(p * (log_p - log_q))
# reduction='sum' or 'batchmean' depending on use
kl = F.kl_div(log_q, p, reduction='sum')  # ≈ 0.082 nats, matching our manual calculation
```
The next lesson connects KL divergence to mutual information — measuring how much two variables tell you about each other.