Neural Networks
Lesson 7 ⏱ 12 min

Layer normalization


Layer Normalization - Normalizing Within a Single Example

The limitations of BatchNorm for sequences and small batches, how LayerNorm shifts the normalization axis from 'across the batch' to 'across the features of one example', and its place in every major transformer architecture.


Quick refresher

Batch normalization

BatchNorm normalizes each feature across the batch dimension: for feature j, it subtracts the batch mean and divides by the batch standard deviation. This requires computing statistics over multiple examples simultaneously. The learnable γ and β rescale after normalization.

Example

With batch size 4 and feature dimension 3, BatchNorm computes one mean and one std per feature (3 means, 3 stds total) by averaging across the 4 examples.

Each of the 4 examples' feature j values are normalized using the same μⱼ and σⱼ.
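In code, these per-feature batch statistics look like this (a pure-Python sketch with toy values, not the PyTorch implementation):

```python
# Batch of 4 examples, 3 features each (toy values).
batch = [
    [1.0, 10.0, 100.0],
    [2.0, 20.0, 200.0],
    [3.0, 30.0, 300.0],
    [4.0, 40.0, 400.0],
]
n, d = len(batch), len(batch[0])

# BatchNorm: one mean and one (biased) variance per feature,
# computed by averaging over the n examples in the batch.
means = [sum(row[j] for row in batch) / n for j in range(d)]
vars_ = [sum((row[j] - means[j]) ** 2 for row in batch) / n for j in range(d)]

print(means)  # → [2.5, 25.0, 250.0]  (3 means: one per feature)
print(vars_)  # → [1.25, 125.0, 12500.0]
```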

The Limits of Batch Normalization

BatchNorm has a fundamental dependency: it needs multiple examples in a batch to compute meaningful statistics. In the three scenarios below, this dependency causes real problems:

Small batches: When memory constraints or architectural choices force batch size 1 or 2, batch statistics are noisy or undefined. BatchNorm degrades or breaks entirely.

Variable-length sequences: In NLP, sentences have different lengths. Padding brings them to equal length, but padded positions are not real data. Mixing statistics across real and padded positions corrupts the normalization.

Recurrent networks: In an RNN processing sequences step by step, each time step sees different input patterns. Computing batch statistics that mix different time steps conflates structurally different positions.

Layer Normalization sidesteps all of these problems with a single key change: normalize across the feature dimension, not the batch dimension.
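The axis change is easy to see with NumPy (toy values; `axis=0` is the batch dimension, `axis=1` the feature dimension):

```python
import numpy as np

x = np.array([[6.0, 2.0, 4.0, 8.0],
              [1.0, 3.0, 5.0, 7.0]])  # shape (batch=2, features=4), toy values

batch_mu = x.mean(axis=0)  # BatchNorm axis: one mean per feature  → shape (4,)
layer_mu = x.mean(axis=1)  # LayerNorm axis: one mean per example  → shape (2,)

print(batch_mu.shape, layer_mu.shape)  # → (4,) (2,)
print(layer_mu)  # each example's mean depends only on its own features
```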

The Algorithm

For a single example with a feature vector of dimension $d$:

Step 1: Compute the mean over features for this single example:

$$\mu = \frac{1}{d} \sum_{i=1}^{d} x_i$$

where $\mu$ is the mean of this example's features, $d$ is the number of features, and $x_i$ is the i-th feature of this example.

Step 2: Compute the variance over features:

$$\sigma^2 = \frac{1}{d} \sum_{i=1}^{d} (x_i - \mu)^2$$

where $\sigma^2$ is the variance of this example's features.

Step 3: Normalize:

$$\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \varepsilon}}$$

where $\hat{x}_i$ is the normalized value of feature i and $\varepsilon$ is a small constant for numerical stability (typically 1e-5).

Step 4: Learned rescaling:

$$y = \gamma \odot \hat{x} + \beta$$

where $\gamma$ is a learnable scale vector of dimension $d$, $\beta$ is a learnable shift vector of dimension $d$, $\odot$ denotes elementwise multiplication, and $y$ is the final LayerNorm output.

Note that $\gamma$ and $\beta$ here are vectors (one per feature), not scalars.

Worked Numerical Example

Single example with features $x = [6, 2, 4, 8]$, $\gamma = [1, 1, 1, 1]$, $\beta = [0, 0, 0, 0]$:

$$\mu = (6+2+4+8)/4 = 5$$
$$\sigma^2 = \big[(6-5)^2 + (2-5)^2 + (4-5)^2 + (8-5)^2\big]/4 = [1+9+1+9]/4 = 5$$
$$\sigma = \sqrt{5} \approx 2.236$$
$$\hat{x} = [(6-5)/2.236,\ (2-5)/2.236,\ (4-5)/2.236,\ (8-5)/2.236]$$
$$\hat{x} \approx [0.447,\ -1.342,\ -0.447,\ 1.342]$$

These values have mean 0 and variance 1 within this single example, regardless of what other examples look like or what batch size is used.
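The hand computation can be verified with a short pure-Python sketch of the four steps (with ε = 0 to match the arithmetic above; not the optimized PyTorch kernel):

```python
import math

def layer_norm(x, gamma, beta, eps=0.0):
    d = len(x)
    mu = sum(x) / d                                         # Step 1: mean
    var = sum((xi - mu) ** 2 for xi in x) / d               # Step 2: variance
    x_hat = [(xi - mu) / math.sqrt(var + eps) for xi in x]  # Step 3: normalize
    return [g * xh + b for g, xh, b in zip(gamma, x_hat, beta)]  # Step 4: rescale

y = layer_norm([6.0, 2.0, 4.0, 8.0], gamma=[1.0] * 4, beta=[0.0] * 4)
print([round(v, 3) for v in y])  # → [0.447, -1.342, -0.447, 1.342]
```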

Where Each Is Used

| Architecture type | Norm type | Why |
| --- | --- | --- |
| Vision CNNs (ResNet, EfficientNet) | BatchNorm | Large batches, fixed spatial structure |
| Transformers (BERT, GPT, T5) | LayerNorm | Sequences, variable lengths, small batches |
| RNNs, LSTMs | LayerNorm | Sequential processing; batch statistics unstable |
| Object detection | GroupNorm or BatchNorm | Batch sizes are sometimes forced small |
| Diffusion models | GroupNorm | Operates on spatial feature groups |

GroupNorm is a middle ground: normalize over groups of channels within each example. It avoids batch dependency (like LayerNorm) while respecting spatial structure (like BatchNorm). Used when batch size is forced to be small (e.g., detection with high-resolution images).
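The grouping idea can be sketched in pure Python for a single example (a simplification: the real `nn.GroupNorm` also averages over spatial positions and applies learnable γ and β):

```python
import math

def group_norm(x, num_groups, eps=1e-5):
    """Normalize the channels of one example in independent groups."""
    size = len(x) // num_groups
    out = []
    for g in range(num_groups):
        group = x[g * size:(g + 1) * size]
        mu = sum(group) / size                           # mean within this group
        var = sum((v - mu) ** 2 for v in group) / size   # variance within this group
        out += [(v - mu) / math.sqrt(var + eps) for v in group]
    return out

# 4 channels, 2 groups: each pair of channels is normalized on its own.
out = group_norm([1.0, 3.0, 10.0, 30.0], num_groups=2)
print([round(v, 3) for v in out])  # → [-1.0, 1.0, -1.0, 1.0]
```

With `num_groups=1` this reduces to LayerNorm over the channels; with one group per channel it approaches InstanceNorm.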

Code: LayerNorm in PyTorch

import torch.nn as nn

d_model = 512  # feature dimension

# LayerNorm over last dimension (features)
ln = nn.LayerNorm(d_model)  # γ and β are d_model-dimensional vectors

# In a transformer block:
class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ff = nn.Sequential(nn.Linear(d_model, 4*d_model), nn.GELU(), nn.Linear(4*d_model, d_model))
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        # Pre-norm style (modern): normalize before each sublayer,
        # then add the residual connection.
        h = self.norm1(x)
        x = x + self.attn(h, h, h)[0]
        x = x + self.ff(self.norm2(x))
        return x

nn.LayerNorm(d_model) normalizes over the last dimension, which must have size d_model (pass a tuple of trailing dimensions to normalize over more than one). It keeps no running statistics, so it handles any batch size, including 1, and behaves identically in training and inference.
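A quick check that batch size 1 works (assuming default initialization, where γ is all ones and β all zeros, so the output is just the normalized features):

```python
import torch
import torch.nn as nn

ln = nn.LayerNorm(4)                      # γ initialized to ones, β to zeros
x = torch.tensor([[6.0, 2.0, 4.0, 8.0]])  # batch size 1 — fine for LayerNorm

y = ln(x)
print(y.mean().item())               # ≈ 0.0
print(y.var(unbiased=False).item())  # ≈ 1.0
```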

Quiz

1 / 3

For a single feature vector x = [3, 1, 4, 2], what is the LayerNorm output (before γ and β), with ε ≈ 0?