Normalization & Initialization
Lesson 6 ⏱ 12 min

Why weight initialization matters

Video coming soon

Why Weight Initialization Matters: Variance Propagation

Numerically demonstrates activation explosion and vanishing from poor initialization, derives the variance propagation formula through a linear layer, and shows what scale of initialization is just right.

⏱ ~6 min

🧮

Quick refresher

Variance of a product of random variables

If X and Y are independent random variables with mean 0, then Var(XY) = Var(X)·Var(Y). This means variance multiplies when you multiply independent zero-mean random variables together.

Example

X ~ N(0, 2) and Y ~ N(0, 3): Var(XY) = 2×3 = 6.

Standard deviation of XY = √6 ≈ 2.45.
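A quick numerical check of this rule (a sketch using NumPy; the seed and sample size are arbitrary choices):

```python
import numpy as np

rng = np.random.default_rng(0)

# Independent zero-mean samples with Var(X) = 2 and Var(Y) = 3.
x = rng.normal(0.0, np.sqrt(2.0), size=1_000_000)
y = rng.normal(0.0, np.sqrt(3.0), size=1_000_000)

prod_var = np.var(x * y)
print(prod_var)  # close to Var(X) * Var(Y) = 6
```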

You've set up your network architecture. Every parameter starts as a random number. But how random? What distribution should you sample from?

This sounds like an implementation detail — surely the network will learn its way to good weights regardless. It won't. The initial scale of weights determines whether training succeeds at all, especially for deep networks. Let's see exactly why.

Poor weight initialization can cause training to fail before it starts. A network initialized with weights that are slightly too large will explode activations through every layer; slightly too small and all gradients vanish. Xavier and Kaiming initialization are the principled solutions that every modern deep learning library uses by default.

The Problem: Weights Are Multiplied

A network's forward pass is a long chain of matrix multiplications. Each layer takes the previous layer's output, multiplies it by a weight matrix, and passes the result forward. If the weight matrices are slightly too large, the multiplication amplifies signals. Slightly too small, and it shrinks them.

Neither effect sounds catastrophic for a single layer. But they're applied again, and again, and again — once per layer. Amplification or shrinkage compounds exponentially.

Variance Propagation: The Math

Consider one layer of neurons. Each output neuron computes:

y = \sum_{i=1}^{n} w_i x_i

y: output of one neuron
w_i: weight connecting input i to this neuron
x_i: i-th input value

Assume w_i and x_i are independent, each with mean 0 and variances Var(W) and Var(x) respectively.

Since each term w_i x_i is independent with variance Var(W) · Var(x), and we're summing n such terms:

This is the same rule for the variance of a sum of independent random variables that appears in error propagation analysis.

\text{Var}(y) = n \cdot \text{Var}(W) \cdot \text{Var}(x)

Var(y): variance of the output neuron

This one equation governs everything. Layer by layer, variance is multiplied by n · Var(W).
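The formula can be checked directly by simulating one linear layer (a sketch; the width, weight variance, and sample counts here are illustrative choices, not from the lesson):

```python
import numpy as np

rng = np.random.default_rng(0)
n = 100                                            # inputs per neuron

x = rng.normal(0.0, 1.0, size=(n, 20_000))         # Var(x) = 1
W = rng.normal(0.0, np.sqrt(0.02), size=(200, n))  # Var(W) = 0.02, 200 output neurons

out_var = np.var(W @ x)
print(out_var)  # close to n * Var(W) * Var(x) = 100 * 0.02 * 1 = 2.0
```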

The Explosion Case

Say n = 100 neurons per layer, with weights drawn from \mathcal{N}(0, 1) so Var(W) = 1:

\text{Var}(y) = 100 \times 1 \times \text{Var}(x) = 100 \cdot \text{Var}(x)

Each layer multiplies variance by 100. Starting with Var(x_0) = 1, after L layers:

\text{Var}_L = 100^L

Var_L: variance after L layers
Layers | Variance | Std Dev
-------|----------|--------
1      | 100      | 10
2      | 10,000   | 100
5      | 10¹⁰     | 10⁵
10     | 10²⁰     | 10¹⁰

After 10 layers, activations have standard deviation ten billion. Any loss computed on these values is numerically meaningless — usually NaN within the first forward pass.
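This blow-up is easy to reproduce in a few lines (a sketch; the width, depth, and batch size are the illustrative values from the example above):

```python
import numpy as np

rng = np.random.default_rng(0)
n, depth = 100, 10

x = rng.normal(0.0, 1.0, size=(n, 10_000))  # Var(x) = 1 at the input
for _ in range(depth):
    W = rng.normal(0.0, 1.0, size=(n, n))   # Var(W) = 1 -> variance x100 per layer
    x = W @ x

exploded_var = np.var(x)
print(exploded_var)  # on the order of 100**10 = 1e20
```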

The Vanishing Case

Now try small weights: \mathcal{N}(0, 0.001), so Var(W) = 0.001:

\text{Var}(y) = 100 \times 0.001 \times \text{Var}(x) = 0.1 \cdot \text{Var}(x) per layer

After 10 layers: 0.1^{10} = 10^{-10}. Standard deviation: 10^{-5}. All activations collapse to near-zero.

A network with vanished activations learns nothing. All the neurons produce the same near-zero output regardless of input — no signal can propagate.

The Goldilocks Condition

We want variance to be preserved through the network: \text{Var}(y) = \text{Var}(x). From the propagation formula:

n \cdot \text{Var}(W) \cdot \text{Var}(x) = \text{Var}(x)

Solving:

\text{Var}(W) = \frac{1}{n}

Initialize each weight as w \sim \mathcal{N}(0, 1/n) and the variance stays constant through every layer. Activations neither grow nor shrink.

Let's verify numerically. With n = 100 and Var(W) = 1/100 = 0.01:

\text{Var}(y) = 100 \times 0.01 \times \text{Var}(x) = 1.0 \times \text{Var}(x)

Variance is exactly preserved. After 10 layers: 1.0^{10} = 1.0. ✓
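A small simulation comparing the three weight scales side by side (a sketch; width and depth match the running example, and exact outputs will vary with the seed):

```python
import numpy as np

rng = np.random.default_rng(0)
n, depth = 100, 10

final_vars = {}
for var_w in (1.0, 0.001, 1 / n):               # explode, vanish, Goldilocks
    x = rng.normal(0.0, 1.0, size=(n, 10_000))  # Var(x) = 1 at the input
    for _ in range(depth):
        W = rng.normal(0.0, np.sqrt(var_w), size=(n, n))
        x = W @ x
    final_vars[var_w] = np.var(x)

print(final_vars)  # ~1e20 for Var(W)=1, ~1e-10 for 0.001, ~1 for 1/n
```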

Why This Still Isn't Complete

The analysis above assumed linear layers (no activation functions). With nonlinearities, the story changes:

ReLU zeros out half of all activations (those where the input is negative). This effectively halves the variance at every layer. You'd need to compensate by doubling the weight variance: Var(W) = 2/n.

Tanh compresses values into [-1, 1] and sigmoid into [0, 1]. For small inputs near zero, both are approximately linear (tanh has slope ≈ 1 there). For large inputs, they saturate (slope ≈ 0). If activations are too large, all units saturate and gradients vanish — even with well-initialized weights.
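The ReLU correction can be seen numerically. This sketch (width, depth, and scales are illustrative choices) runs a deep ReLU network with Var(W) = 1/n versus the corrected 2/n:

```python
import numpy as np

rng = np.random.default_rng(0)
n, depth = 100, 20

final_vars = {}
for var_w in (1 / n, 2 / n):                     # plain 1/n vs ReLU-corrected 2/n
    x = rng.normal(0.0, 1.0, size=(n, 5_000))
    for _ in range(depth):
        W = rng.normal(0.0, np.sqrt(var_w), size=(n, n))
        x = np.maximum(W @ x, 0.0)               # ReLU zeros the negative half
    final_vars[var_w] = np.var(x)

print(final_vars)  # 1/n shrinks toward zero; 2/n stays O(1)
```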

The next lesson derives the two most important initialization schemes — Xavier (for tanh/sigmoid) and He (for ReLU) — from exactly this analysis.

Interactive example

Track activation variance across layers — adjust weight scale and layer count

Coming soon

A Quick Reference

import torch.nn as nn

layer = nn.Linear(100, 100)  # example layer with fan-in 100

# BAD: Default Normal — explodes at scale 1
nn.init.normal_(layer.weight, mean=0.0, std=1.0)

# GOOD: Fan-in normalization
n = layer.weight.shape[1]  # fan-in
nn.init.normal_(layer.weight, mean=0.0, std=(1 / n) ** 0.5)

# BETTER: Use the derived schemes (next lesson)
nn.init.xavier_normal_(layer.weight)   # for tanh/sigmoid
nn.init.kaiming_normal_(layer.weight)  # for ReLU

The derived schemes apply exactly the analysis above, with careful corrections for the specific activation function used.

Quiz

1 / 3

A 10-layer network uses weights W ~ N(0,1) with 100 neurons per layer. Starting with Var(input)=1, what is the approximate variance after 10 layers?