Ускоряем DistilBERT: оптимизация, экспорт и квантизация с Hugging Face Optimum и ONNX Runtime

Установка и зависимые пакеты

Установите необходимые библиотеки для выполнения полного рабочего процесса оптимизации в Colab или любой другой среде Python:

!pip -q install "transformers>=4.49" "optimum[onnxruntime]>=1.20.0" "datasets>=2.20" "evaluate>=0.4" accelerate

Импорты, переменные окружения и константы

Этот фрагмент настраивает импорты, переменные окружения, идентификаторы модели, пути и выбор устройства (CPU/GPU). Также выводит информацию об устройстве.

from pathlib import Path
import os, time, numpy as np, torch
from datasets import load_dataset
import evaluate
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
from optimum.onnxruntime import ORTModelForSequenceClassification, ORTQuantizer
from optimum.onnxruntime.configuration import QuantizationConfig


os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")


MODEL_ID = "distilbert-base-uncased-finetuned-sst-2-english"
ORT_DIR  = Path("onnx-distilbert")
Q_DIR    = Path("onnx-distilbert-quant")
DEVICE   = "cuda" if torch.cuda.is_available() else "cpu"
BATCH    = 16
MAXLEN   = 128
N_WARM   = 3
N_ITERS  = 8


print(f"Device: {DEVICE} | torch={torch.__version__}")

Загрузка данных и вспомогательные функции

Загружаем срез валидационного набора SST-2, подготавливаем токенизатор и метрику точности. Функции помогают формировать батчи, оценивать предсказание и измерять латентность с прогревом.

ds = load_dataset("glue", "sst2", split="validation[:20%]")
texts, labels = ds["sentence"], ds["label"]
metric = evaluate.load("accuracy")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)


def make_batches(texts, max_len=MAXLEN, batch=BATCH):
   for i in range(0, len(texts), batch):
       yield tokenizer(texts[i:i+batch], padding=True, truncation=True,
                       max_length=max_len, return_tensors="pt")


def run_eval(predict_fn, texts, labels):
   preds = []
   for toks in make_batches(texts):
       preds.extend(predict_fn(toks))
   return metric.compute(predictions=preds, references=labels)["accuracy"]


def bench(predict_fn, texts, n_warm=N_WARM, n_iters=N_ITERS):
   for _ in range(n_warm):
       for toks in make_batches(texts[:BATCH*2]):
           predict_fn(toks)
   times = []
   for _ in range(n_iters):
       t0 = time.time()
       for toks in make_batches(texts):
           predict_fn(toks)
       times.append((time.time() - t0) * 1000)
   return float(np.mean(times)), float(np.std(times))

Базовая PyTorch модель и torch.compile (опционально)

Загружаем классификатор DistilBERT в PyTorch, определяем функцию предсказания, замеряем время и точность. Попытка задействовать torch.compile выполняется при доступности функции, и результаты сравниваются с другим окружением.

torch_model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID).to(DEVICE).eval()


@torch.no_grad()
def pt_predict(toks):
   toks = {k: v.to(DEVICE) for k, v in toks.items()}
   logits = torch_model(**toks).logits
   return logits.argmax(-1).detach().cpu().tolist()


pt_ms, pt_sd = bench(pt_predict, texts)
pt_acc = run_eval(pt_predict, texts, labels)
print(f"[PyTorch eager]   {pt_ms:.1f}±{pt_sd:.1f} ms | acc={pt_acc:.4f}")


compiled_model = torch_model
compile_ok = False
try:
   compiled_model = torch.compile(torch_model, mode="reduce-overhead", fullgraph=False)
   compile_ok = True
except Exception as e:
   print("torch.compile unavailable or failed -> skipping:", repr(e))


@torch.no_grad()
def ptc_predict(toks):
   toks = {k: v.to(DEVICE) for k, v in toks.items()}
   logits = compiled_model(**toks).logits
   return logits.argmax(-1).detach().cpu().tolist()


if compile_ok:
   ptc_ms, ptc_sd = bench(ptc_predict, texts)
   ptc_acc = run_eval(ptc_predict, texts, labels)
   print(f"[torch.compile]   {ptc_ms:.1f}±{ptc_sd:.1f} ms | acc={ptc_acc:.4f}")

Экспорт в ONNX Runtime и динамическая квантизация

Экспортируем модель в ONNX с помощью Optimum, запускаем через ONNX Runtime, затем применяем динамическую квантизацию через ORTQuantizer. Замеряем производительность и точность до и после квантизации.

provider = "CUDAExecutionProvider" if DEVICE == "cuda" else "CPUExecutionProvider"
ort_model = ORTModelForSequenceClassification.from_pretrained(
   MODEL_ID, export=True, provider=provider, cache_dir=ORT_DIR
)


@torch.no_grad()
def ort_predict(toks):
   logits = ort_model(**{k: v.cpu() for k, v in toks.items()}).logits
   return logits.argmax(-1).cpu().tolist()


ort_ms, ort_sd = bench(ort_predict, texts)
ort_acc = run_eval(ort_predict, texts, labels)
print(f"[ONNX Runtime]    {ort_ms:.1f}±{ort_sd:.1f} ms | acc={ort_acc:.4f}")


Q_DIR.mkdir(parents=True, exist_ok=True)
quantizer = ORTQuantizer.from_pretrained(ORT_DIR)
qconfig = QuantizationConfig(approach="dynamic", per_channel=False, reduce_range=True)
quantizer.quantize(model_input=ORT_DIR, quantization_config=qconfig, save_dir=Q_DIR)


ort_quant = ORTModelForSequenceClassification.from_pretrained(Q_DIR, provider=provider)


@torch.no_grad()
def ortq_predict(toks):
   logits = ort_quant(**{k: v.cpu() for k, v in toks.items()}).logits
   return logits.argmax(-1).cpu().tolist()


oq_ms, oq_sd = bench(ortq_predict, texts)
oq_acc = run_eval(ortq_predict, texts, labels)
print(f"[ORT Quantized]   {oq_ms:.1f}±{oq_sd:.1f} ms | acc={oq_acc:.4f}")

Проверка предсказаний и сводная таблица результатов

Прогоняем пару примеров через пайплайны и формируем таблицу с измерениями для наглядного сравнения. В заметках перечислены возможные дальнейшие шаги и улучшения (FlashAttention2, FP8, TensorRT-LLM, статическая калиброванная квантизация, настройка потоков на CPU).

pt_pipe  = pipeline("sentiment-analysis", model=torch_model, tokenizer=tokenizer,
                   device=0 if DEVICE=="cuda" else -1)
ort_pipe = pipeline("sentiment-analysis", model=ort_model, tokenizer=tokenizer, device=-1)
samples = [
   "What a fantastic movie—performed brilliantly!",
   "This was a complete waste of time.",
   "I’m not sure how I feel about this one."
]
print("\nSample predictions (PT | ORT):")
for s in samples:
   a = pt_pipe(s)[0]["label"]
   b = ort_pipe(s)[0]["label"]
   print(f"- {s}\n  PT={a} | ORT={b}")


import pandas as pd
rows = [["PyTorch eager", pt_ms, pt_sd, pt_acc],
       ["ONNX Runtime",  ort_ms, ort_sd, ort_acc],
       ["ORT Quantized", oq_ms, oq_sd, oq_acc]]
if compile_ok: rows.insert(1, ["torch.compile", ptc_ms, ptc_sd, ptc_acc])
df = pd.DataFrame(rows, columns=["Engine", "Mean ms (↓)", "Std ms", "Accuracy"])
display(df)


print("""
Notes:
- BetterTransformer is deprecated on transformers>=4.49, hence omitted.
- For larger gains on GPU, also try FlashAttention2 models or FP8 with TensorRT-LLM.
- For CPU, tune threads: set OMP_NUM_THREADS/MKL_NUM_THREADS; try NUMA pinning.
- For static (calibrated) quantization, use QuantizationConfig(approach='static') with a calibration set.
""")

Этот рецепт показывает практический путь от стандартной PyTorch-модели к более быстрой и пригодной для продакшена реализации с ONNX Runtime и опциями квантизации, при этом точность остаётся сопоставимой. Используйте функции bench и run_eval для честного сравнения разных движков и для экспериментов с другими бэкендами и режимами квантизации.