Обучение и визуализация поведения робота на PushT с LeRobot — практическое руководство

Установка и зависимости

Начинаем с установки LeRobot и сопутствующих библиотек, затем импортируем стандартные модули Python и PyTorch. В руководстве фиксируется seed для воспроизводимости и определяется, доступна ли GPU, чтобы тот же код работал эффективно на Colab или локальной машине.

!pip -q install --upgrade lerobot torch torchvision timm imageio[ffmpeg]


import os, math, random, io, sys, json, pathlib, time
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision.utils import make_grid, save_image
import numpy as np
import imageio.v2 as imageio


try:
   from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
except Exception:
   from lerobot.datasets.lerobot_dataset import LeRobotDataset


DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)

Загрузка PushT через LeRobot

Мы загружаем датасет PushT через унифицированный API LeRobot, исследуем один пример, чтобы определить ключи для изображений, состояний и действий, и сопоставляем их для удобного доступа в дальнейшем.

REPO_ID = "lerobot/pusht" 
ds = LeRobotDataset(REPO_ID) 
print("Dataset length:", len(ds))


s0 = ds[0]
keys = list(s0.keys())
print("Sample keys:", keys)


def key_with(prefixes):
   for k in keys:
       for p in prefixes:
           if k.startswith(p): return k
   return None


K_IMG = key_with(["observation.image", "observation.images", "observation.rgb"])
K_STATE = key_with(["observation.state"])
K_ACT = "action"
assert K_ACT in s0, f"No 'action' key found in sample. Found: {keys}"
print("Using keys -> IMG:", K_IMG, "STATE:", K_STATE, "ACT:", K_ACT)

Обёртка датасета и DataLoader’ы

Мы создаём обёртку вокруг исходного датасета так, чтобы каждый пример возвращал нормализованное изображение 96×96, плоский тензор состояния и тензор действия. Если в данных присутствует стек кадров, берём последний кадр. Для быстрого запуска в Colab ограничиваем размер тренировочной и валидационной выборок и создаём эффективные DataLoader’ы.

class PushTWrapper(torch.utils.data.Dataset):
   def __init__(self, base):
       self.base = base
   def __len__(self): return len(self.base)
   def __getitem__(self, i):
       x = self.base[i]
       img = x[K_IMG]
       if img.ndim == 4: img = img[-1]
       img = img.float() / 255.0 if img.dtype==torch.uint8 else img.float()
       state = x.get(K_STATE, torch.zeros(2))
       state = state.float().reshape(-1)
       act = x[K_ACT].float().reshape(-1)
       if img.shape[-2:] != (96,96):
           img = F.interpolate(img.unsqueeze(0), size=(96,96), mode="bilinear", align_corners=False)[0]
       return {"image": img, "state": state, "action": act}


wrapped = PushTWrapper(ds)
N = len(wrapped)
idx = list(range(N))
random.shuffle(idx)
n_train = int(0.9*N)
train_idx, val_idx = idx[:n_train], idx[n_train:]


train_ds = Subset(wrapped, train_idx[:12000])
val_ds   = Subset(wrapped, val_idx[:2000])


BATCH = 128
train_loader = DataLoader(train_ds, batch_size=BATCH, shuffle=True, num_workers=2, pin_memory=True)
val_loader   = DataLoader(val_ds,   batch_size=BATCH, shuffle=False, num_workers=2, pin_memory=True)

Модель: компактная визуомоторная политика

Небольшая сверточная backbone извлекает признаки из изображения, которые конкатенируются с состоянием робота и передаются в MLP для предсказания двухмерного действия. Архитектура лёгкая, чтобы ускорить эксперименты в Colab.

class SmallBackbone(nn.Module):
   def __init__(self, out=256):
       super().__init__()
       self.conv = nn.Sequential(
           nn.Conv2d(3, 32, 5, 2, 2), nn.ReLU(inplace=True),
           nn.Conv2d(32, 64, 3, 2, 1), nn.ReLU(inplace=True),
           nn.Conv2d(64,128, 3, 2, 1), nn.ReLU(inplace=True),
           nn.Conv2d(128,128,3, 1, 1), nn.ReLU(inplace=True),
       )
       self.head = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(128, out), nn.ReLU(inplace=True))
   def forward(self, x): return self.head(self.conv(x))


class BCPolicy(nn.Module):
   def __init__(self, img_dim=256, state_dim=2, hidden=256, act_dim=2):
       super().__init__()
       self.backbone = SmallBackbone(img_dim)
       self.mlp = nn.Sequential(
           nn.Linear(img_dim + state_dim, hidden), nn.ReLU(inplace=True),
           nn.Linear(hidden, hidden//2), nn.ReLU(inplace=True),
           nn.Linear(hidden//2, act_dim)
       )
   def forward(self, img, state):
       z = self.backbone(img)
       if state.ndim==1: state = state.unsqueeze(0)
       z = torch.cat([z, state], dim=-1)
       return self.mlp(z)


policy = BCPolicy().to(DEVICE)
opt = torch.optim.AdamW(policy.parameters(), lr=3e-4, weight_decay=1e-4)
scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE=="cuda"))

Цикл обучения и оценка

Обучение проводится с AdamW, косинусным уменьшением шага обучения, смешанной точностью и обрезкой градиента. Валидационная MSE используется для выбора лучшей модели, которая сохраняется в чекпоинт.

@torch.no_grad()
def evaluate():
   policy.eval()
   mse, n = 0.0, 0
   for batch in val_loader:
       img = batch["image"].to(DEVICE, non_blocking=True)
       st  = batch["state"].to(DEVICE, non_blocking=True)
       act = batch["action"].to(DEVICE, non_blocking=True)
       pred = policy(img, st)
       mse += F.mse_loss(pred, act, reduction="sum").item()
       n += act.numel()
   return mse / n


def cosine_lr(step, total, base=3e-4, min_lr=3e-5):
   if step>=total: return min_lr
   cos = 0.5*(1+math.cos(math.pi*step/total))
   return min_lr + (base-min_lr)*cos


EPOCHS = 4 
steps_total = EPOCHS*len(train_loader)
step = 0
best = float("inf")
ckpt = "/content/lerobot_pusht_bc.pt"


for epoch in range(EPOCHS):
   policy.train()
   for batch in train_loader:
       lr = cosine_lr(step, steps_total); step += 1
       for g in opt.param_groups: g["lr"] = lr


       img = batch["image"].to(DEVICE, non_blocking=True)
       st  = batch["state"].to(DEVICE, non_blocking=True)
       act = batch["action"].to(DEVICE, non_blocking=True)


       opt.zero_grad(set_to_none=True)
       with torch.cuda.amp.autocast(enabled=(DEVICE=="cuda")):
           pred = policy(img, st)
           loss = F.smooth_l1_loss(pred, act)
       scaler.scale(loss).backward()
       nn.utils.clip_grad_norm_(policy.parameters(), 1.0)
       scaler.step(opt); scaler.update()


   val_mse = evaluate()
   print(f"Epoch {epoch+1}/{EPOCHS} | Val MSE: {val_mse:.6f}")
   if val_mse < best:
       best = val_mse
       torch.save({"state_dict": policy.state_dict(), "val_mse": best}, ckpt)


print("Best Val MSE:", best, "| Saved:", ckpt)

Визуализация предсказаний

После обучения загружаем лучший чекпоинт, переключаем модель в eval и создаём короткое видео и сетку изображений. В примере на кадры накладываются стрелки предсказанных действий и результат сохраняется в MP4, что позволяет быстро посмотреть, что модель предсказывает на настоящих наблюдениях PushT.

policy.load_state_dict(torch.load(ckpt)["state_dict"]); policy.eval()
os.makedirs("/content/vis", exist_ok=True)


def draw_arrow(imgCHW, action_xy, scale=40):
   import PIL.Image, PIL.ImageDraw
   C,H,W = imgCHW.shape
   arr = (imgCHW.clamp(0,1).permute(1,2,0).cpu().numpy()*255).astype(np.uint8)
   im = PIL.Image.fromarray(arr)
   dr = PIL.ImageDraw.Draw(im)
   cx, cy = W//2, H//2
   dx, dy = float(action_xy[0])*scale, float(-action_xy[1])*scale
   dr.line((cx, cy, cx+dx, cy+dy), width=3, fill=(0,255,0))
   return np.array(im)


frames = []
with torch.no_grad():
   for i in range(60):
       b = wrapped[i]
       img = b["image"].unsqueeze(0).to(DEVICE)
       st  = b["state"].unsqueeze(0).to(DEVICE)
       pred = policy(img, st)[0].cpu()
       frames.append(draw_arrow(b["image"], pred))
video_path = "/content/vis/pusht_pred.mp4"
imageio.mimsave(video_path, frames, fps=10)
print("Wrote", video_path)


grid = make_grid(torch.stack([wrapped[i]["image"] for i in range(16)]), nrow=8)
save_image(grid, "/content/vis/grid.png")
print("Saved grid:", "/content/vis/grid.png")

Данный конвейер демонстрирует, как LeRobot объединяет доступ к данным, определение модели, цикл обучения и визуализацию, позволяя запускать воспроизводимые эксперименты по обучению роботов без подключения к аппаратуре. Далее легко сменить backbone, голову политики или метод обучения и опубликовать модели на Hugging Face Hub.