The Shifting Distribution Problem
Think about what happens inside a deep neural network during training. The first layer's weights change with each gradient update. This changes the distribution of values passed to the second layer. The second layer's weights then need to adapt not only to get better at its task, but also to compensate for the shifting input distribution. The third layer faces an even more complex, doubly-shifted input.
Imagine a factory assembly line with ten stations. Station 1 cuts parts to a certain size and passes them to Station 2. Station 2 has calibrated its tools for those sizes. If Station 1 suddenly starts cutting parts 20% larger, Station 2 must re-calibrate — and when Station 2's calibration changes, Station 3 must also re-calibrate, and so on. None of the stations can settle into their optimal configuration because the upstream is always changing.
This cascade of shifting distributions is called internal covariate shift. Deep networks are essentially trying to hit a moving target at every layer.
Batch normalization attacks this problem directly: after each layer (or before, depending on architecture), normalize the activations back to a stable distribution.
The Algorithm
For a single feature across a mini-batch of m examples x_1, x_2, ..., x_m:
Step 1: Compute batch mean:
μ_B = (1/m) · Σᵢ x_i
- μ_B: batch mean for this feature
- m: number of examples in the batch
- x_i: the value of this feature for example i
Step 2: Compute batch variance:
σ_B² = (1/m) · Σᵢ (x_i − μ_B)²
- σ_B²: batch variance for this feature
Step 3: Normalize:
x̂_i = (x_i − μ_B) / √(σ_B² + ε)
- x̂_i: normalized value of x_i
- ε: small constant for numerical stability (typical: 1e-5)
Step 4: Scale and shift with learnable parameters:
y_i = γ · x̂_i + β
- γ: learnable scale parameter (initialized to 1)
- β: learnable shift parameter (initialized to 0)
- y_i: final BatchNorm output for example i
The γ and β parameters are trained by backpropagation. They give the network the freedom to undo the normalization if that's optimal (setting γ = √(σ_B² + ε) and β = μ_B recovers the original activations), ensuring BatchNorm never reduces expressivity.
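The four steps above can be sketched directly in NumPy. This is a minimal forward pass for a single feature column, not a full BatchNorm layer (no running statistics, no backward pass); the function name and variable names are ours:

```python
import numpy as np

def batchnorm_forward(x, gamma, beta, eps=1e-5):
    """Steps 1-4 for a batch of values of one feature, shape (m,)."""
    mu = x.mean()                            # Step 1: batch mean
    var = x.var()                            # Step 2: batch variance (1/m, biased)
    x_hat = (x - mu) / np.sqrt(var + eps)    # Step 3: normalize
    return gamma * x_hat + beta              # Step 4: scale and shift

x = np.array([1.0, 2.0, 3.0, 4.0])
y = batchnorm_forward(x, gamma=1.0, beta=0.0)
# With gamma=1, beta=0 the output has (approximately) zero mean and unit variance.
```

With γ = 1 and β = 0 the output is just the normalized x̂; any other γ, β rescale and shift that result.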
Worked Numerical Example
A batch of 4 examples, one feature. For illustration, take the values [1, 2, 3, 4] with γ = 2, β = 0.5 (hypothetical learned values, shown for illustration):
Batch mean: μ_B = (1 + 2 + 3 + 4)/4 = 2.5
Batch variance: σ_B² = ((1 − 2.5)² + (2 − 2.5)² + (3 − 2.5)² + (4 − 2.5)²)/4 = 1.25
Normalized values: x̂ = (x − 2.5)/√(1.25 + 1e-5) ≈ [−1.342, −0.447, 0.447, 1.342]
After learnable rescaling: y = 2 · x̂ + 0.5, giving ≈ [−2.183, −0.394, 1.394, 3.183]
The output is no longer constrained to zero mean / unit variance: γ and β shifted it. But the normalization step ensured this layer's output is insensitive to whatever scale the previous layer operated at.
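A worked example like this is easy to check in plain Python. The batch values and the γ, β below are illustrative choices, not outputs of a trained network:

```python
import math

x = [1.0, 2.0, 3.0, 4.0]           # illustrative batch, one feature
gamma, beta, eps = 2.0, 0.5, 1e-5  # illustrative "learned" parameters

mu = sum(x) / len(x)                           # batch mean: 2.5
var = sum((v - mu) ** 2 for v in x) / len(x)   # batch variance: 1.25
x_hat = [(v - mu) / math.sqrt(var + eps) for v in x]   # normalize
y = [gamma * v + beta for v in x_hat]                  # scale and shift
```

Printing `y` rounded to four places gives [-2.1833, -0.3944, 1.3944, 3.1833].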
Benefits in Practice
- Higher learning rates: Without BatchNorm, learning rates above 0.01 often cause divergence. With BatchNorm, learning rates of 0.1 or higher are stable, directly speeding up training.
- Less sensitive to initialization: The normalized activations stay in a reasonable range regardless of how weights were initialized.
- Mild regularization: as noted, each example's output depends on which other examples share its mini-batch, and this batch-to-batch noise acts as implicit regularization.
Placement in the Network
BatchNorm is typically placed after the linear transformation, before the activation function:
Linear → BatchNorm → ReLU → Linear → BatchNorm → ReLU → ...
Some architectures place it after the activation. Empirically, the pre-activation order often works slightly better, but the difference is usually small. In residual networks (ResNets), BatchNorm is placed inside the residual branch.
Code: BatchNorm in PyTorch
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self, in_features, hidden, out):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.BatchNorm1d(hidden),   # BatchNorm after linear, before ReLU
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.BatchNorm1d(hidden),
            nn.ReLU(),
            nn.Linear(hidden, out),
        )

    def forward(self, x):
        return self.net(x)

model = MLP(784, 256, 10)   # example sizes

# IMPORTANT: switch mode for training vs evaluation
model.train()  # uses batch statistics + updates running stats
model.eval()   # uses running statistics (no updates)
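The train/eval distinction is observable directly. A small sketch (the seed, batch size, and feature count are arbitrary choices): in train mode a fresh BatchNorm layer normalizes with the current batch's statistics, so the output is roughly zero-mean per feature; in eval mode it uses the running estimates, which after a single update still lag far behind.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(3)

xb = torch.randn(8, 3) * 5 + 10   # batch with mean ~10, std ~5

bn.train()
out_train = bn(xb)   # normalized with this batch's own statistics

bn.eval()
out_eval = bn(xb)    # normalized with running_mean / running_var instead
# The running stats were updated only once (momentum 0.1 by default),
# so they are still close to their initial values (0 mean, 1 var) and
# the eval-mode output is far from zero-mean.
```

Forgetting `model.eval()` at inference time is a classic bug: single-example or small-batch predictions get normalized by unrepresentative batch statistics.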
nn.BatchNorm1d is for 1D feature vectors (fully-connected layers). nn.BatchNorm2d is for 2D feature maps (convolutional layers), normalizing across the batch and spatial dimensions.
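As a quick sanity check on shapes, a BatchNorm2d sketch (channel count and tensor sizes are our choices):

```python
import torch
import torch.nn as nn

bn2d = nn.BatchNorm2d(16)        # one (gamma, beta) pair per channel
feat = torch.randn(4, 16, 8, 8)  # (batch, channels, height, width)
out = bn2d(feat)

# Statistics are computed per channel, pooled over batch x height x width
# = 4 * 8 * 8 = 256 values per channel.
print(out.shape)          # torch.Size([4, 16, 8, 8])
print(bn2d.weight.shape)  # torch.Size([16])  <- gamma, one per channel
```

The output shape matches the input; only the per-channel statistics and the 16 (γ, β) pairs are involved in the normalization.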