<НА ГЛАВНУЮ

Focal Loss против BCE: как исправить несбалансированную бинарную классификацию

'Сравнение Focal Loss и BCE на примере датасета с дисбалансом 99:1 показывает, как Focal Loss улучшает обнаружение редкого класса и более информативные разделяющие границы.'

Почему стандартный Binary Cross-Entropy не справляется с несбалансированными данными

Binary cross-entropy (BCE) одинаково взвешивает ошибки для обоих классов. Когда один класс очень редок, эта симметрия становится проблемой: сеть может достичь высокой общей точности, просто предсказывая преобладающий класс в большинстве случаев, при этом полностью игнорируя редкий класс. Например, редкий образец с меткой 1, предсказанный как 0.3, и преобладающий образец с меткой 0, предсказанный как 0.7, дают одинаковое значение BCE, хотя первая ошибка намного важнее.

Как Focal Loss меняет фокус обучения

Focal Loss решает проблему, понижая вклад простых и уверенных предсказаний и усиливая влияние сложных, неправильно классифицированных примеров. В формуле присутствуют две ключевые гиперпараметры: gamma, контролирующий агрессивность подавления простых примеров, и alpha, позволяющий придать дополнительный вес редкому классу. Вместе они заставляют оптимизатор уделять больше внимания редким и трудным примерам, вместо того чтобы подстраиваться под доминирующий класс.

Эксперимент: несбалансированность 99:1

Чтобы показать разницу, мы обучаем две одинаковые небольшие нейросети на синтетическом датасете с соотношением классов 99:1 (6000 образцов). Одна модель использует BCE, другая — Focal Loss. Ниже приведены шаги установки, создания датасета, определения модели и функции потерь, цикл обучения и код визуализации, использованные в эксперименте.

Установка зависимостей

pip install numpy pandas matplotlib scikit-learn torch

Создание несбалансированного датасета

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.optim as optim
 
# Generate imbalanced dataset
X, y = make_classification(
    n_samples=6000,
    n_features=2,
    n_redundant=0,
    n_clusters_per_class=1,
    weights=[0.99, 0.01],   
    class_sep=1.5,
    random_state=42
)
 
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42
)
 
X_train = torch.tensor(X_train, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.float32).unsqueeze(1)
X_test  = torch.tensor(X_test,  dtype=torch.float32)
y_test  = torch.tensor(y_test,  dtype=torch.float32).unsqueeze(1)

Определение нейросети

class SimpleNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(2, 16),
            nn.ReLU(),
            nn.Linear(16, 8),
            nn.ReLU(),
            nn.Linear(8, 1),
            nn.Sigmoid()
        )
    def forward(self, x):
        return self.layers(x)

Реализация Focal Loss

class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
 
    def forward(self, preds, targets):
        eps = 1e-7
        preds = torch.clamp(preds, eps, 1 - eps)
        
        pt = torch.where(targets == 1, preds, 1 - preds)
        loss = -self.alpha * (1 - pt) ** self.gamma * torch.log(pt)
        return loss.mean()

Обучение и сравнение

Мы обучаем две одинаковые модели с одним и тем же оптимизатором и гиперпараметрами: одна с nn.BCELoss, другая с FocalLoss(alpha=0.25, gamma=2). Цикл обучения ниже вычисляет точность на тестовой выборке после обучения.

def train(model, loss_fn, lr=0.01, epochs=30):
    opt = optim.Adam(model.parameters(), lr=lr)
 
    for _ in range(epochs):
        preds = model(X_train)
        loss = loss_fn(preds, y_train)
        opt.zero_grad()
        loss.backward()
        opt.step()
 
    with torch.no_grad():
        test_preds = model(X_test)
        test_acc = ((test_preds > 0.5).float() == y_test).float().mean().item()
    return test_acc, test_preds.squeeze().detach().numpy()
 
# Models
model_bce = SimpleNN()
model_focal = SimpleNN()
 
acc_bce, preds_bce = train(model_bce, nn.BCELoss())
acc_focal, preds_focal = train(model_focal, FocalLoss(alpha=0.25, gamma=2))
 
print("Test Accuracy (BCE):", acc_bce)
print("Test Accuracy (Focal Loss):", acc_focal)

В этом примере BCE часто показывает обманчиво высокую точность (например, 98%), потому что модель предсказывает почти всё как преобладающий класс. Focal Loss улучшает обнаружение редкого класса и даёт более информативную оценку качества на несбалансированных данных.

Визуализация разделяющей границы

def plot_decision_boundary(model, title):
    # Create a grid
    x_min, x_max = X[:,0].min()-1, X[:,0].max()+1
    y_min, y_max = X[:,1].min()-1, X[:,1].max()+1
    xx, yy = np.meshgrid(
        np.linspace(x_min, x_max, 300),
        np.linspace(y_min, y_max, 300)
    )
    grid = torch.tensor(np.c_[xx.ravel(), yy.ravel()], dtype=torch.float32)
    with torch.no_grad():
        Z = model(grid).reshape(xx.shape)
 
    # Plot
    plt.contourf(xx, yy, Z, levels=[0,0.5,1], alpha=0.4)
    plt.scatter(X[:,0], X[:,1], c=y, cmap='coolwarm', s=10)
    plt.title(title)
    plt.show()
 
plot_decision_boundary(model_bce, "Decision Boundary -- BCE Loss")
plot_decision_boundary(model_focal, "Decision Boundary -- Focal Loss")

BCE-модель обычно даёт почти плоскую границу, смещённую в сторону преобладающего класса, тогда как модель с Focal Loss показывает более сложную и информативную границу.

Матрицы ошибок (confusion matrices)

from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
 
def plot_conf_matrix(y_true, y_pred, title):
    cm = confusion_matrix(y_true, y_pred)
    disp = ConfusionMatrixDisplay(confusion_matrix=cm)
    disp.plot(cmap="Blues", values_format='d')
    plt.title(title)
    plt.show()
 
# Convert torch tensors to numpy
y_test_np = y_test.numpy().astype(int)
 
preds_bce_label   = (preds_bce > 0.5).astype(int)
preds_focal_label = (preds_focal > 0.5).astype(int)
 
plot_conf_matrix(y_test_np, preds_bce_label, "Confusion Matrix -- BCE Loss")
plot_conf_matrix(y_test_np, preds_focal_label, "Confusion Matrix -- Focal Loss")

В описанном запуске BCE правильно идентифицировала только 1 образец редкого класса и ошиблась в 27 случаях, тогда как Focal Loss верно предсказала 14 редких образцов, уменьшив число ошибок с 27 до 14. Это наглядно показывает, как Focal Loss помогает модели фокусироваться на редких и трудных для классификации примерах.

Практические выводы

  • Для сильно несбалансированных задач общая точность при использовании BCE может вводить в заблуждение.
  • Focal Loss перераспределяет вклад примеров, чтобы модель училась на трудных образцах редкого класса.
  • Оценивайте модели не только по accuracy, но и по матрицам ошибок и визуализациям, особенно при несбалансированных данных.

Все приведённые фрагменты кода позволяют воспроизвести эксперимент и настроить гиперпараметры focal loss (alpha и gamma) под ваш датасет.

🇬🇧

Switch Language

Read this article in English

Switch to English