Train Large Transformers on Colab with DeepSpeed: ZeRO, FP16 & Gradient Checkpointing
Why DeepSpeed and ZeRO matter
Training transformer models at scale requires careful management of GPU memory and communication overhead. DeepSpeed’s ZeRO optimization, mixed-precision (FP16) training, gradient accumulation, gradient checkpointing and CPU offloading combine to let you train larger models, or use larger effective batch sizes, on constrained hardware such as a Colab GPU.
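As a back-of-the-envelope illustration of why this matters, the ZeRO paper accounts for mixed-precision Adam training at roughly 16 bytes per parameter: 2 for FP16 weights, 2 for FP16 gradients and 12 for FP32 optimizer states. The sketch below (ours, not part of DeepSpeed) shows which of those terms each ZeRO stage partitions across GPUs; note that on a single Colab GPU the partitioning denominators are 1, so the practical wins there come mainly from CPU offload and the other optimizations covered later.
def estimate_zero_memory_gb(num_params: float, num_gpus: int = 1, stage: int = 0) -> float:
    """Rough per-GPU model-state memory for FP16 Adam training,
    following the ZeRO paper's 2 + 2 + 12 bytes-per-parameter split.
    Activations and temporary buffers are deliberately excluded."""
    params_b, grads_b, optim_b = 2.0, 2.0, 12.0  # bytes per parameter
    if stage >= 1:
        optim_b /= num_gpus   # Stage 1: partition optimizer states
    if stage >= 2:
        grads_b /= num_gpus   # Stage 2: also partition gradients
    if stage >= 3:
        params_b /= num_gpus  # Stage 3: also partition parameters
    return num_params * (params_b + grads_b + optim_b) / 1024**3

# e.g. a ~124M-parameter GPT-2 across 4 GPUs:
for s in range(4):
    print(f"ZeRO stage {s}: ~{estimate_zero_memory_gb(124e6, num_gpus=4, stage=s):.2f} GB per GPU")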
Setting up Colab and dependencies
Start by installing PyTorch with CUDA support, DeepSpeed and related libraries so the environment is ready for model creation and training. The tutorial includes an installation helper that installs torch, torchvision, torchaudio, deepspeed, transformers, datasets, accelerate and wandb.
import subprocess
import sys
import os
import json
import time
from pathlib import Path
def install_dependencies():
"""Install required packages for DeepSpeed in Colab"""
print(" Installing DeepSpeed and dependencies...")
subprocess.check_call([
sys.executable, "-m", "pip", "install",
"torch", "torchvision", "torchaudio", "--index-url",
"https://download.pytorch.org/whl/cu118"
])
subprocess.check_call([sys.executable, "-m", "pip", "install", "deepspeed"])
subprocess.check_call([
sys.executable, "-m", "pip", "install",
"transformers", "datasets", "accelerate", "wandb"
])
print(" Installation complete!")
install_dependencies()
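# Note: Colab preinstalls torch, so reinstalling it from the CUDA wheel
# index above can require a runtime restart before the imports below
# pick up the newly installed packages.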
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import deepspeed
from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer
import numpy as np
from typing import Dict, Any
import argparse
Synthetic dataset for fast iteration
To test training flows without downloading large corpora, we create a synthetic dataset that generates random token sequences. This is useful for debugging the memory footprint, throughput and correctness of a training loop.
class SyntheticTextDataset(Dataset):
"""Synthetic dataset for demonstration purposes"""
def __init__(self, size: int = 1000, seq_length: int = 512, vocab_size: int = 50257):
self.size = size
self.seq_length = seq_length
self.vocab_size = vocab_size
self.data = torch.randint(0, vocab_size, (size, seq_length))
def __len__(self):
return self.size
def __getitem__(self, idx):
return {
'input_ids': self.data[idx],
'labels': self.data[idx].clone()
}
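A quick, illustrative sanity check of the dataset's shapes (the throwaway names here are ours):
_ds = SyntheticTextDataset(size=4, seq_length=8, vocab_size=100)
_batch = next(iter(DataLoader(_ds, batch_size=2)))
print(_batch['input_ids'].shape, _batch['labels'].shape)  # torch.Size([2, 8]) torch.Size([2, 8])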
Design of the advanced DeepSpeed trainer
The AdvancedDeepSpeedTrainer encapsulates model creation, DeepSpeed config generation, engine initialization, the training step, checkpointing and an inference demo. The trainer shows how to select a ZeRO stage, enable FP16 and tune optimizer and scheduler settings. The full implementation follows so you can run it directly.
class AdvancedDeepSpeedTrainer:
"""Advanced DeepSpeed trainer with multiple optimization techniques"""
def __init__(self, model_config: Dict[str, Any], ds_config: Dict[str, Any]):
self.model_config = model_config
self.ds_config = ds_config
self.model = None
self.engine = None
self.tokenizer = None
def create_model(self):
"""Create a GPT-2 style model for demonstration"""
print(" Creating model...")
config = GPT2Config(
vocab_size=self.model_config['vocab_size'],
n_positions=self.model_config['seq_length'],
n_embd=self.model_config['hidden_size'],
n_layer=self.model_config['num_layers'],
n_head=self.model_config['num_heads'],
resid_pdrop=0.1,
embd_pdrop=0.1,
attn_pdrop=0.1,
)
self.model = GPT2LMHeadModel(config)
self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
self.tokenizer.pad_token = self.tokenizer.eos_token
print(f" Model parameters: {sum(p.numel() for p in self.model.parameters()):,}")
return self.model
def create_deepspeed_config(self):
"""Create comprehensive DeepSpeed configuration"""
return {
"train_batch_size": self.ds_config['train_batch_size'],
"train_micro_batch_size_per_gpu": self.ds_config['micro_batch_size'],
"gradient_accumulation_steps": self.ds_config['gradient_accumulation_steps'],
"zero_optimization": {
"stage": self.ds_config['zero_stage'],
"allgather_partitions": True,
"allgather_bucket_size": 5e8,
"overlap_comm": True,
"reduce_scatter": True,
"reduce_bucket_size": 5e8,
"contiguous_gradients": True,
"cpu_offload": self.ds_config.get('cpu_offload', False)
},
"fp16": {
"enabled": True,
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": self.ds_config['learning_rate'],
"betas": [0.9, 0.999],
"eps": 1e-8,
"weight_decay": 0.01
}
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": 0,
"warmup_max_lr": self.ds_config['learning_rate'],
"warmup_num_steps": 100
}
},
"gradient_clipping": 1.0,
"wall_clock_breakdown": True,
"memory_breakdown": True,
"tensorboard": {
"enabled": True,
"output_path": "./logs/",
"job_name": "deepspeed_advanced_tutorial"
}
}
def initialize_deepspeed(self):
"""Initialize DeepSpeed engine"""
print(" Initializing DeepSpeed...")
parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args([])
self.engine, optimizer, _, lr_scheduler = deepspeed.initialize(
args=args,
model=self.model,
config=self.create_deepspeed_config()
)
print(f" DeepSpeed engine initialized with ZeRO stage {self.ds_config['zero_stage']}")
return self.engine
def train_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, float]:
"""Perform a single training step with DeepSpeed optimizations"""
input_ids = batch['input_ids'].to(self.engine.device)
labels = batch['labels'].to(self.engine.device)
outputs = self.engine(input_ids=input_ids, labels=labels)
loss = outputs.loss
self.engine.backward(loss)
self.engine.step()
return {
'loss': loss.item(),
'lr': self.engine.lr_scheduler.get_last_lr()[0] if self.engine.lr_scheduler else 0
}
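    # Note: engine.backward() and engine.step() are invoked on every
    # micro-batch; DeepSpeed applies the actual optimizer update (and the
    # LR scheduler step) only at gradient-accumulation boundaries, i.e.
    # every `gradient_accumulation_steps` micro-batches, handling FP16
    # loss scaling and gradient clipping internally.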
def train(self, dataloader: DataLoader, num_epochs: int = 2):
"""Complete training loop with monitoring"""
print(f" Starting training for {num_epochs} epochs...")
self.engine.train()
total_steps = 0
for epoch in range(num_epochs):
epoch_loss = 0.0
epoch_steps = 0
print(f"\n Epoch {epoch + 1}/{num_epochs}")
for step, batch in enumerate(dataloader):
start_time = time.time()
metrics = self.train_step(batch)
epoch_loss += metrics['loss']
epoch_steps += 1
total_steps += 1
if step % 10 == 0:
step_time = time.time() - start_time
print(f" Step {step:4d} | Loss: {metrics['loss']:.4f} | "
f"LR: {metrics['lr']:.2e} | Time: {step_time:.3f}s")
if step % 20 == 0 and hasattr(self.engine, 'monitor'):
self.log_memory_stats()
if step >= 50:
break
avg_loss = epoch_loss / epoch_steps
print(f" Epoch {epoch + 1} completed | Average Loss: {avg_loss:.4f}")
print(" Training completed!")
def log_memory_stats(self):
"""Log GPU memory statistics"""
if torch.cuda.is_available():
allocated = torch.cuda.memory_allocated() / 1024**3
reserved = torch.cuda.memory_reserved() / 1024**3
print(f" GPU Memory - Allocated: {allocated:.2f}GB | Reserved: {reserved:.2f}GB")
def save_checkpoint(self, path: str):
"""Save model checkpoint using DeepSpeed"""
print(f" Saving checkpoint to {path}")
self.engine.save_checkpoint(path)
def demonstrate_inference(self, text: str = "The future of AI is"):
"""Demonstrate optimized inference with DeepSpeed"""
print(f"\n Running inference with prompt: '{text}'")
inputs = self.tokenizer.encode(text, return_tensors='pt').to(self.engine.device)
self.engine.eval()
with torch.no_grad():
outputs = self.engine.module.generate(
inputs,
max_length=inputs.shape[1] + 50,
num_return_sequences=1,
temperature=0.8,
do_sample=True,
pad_token_id=self.tokenizer.eos_token_id
)
generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
print(f" Generated text: {generated_text}")
self.engine.train()
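The checkpoint written by save_checkpoint can be restored with the engine's load_checkpoint counterpart; a minimal sketch, assuming an already-initialized trainer:
# load_checkpoint returns (load_path, client_state); load_path is None
# when no checkpoint is found under the directory.
load_path, client_state = trainer.engine.load_checkpoint("./deepspeed_checkpoint")
if load_path is None:
    print("No checkpoint found; starting from scratch")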
Running the end-to-end tutorial
The run_advanced_tutorial function orchestrates config setup, model creation, engine initialization, synthetic-data creation, training, an inference demonstration and checkpoint saving. Companion functions explain the ZeRO stages and memory-optimization techniques, and a benchmark routine compares peak memory and step time across ZeRO stages.
def run_advanced_tutorial():
"""Main function to run the advanced DeepSpeed tutorial"""
print(" Advanced DeepSpeed Tutorial Starting...")
print("=" * 60)
model_config = {
'vocab_size': 50257,
'seq_length': 512,
'hidden_size': 768,
'num_layers': 6,
'num_heads': 12
}
ds_config = {
'train_batch_size': 16,
'micro_batch_size': 4,
'gradient_accumulation_steps': 4,
'zero_stage': 2,
'learning_rate': 1e-4,
'cpu_offload': False
}
print(" Configuration:")
print(f" Model size: ~{sum(np.prod(shape) for shape in [[model_config['vocab_size'], model_config['hidden_size']], [model_config['hidden_size'], model_config['hidden_size']] * model_config['num_layers']]) / 1e6:.1f}M parameters")
print(f" ZeRO Stage: {ds_config['zero_stage']}")
print(f" Batch size: {ds_config['train_batch_size']}")
trainer = AdvancedDeepSpeedTrainer(model_config, ds_config)
model = trainer.create_model()
engine = trainer.initialize_deepspeed()
print("\n Creating synthetic dataset...")
dataset = SyntheticTextDataset(
size=200,
seq_length=model_config['seq_length'],
vocab_size=model_config['vocab_size']
)
dataloader = DataLoader(
dataset,
batch_size=ds_config['micro_batch_size'],
shuffle=True
)
print("\n Pre-training memory stats:")
trainer.log_memory_stats()
trainer.train(dataloader, num_epochs=2)
print("\n Post-training memory stats:")
trainer.log_memory_stats()
trainer.demonstrate_inference("DeepSpeed enables efficient training of")
checkpoint_path = "./deepspeed_checkpoint"
trainer.save_checkpoint(checkpoint_path)
demonstrate_zero_stages()
demonstrate_memory_optimization()
print("\n Tutorial completed successfully!")
print("Key DeepSpeed features demonstrated:")
print(" ZeRO optimization for memory efficiency")
print(" Mixed precision training (FP16)")
print(" Gradient accumulation")
print(" Learning rate scheduling")
print(" Checkpoint saving/loading")
print(" Memory monitoring")
def demonstrate_zero_stages():
"""Demonstrate different ZeRO optimization stages"""
print("\n ZeRO Optimization Stages Explained:")
print(" Stage 0: Disabled (baseline)")
print(" Stage 1: Optimizer state partitioning (~4x memory reduction)")
print(" Stage 2: Gradient partitioning (~8x memory reduction)")
print(" Stage 3: Parameter partitioning (~Nx memory reduction)")
zero_configs = {
1: {"stage": 1, "reduce_bucket_size": 5e8},
2: {"stage": 2, "allgather_partitions": True, "reduce_scatter": True},
3: {"stage": 3, "stage3_prefetch_bucket_size": 5e8, "stage3_param_persistence_threshold": 1e6}
}
for stage, config in zero_configs.items():
        estimated_memory_reduction = {1: "4", 2: "8", 3: "N"}[stage]  # stage 3 scales with GPU count
print(f" Stage {stage}: ~{estimated_memory_reduction}x memory reduction")
def demonstrate_memory_optimization():
"""Show memory optimization techniques"""
print("\n Memory Optimization Techniques:")
print(" Gradient Checkpointing: Trade compute for memory")
print(" CPU Offloading: Move optimizer states to CPU")
print(" Compression: Reduce communication overhead")
print(" Mixed Precision: Use FP16 for faster training")
Benchmarking and config generator
The tutorial includes a DeepSpeedConfigGenerator that produces tuned configs (including special handling for stage 3 and CPU offload) and a benchmark routine that runs short training loops to measure peak memory and step time across ZeRO stages.
class DeepSpeedConfigGenerator:
"""Utility class to generate DeepSpeed configurations"""
@staticmethod
def generate_config(
batch_size: int = 16,
zero_stage: int = 2,
use_cpu_offload: bool = False,
learning_rate: float = 1e-4
) -> Dict[str, Any]:
"""Generate a complete DeepSpeed configuration"""
config = {
"train_batch_size": batch_size,
"train_micro_batch_size_per_gpu": max(1, batch_size // 4),
"gradient_accumulation_steps": max(1, batch_size // max(1, batch_size // 4)),
"zero_optimization": {
"stage": zero_stage,
"allgather_partitions": True,
"allgather_bucket_size": 5e8,
"overlap_comm": True,
"reduce_scatter": True,
"reduce_bucket_size": 5e8,
"contiguous_gradients": True
},
"fp16": {
"enabled": True,
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": learning_rate,
"betas": [0.9, 0.999],
"eps": 1e-8,
"weight_decay": 0.01
}
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": 0,
"warmup_max_lr": learning_rate,
"warmup_num_steps": 100
}
},
"gradient_clipping": 1.0,
"wall_clock_breakdown": True
}
        if use_cpu_offload:
            # Boolean `cpu_offload` is deprecated in newer DeepSpeed
            # releases; the `offload_optimizer` block is the current equivalent.
            config["zero_optimization"]["offload_optimizer"] = {
                "device": "cpu",
                "pin_memory": True
            }
if zero_stage == 3:
config["zero_optimization"].update({
"stage3_prefetch_bucket_size": 5e8,
"stage3_param_persistence_threshold": 1e6,
"stage3_gather_16bit_weights_on_model_save": True
})
return config
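For example, a generated config can be written to disk, which also makes it usable outside the notebook with the deepspeed launcher (exact launcher flags depend on how your training script parses arguments):
cfg = DeepSpeedConfigGenerator.generate_config(
    batch_size=16, zero_stage=3, use_cpu_offload=True, learning_rate=1e-4
)
with open("ds_config.json", "w") as f:
    json.dump(cfg, f, indent=2)  # `json` is imported at the top of the tutorial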
def benchmark_zero_stages():
"""Benchmark different ZeRO stages"""
print("\n Benchmarking ZeRO Stages...")
model_config = {
'vocab_size': 50257,
'seq_length': 256,
'hidden_size': 512,
'num_layers': 4,
'num_heads': 8
}
results = {}
for stage in [1, 2]:
print(f"\n Testing ZeRO Stage {stage}...")
ds_config = {
'train_batch_size': 8,
'micro_batch_size': 2,
'gradient_accumulation_steps': 4,
'zero_stage': stage,
'learning_rate': 1e-4
}
try:
trainer = AdvancedDeepSpeedTrainer(model_config, ds_config)
model = trainer.create_model()
engine = trainer.initialize_deepspeed()
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats()
dataset = SyntheticTextDataset(size=20, seq_length=model_config['seq_length'])
dataloader = DataLoader(dataset, batch_size=ds_config['micro_batch_size'])
start_time = time.time()
for i, batch in enumerate(dataloader):
if i >= 5:
break
trainer.train_step(batch)
end_time = time.time()
            peak_memory = (torch.cuda.max_memory_allocated() / 1024**3
                           if torch.cuda.is_available() else 0.0)
results[stage] = {
'peak_memory_gb': peak_memory,
'time_per_step': (end_time - start_time) / 5
}
print(f" Peak Memory: {peak_memory:.2f}GB")
print(f" Time per step: {results[stage]['time_per_step']:.3f}s")
del trainer, model, engine
torch.cuda.empty_cache()
except Exception as e:
print(f" Error with stage {stage}: {str(e)}")
if len(results) > 1:
print(f"\n Comparison:")
stage_1_memory = results.get(1, {}).get('peak_memory_gb', 0)
stage_2_memory = results.get(2, {}).get('peak_memory_gb', 0)
if stage_1_memory > 0 and stage_2_memory > 0:
memory_reduction = (stage_1_memory - stage_2_memory) / stage_1_memory * 100
print(f" Memory reduction from Stage 1 to 2: {memory_reduction:.1f}%")
Advanced features and practical tips
The tutorial points out additional DeepSpeed features: dynamic loss scaling, gradient compression, pipeline and expert parallelism (Mixture-of-Experts), and curriculum learning strategies. Practical troubleshooting tips include enabling the GPU runtime in Colab, reducing batch or model size when you hit out-of-memory errors, and enabling CPU offload if needed.
def demonstrate_advanced_features():
"""Demonstrate additional advanced DeepSpeed features"""
print("\n Advanced DeepSpeed Features:")
print(" Dynamic Loss Scaling: Automatically adjusts FP16 loss scaling")
print(" Gradient Compression: Reduces communication overhead")
print(" Pipeline Parallelism: Splits model across devices")
print(" Expert Parallelism: Efficient Mixture-of-Experts training")
print(" Curriculum Learning: Progressive training strategies")
if __name__ == "__main__":
print(f" CUDA Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
print(f" GPU: {torch.cuda.get_device_name()}")
print(f" Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f}GB")
try:
run_advanced_tutorial()
benchmark_zero_stages()
demonstrate_advanced_features()
except Exception as e:
print(f" Error during tutorial: {str(e)}")
print(" Tips for troubleshooting:")
print(" - Ensure you have GPU runtime enabled in Colab")
print(" - Try reducing batch_size or model size if facing memory issues")
print(" - Enable CPU offloading in ds_config if needed")
Putting it together
This hands-on walkthrough gives you a reproducible trainer, a config generator and benchmark scripts for evaluating memory and speed trade-offs across ZeRO stages. Use the code blocks directly in Colab, tweak ds_config and the model sizes, and iterate to find the right balance between compute and memory for your training budget.