Neural Networks
Lesson 6 ⏱ 14 min

Batch normalization


Batch Normalization - Stabilizing the Shifting Distribution Problem

The internal covariate shift problem visualized, the BatchNorm algorithm step-by-step on a small batch, why learnable γ and β are essential, and the effects on training stability and speed.


Quick refresher

Standardization (z-score normalization)

Standardizing a set of values means subtracting the mean and dividing by the standard deviation: z = (x - μ) / σ. The result has mean 0 and standard deviation 1. This makes values from different scales comparable and is the core operation in batch normalization.

Example

Values [2, 4, 4, 4, 5, 5, 7, 9]: mean μ=5, std σ=2.

Standardized: [(2-5)/2, (4-5)/2, …, (9-5)/2] = [-1.5, -0.5, -0.5, -0.5, 0, 0, 1, 2].

The new mean is 0 and std is 1.
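The refresher arithmetic above can be checked in a few lines of NumPy (note that `np.std` defaults to the population standard deviation, dividing by n, which matches the example):

```python
import numpy as np

x = np.array([2, 4, 4, 4, 5, 5, 7, 9], dtype=float)
mu, sigma = x.mean(), x.std()   # population std (divide by n), as in the example
z = (x - mu) / sigma            # standardize: subtract mean, divide by std
print(mu, sigma)                # 5.0 2.0
print(z)                        # [-1.5 -0.5 -0.5 -0.5  0.   0.   1.   2. ]
```

The standardized values have mean 0 and standard deviation 1 by construction, whatever scale `x` started on.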

The Shifting Distribution Problem

Think about what happens inside a deep neural network during training. The first layer's weights change with each gradient update. This changes the distribution of values passed to the second layer. The second layer's weights then need to adapt not only to get better at its task, but also to compensate for the shifting input distribution. The third layer faces an even more complex, doubly-shifted input.

Imagine a factory assembly line with ten stations. Station 1 cuts parts to a certain size and passes them to Station 2. Station 2 has calibrated its tools for those sizes. If Station 1 suddenly starts cutting parts 20% larger, Station 2 must re-calibrate — and when Station 2's calibration changes, Station 3 must also re-calibrate, and so on. None of the stations can settle into their optimal configuration because the upstream is always changing.

This cascade of shifting distributions is called internal covariate shift. Deep networks are essentially trying to hit a moving target at every layer.

Batch normalization attacks this problem directly: after each layer (or before, depending on architecture), normalize the activations back to a stable distribution.

The Algorithm

For a single feature across a mini-batch of examples \{x_1, \ldots, x_m\}:

Step 1: Compute batch mean:

\mu_B = \frac{1}{m} \sum_{i=1}^{m} x_i

where \mu_B is the batch mean for this feature, m is the number of examples in the batch, and x_i is the value of this feature for example i.

Step 2: Compute batch variance:

\sigma_B^2 = \frac{1}{m} \sum_{i=1}^{m} (x_i - \mu_B)^2

where \sigma_B^2 is the batch variance for this feature, and \varepsilon (used in the next step) is a small constant for numerical stability, typically 1e-5.

Step 3: Normalize:

\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \varepsilon}}

where \hat{x}_i is the normalized value of x_i.

Step 4: Scale and shift with learnable parameters:

y_i = \gamma \cdot \hat{x}_i + \beta

where \gamma is a learnable scale parameter (initialized to 1), \beta is a learnable shift parameter (initialized to 0), and y_i is the final BatchNorm output for example i.

The γ and β parameters are trained by backpropagation. They give the network the freedom to undo the normalization if that's optimal, ensuring BatchNorm never reduces expressivity.
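The four steps can be sketched as a from-scratch forward pass in NumPy (a minimal sketch for a (batch, features) array; the function name and test sizes are illustrative, not a library API):

```python
import numpy as np

def batchnorm_forward(x, gamma, beta, eps=1e-5):
    """BatchNorm forward pass for x of shape (batch, features)."""
    mu = x.mean(axis=0)                    # Step 1: batch mean, per feature
    var = x.var(axis=0)                    # Step 2: biased batch variance (divide by m)
    x_hat = (x - mu) / np.sqrt(var + eps)  # Step 3: normalize
    return gamma * x_hat + beta            # Step 4: learnable scale and shift

# With gamma = 1, beta = 0, the output has (near) zero mean and unit variance
# per feature, regardless of the input's scale.
rng = np.random.default_rng(0)
x = rng.normal(3.0, 5.0, size=(64, 4))    # deliberately far from mean 0 / std 1
y = batchnorm_forward(x, gamma=np.ones(4), beta=np.zeros(4))
print(y.mean(axis=0).round(6))            # ~[0. 0. 0. 0.]
```

During training, a real implementation also maintains running averages of the batch statistics for use at inference time, as the PyTorch code later in this lesson notes.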

Worked Numerical Example

A batch of 4 examples, one feature with values x = [2, 4, 8, 10]. With γ = 2, β = 1 (learned values, shown for illustration):

Batch mean: μ_B = (2 + 4 + 8 + 10)/4 = 6

Batch variance: σ_B² = [(2−6)² + (4−6)² + (8−6)² + (10−6)²]/4 = [16 + 4 + 4 + 16]/4 = 10

Normalized values: x̂ ≈ [−1.265, −0.632, 0.632, 1.265]

After learnable rescaling: y = γx̂ + β = 2x̂ + 1, giving y ≈ [−1.53, −0.26, 2.26, 3.53]

The output is no longer constrained to zero mean / unit variance — γ and β shifted it. But the normalization step ensured this layer's output is not impacted by whatever scale the previous layer operated at.
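The arithmetic above can be reproduced numerically (γ = 2, β = 1 as in the example, with the usual ε = 1e-5):

```python
import numpy as np

x = np.array([2.0, 4.0, 8.0, 10.0])
mu = x.mean()                      # 6.0
var = x.var()                      # 10.0 (biased variance: divide by m)
x_hat = (x - mu) / np.sqrt(var + 1e-5)
y = 2.0 * x_hat + 1.0              # gamma = 2, beta = 1
print(x_hat.round(3))              # [-1.265 -0.632  0.632  1.265]
print(y.round(2))                  # [-1.53 -0.26  2.26  3.53]
```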

Benefits in Practice

  1. Higher learning rates: Without BatchNorm, learning rates above 0.01 often cause divergence. With BatchNorm, learning rates of 0.1 or higher are stable, directly speeding up training.
  2. Less sensitive to initialization: The normalized activations stay in a reasonable range regardless of how weights were initialized.
  3. Mild regularization: Each example is normalized with statistics computed from the other examples in its batch, so it sees slightly different normalization from batch to batch. This noise acts as a mild implicit regularizer.

Placement in the Network

BatchNorm is typically placed after the linear transformation, before the activation function:

Linear → BatchNorm → ReLU → Linear → BatchNorm → ReLU → ...

Some architectures place it after the activation. Empirically, the pre-activation order often works slightly better, but the difference is usually small. In residual networks (ResNets), BatchNorm is placed inside the residual branch.

Code: BatchNorm in PyTorch

import torch.nn as nn

class MLP(nn.Module):
    def __init__(self, in_features, hidden, out):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.BatchNorm1d(hidden),   # BatchNorm after linear
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.BatchNorm1d(hidden),
            nn.ReLU(),
            nn.Linear(hidden, out)
        )

    def forward(self, x):
        return self.net(x)

# IMPORTANT: switch mode for training vs evaluation
model = MLP(in_features=20, hidden=64, out=10)   # example sizes
model.train()   # uses batch statistics + updates running stats
model.eval()    # uses running statistics (no updates)

nn.BatchNorm1d is for 1D feature vectors (fully-connected layers). nn.BatchNorm2d is for 2D feature maps (convolutional layers), normalizing each channel across the batch and spatial dimensions.
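The difference in which axes the statistics are reduced over can be sketched in NumPy (a conceptual sketch of the reduction axes, not PyTorch's internals; the sizes are arbitrary):

```python
import numpy as np

# BatchNorm1d input: (N, C) — one mean/var per feature, reduced over the batch axis
x = np.random.randn(32, 8)
mu_1d = x.mean(axis=0)             # shape (8,): one statistic per feature

# BatchNorm2d input: (N, C, H, W) — one mean/var per channel, reduced over N, H, W
fmap = np.random.randn(32, 16, 28, 28)
mu_2d = fmap.mean(axis=(0, 2, 3))  # shape (16,): one statistic per channel
print(mu_1d.shape, mu_2d.shape)    # (8,) (16,)
```

In the 2D case each channel's statistic pools over all 32 × 28 × 28 values for that channel, which is why BatchNorm works well even with small batches in convolutional layers.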

Quiz

1 / 3

A batch contains four values for one feature: [1, 3, 5, 7]. What is the batch-normalized output (before the learnable rescaling)?