Обучение и визуализация поведения робота на PushT с LeRobot — практическое руководство
Установка и зависимости
Начинаем с установки LeRobot и сопутствующих библиотек, затем импортируем стандартные модули Python и PyTorch. В руководстве фиксируется seed для воспроизводимости и определяется, доступна ли GPU, чтобы тот же код работал эффективно на Colab или локальной машине.
!pip -q install --upgrade lerobot torch torchvision timm imageio[ffmpeg]
import os, math, random, io, sys, json, pathlib, time
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision.utils import make_grid, save_image
import numpy as np
import imageio.v2 as imageio
try:
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
except Exception:
from lerobot.datasets.lerobot_dataset import LeRobotDataset
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
Загрузка PushT через LeRobot
Мы загружаем датасет PushT через унифицированный API LeRobot, исследуем один пример, чтобы определить ключи для изображений, состояний и действий, и сопоставляем их для удобного доступа в дальнейшем.
REPO_ID = "lerobot/pusht"
ds = LeRobotDataset(REPO_ID)
print("Dataset length:", len(ds))
s0 = ds[0]
keys = list(s0.keys())
print("Sample keys:", keys)
def key_with(prefixes):
for k in keys:
for p in prefixes:
if k.startswith(p): return k
return None
K_IMG = key_with(["observation.image", "observation.images", "observation.rgb"])
K_STATE = key_with(["observation.state"])
K_ACT = "action"
assert K_ACT in s0, f"No 'action' key found in sample. Found: {keys}"
print("Using keys -> IMG:", K_IMG, "STATE:", K_STATE, "ACT:", K_ACT)
Обёртка датасета и DataLoader’ы
Мы создаём обёртку вокруг исходного датасета так, чтобы каждый пример возвращал нормализованное изображение 96×96, плоский тензор состояния и тензор действия. Если в данных присутствует стек кадров, берём последний кадр. Для быстрого запуска в Colab ограничиваем размер тренировочной и валидационной выборок и создаём эффективные DataLoader’ы.
class PushTWrapper(torch.utils.data.Dataset):
def __init__(self, base):
self.base = base
def __len__(self): return len(self.base)
def __getitem__(self, i):
x = self.base[i]
img = x[K_IMG]
if img.ndim == 4: img = img[-1]
img = img.float() / 255.0 if img.dtype==torch.uint8 else img.float()
state = x.get(K_STATE, torch.zeros(2))
state = state.float().reshape(-1)
act = x[K_ACT].float().reshape(-1)
if img.shape[-2:] != (96,96):
img = F.interpolate(img.unsqueeze(0), size=(96,96), mode="bilinear", align_corners=False)[0]
return {"image": img, "state": state, "action": act}
wrapped = PushTWrapper(ds)
N = len(wrapped)
idx = list(range(N))
random.shuffle(idx)
n_train = int(0.9*N)
train_idx, val_idx = idx[:n_train], idx[n_train:]
train_ds = Subset(wrapped, train_idx[:12000])
val_ds = Subset(wrapped, val_idx[:2000])
BATCH = 128
train_loader = DataLoader(train_ds, batch_size=BATCH, shuffle=True, num_workers=2, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=BATCH, shuffle=False, num_workers=2, pin_memory=True)
Модель: компактная визуомоторная политика
Небольшая сверточная backbone извлекает признаки из изображения, которые конкатенируются с состоянием робота и передаются в MLP для предсказания двухмерного действия. Архитектура лёгкая, чтобы ускорить эксперименты в Colab.
class SmallBackbone(nn.Module):
def __init__(self, out=256):
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(3, 32, 5, 2, 2), nn.ReLU(inplace=True),
nn.Conv2d(32, 64, 3, 2, 1), nn.ReLU(inplace=True),
nn.Conv2d(64,128, 3, 2, 1), nn.ReLU(inplace=True),
nn.Conv2d(128,128,3, 1, 1), nn.ReLU(inplace=True),
)
self.head = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(128, out), nn.ReLU(inplace=True))
def forward(self, x): return self.head(self.conv(x))
class BCPolicy(nn.Module):
def __init__(self, img_dim=256, state_dim=2, hidden=256, act_dim=2):
super().__init__()
self.backbone = SmallBackbone(img_dim)
self.mlp = nn.Sequential(
nn.Linear(img_dim + state_dim, hidden), nn.ReLU(inplace=True),
nn.Linear(hidden, hidden//2), nn.ReLU(inplace=True),
nn.Linear(hidden//2, act_dim)
)
def forward(self, img, state):
z = self.backbone(img)
if state.ndim==1: state = state.unsqueeze(0)
z = torch.cat([z, state], dim=-1)
return self.mlp(z)
policy = BCPolicy().to(DEVICE)
opt = torch.optim.AdamW(policy.parameters(), lr=3e-4, weight_decay=1e-4)
scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE=="cuda"))
Цикл обучения и оценка
Обучение проводится с AdamW, косинусным уменьшением шага обучения, смешанной точностью и обрезкой градиента. Валидационная MSE используется для выбора лучшей модели, которая сохраняется в чекпоинт.
@torch.no_grad()
def evaluate():
policy.eval()
mse, n = 0.0, 0
for batch in val_loader:
img = batch["image"].to(DEVICE, non_blocking=True)
st = batch["state"].to(DEVICE, non_blocking=True)
act = batch["action"].to(DEVICE, non_blocking=True)
pred = policy(img, st)
mse += F.mse_loss(pred, act, reduction="sum").item()
n += act.numel()
return mse / n
def cosine_lr(step, total, base=3e-4, min_lr=3e-5):
if step>=total: return min_lr
cos = 0.5*(1+math.cos(math.pi*step/total))
return min_lr + (base-min_lr)*cos
EPOCHS = 4
steps_total = EPOCHS*len(train_loader)
step = 0
best = float("inf")
ckpt = "/content/lerobot_pusht_bc.pt"
for epoch in range(EPOCHS):
policy.train()
for batch in train_loader:
lr = cosine_lr(step, steps_total); step += 1
for g in opt.param_groups: g["lr"] = lr
img = batch["image"].to(DEVICE, non_blocking=True)
st = batch["state"].to(DEVICE, non_blocking=True)
act = batch["action"].to(DEVICE, non_blocking=True)
opt.zero_grad(set_to_none=True)
with torch.cuda.amp.autocast(enabled=(DEVICE=="cuda")):
pred = policy(img, st)
loss = F.smooth_l1_loss(pred, act)
scaler.scale(loss).backward()
nn.utils.clip_grad_norm_(policy.parameters(), 1.0)
scaler.step(opt); scaler.update()
val_mse = evaluate()
print(f"Epoch {epoch+1}/{EPOCHS} | Val MSE: {val_mse:.6f}")
if val_mse < best:
best = val_mse
torch.save({"state_dict": policy.state_dict(), "val_mse": best}, ckpt)
print("Best Val MSE:", best, "| Saved:", ckpt)
Визуализация предсказаний
После обучения загружаем лучший чекпоинт, переключаем модель в eval и создаём короткое видео и сетку изображений. В примере на кадры накладываются стрелки предсказанных действий и результат сохраняется в MP4, что позволяет быстро посмотреть, что модель предсказывает на настоящих наблюдениях PushT.
policy.load_state_dict(torch.load(ckpt)["state_dict"]); policy.eval()
os.makedirs("/content/vis", exist_ok=True)
def draw_arrow(imgCHW, action_xy, scale=40):
import PIL.Image, PIL.ImageDraw
C,H,W = imgCHW.shape
arr = (imgCHW.clamp(0,1).permute(1,2,0).cpu().numpy()*255).astype(np.uint8)
im = PIL.Image.fromarray(arr)
dr = PIL.ImageDraw.Draw(im)
cx, cy = W//2, H//2
dx, dy = float(action_xy[0])*scale, float(-action_xy[1])*scale
dr.line((cx, cy, cx+dx, cy+dy), width=3, fill=(0,255,0))
return np.array(im)
frames = []
with torch.no_grad():
for i in range(60):
b = wrapped[i]
img = b["image"].unsqueeze(0).to(DEVICE)
st = b["state"].unsqueeze(0).to(DEVICE)
pred = policy(img, st)[0].cpu()
frames.append(draw_arrow(b["image"], pred))
video_path = "/content/vis/pusht_pred.mp4"
imageio.mimsave(video_path, frames, fps=10)
print("Wrote", video_path)
grid = make_grid(torch.stack([wrapped[i]["image"] for i in range(16)]), nrow=8)
save_image(grid, "/content/vis/grid.png")
print("Saved grid:", "/content/vis/grid.png")
Данный конвейер демонстрирует, как LeRobot объединяет доступ к данным, определение модели, цикл обучения и визуализацию, позволяя запускать воспроизводимые эксперименты по обучению роботов без подключения к аппаратуре. Далее легко сменить backbone, голову политики или метод обучения и опубликовать модели на Hugging Face Hub.