
The Reparameterization Trick: Making Sampling Differentiable

Why you cannot backpropagate through a sampling operation, how reparameterization shifts the randomness outside the computational graph, and what the gradients look like explicitly.

Quick refresher

Backpropagation through the chain rule

Backprop computes ∂L/∂w by multiplying local gradients along the path from the loss to w. Every operation on the path must have a defined derivative for the chain rule to apply.

Example

For L = (wx)², ∂L/∂w = 2wx · x = 2x²w.

The chain rule multiplies the outer derivative (2wx) by the inner derivative (x).
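A quick way to see this concretely is to let autograd do the same computation. A minimal PyTorch sketch (the values of $w$ and $x$ are arbitrary):

```python
import torch

# Chain-rule check for L = (w*x)^2; the hand-derived gradient is dL/dw = 2*x^2*w.
x = 3.0
w = torch.tensor(2.0, requires_grad=True)

L = (w * x) ** 2
L.backward()

print(w.grad)               # tensor(36.)
print(2 * x**2 * w.item())  # 36.0, matches the chain-rule result
```

Every operation on the path from $w$ to $L$ (multiply, square) has a known local derivative, which is exactly what the sampling step below lacks.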

The Problem with Sampling

Recall the VAE training loop from lesson 14-3:

  1. Encoder outputs $\mu_\phi(x)$ and $\sigma_\phi(x)$
  2. Sample $z \sim \mathcal{N}(\mu_\phi(x), \sigma_\phi(x)^2 \cdot I)$
  3. Decoder produces $\hat{x} = g_\theta(z)$
  4. Compute loss, backpropagate

The reparameterization trick is the single idea that makes VAE training work at all. Without it, gradients cannot flow through the sampling step, and the entire architecture is untrainable. It also appears in many other models with stochastic components — diffusion models, normalizing flows, and stochastic policies in RL all use versions of this trick.

Step 2 is the problem. Backpropagation needs to compute $\partial L / \partial \mu_\phi$ and $\partial L / \partial \sigma_\phi$ — the gradients with respect to the encoder parameters. To do that it needs to propagate gradients through the sampling step, which requires $\partial z / \partial \mu$ and $\partial z / \partial \sigma$.

But $z$ is a random draw. It is not a deterministic function of $\mu$ and $\sigma$. Every time step 2 runs, a different $z$ comes out. There is no fixed derivative to place in the computational graph. The gradient is undefined.
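A minimal PyTorch sketch of the broken version (the numbers are made up; `torch.distributions.Normal.sample()` draws under `no_grad`, so it cuts the graph exactly at step 2):

```python
import torch
from torch.distributions import Normal

# Pretend these are the encoder outputs for one latent dimension.
mu = torch.tensor(0.6, requires_grad=True)
sigma = torch.tensor(0.4, requires_grad=True)

z = Normal(mu, sigma).sample()   # a raw draw: no backward path to mu or sigma
print(z.requires_grad)           # False -> the graph is cut here

loss = (z - 1.0) ** 2
# loss.backward() would fail: the loss has no grad_fn, because nothing
# upstream of it requires grad once the sampling step detached z.
```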

The Trick: Move the Randomness Outside

Instead of sampling $z$ directly, rewrite:

$$z = \mu + \sigma \odot \varepsilon, \qquad \varepsilon \sim \mathcal{N}(0, I)$$

where:

  $z$: the latent code we want to sample
  $\mu$: encoder mean output
  $\sigma$: encoder standard deviation (element-wise)
  $\varepsilon$: auxiliary noise sampled from a standard normal, independent of the encoder parameters
  $\odot$: element-wise multiplication

Check: is $z$ still distributed as $\mathcal{N}(\mu, \sigma^2 I)$? Yes — adding a constant to a Gaussian shifts its mean, and scaling by $\sigma$ scales its standard deviation. The distribution is identical to before.
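This can also be checked empirically. A small sketch (using the same $\mu = 0.6$, $\sigma = 0.4$ that appear in the walk-through below):

```python
import torch

# Draw many eps ~ N(0, 1) and form z = mu + sigma * eps.
mu, sigma = 0.6, 0.4
eps = torch.randn(1_000_000)
z = mu + sigma * eps

print(z.mean().item())  # ~0.6, the mean is shifted to mu
print(z.std().item())   # ~0.4, the spread is scaled to sigma
```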

What has changed: the randomness now lives entirely in $\varepsilon$, which is independent of the encoder parameters. The path from $\mu$ and $\sigma$ to $z$ is now a deterministic function:

$$z = \mu + \sigma \odot \varepsilon$$

This is differentiable. The local gradients are:

$$\frac{\partial z_j}{\partial \mu_j} = 1, \qquad \frac{\partial z_j}{\partial \sigma_j} = \varepsilon_j$$

  $\partial z_j / \partial \mu_j$: gradient of the $j$-th latent dimension with respect to the $j$-th mean — always 1
  $\partial z_j / \partial \sigma_j$: gradient with respect to the $j$-th standard deviation — equals the sampled noise $\varepsilon_j$

Backprop can now flow gradients from the decoder loss all the way back through zz, through μ\mu and σ\sigma, and into the encoder network weights. The encoder is trainable.
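A minimal PyTorch sketch of the working version (arbitrary values, with the decoder loss replaced by a simple stand-in):

```python
import torch

mu = torch.tensor([0.6, -0.2], requires_grad=True)
sigma = torch.tensor([0.4, 1.1], requires_grad=True)

eps = torch.randn_like(sigma)   # leaf noise: no parameters, no gradient needed
z = mu + sigma * eps            # deterministic, differentiable in mu and sigma

loss = (z ** 2).sum()           # stand-in for the decoder/reconstruction loss
loss.backward()

print(mu.grad)     # dL/dmu    = dL/dz * 1   = 2 * z
print(sigma.grad)  # dL/dsigma = dL/dz * eps = 2 * z * eps
```

In PyTorch, `Normal(mu, sigma).rsample()` performs this same reparameterized draw, in contrast to `.sample()`.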

Visualizing the Computational Graph

Without reparameterization (broken):

$$\mu, \sigma \;\xrightarrow{\text{sample}}\; z \;\xrightarrow{g_\theta}\; \hat{x} \;\to\; L$$

The arrow labeled "sample" has no backward pass. Gradients stop at $z$ and never reach $\mu$ or $\sigma$.

With reparameterization (working):

$$\varepsilon \sim \mathcal{N}(0, I) \;\to\; z = \mu + \sigma \odot \varepsilon \;\xrightarrow{g_\theta}\; \hat{x} \;\to\; L$$

Here $\mu$ and $\sigma$ flow directly into the node $z = \mu + \sigma \odot \varepsilon$ as ordinary differentiable inputs. The paths $\mu \to z \to \hat{x} \to L$ and $\sigma \to z \to \hat{x} \to L$ are fully differentiable. The noise $\varepsilon$ is a leaf node with no parameters — its gradient is never needed.

A Complete Numerical Walk-Through

Suppose the encoder outputs $\mu = 0.6$, $\sigma = 0.4$ for a single latent dimension. We draw $\varepsilon = -0.7$.

Forward pass:

$$z = 0.6 + 0.4 \times (-0.7) = 0.6 - 0.28 = 0.32$$

The decoder receives $z = 0.32$ and produces some output. Suppose after computing the full loss, $\partial L / \partial z = 1.5$.

Backward pass through the reparameterization:

$$\frac{\partial L}{\partial \mu} = \frac{\partial L}{\partial z} \cdot \frac{\partial z}{\partial \mu} = 1.5 \times 1 = 1.5$$
$$\frac{\partial L}{\partial \sigma} = \frac{\partial L}{\partial z} \cdot \frac{\partial z}{\partial \sigma} = 1.5 \times (-0.7) = -1.05$$

These gradients flow backward into the encoder network. The KL gradient (from lesson 14-4) adds on top:

$$\frac{\partial \text{KL}}{\partial \mu} = \mu = 0.6, \qquad \frac{\partial \text{KL}}{\partial \sigma} = \sigma - \frac{1}{\sigma} = 0.4 - 2.5 = -2.1$$

The encoder weights are updated using the sum of both gradient contributions — reconstruction signal (via reparameterization) and regularization signal (from the KL term directly).
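The same arithmetic can be reproduced with autograd. A small sketch: the reconstruction loss is replaced by the stand-in $1.5 \cdot z$ so that $\partial L / \partial z = 1.5$ as in the example, and the KL term uses the standard closed form for a unit-Gaussian prior, consistent with the gradients above:

```python
import torch

mu = torch.tensor(0.6, requires_grad=True)
sigma = torch.tensor(0.4, requires_grad=True)
eps = torch.tensor(-0.7)

z = mu + sigma * eps            # forward pass: z = 0.32
recon = 1.5 * z                 # stand-in loss with dL/dz = 1.5
kl = 0.5 * (mu**2 + sigma**2 - 1.0 - torch.log(sigma**2))
(recon + kl).backward()

print(mu.grad)     # 1.5 * 1      + 0.6           =  2.1
print(sigma.grad)  # 1.5 * (-0.7) + (0.4 - 2.5)   = -3.15
```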

Generalizing Beyond Gaussians

The reparameterization trick applies to any distribution with a tractable inverse CDF (quantile function). To sample from a distribution $p$:

  1. Draw $u \sim \text{Uniform}(0, 1)$
  2. Return $z = F^{-1}(u)$, where $F^{-1}$ is the inverse CDF

Since $F^{-1}$ is differentiable in the distribution's parameters $\theta$ (when it exists), the path $\theta \to z$ is differentiable. This works for Exponential, Logistic, Laplace, Beta (approximately), and others. For distributions without easy inverse CDFs, alternatives like implicit reparameterization (Figurnov et al., 2018) generalize the idea further.
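A minimal sketch of inverse-CDF reparameterization for an Exponential distribution (an assumed example; its inverse CDF is $F^{-1}(u) = -\log(1 - u) / \lambda$ for rate $\lambda$):

```python
import torch

rate = torch.tensor(2.0, requires_grad=True)   # the distribution parameter theta

u = torch.rand(5)                  # step 1: u ~ Uniform(0, 1), independent of rate
z = -torch.log1p(-u) / rate        # step 2: z = F^{-1}(u), differentiable in rate

z.sum().backward()                 # stand-in loss
print(rate.grad)                   # nonzero: gradients reach the parameter
```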

Quiz


Why can't you backpropagate through the operation z ~ N(μ, σ²) directly?