The Imbalanced Dataset Trap
Imagine training a model to detect a rare medical condition that affects 1% of patients. You collect 10,000 training examples: 100 positive (has condition), 9,900 negative (healthy).
You train a model and achieve 99% accuracy. Success? No. A model that predicts "healthy" for every patient also achieves 99% accuracy — and it never helps a single sick person.
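To make the failure concrete, here is a minimal sketch (synthetic labels matching the counts above) showing that an "always healthy" baseline scores 99% accuracy but 0% recall on the sick patients:

import torch

y_true = torch.cat([torch.ones(100), torch.zeros(9900)])   # 100 sick, 9,900 healthy
y_pred = torch.zeros(10_000)                                # baseline: predict "healthy" for everyone

accuracy = (y_pred == y_true).float().mean()                # 0.99
recall = (y_pred[y_true == 1] == 1).float().mean()          # 0.0: not one sick patient is caught
print(accuracy.item(), recall.item())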
The problem runs deeper than the accuracy metric. It's in the loss function itself.
Why Loss Fails on Imbalanced Data
During training, gradients are computed and averaged across all examples in a mini-batch. If a batch of 100 examples contains 99 negatives and 1 positive, the gradient update is dominated 99:1 by negative-class information.
The model learns a powerful lesson from the data: predict negative. It gets rewarded for this 99 times in every 100 examples. The minority class barely registers. By the time training converges, the model has learned to output very low probabilities for everything, effectively predicting "negative" always.
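To see the imbalance show up in the gradient itself, here is a small sketch (assumed setup: an untrained linear model and one batch of 99 negatives and 1 positive; for binary cross-entropy, the gradient with respect to each logit is sigmoid(z) - y):

import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(10, 1)
x = torch.randn(100, 10)                 # one mini-batch of 100 examples
y = torch.zeros(100, 1)
y[0] = 1.0                               # a single positive example

per_example = torch.sigmoid(model(x)) - y        # per-example gradient w.r.t. each logit
print(per_example[y == 1].abs().sum())           # pull from the 1 positive example
print(per_example[y == 0].abs().sum())           # pull from the 99 negatives, roughly 99x larger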
Solution 1: Class-Weighted Loss
The fix is direct: amplify the loss from minority class examples so they contribute proportionally to the gradient.
The standard weight formula for class c, given N total examples, k classes, and n_c examples in class c:

w_c = N / (k × n_c)

- w_c: the weight assigned to class c
- N: total number of training examples
- k: number of classes
- n_c: number of examples in class c

Example: 10,000 examples, binary (k = 2), 100 positives, 9,900 negatives:

- w_positive = 10,000 / (2 × 100) = 50
- w_negative = 10,000 / (2 × 9,900) ≈ 0.505
Now each positive example is weighted 99× more heavily than each negative (50 vs. 0.505), exactly canceling the 99:1 imbalance.
The weighted cross-entropy loss becomes:

L = -w_c × log(p_c)

- w_c: class weight for the true class c
- log(p_c): log probability assigned to the true class
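A quick check of the arithmetic, as a sketch (class counts from the example above; the predicted probabilities are made up):

import torch

counts = torch.tensor([9900.0, 100.0])             # [negative, positive]
weights = counts.sum() / (len(counts) * counts)    # w_c = N / (k * n_c)
print(weights)                                     # approx. [0.505, 50.0]
print(weights * counts)                            # approx. [5000., 5000.]: balanced totals

# Manual weighted cross-entropy for one example whose true class is 1 (positive):
p = torch.tensor([0.7, 0.3])                       # model's predicted class probabilities
print(-weights[1] * torch.log(p[1]))               # 50 * 1.204, approx. 60.2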
Solution 2: Oversampling (SMOTE)
Oversampling duplicates minority class examples (or creates synthetic ones) until classes are balanced.
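One way to implement duplication-style oversampling in PyTorch is at the data-loader level with WeightedRandomSampler; a minimal sketch, assuming a map-style dataset called train_dataset with the integer labels below:

import torch
from torch.utils.data import DataLoader, WeightedRandomSampler

labels = torch.cat([torch.zeros(9900, dtype=torch.long),
                    torch.ones(100, dtype=torch.long)])     # assumed training labels

class_counts = torch.bincount(labels).float()
sample_weights = 1.0 / class_counts[labels]                  # rare classes drawn more often

sampler = WeightedRandomSampler(sample_weights,
                                num_samples=len(labels),
                                replacement=True)
# loader = DataLoader(train_dataset, batch_size=64, sampler=sampler)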
SMOTE (Synthetic Minority Oversampling Technique) creates synthetic examples by interpolating between real minority class examples in feature space:
- For each minority example, find its k nearest neighbors (also minority class)
- Create a new synthetic example by random linear interpolation between the example and one of its neighbors
SMOTE is effective when the minority class has a meaningful geometric structure in feature space. It's widely used in tabular data settings.
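A minimal sketch of that interpolation step, using scikit-learn's nearest-neighbor search (X_min is assumed to be the minority-class feature matrix; for production use, the imbalanced-learn library offers a well-tested SMOTE implementation):

import numpy as np
from sklearn.neighbors import NearestNeighbors

def smote_sample(X_min, n_new, k=5, seed=0):
    """Generate n_new synthetic minority examples by interpolating between neighbors."""
    rng = np.random.default_rng(seed)
    nbrs = NearestNeighbors(n_neighbors=k + 1).fit(X_min)   # +1: each point is its own neighbor
    _, neighbor_idx = nbrs.kneighbors(X_min)
    synthetic = []
    for _ in range(n_new):
        i = rng.integers(len(X_min))                 # pick a random minority example
        j = rng.choice(neighbor_idx[i][1:])          # pick one of its k nearest minority neighbors
        lam = rng.random()                           # interpolation factor in [0, 1)
        synthetic.append(X_min[i] + lam * (X_min[j] - X_min[i]))
    return np.vstack(synthetic)

# Example: X_min = np.random.rand(100, 20); X_new = smote_sample(X_min, n_new=9800)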
Downsampling (undersampling): randomly discard majority class examples until classes are balanced. Simple, but wastes data. Only advisable when you have abundant majority class examples.
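A random undersampling sketch with the same running counts (indices only; you would then subset your dataset with balanced_idx):

import numpy as np

rng = np.random.default_rng(0)
y = np.array([0] * 9900 + [1] * 100)                          # labels from the running example
maj_idx, min_idx = np.where(y == 0)[0], np.where(y == 1)[0]

keep = rng.choice(maj_idx, size=len(min_idx), replace=False)  # keep only 100 random negatives
balanced_idx = np.concatenate([keep, min_idx])                # 200 balanced training indices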
Solution 3: Focal Loss
Introduced for object detection (RetinaNet, 2017), focal loss adds a modulating factor that downweights easy examples and focuses training on hard ones:

FL(p_t) = -(1 - p_t)^γ × log(p_t)

- p_t: the model's probability for the correct class
- γ: focusing parameter, typically 2; higher γ means more focus on hard examples
- FL(p_t): the focal loss

When p_t is large (the model is already confident and correct), the modulating factor (1 - p_t)^γ is close to zero, sharply downweighting the loss. When p_t is small (the model is struggling), the factor is close to 1 and the loss is nearly standard cross-entropy.
Worked comparison with γ=2:
| p_t | Standard CE loss | Focal loss (γ=2) | Downweight factor |
|---|---|---|---|
| 0.9 | 0.105 | 0.001 | 100× less |
| 0.5 | 0.693 | 0.173 | 4.0× less |
| 0.1 | 2.303 | 1.865 | 1.2× less |
Easy examples (p_t=0.9) are downweighted by roughly 100×; hard examples (p_t=0.1) barely change. This lets the model focus gradient capacity where it's most needed.
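The table values can be reproduced in a few lines (natural log, γ = 2):

import math

gamma = 2.0
for p_t in (0.9, 0.5, 0.1):
    ce = -math.log(p_t)                        # standard cross-entropy
    fl = (1 - p_t) ** gamma * ce               # focal loss
    print(f"p_t={p_t}: CE={ce:.3f}, FL={fl:.3f}, downweight={ce / fl:.1f}x")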
Choosing a Strategy
| Situation | Recommended approach |
|---|---|
| Moderate imbalance (5:1 to 20:1) | Class-weighted loss |
| Severe imbalance (100:1+) | Class weights + possibly oversampling |
| Tabular data with sufficient minority samples | SMOTE |
| Object detection, dense prediction | Focal loss |
| Small dataset | Oversampling / class weights |
| Large dataset with abundant majority | Undersampling or class weights |
Code: Class-Weighted Loss and Focal Loss in PyTorch
import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Class-weighted CrossEntropy
# Weights: minority class gets higher weight (w_c = N / (k * n_c))
class_counts = torch.tensor([9900.0, 100.0])  # [negative, positive]
weights = class_counts.sum() / (len(class_counts) * class_counts)
# Or simply: weights = 1.0 / class_counts (unnormalized, also works)
criterion = nn.CrossEntropyLoss(weight=weights.to(device))

# Focal loss (simple binary version)
def focal_loss(logits, targets, gamma=2.0):
    bce = nn.functional.binary_cross_entropy_with_logits(
        logits, targets.float(), reduction='none'
    )
    p_t = torch.exp(-bce)  # = predicted probability for the correct class
    return ((1 - p_t) ** gamma * bce).mean()
For production, the torchvision.ops.sigmoid_focal_loss function provides a well-tested implementation.
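A quick usage sketch with dummy tensors (assumed shapes: class-index targets for the weighted criterion; raw logits and 0/1 targets for the binary focal loss):

# Weighted cross-entropy: logits of shape [batch, num_classes], integer class targets
logits_ce = torch.randn(8, 2, device=device)
targets_ce = torch.tensor([0, 0, 0, 0, 0, 0, 0, 1], device=device)
print(criterion(logits_ce, targets_ce))

# Binary focal loss: one logit per example, 0/1 targets
logits_bin = torch.randn(8, device=device)
targets_bin = torch.tensor([0, 0, 0, 0, 0, 0, 0, 1], device=device)
print(focal_loss(logits_bin, targets_bin, gamma=2.0))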