Освойте Self-Supervised Learning с Lightly: SimCLR, кортсеты и активное обучение в Colab

Зачем нужен self-supervised learning и Lightly?

Self-supervised learning (SSL) позволяет моделям извлекать полезные визуальные представления без размеченных данных, сопоставляя разные аугментации одного изображения. Lightly упрощает практические шаги SSL — от обучения SimCLR до извлечения эмбеддингов и выбора кортсетов — что ускоряет эксперименты в Google Colab.

Подготовка окружения и импорты

Сначала фиксируем версии и устанавливаем Lightly, PyTorch и UMAP, затем импортируем необходимые модули для обучения, извлечения эмбеддингов и визуализации.

!pip uninstall -y numpy
!pip install numpy==1.26.4
!pip install -q lightly torch torchvision matplotlib scikit-learn umap-learn


import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader, Subset
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from sklearn.neighbors import NearestNeighbors
import umap


from lightly.loss import NTXentLoss
from lightly.models.modules import SimCLRProjectionHead
from lightly.transforms import SimCLRTransform
from lightly.data import LightlyDataset


print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

Эти команды подготавливают рантайм и проверяют доступность GPU для ускоренного обучения.

Модель SimCLR на ResNet

Основная модель — SimCLR: ResNet-бэкбон и проекционная голова для контрастивного обучения. Метод extract_features позволяет получать признаки бэкбона для downstream-задач.

class SimCLRModel(nn.Module):
   """SimCLR model with ResNet backbone"""
   def __init__(self, backbone, hidden_dim=512, out_dim=128):
       super().__init__()
       self.backbone = backbone
       self.backbone.fc = nn.Identity()
       self.projection_head = SimCLRProjectionHead(
           input_dim=512, hidden_dim=hidden_dim, output_dim=out_dim
       )
  
   def forward(self, x):
       features = self.backbone(x).flatten(start_dim=1)
       z = self.projection_head(features)
       return z
  
   def extract_features(self, x):
       """Extract backbone features without projection"""
       with torch.no_grad():
           return self.backbone(x).flatten(start_dim=1)

Подготовка CIFAR-10 для SSL и оценки

Используем CIFAR-10 с разными преобразованиями для обучения без меток (сильные аугментации, создающие несколько просмотров) и для оценки (нормализованные изображения). Обёртка SSLDataset возвращает требуемые представления для контрастивного обучения.

def load_dataset(train=True):
   """Load CIFAR-10 dataset"""
   ssl_transform = SimCLRTransform(input_size=32, cj_prob=0.8)
  
   eval_transform = transforms.Compose([
       transforms.ToTensor(),
       transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
   ])
  
   base_dataset = torchvision.datasets.CIFAR10(
       root='./data', train=train, download=True
   )
  
   class SSLDataset(torch.utils.data.Dataset):
       def __init__(self, dataset, transform):
           self.dataset = dataset
           self.transform = transform
      
       def __len__(self):
           return len(self.dataset)
      
       def __getitem__(self, idx):
           img, label = self.dataset[idx]
           return self.transform(img), label
  
   ssl_dataset = SSLDataset(base_dataset, ssl_transform)
  
   eval_dataset = torchvision.datasets.CIFAR10(
       root='./data', train=train, download=True, transform=eval_transform
   )
  
   return ssl_dataset, eval_dataset

Обучение SimCLR

Обучение проводится с функцией потерь NT-Xent, которая сближает представления двух аугментаций одного изображения и раздвигает представления разных изображений.

def train_ssl_model(model, dataloader, epochs=5, device='cuda'):
   """Train SimCLR model"""
   model.to(device)
   criterion = NTXentLoss(temperature=0.5)
   optimizer = torch.optim.SGD(model.parameters(), lr=0.06, momentum=0.9, weight_decay=5e-4)
  
   print("\n=== Self-Supervised Training ===")
   for epoch in range(epochs):
       model.train()
       total_loss = 0
       for batch_idx, batch in enumerate(dataloader):
           views = batch[0] 
           view1, view2 = views[0].to(device), views[1].to(device)
          
           z1 = model(view1)
           z2 = model(view2)
           loss = criterion(z1, z2)
          
           optimizer.zero_grad()
           loss.backward()
           optimizer.step()
          
           total_loss += loss.item()
          
           if batch_idx % 50 == 0:
               print(f"Epoch {epoch+1}/{epochs} | Batch {batch_idx} | Loss: {loss.item():.4f}")
      
       avg_loss = total_loss / len(dataloader)
       print(f"Epoch {epoch+1} Complete | Avg Loss: {avg_loss:.4f}")
  
   return model

Генерация и визуализация эмбеддингов

После предобучения извлекаем признаки и снижаем размерность до 2D с помощью UMAP или t-SNE, чтобы визуально оценить структуру кластеров.

def generate_embeddings(model, dataset, device='cuda', batch_size=256):
   """Generate embeddings for the entire dataset"""
   model.eval()
   model.to(device)
  
   dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=2)
  
   embeddings = []
   labels = []
  
   print("\n=== Generating Embeddings ===")
   with torch.no_grad():
       for images, targets in dataloader:
           images = images.to(device)
           features = model.extract_features(images)
           embeddings.append(features.cpu().numpy())
           labels.append(targets.numpy())
  
   embeddings = np.vstack(embeddings)
   labels = np.concatenate(labels)
  
   print(f"Generated {embeddings.shape[0]} embeddings with dimension {embeddings.shape[1]}")
   return embeddings, labels


def visualize_embeddings(embeddings, labels, method='umap', n_samples=5000):
   """Visualize embeddings using UMAP or t-SNE"""
   print(f"\n=== Visualizing Embeddings with {method.upper()} ===")
  
   if len(embeddings) > n_samples:
       indices = np.random.choice(len(embeddings), n_samples, replace=False)
       embeddings = embeddings[indices]
       labels = labels[indices]
  
   if method == 'umap':
       reducer = umap.UMAP(n_neighbors=15, min_dist=0.1, metric='cosine')
   else:
       reducer = TSNE(n_components=2, perplexity=30, metric='cosine')
  
   embeddings_2d = reducer.fit_transform(embeddings)
  
   plt.figure(figsize=(12, 10))
   scatter = plt.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1],
                         c=labels, cmap='tab10', s=5, alpha=0.6)
   plt.colorbar(scatter)
   plt.title(f'CIFAR-10 Embeddings ({method.upper()})')
   plt.xlabel('Component 1')
   plt.ylabel('Component 2')
   plt.tight_layout()
   plt.savefig(f'embeddings_{method}.png', dpi=150)
   print(f"Saved visualization to embeddings_{method}.png")
   plt.show()


def select_coreset(embeddings, labels, budget=1000, method='diversity'):
   """
   Select a coreset using different strategies:
   - diversity: Maximum diversity using k-center greedy
   - balanced: Class-balanced selection
   """
   print(f"\n=== Coreset Selection ({method}) ===")
  
   if method == 'balanced':
       selected_indices = []
       n_classes = len(np.unique(labels))
       per_class = budget // n_classes
      
       for cls in range(n_classes):
           cls_indices = np.where(labels == cls)[0]
           selected = np.random.choice(cls_indices, min(per_class, len(cls_indices)), replace=False)
           selected_indices.extend(selected)
      
       return np.array(selected_indices)
  
   elif method == 'diversity':
       selected_indices = []
       remaining_indices = set(range(len(embeddings)))
      
       first_idx = np.random.randint(len(embeddings))
       selected_indices.append(first_idx)
       remaining_indices.remove(first_idx)
      
       for _ in range(budget - 1):
           if not remaining_indices:
               break
          
           remaining = list(remaining_indices)
           selected_emb = embeddings[selected_indices]
           remaining_emb = embeddings[remaining]
          
           distances = np.min(
               np.linalg.norm(remaining_emb[:, None] - selected_emb, axis=2), axis=1
           )
          
           max_dist_idx = np.argmax(distances)
           selected_idx = remaining[max_dist_idx]
           selected_indices.append(selected_idx)
           remaining_indices.remove(selected_idx)
      
       print(f"Selected {len(selected_indices)} samples")
       return np.array(selected_indices)

Выбор кортсета помогает выделить небольшую, но информативную подвыборку: либо сбалансированную по классам, либо максимально разнообразную.

Оценка через линейный probe

Замораживаем бэкбон и обучаем лёгкий линейный классификатор на извлечённых признаках — так измеряется пригодность эмбеддингов для классификации.

def evaluate_linear_probe(model, train_subset, test_dataset, device='cuda'):
   """Train linear classifier on frozen features"""
   model.eval()
  
   train_loader = DataLoader(train_subset, batch_size=128, shuffle=True, num_workers=2)
   test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False, num_workers=2)
  
   classifier = nn.Linear(512, 10).to(device)
   criterion = nn.CrossEntropyLoss()
   optimizer = torch.optim.Adam(classifier.parameters(), lr=0.001)
  
   for epoch in range(10):
       classifier.train()
       for images, targets in train_loader:
           images, targets = images.to(device), targets.to(device)
          
           with torch.no_grad():
               features = model.extract_features(images)
          
           outputs = classifier(features)
           loss = criterion(outputs, targets)
          
           optimizer.zero_grad()
           loss.backward()
           optimizer.step()
  
   classifier.eval()
   correct = 0
   total = 0
  
   with torch.no_grad():
       for images, targets in test_loader:
           images, targets = images.to(device), targets.to(device)
           features = model.extract_features(images)
           outputs = classifier(features)
           _, predicted = outputs.max(1)
           total += targets.size(0)
           correct += predicted.eq(targets).sum().item()
  
   accuracy = 100. * correct / total
   return accuracy


def main():
   device = 'cuda' if torch.cuda.is_available() else 'cpu'
   print(f"Using device: {device}")
  
   ssl_dataset, eval_dataset = load_dataset(train=True)
   _, test_dataset = load_dataset(train=False)
  
   ssl_subset = Subset(ssl_dataset, range(10000)) 
   ssl_loader = DataLoader(ssl_subset, batch_size=128, shuffle=True, num_workers=2, drop_last=True)
  
   backbone = torchvision.models.resnet18(pretrained=False)
   model = SimCLRModel(backbone)
   model = train_ssl_model(model, ssl_loader, epochs=5, device=device)
  
   eval_subset = Subset(eval_dataset, range(10000))
   embeddings, labels = generate_embeddings(model, eval_subset, device=device)
  
   visualize_embeddings(embeddings, labels, method='umap')
  
   coreset_indices = select_coreset(embeddings, labels, budget=1000, method='diversity')
   coreset_subset = Subset(eval_dataset, coreset_indices)
  
   print("\n=== Active Learning Evaluation ===")
   coreset_acc = evaluate_linear_probe(model, coreset_subset, test_dataset, device=device)
   print(f"Coreset Accuracy (1000 samples): {coreset_acc:.2f}%")
  
   random_indices = np.random.choice(len(eval_subset), 1000, replace=False)
   random_subset = Subset(eval_dataset, random_indices)
   random_acc = evaluate_linear_probe(model, random_subset, test_dataset, device=device)
   print(f"Random Accuracy (1000 samples): {random_acc:.2f}%")
  
   print(f"\nCoreset improvement: +{coreset_acc - random_acc:.2f}%")
  
   print("\n=== Tutorial Complete! ===")
   print("Key takeaways:")
   print("1. Self-supervised learning creates meaningful representations without labels")
   print("2. Embeddings capture semantic similarity between images")
   print("3. Smart data selection (coreset) outperforms random sampling")
   print("4. Active learning reduces labeling costs while maintaining accuracy")


if __name__ == "__main__":
   main()

Этот конвейер демонстрирует полный цикл: предобучение SimCLR, визуализация эмбеддингов, выбор кортсета и оценка через линейный probe. Сравнение кортсета и случайной подвыборки показывает, как интеллектуальный отбор улучшает эффективность разметки и итоговое качество модели.