Classification
Lesson 9 ⏱ 12 min

Class imbalance and weighted loss


Class Imbalance - Teaching a Model to Care About the Minority

Why the naive classifier dominates on imbalanced data, how class-weighted loss rebalances the gradient signal, a comparison of weighting vs. resampling strategies, and the focal loss idea.


Quick refresher

Cross-entropy loss

For multi-class classification, cross-entropy loss is L = -Σ y_c · log(p_c) summed over classes c. For binary classification: L = -[y·log(p) + (1-y)·log(1-p)]. Each training example contributes equally to the average loss. If one class dominates, its examples dominate the gradient.

Example

100 examples: 95 negative (label 0), 5 positive (label 1).

After one epoch, the model sees 95 gradient signals saying 'predict 0' and only 5 saying 'predict 1'.

A model that outputs p=0.05 for everything gets average loss ≈ 0.20 — lower than many models that try to learn the minority class.
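As a quick sanity check of the numbers above, here is the same computation in plain Python (no ML framework needed):

```python
import math

# Refresher setup: 95 negatives, 5 positives, and a model that
# predicts p = 0.05 for every example.
p = 0.05
loss_neg = -math.log(1 - p)   # binary CE on a negative (y=0) example
loss_pos = -math.log(p)       # binary CE on a positive (y=1) example
avg_loss = (95 * loss_neg + 5 * loss_pos) / 100

print(round(avg_loss, 2))  # → 0.2
```

The "do-nothing" constant predictor lands at a deceptively low loss precisely because 95 of the 100 terms in the average are near zero.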

The Imbalanced Dataset Trap

Imagine training a model to detect a rare medical condition that affects 1% of patients. You collect 10,000 training examples: 100 positive (has condition), 9,900 negative (healthy).

You train a model and achieve 99% accuracy. Success? No. A model that predicts "healthy" for every patient also achieves 99% accuracy — and it never helps a single sick person.

The problem runs deeper than the accuracy metric. It's in the loss function itself.

Why Loss Fails on Imbalanced Data

During training, gradients are computed and averaged across all examples in a mini-batch. If a batch of 100 examples contains 99 negatives and 1 positive, the gradient update is dominated 99:1 by negative-class information.

The model learns a powerful lesson from the data: predict negative. It gets rewarded for this 99 times in every 100 examples. The minority class barely registers. By the time training converges, the model has learned to output very low probabilities for everything, effectively predicting "negative" always.
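The 99:1 pull can be made concrete with a toy sketch. For sigmoid plus binary cross-entropy, the gradient of the loss with respect to the logit z is p − y; the single shared logit below is a simplifying assumption for illustration, not part of the lesson:

```python
import math

def sigmoid(z):
    return 1 / (1 + math.exp(-z))

# Toy mini-batch: 99 negatives, 1 positive, one shared logit.
z = 0.0            # model currently outputs p = 0.5 for everything
p = sigmoid(z)

# dL/dz = (p - y) per example; average over the batch.
grads = [p - 0] * 99 + [p - 1] * 1
avg_grad = sum(grads) / len(grads)

print(avg_grad)  # positive: descent pushes z down, toward "always negative"
```

The 99 negative examples contribute +0.5 each and the lone positive contributes −0.5, so the averaged gradient points firmly toward predicting the majority class.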

Solution 1: Class-Weighted Loss

The fix is direct: amplify the loss from minority class examples so they contribute proportionally to the gradient.

The standard weight formula for class c, given N total examples, k classes, and N_c examples in class c:

w_c = N / (k · N_c)

  • w_c — the weight assigned to class c
  • N — total number of training examples
  • k — number of classes
  • N_c — number of examples in class c

Example: 10,000 examples, binary (k=2), 100 positives, 9,900 negatives:

  • w_positive = 10000 / (2 × 100) = 50
  • w_negative = 10000 / (2 × 9900) ≈ 0.505

Now each positive example contributes 50 / 0.505 ≈ 99× more than each negative — exactly canceling the 99:1 imbalance.
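The worked example checks out numerically; here is the arithmetic in plain Python:

```python
# Class-weight formula w_c = N / (k * N_c) applied to the worked example.
N, k = 10_000, 2
n_pos, n_neg = 100, 9_900

w_pos = N / (k * n_pos)   # weight for the positive (minority) class
w_neg = N / (k * n_neg)   # weight for the negative (majority) class

print(w_pos, round(w_neg, 3), round(w_pos / w_neg))  # → 50.0 0.505 99
```

The weight ratio of 99 is exactly the inverse of the 9900:100 class ratio, which is what restores a balanced gradient signal.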

The weighted cross-entropy loss becomes:

L_weighted = -w_{c_i} · log ŷ_{c_i}

  • w_{c_i} — class weight for the true class of example i
  • log ŷ_{c_i} — log probability the model assigns to the true class

Solution 2: Oversampling (SMOTE)

Oversampling duplicates minority class examples (or creates synthetic ones) until classes are balanced.

SMOTE (Synthetic Minority Oversampling Technique) creates synthetic examples by interpolating between real minority class examples in feature space:

  1. For each minority example, find its k nearest neighbors (also minority class)
  2. Create a new synthetic example by random linear interpolation between the example and one of its neighbors

SMOTE is effective when the minority class has a meaningful geometric structure in feature space. It's widely used in tabular data settings.
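The two SMOTE steps above can be sketched in pure Python. This is a toy illustration of the interpolation idea, not the reference implementation (in practice, the `imbalanced-learn` library's `SMOTE` class is the standard choice):

```python
import random

def smote_sample(minority, k=5):
    """Generate one synthetic minority example by SMOTE-style interpolation.

    minority: list of feature vectors (lists of floats), all minority class.
    """
    x = random.choice(minority)
    # Step 1: find the k nearest minority-class neighbors of x
    # (squared Euclidean distance).
    neighbors = sorted(
        (v for v in minority if v is not x),
        key=lambda v: sum((a - b) ** 2 for a, b in zip(x, v)),
    )[:k]
    # Step 2: random linear interpolation between x and one neighbor.
    nn = random.choice(neighbors)
    t = random.random()  # random point on the segment from x to nn
    return [a + t * (b - a) for a, b in zip(x, nn)]

minority = [[1.0, 1.0], [1.2, 0.9], [0.8, 1.1], [1.1, 1.3]]
synthetic = smote_sample(minority, k=2)
print(synthetic)  # lies on a segment between two real minority points
```

Because each synthetic point sits between two real minority examples, SMOTE only helps when "between two minority points" is itself plausibly minority class — the geometric-structure caveat above.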

Downsampling (undersampling): randomly discard majority class examples until classes are balanced. Simple, but wastes data. Only advisable when you have abundant majority class examples.

Solution 3: Focal Loss

Introduced for object detection (RetinaNet, 2017), focal loss adds a modulating factor that downweights easy examples and focuses training on hard ones:

FL(p_t) = -(1 - p_t)^γ · log(p_t)

  • p_t — the model's probability for the correct class
  • γ — focusing parameter, typically 2; higher γ = more focus on hard examples
  • FL — focal loss

When p_t is large (the model is already confident and correct), (1 - p_t)^γ ≈ 0, downweighting the loss. When p_t is small (the model is struggling), the factor is ≈ 1 and the loss is nearly standard cross-entropy.

Worked comparison with γ=2:

| p_t | Standard CE loss | Focal loss (γ=2) | Downweight factor |
|-----|------------------|------------------|-------------------|
| 0.9 | 0.105            | 0.0011           | ~100× less        |
| 0.5 | 0.693            | 0.173            | 4.0× less         |
| 0.1 | 2.303            | 1.865            | 1.2× less         |

Easy examples (p_t=0.9) are downweighted by roughly 100×; hard examples (p_t=0.1) barely change. This lets the model focus gradient capacity where it's most needed.
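The table values can be reproduced directly from the formula in plain Python:

```python
import math

def focal_loss_value(p_t, gamma=2.0):
    # FL(p_t) = -(1 - p_t)^gamma * log(p_t)
    return -((1 - p_t) ** gamma) * math.log(p_t)

for p_t in (0.9, 0.5, 0.1):
    ce = -math.log(p_t)            # standard cross-entropy
    fl = focal_loss_value(p_t)     # focal loss with gamma = 2
    print(p_t, round(ce, 3), round(fl, 4), round(ce / fl, 1))
```

Note that the downweight factor is just 1 / (1 - p_t)^γ, so with γ = 2 a confident p_t = 0.9 is scaled down by exactly 1 / 0.01 = 100.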

Choosing a Strategy

| Situation | Recommended approach |
|-----------|----------------------|
| Moderate imbalance (5:1 to 20:1) | Class-weighted loss |
| Severe imbalance (100:1+) | Class weights + possibly oversampling |
| Tabular data with sufficient minority samples | SMOTE |
| Object detection, dense prediction | Focal loss |
| Small dataset | Oversampling / class weights |
| Large dataset with abundant majority | Undersampling or class weights |

Code: Class-Weighted Loss and Focal Loss in PyTorch

import torch
import torch.nn as nn

# Class-weighted CrossEntropy
# Weights follow w_c = N / (k * N_c): minority class gets the higher weight
class_counts = torch.tensor([9900.0, 100.0])  # [negative, positive]
num_classes = len(class_counts)
weights = class_counts.sum() / (num_classes * class_counts)  # [0.505, 50.0]
# Or simply: weights = 1.0 / class_counts (unnormalized, proportionally the same)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
criterion = nn.CrossEntropyLoss(weight=weights.to(device))

# Focal loss (simple binary version)
def focal_loss(logits, targets, gamma=2.0):
    bce = nn.functional.binary_cross_entropy_with_logits(
        logits, targets.float(), reduction='none'
    )
    p_t = torch.exp(-bce)  # = predicted probability for the correct class
    return ((1 - p_t) ** gamma * bce).mean()

For production, the torchvision.ops.sigmoid_focal_loss function provides a well-tested implementation.

Quiz

1 / 3

With class weights w₀=1, w₁=99 applied to a binary cross-entropy loss, what happens to the gradient signal from minority class (class 1) examples?