You've seen the BatchNorm algorithm: compute batch statistics, normalize, scale and shift. It works beautifully during training. But the moment you deploy a model and run it on individual inputs, a problem emerges — there's no batch.
This is the gap between research and production that trips up many practitioners. A model trained with batch normalization behaves differently at inference time, and getting it wrong causes silent accuracy degradation. Every deep learning framework has special handling for this, and understanding it is essential before deploying any BatchNorm model.
The Inference Problem
During training, BatchNorm computes μ_B and σ²_B over a mini-batch of examples. For a batch size of 64, that's 64 values per feature from which to estimate statistics — a reasonable sample.
At inference time, you process one input at a time. Your "batch" has n = 1. The mean of a single number is that number itself. The variance of a single number is zero. Normalizing: x̂ = (x − μ_B) / √(σ²_B + ε) = (x − x) / √(0 + ε) = 0. Every single activation becomes zero. Your entire network's output is garbage.
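To see the collapse concretely, here's a minimal plain-Python sketch (not framework code) that normalizes a batch with its own statistics, the way training-mode BatchNorm does:

```python
import math

def normalize(batch, eps=1e-5):
    """Normalize a batch using its own mean and variance (training-mode behavior)."""
    mean = sum(batch) / len(batch)
    var = sum((x - mean) ** 2 for x in batch) / len(batch)
    return [(x - mean) / math.sqrt(var + eps) for x in batch]

# A healthy batch of 4 activations normalizes to a spread of values:
print(normalize([2.0, 4.0, 6.0, 8.0]))

# A "batch" of one: the mean equals the value, the variance is 0,
# and every output collapses to exactly zero.
print(normalize([5.0]))  # [0.0]
```

Any single input, fed through this in training mode, produces all zeros — which is why inference needs a different source of statistics.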
The Solution: Running Statistics
The fix is to maintain a running estimate of the population mean and variance during training. After processing each mini-batch, update these running statistics using an exponential moving average (EMA):

μ_running ← α · μ_running + (1 − α) · μ_B
σ²_running ← α · σ²_running + (1 − α) · σ²_B

where:
- μ_running — accumulated running mean estimate
- α — momentum: weight given to the current running value (here α = 0.9; note that PyTorch's `momentum` argument weights the *new* batch statistic instead, so its default `momentum=0.1` corresponds to α = 0.9)
- μ_B — batch mean from the current mini-batch
- σ²_running — accumulated running variance estimate
After training on many batches, μ_running and σ²_running converge to the true population mean and variance of the training data. At inference time, BatchNorm uses these running statistics instead of batch statistics.
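As a sketch (plain Python, not a framework API), one EMA step is just a weighted blend of the old running value and the new batch statistic:

```python
def ema_update(running, batch_stat, alpha=0.9):
    """One EMA step: keep alpha of the old running value, blend in (1 - alpha) of the new batch statistic."""
    return alpha * running + (1 - alpha) * batch_stat

# Three updates starting from 0, using the batch means from the worked example:
running_mean = 0.0
for batch_mean in [4.2, 3.8, 4.1]:
    running_mean = ema_update(running_mean, batch_mean)
print(round(running_mean, 4))  # 1.0922
```

The same update is applied independently to the variance estimate.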
Worked Example: Accumulating Running Mean
Suppose we're tracking one feature. Initial μ_running = 0. Training proceeds:
| Batch | μ_B | μ_running (after update, α=0.9) |
|---|---|---|
| 1 | 4.2 | 0.9×0 + 0.1×4.2 = 0.42 |
| 2 | 3.8 | 0.9×0.42 + 0.1×3.8 = 0.758 |
| 3 | 4.1 | 0.9×0.758 + 0.1×4.1 = 1.092 |
| 10 | ~4.0 | ≈ 2.61 |
| 50 | ~4.0 | ≈ 3.98 |
After ~50 batches, the running mean has converged close to the true feature mean of ≈ 4.0. The EMA progressively forgets old batches and converges to a stable estimate.
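The whole accumulation can be reproduced in a few lines of plain Python (a sketch, assuming μ_B hovers around 4.0 after the first three batches):

```python
# Reproduce the running-mean table: batches 1-3 as given, then mu_B ~ 4.0 thereafter.
alpha = 0.9
batch_means = [4.2, 3.8, 4.1] + [4.0] * 47
running = 0.0
history = []
for mu_b in batch_means:
    running = alpha * running + (1 - alpha) * mu_b
    history.append(running)

print(round(history[9], 2))   # after batch 10: 2.61
print(round(history[49], 2))  # after batch 50: 3.98
```

Because each step keeps 90% of the old estimate, the EMA starts biased toward its initial value of 0 and needs a few dozen batches to close the gap.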
How Batch Size Affects Quality
The running statistics are only as good as the batch statistics that feed them. With large batches (128+), each μ_B is a precise estimate of the true mean — the running average converges quickly and accurately.
With small batches (4 or fewer), each μ_B has high variance: the batch might happen to contain all large values, or all small ones, by chance. The running statistics become noisy, and at inference, the normalization can be miscalibrated.
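A quick simulation makes the noise visible (plain Python with illustrative numbers: a population with mean 4 and standard deviation 1):

```python
import random
import statistics

random.seed(0)
population = [random.gauss(4.0, 1.0) for _ in range(100_000)]

def batch_mean_spread(batch_size, n_batches=2000):
    """Standard deviation of batch means: how noisy mu_B is at a given batch size."""
    means = [statistics.fmean(random.sample(population, batch_size))
             for _ in range(n_batches)]
    return statistics.stdev(means)

print(batch_mean_spread(4))    # noisy: roughly sigma/sqrt(4) = 0.5
print(batch_mean_spread(128))  # tight: roughly sigma/sqrt(128) ~ 0.09
```

The spread of μ_B shrinks as 1/√(batch size), which is why the running statistics fed by tiny batches stay noisy no matter how long you train.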
The Two Modes in Practice
BatchNorm behaves differently in training vs. inference mode. Here's the full picture:
Training mode (model.train()):
- Uses batch statistics μ_B and σ²_B for normalization
- Updates running_mean and running_var via EMA
- Introduces stochastic noise (regularization effect)
Inference mode (model.eval()):
- Uses running_mean and running_var (not batch statistics)
- Deterministic: same input always gives same output
- No running stat updates
The critical insight: the learned parameters γ and β are used identically in both modes. Only the statistics used for normalization change.
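Both modes can be captured in a tiny single-feature sketch (plain Python, no gradients, illustrative only — real BatchNorm is per-feature and differentiable):

```python
import math

class TinyBatchNorm:
    """Minimal one-feature BatchNorm sketch illustrating the two modes."""
    def __init__(self, alpha=0.9, eps=1e-5):
        self.gamma, self.beta = 1.0, 0.0            # learned scale/shift, used in BOTH modes
        self.running_mean, self.running_var = 0.0, 1.0
        self.alpha, self.eps = alpha, eps
        self.training = True

    def __call__(self, batch):
        if self.training:
            # Training mode: normalize with batch statistics, update running stats via EMA.
            mean = sum(batch) / len(batch)
            var = sum((x - mean) ** 2 for x in batch) / len(batch)
            self.running_mean = self.alpha * self.running_mean + (1 - self.alpha) * mean
            self.running_var = self.alpha * self.running_var + (1 - self.alpha) * var
        else:
            # Inference mode: use frozen running statistics; no updates.
            mean, var = self.running_mean, self.running_var
        return [self.gamma * (x - mean) / math.sqrt(var + self.eps) + self.beta
                for x in batch]

bn = TinyBatchNorm()
for _ in range(200):              # "train" on batches centered at 4
    bn([3.0, 4.0, 5.0])
bn.training = False               # switch to inference mode
print(bn([4.0]))                  # a single input now works: normalized with running stats
```

In eval mode the same input always produces the same output, and the running statistics stop changing — exactly the behavior `model.eval()` selects in a real framework.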
What This Means for Your Workflow
Every time you train a BatchNorm model and then evaluate it:
- Call `model.train()` before the training loop — BatchNorm uses batch stats and updates running stats
- Call `model.eval()` before validation/inference — BatchNorm uses running stats
- When loading a pretrained model for fine-tuning, decide: do you want running stats from pretraining (keep eval mode) or recalibrate to your data (train mode)?
A particularly subtle bug arises with transfer learning: if you freeze the BatchNorm layers but keep them in training mode, the running statistics get corrupted by your new data distribution. The safe pattern is to freeze BatchNorm layers in eval mode when fine-tuning.
```python
import torch.nn as nn

# Freeze BN layers properly during fine-tuning
for module in model.modules():
    if isinstance(module, nn.BatchNorm2d):
        module.eval()                       # normalize with running stats; stop EMA updates
        for param in module.parameters():
            param.requires_grad = False     # freeze the learned γ and β
```
The inference-time behavior of BatchNorm is one of the most important — and most frequently forgotten — details in practical deep learning. Next, we'll see LayerNorm, which sidesteps this problem entirely by not depending on batch statistics at all.