Focal Loss vs BCE: How to Fix Imbalanced Binary Classification
Compare Focal Loss and Binary Cross-Entropy on a 99:1 imbalanced dataset to see how Focal Loss improves minority-class detection and yields more meaningful decision boundaries.
Why standard Binary Cross-Entropy fails on imbalanced data
Binary cross-entropy (BCE) treats errors from both classes equally. When one class is extremely rare, that symmetry becomes a problem: the network can achieve high overall accuracy by simply predicting the majority class most of the time, while completely ignoring the minority class. For example, a minority sample labeled 1 predicted at 0.3 and a majority sample labeled 0 predicted at 0.7 produce the same BCE loss value, yet the first error is far more consequential in practice.
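To make that symmetry concrete, here is a minimal sanity check with PyTorch's nn.BCELoss; the 0.3 and 0.7 predictions are the illustrative values from the paragraph above:
import torch
import torch.nn as nn

bce = nn.BCELoss()
# Minority sample: label 1 predicted at 0.3
loss_minority = bce(torch.tensor([0.3]), torch.tensor([1.0]))
# Majority sample: label 0 predicted at 0.7
loss_majority = bce(torch.tensor([0.7]), torch.tensor([0.0]))
print(loss_minority.item(), loss_majority.item())  # both ~1.204 -- BCE cannot tell them apart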
How Focal Loss changes the training focus
Focal Loss addresses this imbalance by down-weighting easy, confident predictions and amplifying the contribution of hard, misclassified samples. It introduces two key hyperparameters: gamma, which controls how aggressively easy examples are suppressed, and alpha, which allows placing extra weight on the minority class. Together they steer the optimizer to pay more attention to rare and difficult examples instead of being dominated by the abundant majority class.
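A quick back-of-the-envelope look at the modulating factor (1 - pt)**gamma shows why easy examples fade out of the loss; the pt values below are chosen purely for illustration, with gamma = 2 as used later in the experiment:
# pt is the predicted probability of the true class
for pt in (0.9, 0.6, 0.3):
    print(f"pt={pt}: down-weighting factor={(1 - pt) ** 2:.2f}")
# pt=0.9 (easy, confident)    -> 0.01, the loss is scaled down ~100x
# pt=0.6                      -> 0.16
# pt=0.3 (hard, misclassified)-> 0.49, the loss is largely preserved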
Experimental setup: 99:1 imbalance
To demonstrate the difference, we train two identical small neural networks on a synthetic dataset with a 99:1 class imbalance (6000 samples). One model uses BCE, the other uses Focal Loss. Below are the installation steps, dataset creation, model and loss implementations, training loop, and visualization code used in the experiment.
Install dependencies
pip install numpy pandas matplotlib scikit-learn torch
Creating an imbalanced dataset
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.optim as optim
# Generate imbalanced dataset
X, y = make_classification(
    n_samples=6000,
    n_features=2,
    n_redundant=0,
    n_clusters_per_class=1,
    weights=[0.99, 0.01],
    class_sep=1.5,
    random_state=42
)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42
)
X_train = torch.tensor(X_train, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.float32).unsqueeze(1)
X_test = torch.tensor(X_test, dtype=torch.float32)
y_test = torch.tensor(y_test, dtype=torch.float32).unsqueeze(1)
Creating the neural network
class SimpleNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(2, 16),
            nn.ReLU(),
            nn.Linear(16, 8),
            nn.ReLU(),
            nn.Linear(8, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.layers(x)
Focal Loss implementation
class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, preds, targets):
        eps = 1e-7
        preds = torch.clamp(preds, eps, 1 - eps)
        pt = torch.where(targets == 1, preds, 1 - preds)
        loss = -self.alpha * (1 - pt) ** self.gamma * torch.log(pt)
        return loss.mean()
Training and comparison
We train two identical models with the same optimizer and hyperparameters: one with nn.BCELoss and one with FocalLoss(alpha=0.25, gamma=2). The training loop below computes test accuracy after training.
def train(model, loss_fn, lr=0.01, epochs=30):
    opt = optim.Adam(model.parameters(), lr=lr)
    for _ in range(epochs):
        preds = model(X_train)
        loss = loss_fn(preds, y_train)
        opt.zero_grad()
        loss.backward()
        opt.step()
    with torch.no_grad():
        test_preds = model(X_test)
        test_acc = ((test_preds > 0.5).float() == y_test).float().mean().item()
    return test_acc, test_preds.squeeze().detach().numpy()
# Models
model_bce = SimpleNN()
model_focal = SimpleNN()
acc_bce, preds_bce = train(model_bce, nn.BCELoss())
acc_focal, preds_focal = train(model_focal, FocalLoss(alpha=0.25, gamma=2))
print("Test Accuracy (BCE):", acc_bce)
print("Test Accuracy (Focal Loss):", acc_focal)On this dataset BCE often reports deceptively high accuracy (e.g., 98%) because it predicts almost everything as the majority class. Focal Loss tends to improve detection of minority-class samples and therefore yields a more meaningful measure of performance for imbalanced data.
Visualizing decision boundaries
def plot_decision_boundary(model, title):
    # Create a grid
    x_min, x_max = X[:,0].min()-1, X[:,0].max()+1
    y_min, y_max = X[:,1].min()-1, X[:,1].max()+1
    xx, yy = np.meshgrid(
        np.linspace(x_min, x_max, 300),
        np.linspace(y_min, y_max, 300)
    )
    grid = torch.tensor(np.c_[xx.ravel(), yy.ravel()], dtype=torch.float32)
    with torch.no_grad():
        Z = model(grid).reshape(xx.shape)
    # Plot
    plt.contourf(xx, yy, Z, levels=[0, 0.5, 1], alpha=0.4)
    plt.scatter(X[:,0], X[:,1], c=y, cmap='coolwarm', s=10)
    plt.title(title)
    plt.show()
plot_decision_boundary(model_bce, "Decision Boundary -- BCE Loss")
plot_decision_boundary(model_focal, "Decision Boundary -- Focal Loss")
The BCE-trained model typically produces an almost flat decision boundary biased toward the majority class, while the Focal Loss model shows a more refined boundary that captures minority-class regions.
Confusion matrices
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
def plot_conf_matrix(y_true, y_pred, title):
    cm = confusion_matrix(y_true, y_pred)
    disp = ConfusionMatrixDisplay(confusion_matrix=cm)
    disp.plot(cmap="Blues", values_format='d')
    plt.title(title)
    plt.show()
# Convert torch tensors to numpy
y_test_np = y_test.numpy().astype(int).ravel()
preds_bce_label = (preds_bce > 0.5).astype(int)
preds_focal_label = (preds_focal > 0.5).astype(int)
plot_conf_matrix(y_test_np, preds_bce_label, "Confusion Matrix -- BCE Loss")
plot_conf_matrix(y_test_np, preds_focal_label, "Confusion Matrix -- Focal Loss")
In the example run described here, BCE correctly identified only 1 minority sample and misclassified 27, while Focal Loss correctly identified 14 and cut the misclassifications from 27 to 14. That concrete improvement shows how Focal Loss helps models focus on rare, hard-to-classify examples.
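In recall terms, that is 1/28 ≈ 3.6% of minority test samples recovered under BCE versus 14/28 = 50% under Focal Loss.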
Practical takeaways
- For highly imbalanced binary tasks, overall accuracy with BCE can be misleading.
- Focal Loss re-weights the contribution of samples so the model learns from hard minority examples.
- Use visualization and confusion matrices, not just accuracy, to evaluate models on imbalanced data.
The provided code snippets let you reproduce the experiment and adapt the focal loss hyperparameters (alpha and gamma) to your dataset.
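As a starting point for that tuning, a small sweep over alpha and gamma can reuse the train function and FocalLoss class above; this is only a rough sketch, and the value grids are common choices rather than recommendations for any particular dataset:
from sklearn.metrics import recall_score

y_true = y_test.numpy().astype(int).ravel()
for alpha in (0.1, 0.25, 0.5):
    for gamma in (1, 2, 5):
        model = SimpleNN()
        acc, preds = train(model, FocalLoss(alpha=alpha, gamma=gamma))
        rec = recall_score(y_true, (preds > 0.5).astype(int))  # recall of the minority class (label 1)
        print(f"alpha={alpha}, gamma={gamma}: accuracy={acc:.3f}, minority recall={rec:.3f}")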