Обучение больших трансформеров в Colab с DeepSpeed: ZeRO, FP16 и градиентный чекпойнтинг
Почему DeepSpeed и ZeRO важны
Обучение трансформеров в масштабе требует аккуратного управления памятью GPU и коммуникационной нагрузкой. Оптимизации DeepSpeed, такие как ZeRO, смешанная точность (FP16), накопление градиентов и перенос состояний оптимизатора на CPU, позволяют обучать большие модели или использовать большие эффективные батчи на скромном железе вроде Colab.
Подготовка окружения в Colab
Сначала устанавливаем PyTorch с поддержкой CUDA, DeepSpeed и сопутствующие библиотеки, чтобы среда была готова для создания и обучения моделей. В руководстве есть helper для установки требуемых пакетов.
import subprocess
import sys
import os
import json
import time
from pathlib import Path
def install_dependencies():
"""Install required packages for DeepSpeed in Colab"""
print(" Installing DeepSpeed and dependencies...")
subprocess.check_call([
sys.executable, "-m", "pip", "install",
"torch", "torchvision", "torchaudio", "--index-url",
"https://download.pytorch.org/whl/cu118"
])
subprocess.check_call([sys.executable, "-m", "pip", "install", "deepspeed"])
subprocess.check_call([
sys.executable, "-m", "pip", "install",
"transformers", "datasets", "accelerate", "wandb"
])
print(" Installation complete!")
install_dependencies()
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import deepspeed
from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer
import numpy as np
from typing import Dict, Any
import argparse
Синтетический датасет для быстрой отладки
Чтобы не тянуть большие датасеты, создаём синтетический датасет, который генерирует случайные последовательности токенов. Это удобно для отладки памяти, пропускной способности и корректности тренировочного цикла.
class SyntheticTextDataset(Dataset):
"""Synthetic dataset for demonstration purposes"""
def __init__(self, size: int = 1000, seq_length: int = 512, vocab_size: int = 50257):
self.size = size
self.seq_length = seq_length
self.vocab_size = vocab_size
self.data = torch.randint(0, vocab_size, (size, seq_length))
def __len__(self):
return self.size
def __getitem__(self, idx):
return {
'input_ids': self.data[idx],
'labels': self.data[idx].clone()
}
Архитектура продвинутого тренера DeepSpeed
Класс AdvancedDeepSpeedTrainer инкапсулирует создание модели, генерацию конфигурации DeepSpeed, инициализацию движка, шаг обучения, сохранение чекпоинтов и демонстрацию инференса. Ниже приводится полный код, который можно запустить без изменений.
class AdvancedDeepSpeedTrainer:
"""Advanced DeepSpeed trainer with multiple optimization techniques"""
def __init__(self, model_config: Dict[str, Any], ds_config: Dict[str, Any]):
self.model_config = model_config
self.ds_config = ds_config
self.model = None
self.engine = None
self.tokenizer = None
def create_model(self):
"""Create a GPT-2 style model for demonstration"""
print(" Creating model...")
config = GPT2Config(
vocab_size=self.model_config['vocab_size'],
n_positions=self.model_config['seq_length'],
n_embd=self.model_config['hidden_size'],
n_layer=self.model_config['num_layers'],
n_head=self.model_config['num_heads'],
resid_pdrop=0.1,
embd_pdrop=0.1,
attn_pdrop=0.1,
)
self.model = GPT2LMHeadModel(config)
self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
self.tokenizer.pad_token = self.tokenizer.eos_token
print(f" Model parameters: {sum(p.numel() for p in self.model.parameters()):,}")
return self.model
def create_deepspeed_config(self):
"""Create comprehensive DeepSpeed configuration"""
return {
"train_batch_size": self.ds_config['train_batch_size'],
"train_micro_batch_size_per_gpu": self.ds_config['micro_batch_size'],
"gradient_accumulation_steps": self.ds_config['gradient_accumulation_steps'],
"zero_optimization": {
"stage": self.ds_config['zero_stage'],
"allgather_partitions": True,
"allgather_bucket_size": 5e8,
"overlap_comm": True,
"reduce_scatter": True,
"reduce_bucket_size": 5e8,
"contiguous_gradients": True,
"cpu_offload": self.ds_config.get('cpu_offload', False)
},
"fp16": {
"enabled": True,
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": self.ds_config['learning_rate'],
"betas": [0.9, 0.999],
"eps": 1e-8,
"weight_decay": 0.01
}
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": 0,
"warmup_max_lr": self.ds_config['learning_rate'],
"warmup_num_steps": 100
}
},
"gradient_clipping": 1.0,
"wall_clock_breakdown": True,
"memory_breakdown": True,
"tensorboard": {
"enabled": True,
"output_path": "./logs/",
"job_name": "deepspeed_advanced_tutorial"
}
}
def initialize_deepspeed(self):
"""Initialize DeepSpeed engine"""
print(" Initializing DeepSpeed...")
parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args([])
self.engine, optimizer, _, lr_scheduler = deepspeed.initialize(
args=args,
model=self.model,
config=self.create_deepspeed_config()
)
print(f" DeepSpeed engine initialized with ZeRO stage {self.ds_config['zero_stage']}")
return self.engine
def train_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, float]:
"""Perform a single training step with DeepSpeed optimizations"""
input_ids = batch['input_ids'].to(self.engine.device)
labels = batch['labels'].to(self.engine.device)
outputs = self.engine(input_ids=input_ids, labels=labels)
loss = outputs.loss
self.engine.backward(loss)
self.engine.step()
return {
'loss': loss.item(),
'lr': self.engine.lr_scheduler.get_last_lr()[0] if self.engine.lr_scheduler else 0
}
def train(self, dataloader: DataLoader, num_epochs: int = 2):
"""Complete training loop with monitoring"""
print(f" Starting training for {num_epochs} epochs...")
self.engine.train()
total_steps = 0
for epoch in range(num_epochs):
epoch_loss = 0.0
epoch_steps = 0
print(f"\n Epoch {epoch + 1}/{num_epochs}")
for step, batch in enumerate(dataloader):
start_time = time.time()
metrics = self.train_step(batch)
epoch_loss += metrics['loss']
epoch_steps += 1
total_steps += 1
if step % 10 == 0:
step_time = time.time() - start_time
print(f" Step {step:4d} | Loss: {metrics['loss']:.4f} | "
f"LR: {metrics['lr']:.2e} | Time: {step_time:.3f}s")
if step % 20 == 0 and hasattr(self.engine, 'monitor'):
self.log_memory_stats()
if step >= 50:
break
avg_loss = epoch_loss / epoch_steps
print(f" Epoch {epoch + 1} completed | Average Loss: {avg_loss:.4f}")
print(" Training completed!")
def log_memory_stats(self):
"""Log GPU memory statistics"""
if torch.cuda.is_available():
allocated = torch.cuda.memory_allocated() / 1024**3
reserved = torch.cuda.memory_reserved() / 1024**3
print(f" GPU Memory - Allocated: {allocated:.2f}GB | Reserved: {reserved:.2f}GB")
def save_checkpoint(self, path: str):
"""Save model checkpoint using DeepSpeed"""
print(f" Saving checkpoint to {path}")
self.engine.save_checkpoint(path)
def demonstrate_inference(self, text: str = "The future of AI is"):
"""Demonstrate optimized inference with DeepSpeed"""
print(f"\n Running inference with prompt: '{text}'")
inputs = self.tokenizer.encode(text, return_tensors='pt').to(self.engine.device)
self.engine.eval()
with torch.no_grad():
outputs = self.engine.module.generate(
inputs,
max_length=inputs.shape[1] + 50,
num_return_sequences=1,
temperature=0.8,
do_sample=True,
pad_token_id=self.tokenizer.eos_token_id
)
generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
print(f" Generated text: {generated_text}")
self.engine.train()
Запуск учебного сценария и бенчмарки
Функция run_advanced_tutorial выполняет всю последовательность: подбор конфигураций, создание модели, инициализация DeepSpeed, создание синтетического датасета, обучение, инференс и сохранение чекпоинта. Также есть процедуры для демонстрации ZeRO-стадий и оптимизаций памяти, а бенчмарк сравнивает пиковую память и время на шаг.
def run_advanced_tutorial():
"""Main function to run the advanced DeepSpeed tutorial"""
print(" Advanced DeepSpeed Tutorial Starting...")
print("=" * 60)
model_config = {
'vocab_size': 50257,
'seq_length': 512,
'hidden_size': 768,
'num_layers': 6,
'num_heads': 12
}
ds_config = {
'train_batch_size': 16,
'micro_batch_size': 4,
'gradient_accumulation_steps': 4,
'zero_stage': 2,
'learning_rate': 1e-4,
'cpu_offload': False
}
print(" Configuration:")
print(f" Model size: ~{sum(np.prod(shape) for shape in [[model_config['vocab_size'], model_config['hidden_size']], [model_config['hidden_size'], model_config['hidden_size']] * model_config['num_layers']]) / 1e6:.1f}M parameters")
print(f" ZeRO Stage: {ds_config['zero_stage']}")
print(f" Batch size: {ds_config['train_batch_size']}")
trainer = AdvancedDeepSpeedTrainer(model_config, ds_config)
model = trainer.create_model()
engine = trainer.initialize_deepspeed()
print("\n Creating synthetic dataset...")
dataset = SyntheticTextDataset(
size=200,
seq_length=model_config['seq_length'],
vocab_size=model_config['vocab_size']
)
dataloader = DataLoader(
dataset,
batch_size=ds_config['micro_batch_size'],
shuffle=True
)
print("\n Pre-training memory stats:")
trainer.log_memory_stats()
trainer.train(dataloader, num_epochs=2)
print("\n Post-training memory stats:")
trainer.log_memory_stats()
trainer.demonstrate_inference("DeepSpeed enables efficient training of")
checkpoint_path = "./deepspeed_checkpoint"
trainer.save_checkpoint(checkpoint_path)
demonstrate_zero_stages()
demonstrate_memory_optimization()
print("\n Tutorial completed successfully!")
print("Key DeepSpeed features demonstrated:")
print(" ZeRO optimization for memory efficiency")
print(" Mixed precision training (FP16)")
print(" Gradient accumulation")
print(" Learning rate scheduling")
print(" Checkpoint saving/loading")
print(" Memory monitoring")
Практические советы
- Включите GPU-рантайм в Colab для корректной работы.
- Уменьшайте размер батча или архитектуру модели при проблемах с памятью.
- Включайте CPU offload для переносa состояний оптимизатора, если память GPU ограничена.
- Используйте синтетические данные для быстрой отладки тренировочного конвейера.