Ускоряем DistilBERT: оптимизация, экспорт и квантизация с Hugging Face Optimum и ONNX Runtime
Установка и зависимые пакеты
Установите необходимые библиотеки для выполнения полного рабочего процесса оптимизации в Colab или любой другой среде Python:
!pip -q install "transformers>=4.49" "optimum[onnxruntime]>=1.20.0" "datasets>=2.20" "evaluate>=0.4" accelerate
Импорты, переменные окружения и константы
Этот фрагмент настраивает импорты, переменные окружения, идентификаторы модели, пути и выбор устройства (CPU/GPU). Также выводит информацию об устройстве.
from pathlib import Path
import os, time, numpy as np, torch
from datasets import load_dataset
import evaluate
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
from optimum.onnxruntime import ORTModelForSequenceClassification, ORTQuantizer
from optimum.onnxruntime.configuration import QuantizationConfig
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")
MODEL_ID = "distilbert-base-uncased-finetuned-sst-2-english"
ORT_DIR = Path("onnx-distilbert")
Q_DIR = Path("onnx-distilbert-quant")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH = 16
MAXLEN = 128
N_WARM = 3
N_ITERS = 8
print(f"Device: {DEVICE} | torch={torch.__version__}")
Загрузка данных и вспомогательные функции
Загружаем срез валидационного набора SST-2, подготавливаем токенизатор и метрику точности. Функции помогают формировать батчи, оценивать предсказание и измерять латентность с прогревом.
ds = load_dataset("glue", "sst2", split="validation[:20%]")
texts, labels = ds["sentence"], ds["label"]
metric = evaluate.load("accuracy")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
def make_batches(texts, max_len=MAXLEN, batch=BATCH):
for i in range(0, len(texts), batch):
yield tokenizer(texts[i:i+batch], padding=True, truncation=True,
max_length=max_len, return_tensors="pt")
def run_eval(predict_fn, texts, labels):
preds = []
for toks in make_batches(texts):
preds.extend(predict_fn(toks))
return metric.compute(predictions=preds, references=labels)["accuracy"]
def bench(predict_fn, texts, n_warm=N_WARM, n_iters=N_ITERS):
for _ in range(n_warm):
for toks in make_batches(texts[:BATCH*2]):
predict_fn(toks)
times = []
for _ in range(n_iters):
t0 = time.time()
for toks in make_batches(texts):
predict_fn(toks)
times.append((time.time() - t0) * 1000)
return float(np.mean(times)), float(np.std(times))
Базовая PyTorch модель и torch.compile (опционально)
Загружаем классификатор DistilBERT в PyTorch, определяем функцию предсказания, замеряем время и точность. Попытка задействовать torch.compile выполняется при доступности функции, и результаты сравниваются с другим окружением.
torch_model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID).to(DEVICE).eval()
@torch.no_grad()
def pt_predict(toks):
toks = {k: v.to(DEVICE) for k, v in toks.items()}
logits = torch_model(**toks).logits
return logits.argmax(-1).detach().cpu().tolist()
pt_ms, pt_sd = bench(pt_predict, texts)
pt_acc = run_eval(pt_predict, texts, labels)
print(f"[PyTorch eager] {pt_ms:.1f}±{pt_sd:.1f} ms | acc={pt_acc:.4f}")
compiled_model = torch_model
compile_ok = False
try:
compiled_model = torch.compile(torch_model, mode="reduce-overhead", fullgraph=False)
compile_ok = True
except Exception as e:
print("torch.compile unavailable or failed -> skipping:", repr(e))
@torch.no_grad()
def ptc_predict(toks):
toks = {k: v.to(DEVICE) for k, v in toks.items()}
logits = compiled_model(**toks).logits
return logits.argmax(-1).detach().cpu().tolist()
if compile_ok:
ptc_ms, ptc_sd = bench(ptc_predict, texts)
ptc_acc = run_eval(ptc_predict, texts, labels)
print(f"[torch.compile] {ptc_ms:.1f}±{ptc_sd:.1f} ms | acc={ptc_acc:.4f}")
Экспорт в ONNX Runtime и динамическая квантизация
Экспортируем модель в ONNX с помощью Optimum, запускаем через ONNX Runtime, затем применяем динамическую квантизацию через ORTQuantizer. Замеряем производительность и точность до и после квантизации.
provider = "CUDAExecutionProvider" if DEVICE == "cuda" else "CPUExecutionProvider"
ort_model = ORTModelForSequenceClassification.from_pretrained(
MODEL_ID, export=True, provider=provider, cache_dir=ORT_DIR
)
@torch.no_grad()
def ort_predict(toks):
logits = ort_model(**{k: v.cpu() for k, v in toks.items()}).logits
return logits.argmax(-1).cpu().tolist()
ort_ms, ort_sd = bench(ort_predict, texts)
ort_acc = run_eval(ort_predict, texts, labels)
print(f"[ONNX Runtime] {ort_ms:.1f}±{ort_sd:.1f} ms | acc={ort_acc:.4f}")
Q_DIR.mkdir(parents=True, exist_ok=True)
quantizer = ORTQuantizer.from_pretrained(ORT_DIR)
qconfig = QuantizationConfig(approach="dynamic", per_channel=False, reduce_range=True)
quantizer.quantize(model_input=ORT_DIR, quantization_config=qconfig, save_dir=Q_DIR)
ort_quant = ORTModelForSequenceClassification.from_pretrained(Q_DIR, provider=provider)
@torch.no_grad()
def ortq_predict(toks):
logits = ort_quant(**{k: v.cpu() for k, v in toks.items()}).logits
return logits.argmax(-1).cpu().tolist()
oq_ms, oq_sd = bench(ortq_predict, texts)
oq_acc = run_eval(ortq_predict, texts, labels)
print(f"[ORT Quantized] {oq_ms:.1f}±{oq_sd:.1f} ms | acc={oq_acc:.4f}")
Проверка предсказаний и сводная таблица результатов
Прогоняем пару примеров через пайплайны и формируем таблицу с измерениями для наглядного сравнения. В заметках перечислены возможные дальнейшие шаги и улучшения (FlashAttention2, FP8, TensorRT-LLM, статическая калиброванная квантизация, настройка потоков на CPU).
pt_pipe = pipeline("sentiment-analysis", model=torch_model, tokenizer=tokenizer,
device=0 if DEVICE=="cuda" else -1)
ort_pipe = pipeline("sentiment-analysis", model=ort_model, tokenizer=tokenizer, device=-1)
samples = [
"What a fantastic movie—performed brilliantly!",
"This was a complete waste of time.",
"I’m not sure how I feel about this one."
]
print("\nSample predictions (PT | ORT):")
for s in samples:
a = pt_pipe(s)[0]["label"]
b = ort_pipe(s)[0]["label"]
print(f"- {s}\n PT={a} | ORT={b}")
import pandas as pd
rows = [["PyTorch eager", pt_ms, pt_sd, pt_acc],
["ONNX Runtime", ort_ms, ort_sd, ort_acc],
["ORT Quantized", oq_ms, oq_sd, oq_acc]]
if compile_ok: rows.insert(1, ["torch.compile", ptc_ms, ptc_sd, ptc_acc])
df = pd.DataFrame(rows, columns=["Engine", "Mean ms (↓)", "Std ms", "Accuracy"])
display(df)
print("""
Notes:
- BetterTransformer is deprecated on transformers>=4.49, hence omitted.
- For larger gains on GPU, also try FlashAttention2 models or FP8 with TensorRT-LLM.
- For CPU, tune threads: set OMP_NUM_THREADS/MKL_NUM_THREADS; try NUMA pinning.
- For static (calibrated) quantization, use QuantizationConfig(approach='static') with a calibration set.
""")
Этот рецепт показывает практический путь от стандартной PyTorch-модели к более быстрой и пригодной для продакшена реализации с ONNX Runtime и опциями квантизации, при этом точность остаётся сопоставимой. Используйте функции bench и run_eval для честного сравнения разных движков и для экспериментов с другими бэкендами и режимами квантизации.