Освойте TorchVision v2: продвинутые трансформы, MixUp, CutMix и современная тренировка CNN
Установка и импорты
Перед сборкой пайплайна и модели установите необходимые пакеты и импортируйте основные модули, используемые в руководстве.
!pip install torch torchvision torchaudio --quiet
!pip install matplotlib pillow numpy --quiet
import torch
import torchvision
from torchvision import transforms as T
from torchvision.transforms import v2
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import requests
from io import BytesIO
print(f"PyTorch version: {torch.__version__}")
print(f"TorchVision version: {torchvision.__version__}")
Эти импорты настраивают PyTorch, TorchVision v2 трансформы и утилиты (NumPy, PIL, Matplotlib), что готовит окружение для продвинутых аугментаций, определения модели и обучения.
Продвинутый пайплайн аугментаций с TorchVision v2
Соберите гибкий пайплайн аугментации, который применяет сильные преобразования во время обучения и простую предобработку для валидации.
class AdvancedAugmentationPipeline:
def __init__(self, image_size=224, training=True):
self.image_size = image_size
self.training = training
base_transforms = [
v2.ToImage(),
v2.ToDtype(torch.uint8, scale=True),
]
if training:
self.transform = v2.Compose([
*base_transforms,
v2.Resize((image_size + 32, image_size + 32)),
v2.RandomResizedCrop(image_size, scale=(0.8, 1.0), ratio=(0.9, 1.1)),
v2.RandomHorizontalFlip(p=0.5),
v2.RandomRotation(degrees=15),
v2.ColorJitter(brights=0.4, contst=0.4, sation=0.4, hue=0.1),
v2.RandomGrayscale(p=0.1),
v2.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
v2.RandomPerspective(distortion_scale=0.1, p=0.3),
v2.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)),
v2.ToDtype(torch.float32, scale=True),
v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
else:
self.transform = v2.Compose([
*base_transforms,
v2.Resize((image_size, image_size)),
v2.ToDtype(torch.float32, scale=True),
v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
def __call__(self, image):
return self.transform(image)
Пайплайн применяет набор преобразований при обучении: изменение размера, случайная обрезка, перевороты, повороты, ColorJitter, размытие, перспектива и афинные преобразования, затем конвертацию типа и нормализацию. Для валидации остаются только resize и нормализация.
MixUp и CutMix: единый модуль
MixUp и CutMix реализованы в одном классе, который либо смешивает изображения по коэффициенту, либо вставляет патч из другого изображения. Метод возвращает смешанные входы, два таргета и коэффициент смешения.
class AdvancedMixupCutmix:
def __init__(self, mixup_alpha=1.0, cutmix_alpha=1.0, prob=0.5):
self.mixup_alpha = mixup_alpha
self.cutmix_alpha = cutmix_alpha
self.prob = prob
def mixup(self, x, y):
batch_size = x.size(0)
lam = np.random.beta(self.mixup_alpha, self.mixup_alpha) if self.mixup_alpha > 0 else 1
index = torch.randperm(batch_size)
mixed_x = lam * x + (1 - lam) * x[index, :]
y_a, y_b = y, y[index]
return mixed_x, y_a, y_b, lam
def cutmix(self, x, y):
batch_size = x.size(0)
lam = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) if self.cutmix_alpha > 0 else 1
index = torch.randperm(batch_size)
y_a, y_b = y, y[index]
bbx1, bby1, bbx2, bby2 = self._rand_bbox(x.size(), lam)
x[:, :, bbx1:bbx2, bby1:bby2] = x[index, :, bbx1:bbx2, bby1:bby2]
lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (x.size()[-1] * x.size()[-2]))
return x, y_a, y_b, lam
def _rand_bbox(self, size, lam):
W = size[2]
H = size[3]
cut_rat = np.sqrt(1. - lam)
cut_w = int(W * cut_rat)
cut_h = int(H * cut_rat)
cx = np.random.randint(W)
cy = np.random.randint(H)
bbx1 = np.clip(cx - cut_w // 2, 0, W)
bby1 = np.clip(cy - cut_h // 2, 0, H)
bbx2 = np.clip(cx + cut_w // 2, 0, W)
bby2 = np.clip(cy + cut_h // 2, 0, H)
return bbx1, bby1, bbx2, bby2
def __call__(self, x, y):
if np.random.random() > self.prob:
return x, y, y, 1.0
if np.random.random() < 0.5:
return self.mixup(x, y)
else:
return self.cutmix(x, y)
class ModernCNN(nn.Module):
def __init__(self, num_classes=10, dropout=0.3):
super(ModernCNN, self).__init__()
self.conv1 = self._conv_block(3, 64)
self.conv2 = self._conv_block(64, 128, downsample=True)
self.conv3 = self._conv_block(128, 256, downsample=True)
self.conv4 = self._conv_block(256, 512, downsample=True)
self.gap = nn.AdaptiveAvgPool2d(1)
self.attention = nn.Sequential(
nn.Linear(512, 256),
nn.ReLU(),
nn.Linear(256, 512),
nn.Sigmoid()
)
self.classifier = nn.Sequential(
nn.Dropout(dropout),
nn.Linear(512, 256),
nn.BatchNorm1d(256),
nn.ReLU(),
nn.Dropout(dropout/2),
nn.Linear(256, num_classes)
)
def _conv_block(self, in_channels, out_channels, downsample=False):
stride = 2 if downsample else 1
return nn.Sequential(
nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, 3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
x = self.conv4(x)
x = self.gap(x)
x = torch.flatten(x, 1)
attention_weights = self.attention(x)
x = x * attention_weights
return self.classifier(x)
ModernCNN использует последовательные сверточные блоки с возможностью даунсемплинга, глобальный average pooling, MLP-внимание и классификатор с дропаутом.
Цикл обучения и оптимизация
Объедините AdamW, OneCycleLR и обрезку градиентов. Если активен MixUp/CutMix, используйте интерполированную функцию потерь.
class AdvancedTrainer:
def __init__(self, model, device='cuda' if torch.cuda.is_available() else 'cpu'):
self.model = model.to(device)
self.device = device
self.mixup_cutmix = AdvancedMixupCutmix()
self.optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
self.scheduler = optim.lr_scheduler.OneCycleLR(
self.optimizer, max_lr=1e-2, epochs=10, steps_per_epoch=100
)
self.criterion = nn.CrossEntropyLoss()
def mixup_criterion(self, pred, y_a, y_b, lam):
return lam * self.criterion(pred, y_a) + (1 - lam) * self.criterion(pred, y_b)
def train_epoch(self, dataloader):
self.model.train()
total_loss = 0
correct = 0
total = 0
for batch_idx, (data, target) in enumerate(dataloader):
data, target = data.to(self.device), target.to(self.device)
data, target_a, target_b, lam = self.mixup_cutmix(data, target)
self.optimizer.zero_grad()
output = self.model(data)
if lam != 1.0:
loss = self.mixup_criterion(output, target_a, target_b, lam)
else:
loss = self.criterion(output, target)
loss.backward()
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
self.optimizer.step()
self.scheduler.step()
total_loss += loss.item()
_, predicted = output.max(1)
total += target.size(0)
if lam != 1.0:
correct += (lam * predicted.eq(target_a).sum().item() +
(1 - lam) * predicted.eq(target_b).sum().item())
else:
correct += predicted.eq(target).sum().item()
return total_loss / len(dataloader), 100. * correct / total
Практики: weight decay (AdamW), OneCycleLR со степом по батчу, обрезка градиентов и интерполированная функция потерь при смешениях.
Демонстрация end-to-end
Небольшой демо-скрипт проверяет взаимодействие пайплайна, миксинга, модели и тренера на синтетических данных.
def demo_advanced_techniques():
batch_size = 16
num_classes = 10
sample_data = torch.randn(batch_size, 3, 224, 224)
sample_labels = torch.randint(0, num_classes, (batch_size,))
transform_pipeline = AdvancedAugmentationPipeline(training=True)
model = ModernCNN(num_classes=num_classes)
trainer = AdvancedTrainer(model)
print(" Advanced Deep Learning Tutorial Demo")
print("=" * 50)
print("\n1. Advanced Augmentation Pipeline:")
augmented = transform_pipeline(Image.fromarray((sample_data[0].permute(1,2,0).numpy() * 255).astype(np.uint8)))
print(f" Original shape: {sample_data[0].shape}")
print(f" Augmented shape: {augmented.shape}")
print(f" Applied transforms: Resize, Crop, Flip, ColorJitter, Blur, Perspective, etc.")
print("\n2. MixUp/CutMix Augmentation:")
mixup_cutmix = AdvancedMixupCutmix()
mixed_data, target_a, target_b, lam = mixup_cutmix(sample_data, sample_labels)
print(f" Mixed batch shape: {mixed_data.shape}")
print(f" Lambda value: {lam:.3f}")
print(f" Technique: {'MixUp' if lam > 0.7 else 'CutMix'}")
print("\n3. Modern CNN Architecture:")
model.eval()
with torch.no_grad():
output = model(sample_data)
print(f" Input shape: {sample_data.shape}")
print(f" Output shape: {output.shape}")
print(f" Features: Residual blocks, Attention, Global Average Pooling")
print(f" Parameters: {sum(p.numel() for p in model.parameters()):,}")
print("\n4. Advanced Training Simulation:")
dummy_loader = [(sample_data, sample_labels)]
loss, acc = trainer.train_epoch(dummy_loader)
print(f" Training loss: {loss:.4f}")
print(f" Training accuracy: {acc:.2f}%")
print(f" Learning rate: {trainer.scheduler.get_last_lr()[0]:.6f}")
print("\n Tutorial completed successfully!")
print("This code demonstrates state-of-the-art techniques in deep learning:")
print("• Advanced data augmentation with TorchVision v2")
print("• MixUp and CutMix for better generalization")
print("• Modern CNN architecture with attention")
print("• Advanced training loop with OneCycleLR")
print("• Gradient clipping and weight decay")
if __name__ == "__main__":
demo_advanced_techniques()
Этот набор компонентов полезен для быстрого прототипирования современных методов компьютерного зрения: от аугментаций и смешений до архитектуры и надежного цикла обучения. Запустите демо на Colab и масштабируйте его на реальных датасетах.