Neural Networks
Lesson 8 ⏱ 12 min

Weight initialization: Xavier and He


Weight Initialization - Preventing Signals from Exploding or Vanishing

Deriving the variance propagation formula for a linear layer, showing numerically how bad initialization kills or explodes signals in a 5-layer network, and arriving at Xavier and He initialization from the variance-preservation condition.

⏱ ~7 min

🧮

Quick refresher

Variance and standard deviation

Variance measures how spread out a set of values is: Var(x) = E[(x - E[x])²]. If values are zero-mean, Var(x) = E[x²]. The standard deviation is the square root of variance. For independent random variables, variances add: Var(X+Y) = Var(X) + Var(Y).

Example

If X ~ N(0, 1), then Var(X) = 1.

If Y = 3X, then Var(Y) = 9·Var(X) = 9.

Scaling by the constant 3 multiplies the standard deviation by 3 and the variance by 3² = 9.
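A quick numerical check of this (a minimal sketch, assuming NumPy is available):

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(0.0, 1.0, size=1_000_000)  # X ~ N(0, 1)
y = 3 * x                                  # Y = 3X

print(np.var(x))  # ≈ 1.0
print(np.var(y))  # ≈ 9.0 — the factor 3 scales the variance by 3² = 9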

Why Initialization Matters

A neural network at initialization has random weights. Those random weights get multiplied together across every layer to produce activations and gradients. This is where a subtle but catastrophic failure can happen before training even begins.

Consider a 10-layer network. If each layer multiplies the signal magnitude by 1.5 on average, after 10 layers the magnitude is 1.5¹⁰ ≈ 57. If each layer multiplies it by 0.7, the magnitude after 10 layers is 0.7¹⁰ ≈ 0.028.

Exploding activations: values grow exponentially, gradients blow up, training fails immediately. Vanishing activations: values shrink to near zero, gradients also near zero, no learning happens.

The right initialization keeps signal magnitudes approximately stable as they pass through the network.
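To see the compounding concretely, here is a minimal sketch (assuming NumPy; the layer width and the weight scales are illustrative choices, not prescribed by the lesson) that pushes a random vector through 10 linear layers:

import numpy as np

rng = np.random.default_rng(0)
x0 = rng.normal(size=256)

# 1/sqrt(256) ≈ 0.0625 is the "just right" scale for a 256-wide linear layer
for weight_std in (0.1, 0.0625, 0.02):
    x = x0.copy()
    for _ in range(10):                       # 10 linear layers, no activation
        W = rng.normal(0.0, weight_std, size=(256, 256))
        x = W @ x
    print(weight_std, np.linalg.norm(x))      # too large → explodes, too small → vanishes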

Deriving the Variance Condition

Consider a single layer: y = Wx, where W is n×m and x is m×1. Output neuron i is:

y_i = \sum_{j=1}^{m} W_{ij} x_j

Assume W and x are independent and both zero-mean. Each term W_ij x_j then has variance Var(W_ij) · Var(x_j). Summing m independent terms:

\text{Var}(y_i) = m \cdot \text{Var}(W) \cdot \text{Var}(x)

where:
Var(y_i) — variance of output neuron i
m — fan_in, the number of input neurons (input dimension)
Var(W) — variance of each weight (all weights share this)
Var(x) — variance of the input features

To keep Var(y) = Var(x) (no signal amplification or attenuation), we need:

m \cdot \text{Var}(W) = 1 \quad \Longrightarrow \quad \text{Var}(W) = \frac{1}{m}

where m is fan_in.
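An empirical check of both the propagation formula and the 1/m condition (a minimal sketch, assuming NumPy; the dimensions are arbitrary):

import numpy as np

rng = np.random.default_rng(0)
m, n, batch = 512, 512, 10_000

x = rng.normal(0.0, 1.0, size=(m, batch))                 # Var(x) = 1

W_bad  = rng.normal(0.0, 1.0, size=(n, m))                # Var(W) = 1
W_good = rng.normal(0.0, np.sqrt(1.0 / m), size=(n, m))   # Var(W) = 1/m

print(np.var(W_bad @ x))   # ≈ m · 1 · 1 = 512
print(np.var(W_good @ x))  # ≈ m · (1/m) · 1 = 1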

Xavier/Glorot Initialization

The forward-pass condition gives Var(W) = 1/fan_in. The backward-pass (gradient) condition gives Var(W) = 1/fan_out. Xavier/Glorot initialization takes the compromise:

\text{Var}(W) = \frac{2}{\text{fan\_in} + \text{fan\_out}}

where:
fan_in — number of input neurons to this layer
fan_out — number of output neurons from this layer

In practice, this is usually implemented as a uniform distribution:

W_{ij} \sim \mathcal{U}\!\left(-\sqrt{\frac{6}{\text{fan\_in} + \text{fan\_out}}},\ +\sqrt{\frac{6}{\text{fan\_in} + \text{fan\_out}}}\right)

where each weight W_ij is drawn independently from this uniform distribution.

Xavier is designed for tanh and sigmoid activations — functions that are roughly linear near zero, preserving the variance analysis above.
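A hand-rolled version of the Xavier uniform rule (a minimal sketch, assuming NumPy; the layer sizes are illustrative), confirming the sampled weights land on the target variance:

import numpy as np

rng = np.random.default_rng(0)
fan_in, fan_out = 256, 128

limit = np.sqrt(6.0 / (fan_in + fan_out))
W = rng.uniform(-limit, limit, size=(fan_out, fan_in))  # Xavier/Glorot uniform

# A uniform distribution on [-a, a] has variance a²/3, which here equals 2/(fan_in + fan_out)
print(np.var(W))                 # ≈ 0.0052
print(2.0 / (fan_in + fan_out))  # target variance ≈ 0.0052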

He Initialization

ReLU introduces a new wrinkle. For a zero-mean Gaussian input, ReLU zeroes out roughly 50% of values (all the negatives). The variance of the output is:

\text{Var}(\text{ReLU}(x)) \approx \frac{1}{2} \text{Var}(x)

Each layer with ReLU halves the signal variance. He initialization (Kaiming initialization) corrects for this by doubling the weight variance:

W_{ij} \sim \mathcal{N}\!\left(0,\ \frac{2}{\text{fan\_in}}\right)

where fan_in is the number of input neurons to this layer and the second argument is the weight variance, so the standard deviation is √(2/fan_in).

The factor 2 in the numerator exactly compensates for ReLU's 50% kill rate, keeping variance stable across layers.
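A quick check of the 50% figure (a minimal sketch, assuming NumPy). Once ReLU makes the output non-negative, the quantity the derivation actually tracks is the mean squared activation (the second moment), and that is what the code measures:

import numpy as np

rng = np.random.default_rng(0)
z = rng.normal(0.0, 1.0, size=1_000_000)  # zero-mean Gaussian pre-activations
a = np.maximum(z, 0.0)                    # ReLU

print(np.mean(z**2))  # ≈ 1.0
print(np.mean(a**2))  # ≈ 0.5 — ReLU keeps half the signal power; He's factor 2 undoes this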

Worked Numerical Comparison

A 5-layer network in which each layer has fan_in = fan_out = 256 and the input features have variance 1:

Bad initialization: Var(W) = 1 (too large)

Layer | Var(activation)
1 | 256 × 1 × 1 = 256
2 | 256 × 1 × 256 = 65,536
3 | ≈ 16,777,216
4 | ≈ 4.3 × 10⁹
5 | ≈ 1.1 × 10¹²

He initialization: Var(W) = 2/256 ≈ 0.0078 (for ReLU)

Layer | Var(activation)
1 | 256 × 0.0078 × 1 ≈ 2, halved by ReLU → ≈ 1.0
2 | ≈ 1.0
3 | ≈ 1.0
4 | ≈ 1.0
5 | ≈ 1.0

Signal variance stays stable through the entire forward pass.
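The same comparison can be run end to end (a minimal sketch, assuming NumPy). It tracks the mean squared activation, which is what the tables' variance bookkeeping propagates once ReLU makes the activations non-negative; because this version applies ReLU in the bad-init case too, its numbers are roughly a factor of 2 per layer smaller than the first table, but they explode just the same:

import numpy as np

rng = np.random.default_rng(0)
width, depth = 256, 5
x_in = rng.normal(0.0, 1.0, size=(width, 10_000))  # input features with variance 1

for name, w_std in [("bad: Var(W) = 1", 1.0), ("He: Var(W) = 2/256", np.sqrt(2.0 / width))]:
    a = x_in
    print(name)
    for layer in range(1, depth + 1):
        W = rng.normal(0.0, w_std, size=(width, width))
        a = np.maximum(W @ a, 0.0)                  # linear layer followed by ReLU
        print(f"  layer {layer}: mean squared activation ≈ {np.mean(a**2):.3g}")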

Which to Use

Activation | Initialization | Why
tanh, sigmoid | Xavier/Glorot | Designed for near-linear activations
ReLU | He (Kaiming) | Corrects for the 50% kill rate
GELU, Swish | He (usually) | Similar to ReLU; safe default
Linear (no activation) | Xavier or He | Both work; Xavier is common

Code: Initialization in PyTorch

import torch.nn as nn

# PyTorch default for Linear: Kaiming uniform (He-like)
layer = nn.Linear(256, 256)  # already initialized with Kaiming uniform

# Explicit initialization
nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')  # He normal
nn.init.xavier_uniform_(layer.weight)                       # Xavier uniform

# Custom initialization for a whole model
def init_weights(module):
    if isinstance(module, nn.Linear):
        nn.init.kaiming_normal_(module.weight, nonlinearity='relu')
        if module.bias is not None:  # layers built with bias=False have no bias to zero
            nn.init.zeros_(module.bias)

model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 256))
model.apply(init_weights)

kaiming_normal_ is He initialization with a normal distribution. kaiming_uniform_ uses a uniform distribution. For most ReLU-based networks, either works — the difference is minor in practice when BatchNorm is present.
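As a quick usage check (a sketch, reusing the 256×256 layer above), the variance of the re-initialized weights can be read off directly and compared with the formulas from this lesson:

import torch.nn as nn

layer = nn.Linear(256, 256)

nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')
print(layer.weight.var().item())  # ≈ 2/256 ≈ 0.0078 (He)

nn.init.xavier_uniform_(layer.weight)
print(layer.weight.var().item())  # ≈ 2/(256 + 256) ≈ 0.0039 (Xavier)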

Quiz

Question 1 of 3

For a layer computing y = Wx where W is n×m and x is m×1, both W and x are i.i.d. zero-mean. What is Var(yᵢ) in terms of n, m, Var(W), Var(x)?