The Limits of Batch Normalization
BatchNorm has a fundamental dependency: it needs multiple examples in a batch to compute meaningful statistics. In the three scenarios below, this dependency causes real problems:
Small batches: When memory constraints or architectural choices force batch size 1 or 2, batch statistics are noisy or undefined, and BatchNorm degrades or breaks entirely (see the sketch after this list).
Variable-length sequences: In NLP, sentences have different lengths. Padding brings them to equal length, but padded positions are not real data. Mixing statistics across real and padded positions corrupts the normalization.
Recurrent networks: In an RNN processing sequences step by step, each time step sees different input patterns. Computing batch statistics that mix different time steps conflates structurally different positions.
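A minimal sketch of the small-batch failure, assuming recent PyTorch (the exact error text may differ by version):

import torch
import torch.nn as nn

bn = nn.BatchNorm1d(4)      # 4 features
ln = nn.LayerNorm(4)
x = torch.randn(1, 4)       # a batch containing a single example

bn.train()
try:
    bn(x)                   # per-feature batch statistics need more than one example
except ValueError as err:
    print("BatchNorm on batch size 1:", err)

print(ln(x).shape)          # LayerNorm handles the same input: torch.Size([1, 4])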
Layer Normalization sidesteps all of these problems with a single key change: normalize across the feature dimension, not the batch dimension.
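Concretely, for activations of shape (batch, features), the two norms reduce over different axes; a minimal sketch:

import torch

x = torch.randn(8, 512)     # (batch, features)

# BatchNorm: one mean/variance per feature, computed across the batch axis
bn_mean = x.mean(dim=0)     # shape (512,)

# LayerNorm: one mean/variance per example, computed across the feature axis
ln_mean = x.mean(dim=1)     # shape (8,)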
The Algorithm
For a single example with feature vector $x \in \mathbb{R}^d$ of dimension $d$:
Step 1: Compute the mean over features for this single example:
$$\mu = \frac{1}{d} \sum_{i=1}^{d} x_i$$
- $\mu$: mean of this example's features
- $d$: number of features
- $x_i$: the i-th feature of this example
Step 2: Compute the variance over features:
$$\sigma^2 = \frac{1}{d} \sum_{i=1}^{d} (x_i - \mu)^2$$
- $\sigma^2$: variance of this example's features
Step 3: Normalize:
$$\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}$$
- $\hat{x}_i$: normalized value of feature i
- $\epsilon$: numerical stability constant (typical: 1e-5)
Step 4: Learned rescaling:
$$y = \gamma \odot \hat{x} + \beta$$
- $\gamma$: learnable scale vector of dimension d
- $\beta$: learnable shift vector of dimension d
- $\odot$: elementwise multiplication
- $y$: final LayerNorm output
Note that $\gamma$ and $\beta$ here are vectors (one per feature), not scalars.
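These four steps are short enough to check directly in code. A minimal sketch with plain PyTorch tensor ops; the result should match nn.LayerNorm up to floating-point error:

import torch
import torch.nn as nn

d = 8
x = torch.randn(d)                      # one example's feature vector

# Step 1: mean over features
mu = x.mean()
# Step 2: variance over features (biased, i.e. divide by d)
var = x.var(unbiased=False)
# Step 3: normalize
eps = 1e-5
x_hat = (x - mu) / torch.sqrt(var + eps)
# Step 4: learned rescaling (gamma initialized to ones, beta to zeros)
gamma = torch.ones(d)
beta = torch.zeros(d)
y = gamma * x_hat + beta

# The built-in module gives the same result (default init: gamma = 1, beta = 0)
ln = nn.LayerNorm(d)
print(torch.allclose(y, ln(x), atol=1e-6))   # True (up to float error)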
Worked Numerical Example
Take a single example with three features. After Steps 1–3, its normalized values have mean 0 and variance 1 within this single example, regardless of what other examples look like or what batch size is used. The sketch below runs illustrative numbers through the computation.
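A quick check of the arithmetic, using arbitrary illustrative values 1, 2, 3 and comparing against nn.LayerNorm:

import torch
import torch.nn as nn

x = torch.tensor([1.0, 2.0, 3.0])

# By hand: mu = 2, sigma^2 = ((1-2)^2 + (2-2)^2 + (3-2)^2) / 3 = 2/3
# Normalized (ignoring eps): roughly (-1.22, 0.00, +1.22)
ln = nn.LayerNorm(3)
y = ln(x)
print(y)                                  # ~[-1.2247, 0.0000, 1.2247]
print(y.mean().item())                    # ~0.0
print(y.var(unbiased=False).item())       # ~1.0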
Where Each Is Used
| Architecture type | Norm type | Why |
|---|---|---|
| Vision CNNs (ResNet, EfficientNet) | BatchNorm | Large batches, fixed spatial structure |
| Transformers (BERT, GPT, T5) | LayerNorm | Sequences, variable lengths, small batches |
| RNNs, LSTMs | LayerNorm | Sequential processing, batch stats unstable |
| Object detection | GroupNorm or BatchNorm | High-resolution inputs often force small batches |
| Diffusion models | GroupNorm | Operates on spatial feature groups |
GroupNorm is a middle ground: normalize over groups of channels within each example. It avoids batch dependency (like LayerNorm) while respecting spatial structure (like BatchNorm). Used when batch size is forced to be small (e.g., detection with high-resolution images).
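A minimal GroupNorm sketch in PyTorch (the 32-group split and tensor shape are illustrative, not prescriptive):

import torch
import torch.nn as nn

x = torch.randn(1, 64, 56, 56)   # one high-resolution-style image: (N, C, H, W)

gn = nn.GroupNorm(num_groups=32, num_channels=64)   # 32 groups of 2 channels each, per example
print(gn(x).shape)               # torch.Size([1, 64, 56, 56])

# BatchNorm2d would still run here, but its per-channel statistics
# would come from this single example only, so training becomes unstable.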
Code: LayerNorm in PyTorch
import torch.nn as nn
d_model = 512 # feature dimension
# LayerNorm over last dimension (features)
ln = nn.LayerNorm(d_model) # γ and β are d_model-dimensional vectors
# In a transformer block:
class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        # Pre-norm style (modern): normalize before each sublayer
        h = self.norm1(x)                  # normalize once, reuse for query, key, value
        x = x + self.attn(h, h, h)[0]      # [0] keeps the output, drops the attention weights
        x = x + self.ff(self.norm2(x))
        return x
nn.LayerNorm(d_model) normalizes over the last dimension of the input, which must have size d_model. It automatically handles any batch size, including 1, making it safe for both training and inference.
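A short usage sketch of the block above (shapes are illustrative):

import torch

block = TransformerBlock(d_model=512, n_heads=8)

# A normal batch of sequences...
x = torch.randn(32, 128, 512)      # (batch, seq_len, d_model)
print(block(x).shape)              # torch.Size([32, 128, 512])

# ...and a single example with a different length work identically,
# because LayerNorm statistics never cross examples or positions.
x_single = torch.randn(1, 7, 512)
print(block(x_single).shape)       # torch.Size([1, 7, 512])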