Backpropagation
Lesson 4 ⏱ 12 min

The full training loop


The Complete Training Loop: From Data to Trained Model

Walking through the full training algorithm (initialization, forward pass, loss, backward pass, and weight update) and mapping it to PyTorch's five-line training step.



Quick refresher

Gradient descent update rule

Gradient descent updates each weight in the direction that decreases the loss: w ← w - α(∂L/∂w). The forward pass computes the prediction. Backpropagation computes all the gradients. The update rule applies them.
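The update rule can be tried on a one-parameter problem. This is a minimal sketch that assumes a made-up loss L(w) = (w - 3)^2. It was picked because its gradient 2(w - 3) is easy to check by hand and its minimum sits at w = 3:

```python
# Minimize the toy loss L(w) = (w - 3)**2 with plain gradient descent.
# Its derivative is dL/dw = 2 * (w - 3), so the minimum is at w = 3.

def loss(w: float) -> float:
    return (w - 3.0) ** 2

def grad(w: float) -> float:
    return 2.0 * (w - 3.0)

w = 0.0        # arbitrary starting weight
alpha = 0.1    # learning rate
for step in range(50):
    w = w - alpha * grad(w)   # w <- w - alpha * dL/dw

print(w, loss(w))
```

Each step multiplies the distance to the minimum by (1 - 2 * alpha) = 0.8. After 50 steps, w is therefore very close to 3. With alpha above 1.0 that factor would exceed 1 in magnitude, and the iterates would diverge.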

Example

optimizer.zero_grad() clears old gradients.

model(X) runs the forward pass.

loss.backward() runs backpropagation.

optimizer.step() applies all weight updates.

Putting Everything Together

You now have all the pieces. Gradient descent gives the update rule. The forward pass computes the prediction. Cross-entropy measures the error. Backpropagation computes all the gradients. Here is how they fit together into the complete training loop.

Think of it as a student sitting an exam, then reviewing it:

  1. Forward pass — the student answers all the questions (makes a prediction).
  2. Loss — the grade comes back: here is what you got wrong and by how much.
  3. Backward pass — the student traces back: which knowledge gaps caused which errors? How should I study differently?
  4. Update — the student adjusts their understanding. Repeat for the next exam (next batch).

After thousands of such cycles, the student's understanding converges. That is training.

The Full Algorithm

The training loop in pseudocode:

Initialize W^(l) for all layers l = 1..L (small random values), b^(l) = 0

For epoch = 1 to max_epochs:
  Shuffle the training dataset

  For each mini-batch (X_batch, Y_batch) of size B:   # one example per column

    # FORWARD PASS
    a^(0) = X_batch
    For l = 1 to L:
      z^(l) = W^(l) · a^(l-1) + b^(l)   <- cache
      a^(l) = activation(z^(l))           <- cache   (softmax at layer L)

    # LOSS
    L_batch = (1/B) * sum_i loss(a^(L)_i, y_i)

    # BACKWARD PASS
    delta^(L) = (1/B) * (a^(L) - Y_batch)      # softmax + cross-entropy, Y_batch one-hot
    For l = L down to 1:
      dL/dW^(l) = delta^(l) · (a^(l-1))^T
      dL/db^(l) = sum of delta^(l) over the batch (its columns)
      If l > 1:
        delta^(l-1) = (W^(l)^T · delta^(l)) * activation'(z^(l-1))   # * is elementwise

    # UPDATE
    For l = 1 to L:
      W^(l) <- W^(l) - alpha * dL/dW^(l)
      b^(l) <- b^(l) - alpha * dL/db^(l)

  Evaluate on validation set; stop early if it has not improved for several epochs
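The pseudocode can be run almost line for line in NumPy. Everything about the setup below is assumed for illustration: two 2-D Gaussian blobs as hypothetical data, a 2 -> 16 (ReLU) -> 2 (softmax) network, He-style random weights, zero biases, learning rate 0.1 and batch size 32. It is a sketch, not a reference implementation. To keep it short, it tracks the loss on the training data instead of a separate validation set and does no early stopping.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy data: two 2-D Gaussian blobs, labels 0 and 1.
N = 400
X = np.vstack([rng.normal(-1.0, 0.7, size=(N // 2, 2)),
               rng.normal(+1.0, 0.7, size=(N // 2, 2))])
y = np.repeat([0, 1], N // 2)

def relu(z):
    return np.maximum(z, 0.0)

def relu_grad(z):
    return (z > 0).astype(z.dtype)

def softmax(z):
    e = np.exp(z - z.max(axis=0, keepdims=True))   # subtract max for stability
    return e / e.sum(axis=0, keepdims=True)

def cross_entropy(probs, labels):
    # Mean negative log-probability assigned to the correct class.
    return -np.mean(np.log(probs[labels, np.arange(labels.size)] + 1e-12))

sizes = [2, 16, 2]                  # 2 inputs -> 16 ReLU units -> 2-way softmax
L = len(sizes) - 1
# Index 0 is unused so that W[l] and b[l] match the pseudocode's layer numbering.
W = [None] + [rng.normal(0.0, np.sqrt(2.0 / sizes[l - 1]), (sizes[l], sizes[l - 1]))
              for l in range(1, L + 1)]
b = [None] + [np.zeros((sizes[l], 1)) for l in range(1, L + 1)]
alpha, B, max_epochs = 0.1, 32, 30

def forward(A0):
    """A0 holds one example per column; returns the cached z and a lists."""
    z, a = [None], [A0]
    for l in range(1, L + 1):
        z.append(W[l] @ a[l - 1] + b[l])
        a.append(softmax(z[l]) if l == L else relu(z[l]))
    return z, a

history = [cross_entropy(forward(X.T)[1][L], y)]   # loss before training
for epoch in range(max_epochs):
    perm = rng.permutation(N)                       # shuffle every epoch
    for start in range(0, N, B):
        idx = perm[start:start + B]
        z, a = forward(X[idx].T)                    # FORWARD PASS (with cache)
        Y = np.eye(sizes[-1])[y[idx]].T             # one-hot targets, (classes, batch)
        delta = (a[L] - Y) / idx.size               # BACKWARD PASS starts at layer L
        grads = {}
        for l in range(L, 0, -1):
            grads[l] = (delta @ a[l - 1].T, delta.sum(axis=1, keepdims=True))
            if l > 1:
                delta = (W[l].T @ delta) * relu_grad(z[l - 1])
        for l in range(1, L + 1):                   # UPDATE
            W[l] -= alpha * grads[l][0]
            b[l] -= alpha * grads[l][1]
    history.append(cross_entropy(forward(X.T)[1][L], y))   # full-data loss per epoch

print(f"loss before training: {history[0]:.3f}, after {max_epochs} epochs: {history[-1]:.3f}")
```

All gradients are computed before any weight changes. Because of this, delta^(l-1) always uses the same W^(l) that the forward pass used, exactly as in the pseudocode.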

In PyTorch: Five Lines

You will almost never write this loop at the matrix level. PyTorch makes it five lines per training step:

optimizer.zero_grad()           # 1. Clear previous gradients
output = model(X)               # 2. Forward pass
loss = criterion(output, y)     # 3. Compute loss
loss.backward()                 # 4. Backward pass: all gradients
optimizer.step()                # 5. Apply weight updates

Every weight in the network, across every layer, gets its gradient computed and applied in these five lines. The rest is data loading and logging.
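The five lines are shown above in isolation. Here they are inside a complete loop. This sketch makes several assumptions that are not part of the lesson: random hypothetical data with three classes, a small nn.Sequential model, plain SGD with lr=0.1, and 20 epochs. DataLoader with shuffle=True takes care of the per-epoch shuffle from the pseudocode.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical data: 512 random 20-dim inputs, labels in {0, 1, 2} derived from them.
X_all = torch.randn(512, 20)
y_all = (X_all[:, 0] > 0).long() + (X_all[:, 1] > 0).long()
loader = DataLoader(TensorDataset(X_all, y_all), batch_size=64, shuffle=True)

model = nn.Sequential(nn.Linear(20, 32), nn.ReLU(), nn.Linear(32, 3))
criterion = nn.CrossEntropyLoss()      # takes raw logits; applies log-softmax internally
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

losses = []
for epoch in range(20):
    model.train()
    running = 0.0
    for X, y in loader:                # a fresh shuffled order every epoch
        optimizer.zero_grad()          # 1. Clear previous gradients
        output = model(X)              # 2. Forward pass
        loss = criterion(output, y)    # 3. Compute loss
        loss.backward()                # 4. Backward pass: all gradients
        optimizer.step()               # 5. Apply weight updates
        running += loss.item() * X.size(0)
    losses.append(running / len(loader.dataset))   # mean training loss this epoch

print(f"epoch 1: {losses[0]:.3f}, epoch 20: {losses[-1]:.3f}")
```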

Why optimizer.zero_grad() Matters

PyTorch accumulates gradients by default: each .backward() call adds to the existing gradient values rather than replacing them.

Forgetting zero_grad() is one of the most common PyTorch bugs. Your loss might still decrease, just much more slowly and erratically, because gradients from many previous batches are stacking up.
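The accumulation is easy to see on a single scalar parameter. This is a minimal sketch that assumes the made-up loss (3w)^2, whose gradient is 18w = 36 at w = 2:

```python
import torch

w = torch.tensor(2.0, requires_grad=True)

loss = (3.0 * w) ** 2       # dL/dw = 18 * w = 36 at w = 2
loss.backward()
first = w.grad.item()       # 36.0

loss = (3.0 * w) ** 2       # rebuild the graph for the same loss
loss.backward()             # no zeroing: PyTorch ADDS into w.grad
accumulated = w.grad.item() # 72.0

w.grad.zero_()              # what optimizer.zero_grad() achieves for every parameter
loss = (3.0 * w) ** 2
loss.backward()
fresh = w.grad.item()       # 36.0 again

print(first, accumulated, fresh)
```

Since PyTorch 2.0, optimizer.zero_grad() by default sets each .grad to None instead of filling it with zeros. In both cases the next backward() starts from a clean gradient.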

Weight Initialization

The loop starts with "initialize weights." This matters more than it might seem.

If you initialize all weights to zero (or to any single constant), every neuron in a layer computes the same output, receives the same gradient, and updates identically. The layer stays symmetric forever. This is the symmetry problem: the whole layer is effectively one neuron regardless of width.
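A small experiment makes the symmetry visible. This sketch assumes a hypothetical 3 -> 4 -> 1 sigmoid network and a made-up binary target. It trains the network twice with plain gradient descent, once from all-zero weights and once from random weights. It then counts how many hidden neurons ended up with distinct incoming weights.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(3, 64))                      # 3 features, 64 examples (one per column)
y = (X[0] + X[1] > 0).astype(float)[None, :]      # hypothetical binary target, shape (1, 64)

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def train(W1, W2, steps=100, alpha=0.5):
    """Plain gradient descent on a 3 -> hidden -> 1 sigmoid network; returns W1."""
    b1, b2 = np.zeros((W1.shape[0], 1)), np.zeros((1, 1))
    for _ in range(steps):
        a1 = sigmoid(W1 @ X + b1)                 # hidden activations
        a2 = sigmoid(W2 @ a1 + b2)                # output probability
        d2 = (a2 - y) / X.shape[1]                # sigmoid + mean binary cross-entropy
        d1 = (W2.T @ d2) * a1 * (1 - a1)          # backprop through the hidden sigmoid
        W2 -= alpha * (d2 @ a1.T)
        b2 -= alpha * d2.sum(axis=1, keepdims=True)
        W1 -= alpha * (d1 @ X.T)
        b1 -= alpha * d1.sum(axis=1, keepdims=True)
    return W1

hidden = 4
W1_zero = train(np.zeros((hidden, 3)), np.zeros((1, hidden)))
W1_rand = train(rng.normal(0.0, 0.5, (hidden, 3)), rng.normal(0.0, 0.5, (1, hidden)))

# Each row of W1 is one hidden neuron's incoming weights; count the distinct rows.
n_zero = len(np.unique(W1_zero.round(6), axis=0))
n_rand = len(np.unique(W1_rand.round(6), axis=0))
print("zero init:", n_zero, "distinct neurons; random init:", n_rand)
```

With zero initialization, all four hidden neurons end up with the same weights, so only one distinct row remains. Random initialization keeps all four distinct.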

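Solution: initialize weights randomly. Xavier (Glorot) initialization scales the random values according to the layer's dimensions, so that the variance of activations and gradients stays roughly constant from layer to layer: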

$$W \sim \mathcal{U}\left[-\sqrt{\frac{6}{n_{in}+n_{out}}},\; \sqrt{\frac{6}{n_{in}+n_{out}}}\right]$$

where n_in is the number of inputs to the layer and n_out is the number of outputs from the layer.

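For ReLU activations, He (Kaiming) initialization works better, because it accounts for ReLU zeroing roughly half its inputs:

$$W \sim \mathcal{N}\left(0,\; \frac{2}{n_{in}}\right)$$

where n_in is the number of inputs to this layer, and the second argument of N is the variance (so the standard deviation is sqrt(2/n_in)).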

Biases are typically initialized to zero — unlike weights, they do not contribute to the symmetry problem.
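Both formulas above take only a few lines of NumPy. In this sketch the layer sizes (784 inputs, 256 outputs) are arbitrary choices made for illustration. The printed statistics should land close to the target values. PyTorch provides the same schemes as torch.nn.init.xavier_uniform_ and torch.nn.init.kaiming_normal_.

```python
import numpy as np

rng = np.random.default_rng(0)

def xavier_uniform(n_in, n_out):
    # Xavier/Glorot: U[-limit, limit] with limit = sqrt(6 / (n_in + n_out))
    limit = np.sqrt(6.0 / (n_in + n_out))
    return rng.uniform(-limit, limit, size=(n_out, n_in))

def he_normal(n_in, n_out):
    # He/Kaiming: N(0, variance 2 / n_in), i.e. standard deviation sqrt(2 / n_in)
    return rng.normal(0.0, np.sqrt(2.0 / n_in), size=(n_out, n_in))

n_in, n_out = 784, 256
W_x = xavier_uniform(n_in, n_out)
W_h = he_normal(n_in, n_out)
b = np.zeros(n_out)                                # biases start at zero

# A uniform on [-a, a] has variance a^2 / 3, which here equals 2 / (n_in + n_out).
print(W_x.var(), 2.0 / (n_in + n_out))
print(W_h.std(), np.sqrt(2.0 / n_in))
```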

The Computation Graph

PyTorch's forward pass builds a dynamic computation graph as it runs. Every operation creates a node; every tensor records which operation created it. When you call loss.backward(), PyTorch traverses this graph backward, applying the chain rule at each node.
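The graph can be inspected directly. This minimal sketch assumes a made-up function f that contains an ordinary Python if-statement. Each call records a graph whose last node is y.grad_fn, and backward() follows whichever branch actually ran:

```python
import torch

def f(x: torch.Tensor) -> torch.Tensor:
    # Ordinary Python control flow: the branch taken decides which graph is built.
    if x.sum() > 0:
        return (x ** 2).sum()      # gradient 2x
    return (3 * x).sum()           # gradient 3

x = torch.tensor([1.0, 2.0], requires_grad=True)
y = f(x)
print(y.grad_fn)                   # the node that produced y (a SumBackward0 object)
y.backward()                       # walks the recorded graph backward
print(x.grad)                      # tensor([2., 4.])

x2 = torch.tensor([-1.0, -2.0], requires_grad=True)
f(x2).backward()                   # the other branch ran, so a different graph was built
print(x2.grad)                     # tensor([3., 3.])
```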

"Dynamic" means the graph is built freshly each forward pass. This is why you can use normal Python if-statements and for-loops in model code — the graph accurately reflects whatever path the computation actually took.

Epochs vs. Steps

One epoch = one complete pass through the entire training dataset.

One step (or iteration) = one forward + backward + update on one mini-batch.

Steps per epoch = dataset size / batch size, rounded up if the last, smaller batch is kept. For 50,000 examples with batch size 128: 50,000 / 128 ≈ 390.6, so there are 391 steps per epoch (390 full batches plus one batch of 80).
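Here is the same arithmetic in code. PyTorch's DataLoader keeps the partial last batch by default (drop_last=False), which gives 391 steps. Passing drop_last=True discards it, which gives 390.

```python
import math

dataset_size, batch_size = 50_000, 128

full_batches = dataset_size // batch_size                 # batches of exactly 128
last_batch = dataset_size - full_batches * batch_size     # leftover examples
steps_per_epoch = math.ceil(dataset_size / batch_size)    # keep the smaller last batch

print(full_batches, last_batch, steps_per_epoch)          # 390 80 391
```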

Shuffling data before each epoch ensures the sequence of mini-batches differs each time. Without shuffling, the optimizer always sees examples in the same order, potentially creating periodic patterns in gradient updates.
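Reshuffling usually means drawing a new permutation of the example indices at the start of every epoch. In this sketch, minibatch_indices is a hypothetical helper and not a library function. It yields one epoch's mini-batches in a fresh random order:

```python
import numpy as np

def minibatch_indices(n, batch_size, rng):
    """Yield index arrays covering 0..n-1 once, in a fresh random order."""
    perm = rng.permutation(n)
    for start in range(0, n, batch_size):
        yield perm[start:start + batch_size]

rng = np.random.default_rng(42)
epoch1 = [batch.tolist() for batch in minibatch_indices(10, 4, rng)]
epoch2 = [batch.tolist() for batch in minibatch_indices(10, 4, rng)]
print(epoch1)
print(epoch2)   # the same 10 examples, grouped and ordered differently
```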


Validation and Early Stopping

After each epoch, evaluate the model on a held-out validation set (data not used for training). The validation loss tells you whether the model is generalizing or just memorizing the training data (overfitting).

Early stopping monitors validation performance and stops training when it has not improved for several consecutive epochs. It is one of the simplest and most effective regularization techniques.
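Early stopping is usually implemented with a "patience" counter. In this sketch, early_stop is a hypothetical helper. It receives a precomputed list of per-epoch validation losses, which stands in for "train one epoch, then evaluate" in a real loop. The validation curve is also made up, shaped to improve and then rise as the model starts to overfit.

```python
def early_stop(val_losses, patience=3):
    """Return (best_epoch, stop_epoch) for a sequence of per-epoch validation losses."""
    best_loss, best_epoch, bad_epochs = float("inf"), -1, 0
    for epoch, val_loss in enumerate(val_losses):
        if val_loss < best_loss:
            best_loss, best_epoch, bad_epochs = val_loss, epoch, 0
            # A real loop would save a checkpoint of the weights here.
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                return best_epoch, epoch       # stop training here
    return best_epoch, len(val_losses) - 1     # never triggered: ran all epochs

# Hypothetical validation curve: improves, then rises (overfitting).
curve = [0.90, 0.62, 0.48, 0.41, 0.39, 0.40, 0.42, 0.45, 0.50, 0.55]
print(early_stop(curve, patience=3))   # (4, 7)
```

Training stops at epoch 7, after three epochs without improvement. The weights to keep are the checkpoint from epoch 4.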

Quiz


In PyTorch's training step, what does loss.backward() do?