
Focal Loss vs BCE: How to Fix Imbalanced Binary Classification

Compare Focal Loss and Binary Cross-Entropy on a 99:1 imbalanced dataset to see how Focal Loss improves minority-class detection and yields more meaningful decision boundaries.

Why standard Binary Cross-Entropy fails on imbalanced data

Binary cross-entropy (BCE) scores each sample's error the same way, whatever its class. When one class is extremely rare, the summed loss is dominated by the abundant class: the network can reach high overall accuracy by predicting the majority class almost everywhere while effectively ignoring the minority class. For example, a minority sample labeled 1 predicted at 0.3 and a majority sample labeled 0 predicted at 0.7 produce exactly the same BCE loss, -ln(0.3) ≈ 1.20, yet the first error is usually far more consequential in practice.
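A quick numeric check of the example above, plus a hypothetical 99:1 batch (the probabilities are chosen for illustration, not taken from the experiment) showing how the abundant class can dominate the summed loss even when each minority error is much larger:

```python
import math

def bce(p, y):
    """Binary cross-entropy for one sample: p = predicted P(y=1), y in {0, 1}."""
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

# The two errors from the example: a missed positive and a false-alarm negative
loss_minority = bce(0.3, 1)   # label 1 predicted at 0.3
loss_majority = bce(0.7, 0)   # label 0 predicted at 0.7
print(round(loss_minority, 4), round(loss_majority, 4))

# Hypothetical batch with a 99:1 ratio: 990 majority samples predicted at 0.2
# (correct, but not fully confident) and 10 minority samples badly missed at 0.1.
total_majority = 990 * bce(0.2, 0)
total_minority = 10 * bce(0.1, 1)
share_majority = total_majority / (total_majority + total_minority)
print(f"share of total loss from the majority class: {share_majority:.1%}")
```

Each minority error here costs about ten times more than each majority one, yet the majority class still contributes roughly 90% of the total loss, so it drives most of the gradient.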

How Focal Loss changes the training focus

Focal Loss addresses this imbalance by down-weighting easy, confidently classified examples, so that hard, misclassified examples make up a larger share of the total loss. It introduces two hyperparameters: gamma, which controls how aggressively easy examples are suppressed (with gamma = 0 it reduces to ordinary, possibly alpha-weighted, cross-entropy), and alpha, a weighting factor that can place extra weight on one class, typically the minority class. Together they keep the optimizer from being dominated by the abundant, easy majority-class examples and push it to pay more attention to rare and difficult ones.
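In its commonly used form the focal loss is FL(p_t) = -alpha_t (1 - p_t)^gamma log(p_t), where p_t is the probability the model assigns to the true class. The plain-Python sketch below (the function name `focal_loss` is illustrative) evaluates it for single examples, using the class-dependent convention alpha_t = alpha for positives and 1 - alpha for negatives; note that this differs from the article's implementation, which applies one alpha to every sample.

```python
import math

def focal_loss(p, y, alpha=0.25, gamma=2.0):
    """Per-sample binary focal loss: -alpha_t * (1 - p_t)**gamma * log(p_t).

    p is the predicted probability of class 1 and y is the label (0 or 1).
    alpha_t = alpha for y = 1 and 1 - alpha for y = 0 (class-dependent convention).
    """
    p_t = p if y == 1 else 1 - p          # probability assigned to the true class
    alpha_t = alpha if y == 1 else 1 - alpha
    return -alpha_t * (1 - p_t) ** gamma * math.log(p_t)

# Compare gamma = 2 against gamma = 0 (plain cross-entropy) with alpha switched off
ratios = {}
for p in (0.95, 0.3):                      # an easy and a hard positive example
    ratios[p] = focal_loss(p, 1, alpha=1.0, gamma=2.0) / focal_loss(p, 1, alpha=1.0, gamma=0.0)
    print(f"p_t = {p}: focal / cross-entropy = {ratios[p]:.4f}")
```

The ratio is exactly (1 - p_t)^gamma: 0.0025 for the easy example and 0.49 for the hard one, so the confident prediction's loss shrinks 400-fold while the hard one's is only roughly halved.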

Experimental setup: 99:1 imbalance

To demonstrate the difference, we train two small neural networks with identical architecture on a synthetic dataset of 6000 samples with a 99:1 class imbalance. One model uses BCE, the other uses Focal Loss. Below are the installation steps, dataset creation, model and loss implementations, training loop, and visualization code used in the experiment.

Install dependencies

pip install numpy pandas matplotlib scikit-learn torch

Creating an imbalanced dataset

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.optim as optim
 
# Generate imbalanced dataset
X, y = make_classification(
    n_samples=6000,
    n_features=2,
    n_redundant=0,
    n_clusters_per_class=1,
    weights=[0.99, 0.01],   
    class_sep=1.5,
    random_state=42
)
 
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42
)
 
X_train = torch.tensor(X_train, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.float32).unsqueeze(1)
X_test  = torch.tensor(X_test,  dtype=torch.float32)
y_test  = torch.tensor(y_test,  dtype=torch.float32).unsqueeze(1)
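Because `make_classification` also reassigns about 1% of labels at random by default (`flip_y=0.01`), the realized minority share is not exactly 1%, and an unstratified split can leave the test set with a slightly different ratio than the full dataset. A sketch for checking the class counts and, as an optional change that the original experiment does not use, stratifying the split with `stratify=y`:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

# Same generator settings as the experiment above
X, y = make_classification(
    n_samples=6000, n_features=2, n_redundant=0, n_clusters_per_class=1,
    weights=[0.99, 0.01], class_sep=1.5, random_state=42,
)

counts = np.bincount(y)
print("class counts:", counts, "minority fraction:", counts[1] / counts.sum())

# stratify=y keeps the minority fraction (nearly) equal in both splits;
# the original experiment splits without it.
X_tr, X_te, y_tr, y_te = train_test_split(
    X, y, test_size=0.3, random_state=42, stratify=y
)
print("train minority fraction:", y_tr.mean(), "test minority fraction:", y_te.mean())
```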

Creating the neural network

class SimpleNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(2, 16),
            nn.ReLU(),
            nn.Linear(16, 8),
            nn.ReLU(),
            nn.Linear(8, 1),
            nn.Sigmoid()
        )
    def forward(self, x):
        return self.layers(x)
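To get a sense of scale, the sketch below rebuilds the same layer stack on its own and counts its parameters; it also confirms that the final `nn.Sigmoid` yields values in (0, 1), which is what `nn.BCELoss` and the `FocalLoss` class below expect as input.

```python
import torch
import torch.nn as nn

# Same layer stack as SimpleNN, rebuilt here so the snippet runs standalone
layers = nn.Sequential(
    nn.Linear(2, 16), nn.ReLU(),
    nn.Linear(16, 8), nn.ReLU(),
    nn.Linear(8, 1), nn.Sigmoid(),
)

# Weights + biases per Linear layer: (2*16 + 16) + (16*8 + 8) + (8*1 + 1)
n_params = sum(p.numel() for p in layers.parameters())
print("trainable parameters:", n_params)

# The final Sigmoid maps each output to a probability in (0, 1)
with torch.no_grad():
    out = layers(torch.randn(5, 2))
print(out.shape)
```

The model has 193 trainable parameters, small enough that full-batch training on a few thousand 2-D points is cheap.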

Focal Loss implementation

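class FocalLoss(nn.Module):
    # Binary focal loss on sigmoid probabilities.
    # As written, alpha multiplies every sample's loss by the same constant, so it
    # rescales the loss rather than favoring the minority class; the class-dependent
    # variant uses alpha for positives and (1 - alpha) for negatives.
    def __init__(self, alpha=0.25, gamma=2):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, preds, targets):
        eps = 1e-7
        # Clamp so log() never sees exactly 0 or 1
        preds = torch.clamp(preds, eps, 1 - eps)

        # p_t: probability the model assigns to the true class
        pt = torch.where(targets == 1, preds, 1 - preds)
        # (1 - p_t)^gamma shrinks the loss of confident, easy predictions
        loss = -self.alpha * (1 - pt) ** self.gamma * torch.log(pt)
        return loss.mean()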
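A commonly recommended alternative computes the loss from raw logits, which avoids clamping probabilities, and applies alpha per class. The sketch below is a hypothetical variant (the class name `FocalLossWithLogits` is illustrative, not from the article or from PyTorch). Using it would mean removing `nn.Sigmoid()` from `SimpleNN` and thresholding logits at 0 instead of probabilities at 0.5, so results would not match the experiment as run. Under this convention `alpha=0.25` weights positives by 0.25 and negatives by 0.75, so the value needs tuning for your data rather than being assumed to favor the minority class.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLossWithLogits(nn.Module):
    """Hypothetical variant: binary focal loss on raw logits with per-class alpha.

    alpha_t = alpha for positives and (1 - alpha) for negatives; pass alpha=None
    to disable class weighting. Expects a model WITHOUT a final Sigmoid.
    """
    def __init__(self, alpha=0.25, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, logits, targets):
        # Per-sample -log(p_t), computed stably from logits (no clamping needed)
        ce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
        p_t = torch.exp(-ce)                     # probability of the true class
        loss = (1 - p_t) ** self.gamma * ce      # down-weight easy examples
        if self.alpha is not None:
            alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
            loss = alpha_t * loss
        return loss.mean()

# Sanity check: with gamma=0 and no alpha it reduces to plain BCE with logits
logits = torch.tensor([[2.0], [-1.0], [0.5]])
targets = torch.tensor([[1.0], [0.0], [0.0]])
fl = FocalLossWithLogits(alpha=None, gamma=0.0)(logits, targets)
bce_ref = F.binary_cross_entropy_with_logits(logits, targets)
print(fl.item(), bce_ref.item())
```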

Training and comparison

We train two identical models with the same optimizer and hyperparameters: one with nn.BCELoss and one with FocalLoss(alpha=0.25, gamma=2). The training loop below computes test accuracy after training.

def train(model, loss_fn, lr=0.01, epochs=30):
    opt = optim.Adam(model.parameters(), lr=lr)

    # Full-batch training: each "epoch" is a single gradient step on all of X_train
    for _ in range(epochs):
        preds = model(X_train)
        loss = loss_fn(preds, y_train)
        opt.zero_grad()
        loss.backward()
        opt.step()

    # Evaluate on the held-out test set with a 0.5 decision threshold
    with torch.no_grad():
        test_preds = model(X_test)
        test_acc = ((test_preds > 0.5).float() == y_test).float().mean().item()
    return test_acc, test_preds.squeeze().numpy()

# Models: seeding before each constructor gives both networks the same initial weights
torch.manual_seed(0)
model_bce = SimpleNN()
torch.manual_seed(0)
model_focal = SimpleNN()

acc_bce, preds_bce = train(model_bce, nn.BCELoss())
acc_focal, preds_focal = train(model_focal, FocalLoss(alpha=0.25, gamma=2))

print("Test Accuracy (BCE):", acc_bce)
print("Test Accuracy (Focal Loss):", acc_focal)

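On this dataset BCE often reports deceptively high accuracy (e.g., 98%) because it predicts almost everything as the majority class; a trivial classifier that always outputs the majority label would score about as well. Focal Loss tends to improve detection of minority-class samples, which is usually the point on imbalanced data, yet accuracy alone barely reflects that gain, so per-class metrics are needed to see it.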
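To see why accuracy misleads, the sketch below scores a do-nothing classifier that always predicts the majority class, on hypothetical labels with 28 positives among 1800 samples (the size of the 30% test split). scikit-learn's `balanced_accuracy_score` and `recall_score` expose what plain accuracy hides:

```python
import numpy as np
from sklearn.metrics import accuracy_score, balanced_accuracy_score, recall_score

# Hypothetical test labels: 28 minority (1) samples among 1800
y_true = np.zeros(1800, dtype=int)
y_true[:28] = 1

# A "model" that always predicts the majority class
y_pred = np.zeros_like(y_true)

print("accuracy:         ", accuracy_score(y_true, y_pred))
print("balanced accuracy:", balanced_accuracy_score(y_true, y_pred))
print("minority recall:  ", recall_score(y_true, y_pred))
```

Accuracy comes out at 1772/1800 ≈ 98.4%, while balanced accuracy is 0.5 and minority recall is 0.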

Visualizing decision boundaries

def plot_decision_boundary(model, title):
    # Create a grid
    x_min, x_max = X[:,0].min()-1, X[:,0].max()+1
    y_min, y_max = X[:,1].min()-1, X[:,1].max()+1
    xx, yy = np.meshgrid(
        np.linspace(x_min, x_max, 300),
        np.linspace(y_min, y_max, 300)
    )
    grid = torch.tensor(np.c_[xx.ravel(), yy.ravel()], dtype=torch.float32)
    # Predicted probability of class 1 at every grid point, converted to NumPy for matplotlib
    with torch.no_grad():
        Z = model(grid).reshape(xx.shape).numpy()
 
    # Plot
    plt.contourf(xx, yy, Z, levels=[0,0.5,1], alpha=0.4)
    plt.scatter(X[:,0], X[:,1], c=y, cmap='coolwarm', s=10)
    plt.title(title)
    plt.show()
 
plot_decision_boundary(model_bce, "Decision Boundary -- BCE Loss")
plot_decision_boundary(model_focal, "Decision Boundary -- Focal Loss")

The BCE-trained model typically labels almost the entire plane as the majority class, so its 0.5 contour barely separates anything, while the Focal Loss model produces a more refined boundary that carves out regions around the minority-class samples.

Confusion matrices

from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
 
def plot_conf_matrix(y_true, y_pred, title):
    cm = confusion_matrix(y_true, y_pred)
    disp = ConfusionMatrixDisplay(confusion_matrix=cm)
    disp.plot(cmap="Blues", values_format='d')
    plt.title(title)
    plt.show()
 
# Flatten the (n, 1) label tensor into a 1-D NumPy array of ints
y_test_np = y_test.numpy().astype(int).ravel()
 
preds_bce_label   = (preds_bce > 0.5).astype(int)
preds_focal_label = (preds_focal > 0.5).astype(int)
 
plot_conf_matrix(y_test_np, preds_bce_label, "Confusion Matrix -- BCE Loss")
plot_conf_matrix(y_test_np, preds_focal_label, "Confusion Matrix -- Focal Loss")

In the example run described here, BCE correctly identified only 1 of the 28 minority samples in the test set and missed the other 27, while Focal Loss correctly identified 14, cutting minority-class misclassifications from 27 to 14. Exact counts vary from run to run, but this illustrates how Focal Loss helps the model focus on rare, hard-to-classify examples.
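Minority-class recall, TP / (TP + FN), puts the reported counts on a common scale. The snippet simply recomputes it from the numbers above; false-positive counts are not reported, so precision cannot be derived from them:

```python
# Counts reported for the example run: (correctly identified, missed) minority samples
results = {"BCE": (1, 27), "Focal Loss": (14, 14)}

recalls = {}
for name, (tp, fn) in results.items():
    recalls[name] = tp / (tp + fn)   # share of actual minority samples the model caught
    print(f"{name}: minority recall = {tp}/{tp + fn} = {recalls[name]:.3f}")
```

This gives about 0.036 for BCE versus 0.5 for Focal Loss.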

Practical takeaways

  • For highly imbalanced binary tasks, overall accuracy with BCE can be misleading.
  • Focal Loss re-weights the contribution of samples so the model learns from hard minority examples.
  • Use visualization and confusion matrices, not just accuracy, to evaluate models on imbalanced data.

The provided code snippets let you reproduce the experiment and adapt the focal loss hyperparameters (alpha and gamma) to your dataset.
