Normalization & Initialization
Lesson 3 ⏱ 10 min


BatchNorm at Inference: Running Statistics and the model.eval() Bug

Explains why batch statistics can't be used at inference, how exponential moving averages accumulate population statistics during training, and the notorious PyTorch bug that bites beginners.

⏱ ~6 min

🧮 Quick refresher

Exponential moving average

An exponential moving average (EMA) smoothly tracks a changing quantity by blending each new value with the accumulated estimate. Old values decay geometrically, so recent observations contribute more.

Example

With α=0.9: if the running estimate is 5 and a new batch gives μ=7, update to 0.9×5 + 0.1×7 = 5.2.

The estimate shifts toward the new value but doesn't jump there instantly.

You've seen the BatchNorm algorithm: compute batch statistics, normalize, scale and shift. It works beautifully during training. But the moment you deploy a model and run it on individual inputs, a problem emerges — there's no batch.

This is the gap between research and production that trips up many practitioners. A model trained with batch normalization behaves differently at inference time, and getting it wrong causes silent accuracy degradation. Every deep learning framework has special handling for this, and understanding it is essential before deploying any BatchNorm model.

The Inference Problem

During training, BatchNorm computes $\mu_B$ and $\sigma^2_B$ over a mini-batch of examples. For a batch size of 64, that's 64 values per feature to compute statistics from — reasonable.

At inference time, you process one input at a time. Your "batch" has $m = 1$. The mean of a single number is that number itself. The variance of a single number is zero. Normalizing: $\hat{x} = (x - x)/\sqrt{0 + \varepsilon} = 0$. Every single activation becomes zero. Your entire network's output is garbage.
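To make the collapse concrete, here is a minimal sketch using raw tensor operations on one made-up example (not an actual nn.BatchNorm layer), applying the training-time formula to a "batch" of size one:

# Sketch: applying batch statistics to a "batch" of one example
import torch

eps = 1e-5
x = torch.tensor([[2.3, -0.7, 5.1]])    # one example, three features

mu = x.mean(dim=0)                      # per-feature mean over the batch -> equals x itself
var = x.var(dim=0, unbiased=False)      # per-feature variance over the batch -> all zeros

x_hat = (x - mu) / torch.sqrt(var + eps)
print(x_hat)                            # tensor([[0., 0., 0.]]) -- every activation is zero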

The Solution: Running Statistics

The fix is to maintain a running estimate of the population mean and variance during training. After processing each mini-batch, update these running statistics using an exponential moving average (EMA):

$$\mu_\text{running} \leftarrow \alpha \cdot \mu_\text{running} + (1 - \alpha) \cdot \mu_B$$

$$\sigma^2_\text{running} \leftarrow \alpha \cdot \sigma^2_\text{running} + (1 - \alpha) \cdot \sigma^2_B$$

where:

  • $\mu_\text{running}$, $\sigma^2_\text{running}$: the accumulated running mean and variance estimates
  • $\alpha$: the momentum, i.e. the weight given to the accumulated running value (0.9 here; PyTorch's momentum argument is defined as $1 - \alpha$, so its default of 0.1 corresponds to $\alpha = 0.9$)
  • $\mu_B$, $\sigma^2_B$: the mean and variance of the current mini-batch

After training on many batches, $\mu_\text{running}$ and $\sigma^2_\text{running}$ converge to the true population mean and variance of the training data. At inference time, BatchNorm uses these running statistics instead of batch statistics.

Worked Example: Accumulating Running Mean

Suppose we're tracking one feature. Initial $\mu_\text{running} = 0$. Training proceeds:

Batch   μ_B     μ_running (after update, α = 0.9)
1       4.2     0.9×0 + 0.1×4.2 = 0.42
2       3.8     0.9×0.42 + 0.1×3.8 = 0.758
3       4.1     0.9×0.758 + 0.1×4.1 = 1.092
10      ~4.0    ≈ 2.61
50      ~4.0    ≈ 3.98

After ~50 batches, the running mean has converged close to the true feature mean of ≈ 4.0. The EMA progressively forgets old batches and converges to a stable estimate.
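The table can be reproduced in a few lines of plain Python. This is just a sketch of the EMA arithmetic (the batch means after batch 3 are assumed to hover around 4.0), not PyTorch's internal implementation:

# Sketch: reproducing the running-mean table with a plain EMA loop
alpha = 0.9
running_mean = 0.0

batch_means = [4.2, 3.8, 4.1] + [4.0] * 47   # first three batches from the table, then ~4.0

for i, mu_b in enumerate(batch_means, start=1):
    running_mean = alpha * running_mean + (1 - alpha) * mu_b
    if i in (1, 2, 3, 10, 50):
        print(f"batch {i:2d}: running_mean = {running_mean:.3f}")
# batch 1: 0.420, batch 2: 0.758, batch 3: 1.092, batch 10: ~2.61, batch 50: ~3.98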

How Batch Size Affects Quality

The running statistics are only as good as the batch statistics that feed them. With large batches (128+), each μB\mu_B is a precise estimate of the true mean — the running average converges quickly and accurately.

With small batches (4 or fewer), each μB\mu_B has high variance: the batch might happen to contain all large values, or all small ones, by chance. The running statistics become noisy, and at inference, the normalization can be miscalibrated.
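A quick simulation illustrates this. The data here is hypothetical (a feature drawn from a normal distribution with true mean 4.0), but the spread of the batch means at size 4 versus size 128 tells the story:

# Sketch: spread of batch means at small vs. large batch sizes
import torch

torch.manual_seed(0)
true_mean, true_std = 4.0, 2.0

for batch_size in (4, 128):
    # Draw 1,000 batches and record each batch's mean
    batch_means = (torch.randn(1000, batch_size) * true_std + true_mean).mean(dim=1)
    print(f"batch size {batch_size:3d}: batch means spread from "
          f"{batch_means.min().item():.2f} to {batch_means.max().item():.2f} "
          f"(std {batch_means.std().item():.2f})")
# Small batches scatter widely around 4.0; large batches stay tightly clustered.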

The Two Modes in Practice

BatchNorm behaves differently in training vs. inference mode. Here's the full picture:

Training mode (model.train()):

  • Uses batch statistics μ_B and σ²_B for normalization
  • Updates running_mean and running_var via EMA
  • Introduces stochastic noise (regularization effect)

Inference mode (model.eval()):

  • Uses running_mean and running_var (not batch statistics)
  • Deterministic: same input always gives same output
  • No running stat updates

The critical insight: the learned parameters (γ and β) are used identically in both modes. Only which statistics to normalize with changes.
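A short sketch with a throwaway layer and random data shows both differences at once: train mode normalizes with batch statistics and nudges the running buffers, while eval mode uses the stored running_mean and running_var and leaves them untouched:

# Sketch: the same BatchNorm layer in train vs. eval mode
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(num_features=3)
x = torch.randn(8, 3) * 2.0 + 4.0            # a batch of 8 examples, 3 features

bn.train()
out_train = bn(x)                            # normalized with this batch's mean/var
print(bn.running_mean)                       # buffers have moved away from their initial zeros

bn.eval()
out_eval = bn(x)                             # normalized with running_mean / running_var
print(torch.allclose(out_train, out_eval))   # False: different statistics, different output
print(bn.running_mean)                       # unchanged by the eval-mode forward pass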

What This Means for Your Workflow

Every time you train a BatchNorm model and then evaluate it:

  1. Call model.train() before the training loop — BatchNorm uses batch stats and updates running stats
  2. Call model.eval() before validation/inference — BatchNorm uses running stats (both calls are shown in the loop sketch after this list)
  3. When loading a pretrained model for fine-tuning, decide: do you want running stats from pretraining (keep eval mode) or recalibrate to your data (train mode)?
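Here is a minimal sketch of steps 1 and 2 in a typical loop, using a tiny synthetic model and random data purely for illustration:

# Sketch: where model.train() / model.eval() belong in a typical loop
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 16), nn.BatchNorm1d(16), nn.ReLU(), nn.Linear(16, 1))
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()
train_x, train_y = torch.randn(64, 10), torch.randn(64, 1)
val_x, val_y = torch.randn(32, 10), torch.randn(32, 1)

for epoch in range(3):
    model.train()                        # BatchNorm: batch stats + running-stat updates
    optimizer.zero_grad()
    loss = loss_fn(model(train_x), train_y)
    loss.backward()
    optimizer.step()

    model.eval()                         # BatchNorm: running stats, deterministic
    with torch.no_grad():                # no gradient tracking needed for validation
        val_loss = loss_fn(model(val_x), val_y)
        print(f"epoch {epoch}: train {loss.item():.3f}, val {val_loss.item():.3f}")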

A particularly subtle bug arises with transfer learning: if you freeze the BatchNorm layers but keep them in training mode, the running statistics get corrupted by your new data distribution. The safe pattern is to freeze BatchNorm layers in eval mode when fine-tuning.

# Freeze BN layers properly during fine-tuning
import torch.nn as nn

for module in model.modules():
    if isinstance(module, nn.BatchNorm2d):
        module.eval()                    # use pretrained running stats; stop EMA updates
        for param in module.parameters():
            param.requires_grad = False  # freeze the learned gamma and beta
# Note: model.train() switches every submodule back to training mode,
# so re-apply this after any later model.train() call.

The inference-time behavior of BatchNorm is one of the most important — and most frequently forgotten — details in practical deep learning. Next, we'll see LayerNorm, which sidesteps this problem entirely by not depending on batch statistics at all.

Quiz


At inference time, BatchNorm uses...