Focal Loss против BCE: как исправить несбалансированную бинарную классификацию
'Сравнение Focal Loss и BCE на примере датасета с дисбалансом 99:1 показывает, как Focal Loss улучшает обнаружение редкого класса и более информативные разделяющие границы.'
Почему стандартный Binary Cross-Entropy не справляется с несбалансированными данными
Binary cross-entropy (BCE) одинаково взвешивает ошибки для обоих классов. Когда один класс очень редок, эта симметрия становится проблемой: сеть может достичь высокой общей точности, просто предсказывая преобладающий класс в большинстве случаев, при этом полностью игнорируя редкий класс. Например, редкий образец с меткой 1, предсказанный как 0.3, и преобладающий образец с меткой 0, предсказанный как 0.7, дают одинаковое значение BCE, хотя первая ошибка намного важнее.
Как Focal Loss меняет фокус обучения
Focal Loss решает проблему, понижая вклад простых и уверенных предсказаний и усиливая влияние сложных, неправильно классифицированных примеров. В формуле присутствуют две ключевые гиперпараметры: gamma, контролирующий агрессивность подавления простых примеров, и alpha, позволяющий придать дополнительный вес редкому классу. Вместе они заставляют оптимизатор уделять больше внимания редким и трудным примерам, вместо того чтобы подстраиваться под доминирующий класс.
Эксперимент: несбалансированность 99:1
Чтобы показать разницу, мы обучаем две одинаковые небольшие нейросети на синтетическом датасете с соотношением классов 99:1 (6000 образцов). Одна модель использует BCE, другая — Focal Loss. Ниже приведены шаги установки, создания датасета, определения модели и функции потерь, цикл обучения и код визуализации, использованные в эксперименте.
Установка зависимостей
pip install numpy pandas matplotlib scikit-learn torchСоздание несбалансированного датасета
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.optim as optim
# Generate imbalanced dataset
X, y = make_classification(
n_samples=6000,
n_features=2,
n_redundant=0,
n_clusters_per_class=1,
weights=[0.99, 0.01],
class_sep=1.5,
random_state=42
)
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.3, random_state=42
)
X_train = torch.tensor(X_train, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.float32).unsqueeze(1)
X_test = torch.tensor(X_test, dtype=torch.float32)
y_test = torch.tensor(y_test, dtype=torch.float32).unsqueeze(1)Определение нейросети
class SimpleNN(nn.Module):
def __init__(self):
super().__init__()
self.layers = nn.Sequential(
nn.Linear(2, 16),
nn.ReLU(),
nn.Linear(16, 8),
nn.ReLU(),
nn.Linear(8, 1),
nn.Sigmoid()
)
def forward(self, x):
return self.layers(x)Реализация Focal Loss
class FocalLoss(nn.Module):
def __init__(self, alpha=0.25, gamma=2):
super().__init__()
self.alpha = alpha
self.gamma = gamma
def forward(self, preds, targets):
eps = 1e-7
preds = torch.clamp(preds, eps, 1 - eps)
pt = torch.where(targets == 1, preds, 1 - preds)
loss = -self.alpha * (1 - pt) ** self.gamma * torch.log(pt)
return loss.mean()Обучение и сравнение
Мы обучаем две одинаковые модели с одним и тем же оптимизатором и гиперпараметрами: одна с nn.BCELoss, другая с FocalLoss(alpha=0.25, gamma=2). Цикл обучения ниже вычисляет точность на тестовой выборке после обучения.
def train(model, loss_fn, lr=0.01, epochs=30):
opt = optim.Adam(model.parameters(), lr=lr)
for _ in range(epochs):
preds = model(X_train)
loss = loss_fn(preds, y_train)
opt.zero_grad()
loss.backward()
opt.step()
with torch.no_grad():
test_preds = model(X_test)
test_acc = ((test_preds > 0.5).float() == y_test).float().mean().item()
return test_acc, test_preds.squeeze().detach().numpy()
# Models
model_bce = SimpleNN()
model_focal = SimpleNN()
acc_bce, preds_bce = train(model_bce, nn.BCELoss())
acc_focal, preds_focal = train(model_focal, FocalLoss(alpha=0.25, gamma=2))
print("Test Accuracy (BCE):", acc_bce)
print("Test Accuracy (Focal Loss):", acc_focal)В этом примере BCE часто показывает обманчиво высокую точность (например, 98%), потому что модель предсказывает почти всё как преобладающий класс. Focal Loss улучшает обнаружение редкого класса и даёт более информативную оценку качества на несбалансированных данных.
Визуализация разделяющей границы
def plot_decision_boundary(model, title):
# Create a grid
x_min, x_max = X[:,0].min()-1, X[:,0].max()+1
y_min, y_max = X[:,1].min()-1, X[:,1].max()+1
xx, yy = np.meshgrid(
np.linspace(x_min, x_max, 300),
np.linspace(y_min, y_max, 300)
)
grid = torch.tensor(np.c_[xx.ravel(), yy.ravel()], dtype=torch.float32)
with torch.no_grad():
Z = model(grid).reshape(xx.shape)
# Plot
plt.contourf(xx, yy, Z, levels=[0,0.5,1], alpha=0.4)
plt.scatter(X[:,0], X[:,1], c=y, cmap='coolwarm', s=10)
plt.title(title)
plt.show()
plot_decision_boundary(model_bce, "Decision Boundary -- BCE Loss")
plot_decision_boundary(model_focal, "Decision Boundary -- Focal Loss")BCE-модель обычно даёт почти плоскую границу, смещённую в сторону преобладающего класса, тогда как модель с Focal Loss показывает более сложную и информативную границу.
Матрицы ошибок (confusion matrices)
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
def plot_conf_matrix(y_true, y_pred, title):
cm = confusion_matrix(y_true, y_pred)
disp = ConfusionMatrixDisplay(confusion_matrix=cm)
disp.plot(cmap="Blues", values_format='d')
plt.title(title)
plt.show()
# Convert torch tensors to numpy
y_test_np = y_test.numpy().astype(int)
preds_bce_label = (preds_bce > 0.5).astype(int)
preds_focal_label = (preds_focal > 0.5).astype(int)
plot_conf_matrix(y_test_np, preds_bce_label, "Confusion Matrix -- BCE Loss")
plot_conf_matrix(y_test_np, preds_focal_label, "Confusion Matrix -- Focal Loss")В описанном запуске BCE правильно идентифицировала только 1 образец редкого класса и ошиблась в 27 случаях, тогда как Focal Loss верно предсказала 14 редких образцов, уменьшив число ошибок с 27 до 14. Это наглядно показывает, как Focal Loss помогает модели фокусироваться на редких и трудных для классификации примерах.
Практические выводы
- Для сильно несбалансированных задач общая точность при использовании BCE может вводить в заблуждение.
- Focal Loss перераспределяет вклад примеров, чтобы модель училась на трудных образцах редкого класса.
- Оценивайте модели не только по accuracy, но и по матрицам ошибок и визуализациям, особенно при несбалансированных данных.
Все приведённые фрагменты кода позволяют воспроизвести эксперимент и настроить гиперпараметры focal loss (alpha и gamma) под ваш датасет.
Switch Language
Read this article in English