Normalization & Initialization
Lesson 5 ⏱ 10 min

Instance, group, and weight normalization


The Normalization Family: Organizing by Axis

Unifies BatchNorm, LayerNorm, Instance Norm, and Group Norm into a single framework by showing which dimensions each one averages over, with a visual guide for the 4D CNN tensor.

⏱ ~6 min

🧮

Quick refresher

Tensors and dimensions

A 4D tensor has four axes: batch (N), channels (C), height (H), and width (W). The values can be thought of as a 4D array of numbers. Normalizing 'over dimensions H and W' means computing mean and variance by averaging across all spatial positions.

Example

A CNN feature map with shape [N=2, C=3, H=4, W=4] has 2 examples, 3 channels, and 4×4 spatial maps.

BatchNorm computes one mean per channel, averaging over N=2 examples and H×W=16 positions, giving 3 means total.
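
A quick way to verify this on a toy tensor with the shape above (a minimal PyTorch sketch):

import torch

x = torch.randn(2, 3, 4, 4)  # [N=2, C=3, H=4, W=4]

# BatchNorm-style statistics: average over the batch and spatial dims, keep channels
per_channel_mean = x.mean(dim=(0, 2, 3))
per_channel_var = x.var(dim=(0, 2, 3), unbiased=False)

print(per_channel_mean.shape)  # torch.Size([3]) -> 3 means, one per channel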

You've seen BatchNorm and LayerNorm. The rest of the normalization family follows a single organizing principle: which dimensions do you average over to compute the statistics? Once you see this, the whole family falls into place.

Instance normalization powers neural style transfer, group normalization is the standard for object detection with small batch sizes, and root mean square normalization is used in LLaMA. Knowing which normalization to reach for — and why — is a practical skill every deep learning practitioner needs.

The 4D CNN Tensor

For convolutional networks, activations have shape [N, C, H, W]:

  • N: batch size
  • C: channels
  • H: spatial height
  • W: spatial width

A typical mid-network activation for image classification might be [32, 256, 14, 14]: 32 images, 256 feature maps, 14×14 spatial positions each.

The four normalization schemes differ only in which subset of {N, C, H, W} they reduce over:

Method          Normalize over         One statistic per
BatchNorm       N, H, W                Channel C
LayerNorm       C, H, W                Example N
Instance Norm   H, W                   (N, C) pair
Group Norm      (subset of C), H, W    (N, group) pair
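
One way to see the table in code is to reduce the same tensor over each subset of axes and look at how many statistics come out (a minimal sketch; the group count G=3 is just for illustration):

import torch

N, C, H, W = 2, 6, 4, 4
G = 3                       # number of groups for Group Norm (illustrative)
x = torch.randn(N, C, H, W)

print(x.mean(dim=(0, 2, 3)).shape)   # BatchNorm:     torch.Size([6])    -> one per channel
print(x.mean(dim=(1, 2, 3)).shape)   # LayerNorm:     torch.Size([2])    -> one per example
print(x.mean(dim=(2, 3)).shape)      # Instance Norm: torch.Size([2, 6]) -> one per (N, C) pair
print(x.view(N, G, C // G, H, W).mean(dim=(2, 3, 4)).shape)  # Group Norm: torch.Size([2, 3]) -> one per (N, group) pair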

Let's understand each one concretely.

BatchNorm (Recap for CNNs)

BatchNorm computes one mean per channel, averaging over all N examples and all H×W spatial positions. For 256 channels, you get 256 means and 256 variances — one pair per channel, shared across the whole batch and all spatial locations.

This means every pixel in the same channel, across all examples in the batch, gets normalized by the same statistics. Works great when N is large. Breaks down when N is small.
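
As a sanity check, per-channel statistics computed by hand match nn.BatchNorm2d in training mode (a sketch assuming the default eps=1e-5 and freshly initialized affine parameters):

import torch
import torch.nn as nn

x = torch.randn(32, 256, 14, 14)

# Normalize by hand: one mean/variance per channel, computed over (N, H, W)
mean = x.mean(dim=(0, 2, 3), keepdim=True)
var = x.var(dim=(0, 2, 3), keepdim=True, unbiased=False)
manual = (x - mean) / torch.sqrt(var + 1e-5)

bn = nn.BatchNorm2d(256).train()   # affine params start at gamma=1, beta=0
print(torch.allclose(manual, bn(x), atol=1e-5))  # True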

Instance Normalization

Instance Norm goes further: for each of the N×C feature maps, compute separate statistics from the H×W spatial positions within that map.

For the feature map at example n, channel c:

\mu_{nc} = \frac{1}{HW} \sum_{h=1}^{H} \sum_{w=1}^{W} x_{nchw}

  • \mu_{nc}: mean of the feature map at example n, channel c
  • x_{nchw}: activation at spatial position (h, w)

Each feature map is normalized against itself, with no mixing between examples or channels.
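
A minimal sketch of that computation, checked against nn.InstanceNorm2d (assuming the default eps=1e-5):

import torch
import torch.nn as nn

x = torch.randn(2, 3, 8, 8)

# One mean/variance per (example, channel) feature map, computed over H x W only
mu = x.mean(dim=(2, 3), keepdim=True)                  # shape [2, 3, 1, 1]
var = x.var(dim=(2, 3), keepdim=True, unbiased=False)
manual = (x - mu) / torch.sqrt(var + 1e-5)

print(torch.allclose(manual, nn.InstanceNorm2d(3)(x), atol=1e-5))  # True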

Use case: style transfer. In style transfer (Gatys et al., 2016 and follow-ups), you want to transfer the style of one image to the content of another. Style is captured by per-channel activation statistics. If you use BatchNorm, statistics are computed across the batch — mixing the style of different images. Instance Norm keeps each image's style statistics pure, making it far more effective for style transfer and image generation tasks.

Group Normalization

Group Norm is a middle ground between LayerNorm (all channels together) and Instance Norm (one channel at a time). Divide the C channels into G groups of C/G channels each. For each (example, group) pair, compute statistics over the channels in that group and the H×W spatial positions.

\mu_{ng} = \frac{1}{(C/G) \cdot HW} \sum_{c \in \text{group } g} \sum_{h,w} x_{nchw}

  • G: number of groups to divide the channels into
  • C/G: channels per group

With G=C, Group Norm is Instance Norm. With G=1, it's LayerNorm (over all channels).
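
Both limits are easy to confirm numerically (assuming default eps and freshly initialized affine parameters):

import torch
import torch.nn as nn

x = torch.randn(2, 6, 4, 4)

# G = C: one group per channel -> same result as Instance Norm
print(torch.allclose(nn.GroupNorm(6, 6)(x), nn.InstanceNorm2d(6)(x), atol=1e-5))    # True

# G = 1: one group spanning all channels -> same result as LayerNorm over (C, H, W)
print(torch.allclose(nn.GroupNorm(1, 6)(x), nn.LayerNorm([6, 4, 4])(x), atol=1e-5))  # True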

Use case: object detection. State-of-the-art object detectors like Faster R-CNN and Mask R-CNN train with 1-2 high-resolution images per GPU. At these batch sizes, BatchNorm's batch statistics are far too noisy to be useful. Group Norm maintains stable statistics because it doesn't depend on N at all. Wu & He (2018) showed Group Norm with G=32 matches or exceeds BatchNorm performance for batch sizes ≤ 8.

Weight Normalization: A Different Approach

All methods above normalize activations. Weight Normalization (Salimans & Kingma, 2016) normalizes the weights themselves.

For each weight vector w, reparameterize it as:

\mathbf{w} = \frac{g}{\|\mathbf{v}\|} \mathbf{v}

  • \mathbf{w}: weight vector to be reparameterized
  • g: scalar magnitude parameter (learned)
  • \mathbf{v}: direction vector (learned)
  • \|\mathbf{v}\|: Euclidean norm of \mathbf{v}

Instead of learning w directly, the network learns g (scalar magnitude) and v (direction vector) separately. The weight is always scaled to magnitude g, regardless of how v changes.

Why this helps: gradient descent can independently scale the magnitude without rotating the direction, and vice versa. In the original parameterization, scaling and rotating are coupled: changing one element of w affects both its magnitude and direction simultaneously.
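
A minimal sketch of the reparameterization on a toy fully connected layer (the shapes are hypothetical; PyTorch also ships torch.nn.utils.parametrizations.weight_norm, which does this for you):

import torch

# Hypothetical weight shapes: 4 output units, 8 input features
out_features, in_features = 4, 8
v = torch.randn(out_features, in_features, requires_grad=True)  # direction (learned)
g = torch.ones(out_features, requires_grad=True)                # magnitude (learned)

# Effective weight: each row of v rescaled so its Euclidean norm equals g
w = g.unsqueeze(1) * v / v.norm(dim=1, keepdim=True)

print(w.norm(dim=1))  # every row norm equals g, no matter how large or small v is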

Key difference from other norms: weight normalization has no batch or layer statistics. It's entirely determined by the current weight values. This makes it useful for recurrent networks and reinforcement learning, where batch statistics are unreliable or undefined.

Decision Guide

What architecture are you using?
├── CNN with batch size ≥ 16?       → BatchNorm
├── CNN with batch size < 8?        → Group Norm (G=32)
├── Transformer / language model?   → LayerNorm
├── Style transfer / image generation? → Instance Norm
├── RNN or online learning?         → Weight Norm or Layer Norm
└── Variable-length, batch size 1?  → LayerNorm

In PyTorch:

import torch.nn as nn

nn.BatchNorm2d(C)        # norm over (N, H, W), per channel
nn.LayerNorm([C, H, W])  # norm over (C, H, W), per example
nn.InstanceNorm2d(C)     # norm over (H, W), per (N, C) pair
nn.GroupNorm(G, C)       # norm over (C/G, H, W), per (N, group) pair
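
For example, the small-batch detection case from the decision guide might look like this (the channel count is made up for illustration):

import torch
import torch.nn as nn

# Small-batch detection setting: 2 high-resolution images per GPU
x = torch.randn(2, 64, 128, 128)
norm = nn.GroupNorm(32, 64)    # 32 groups of 2 channels each; no dependence on batch size
print(norm(x).shape)           # torch.Size([2, 64, 128, 128])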

Next we turn to the other half of this lesson: initialization. The question is no longer which values the network produces, but where it starts: the initial weight values before any training begins.
