Normalization & Initialization
Lesson 7 ⏱ 14 min

Xavier and He initialization: the math

Video coming soon

Xavier and He Initialization Derived

Step-by-step derivation of Xavier initialization for tanh networks and He initialization for ReLU networks, with numerical verification that variance stays stable through 10 layers.

⏱ ~7 min

🧮 Quick refresher

Variance of a uniform distribution

A uniform distribution U(a, b) has variance (b-a)²/12. To get Var(W) = c from a uniform distribution symmetric around 0, use U(-L, L) where L = √(3c).

Example

To initialize with Var(W) = 2/300: L = √(3 × 2/300) = √(1/50) ≈ 0.141.

Use U(-0.141, +0.141).
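
This relationship is easy to check numerically. A minimal sketch in PyTorch (the sample count is arbitrary): draw from U(-L, L) with L = √(3c) and confirm the sample variance comes out near c.

import torch

c = 2 / 300                          # target variance Var(W)
L = (3 * c) ** 0.5                   # half-width L = sqrt(3c)
samples = torch.empty(1_000_000).uniform_(-L, L)
print(L)                             # ≈ 0.141
print(samples.var().item())          # ≈ 0.00667 ≈ 2/300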

The previous lesson showed that we need Var(W) = 1/n to keep variance stable through a linear layer. This was derived for a purely linear network. Two refinements make this practical: correcting for the forward-backward asymmetry (Xavier), and correcting for activation functions that kill neurons (He).

Xavier and Kaiming (He) initialization are the defaults in PyTorch and every other major deep learning library. They are why torch.nn.Linear and torch.nn.Conv2d work out of the box without requiring you to set the initial weights manually.

Xavier Initialization: Forward and Backward

The variance propagation analysis has two sides. The forward pass cares about keeping the variance of the activations stable; the backward pass cares about keeping the variance of the gradients stable.

Forward pass condition: to keep activation variance stable,

\text{Var}(W) = \frac{1}{n_\text{in}}

where n_in is the fan-in: the number of inputs to each neuron in this layer.

Backward pass condition: gradients flow backward through the transpose of each weight matrix. By an identical analysis, to keep gradient variance stable,

\text{Var}(W) = \frac{1}{n_\text{out}}

where n_out is the fan-out: the number of outputs from each neuron in this layer.

These two conditions can't both be satisfied exactly unless n_in = n_out. Glorot and Bengio (2010) proposed the harmonic mean of the two target variances:

\text{Var}(W) = \frac{2}{n_\text{in} + n_\text{out}}

This is the Xavier initialization variance.

This approximately preserves variance in both directions. For the uniform version, recall that Var(U(-L, L)) = L^2/3, so we need L^2/3 = 2/(n_in + n_out):

L = \sqrt{\frac{6}{n_\text{in} + n_\text{out}}}

where L is the half-width of the uniform distribution.

Xavier uniform: W \sim U\left(-\sqrt{\frac{6}{n_\text{in}+n_\text{out}}},\ +\sqrt{\frac{6}{n_\text{in}+n_\text{out}}}\right)

Xavier normal: W \sim \mathcal{N}(0, \sigma^2) with \sigma = \sqrt{\frac{2}{n_\text{in}+n_\text{out}}}

Xavier was designed for activations like tanh and sigmoid. The key assumption: the derivative of the activation function at zero is approximately 1 (true for tanh: tanh'(0) = 1).
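
To see these numbers concretely, here is a small sketch (the fan sizes are arbitrary) that computes the Xavier uniform bound by hand and compares it against what nn.init.xavier_uniform_ produces:

import torch
import torch.nn as nn

fan_in, fan_out = 300, 100
bound = (6 / (fan_in + fan_out)) ** 0.5    # L = sqrt(6 / (n_in + n_out)) ≈ 0.1225
target_var = 2 / (fan_in + fan_out)        # Xavier variance = 0.005

W = torch.empty(fan_out, fan_in)           # PyTorch stores Linear weights as (out, in)
nn.init.xavier_uniform_(W)
print(bound, W.abs().max().item())         # every entry lies inside ±bound
print(target_var, W.var().item())          # empirical variance ≈ 2/(n_in + n_out)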

He Initialization: Correcting for ReLU

ReLU violates the assumption that the activation derivative is 1. ReLU(x) = max(0, x) outputs zero for all negative inputs. For a layer whose inputs have mean 0, roughly half the values are negative, so roughly half the ReLU outputs are exactly zero.

This means the effective fan-in is n_in/2, not n_in. The forward variance analysis becomes:

\text{Var}(y) \approx \frac{n_\text{in}}{2} \cdot \text{Var}(W) \cdot \text{Var}(x)

For this to equal Var(x):

\text{Var}(W) = \frac{2}{n_\text{in}}

This is He initialization (He et al., 2015, also called Kaiming initialization). The factor of 2 compensates for ReLU zeroing half the neurons.

He normal: W \sim \mathcal{N}(0, \sigma^2) with \sigma = \sqrt{\frac{2}{n_\text{in}}}

He uniform: W \sim U\left(-\sqrt{\frac{6}{n_\text{in}}},\ +\sqrt{\frac{6}{n_\text{in}}}\right)
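
The factor of 2 is easy to check empirically. A minimal sketch: for zero-mean, symmetric inputs, ReLU zeroes half the mass, so only half of the mean square survives.

import torch

z = torch.randn(1_000_000)                 # zero-mean, unit-variance pre-activations
print(z.pow(2).mean().item())              # ≈ 1.0
print(torch.relu(z).pow(2).mean().item())  # ≈ 0.5, half the mean square survives the ReLU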

Numerical Verification: 10-Layer ReLU Network

Let's verify both initializations through 10 layers.

Xavier with ReLU (wrong choice): Var(W) = 2/(n_in + n_out) ≈ 1/n (taking n_in ≈ n_out ≈ n)

Each layer: Var(y) ≈ (n/2) × (1/n) × Var(x) = 0.5 × Var(x)

After 10 layers: 0.5^10 ≈ 0.001. Variance shrinks 1000× — vanishing activations.

He initialization with ReLU: Var(W) = 2/n

Each layer: Var(y) ≈ (n/2) × (2/n) × Var(x) = 1.0 × Var(x)

After 10 layers: 1.0^10 = 1.0. Variance is stable. ✓
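
The same check runs in a few lines of PyTorch. This is a minimal sketch (layer width, depth, and batch size are arbitrary) that pushes standard-normal inputs through 10 ReLU layers and tracks the mean-square activation, the quantity the variance analysis follows:

import torch

torch.manual_seed(0)
n, depth, batch = 512, 10, 4096

def run(weight_var):
    h = torch.randn(batch, n)                       # mean square ≈ 1 at the input
    for layer in range(1, depth + 1):
        W = torch.randn(n, n) * weight_var ** 0.5   # Var(W) = weight_var
        h = torch.relu(h @ W)
        print(f"  layer {layer:2d}: mean square = {h.pow(2).mean().item():.4f}")

print("Xavier-style Var(W) = 1/n, halves every layer:")
run(1 / n)
print("He Var(W) = 2/n, stays near 1:")
run(2 / n)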

Variants for Leaky ReLU

Leaky ReLU has slope a for negative inputs instead of 0, so the negative half of the inputs is scaled by a rather than zeroed. The surviving fraction of the variance is (1 + a^2)/2 instead of 1/2, and the correction becomes:

\text{Var}(W) = \frac{2}{(1+a^2) \cdot n_\text{in}}

where a is the negative-slope parameter of Leaky ReLU.

For standard Leaky ReLU with a = 0.01: denominator = 1.0001 × n_in ≈ n_in — barely changes. For a = 0.2: denominator = 1.04 × n_in — a modest correction.
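
This correction is what PyTorch's gain machinery encodes: nn.init.calculate_gain('leaky_relu', a) returns sqrt(2 / (1 + a^2)), and the Kaiming initializers use Var(W) = gain^2 / fan. A small sketch:

import torch.nn as nn

for a in (0.01, 0.2):
    gain = nn.init.calculate_gain('leaky_relu', a)   # sqrt(2 / (1 + a^2))
    print(a, gain, gain ** 2)                        # gain^2 ≈ 2.00 for a=0.01, ≈ 1.92 for a=0.2

# kaiming_normal_ with the matching arguments applies exactly this variance:
layer = nn.Linear(300, 100)
nn.init.kaiming_normal_(layer.weight, a=0.2, nonlinearity='leaky_relu')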

PyTorch Defaults and How to Override

import torch.nn as nn

layer = nn.Linear(300, 100)   # example layer; PyTorch initializes it on construction

# Xavier (Glorot) — use with tanh/sigmoid
nn.init.xavier_uniform_(layer.weight)
nn.init.xavier_normal_(layer.weight)

# He (Kaiming) — use with ReLU
nn.init.kaiming_uniform_(layer.weight, nonlinearity='relu')
nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')

# PyTorch defaults by layer type:
# nn.Linear    → Kaiming uniform (He) with fan_in mode
# nn.Conv2d    → Kaiming uniform (He) with fan_in mode
# nn.Embedding → Normal(0, 1) — you often want to override this

One subtlety: kaiming_uniform_ in PyTorch defaults to mode='fan_in', using only the input dimension. Pass mode='fan_out' to use the output dimension — useful if you're more concerned about gradient variance than forward variance.

The practical rule: use He/Kaiming for any network with ReLU or Leaky ReLU. Use Xavier for tanh, sigmoid, or linear activations. For transformers with Layer Normalization, initialization matters less because LayerNorm keeps statistics stable regardless — but Xavier or He are still sensible defaults.
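
If you do want to override the defaults across a whole model, one common pattern (a sketch, not the only way) is to pass an init function to Module.apply after construction:

import torch.nn as nn

def init_weights(m):
    # Re-initialize every Linear layer with He/Kaiming; leave other modules alone.
    if isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
        if m.bias is not None:
            nn.init.zeros_(m.bias)

model = nn.Sequential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10))
model.apply(init_weights)   # calls init_weights on every submodule recursively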

Quiz

Question 1 of 3

Xavier initialization uses Var(W) = 2/(fan_in + fan_out) rather than just 1/fan_in because...