Speed Up DistilBERT: Optimize, Export and Quantize with Hugging Face Optimum + ONNX Runtime

Setup and installation

Install the required libraries to run the end-to-end optimization workflow in Colab or any other Python environment:

!pip -q install "transformers>=4.49" "optimum[onnxruntime]>=1.20.0" "datasets>=2.20" "evaluate>=0.4" accelerate
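
Optionally, sanity-check the install first. The short snippet below prints the installed library versions and the ONNX Runtime execution providers available in your environment:

import transformers, optimum, onnxruntime
print("transformers:", transformers.__version__)
print("optimum:", optimum.__version__)
print("onnxruntime:", onnxruntime.__version__)
print("providers:", onnxruntime.get_available_providers())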

Environment, constants and imports

The next snippet handles imports, caps the OpenMP/MKL thread counts (set to 1 here; see the thread-tuning note at the end), and defines the model identifier, output paths, batch size, sequence length and benchmark iteration counts. It also selects the runtime device (CUDA if available, otherwise CPU) and prints it.

from pathlib import Path
import os, time, numpy as np, torch
from datasets import load_dataset
import evaluate
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
from optimum.onnxruntime import ORTModelForSequenceClassification, ORTQuantizer
from optimum.onnxruntime.configuration import AutoQuantizationConfig


os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")


MODEL_ID = "distilbert-base-uncased-finetuned-sst-2-english"
ORT_DIR  = Path("onnx-distilbert")
Q_DIR    = Path("onnx-distilbert-quant")
DEVICE   = "cuda" if torch.cuda.is_available() else "cpu"
BATCH    = 16
MAXLEN   = 128
N_WARM   = 3
N_ITERS  = 8


print(f"Device: {DEVICE} | torch={torch.__version__}")

Data loading and evaluation helpers

Load a slice of the SST-2 validation set and prepare a tokenizer and an accuracy metric. The helper functions build tokenized batches, evaluate a prediction function against the reference labels, and measure end-to-end latency, running a few warmup passes before timing.

ds = load_dataset("glue", "sst2", split="validation[:20%]")
texts, labels = ds["sentence"], ds["label"]
metric = evaluate.load("accuracy")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)


def make_batches(texts, max_len=MAXLEN, batch=BATCH):
   for i in range(0, len(texts), batch):
       yield tokenizer(texts[i:i+batch], padding=True, truncation=True,
                       max_length=max_len, return_tensors="pt")


def run_eval(predict_fn, texts, labels):
   preds = []
   for toks in make_batches(texts):
       preds.extend(predict_fn(toks))
   return metric.compute(predictions=preds, references=labels)["accuracy"]


def bench(predict_fn, texts, n_warm=N_WARM, n_iters=N_ITERS):
   for _ in range(n_warm):
       for toks in make_batches(texts[:BATCH*2]):
           predict_fn(toks)
   times = []
   for _ in range(n_iters):
       t0 = time.time()
       for toks in make_batches(texts):
           predict_fn(toks)
       times.append((time.time() - t0) * 1000)
   return float(np.mean(times)), float(np.std(times))

Baseline PyTorch and optional torch.compile

Load the PyTorch DistilBERT classifier, define a predict helper, then benchmark it and compute its accuracy. The snippet also attempts to compile the model with torch.compile (when available) and repeats the same measurements for a fair comparison.

torch_model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID).to(DEVICE).eval()


@torch.no_grad()
def pt_predict(toks):
   toks = {k: v.to(DEVICE) for k, v in toks.items()}
   logits = torch_model(**toks).logits
   return logits.argmax(-1).detach().cpu().tolist()


pt_ms, pt_sd = bench(pt_predict, texts)
pt_acc = run_eval(pt_predict, texts, labels)
print(f"[PyTorch eager]   {pt_ms:.1f}±{pt_sd:.1f} ms | acc={pt_acc:.4f}")


compiled_model = torch_model
compile_ok = False
try:
   compiled_model = torch.compile(torch_model, mode="reduce-overhead", fullgraph=False)
   compile_ok = True
except Exception as e:
   print("torch.compile unavailable or failed -> skipping:", repr(e))


@torch.no_grad()
def ptc_predict(toks):
   toks = {k: v.to(DEVICE) for k, v in toks.items()}
   logits = compiled_model(**toks).logits
   return logits.argmax(-1).detach().cpu().tolist()


if compile_ok:
   ptc_ms, ptc_sd = bench(ptc_predict, texts)
   ptc_acc = run_eval(ptc_predict, texts, labels)
   print(f"[torch.compile]   {ptc_ms:.1f}±{ptc_sd:.1f} ms | acc={ptc_acc:.4f}")

Export to ONNX Runtime and dynamic quantization

Export the model to ONNX with Optimum’s ORTModelForSequenceClassification and run it with ONNX Runtime (the CUDAExecutionProvider requires the onnxruntime-gpu build; otherwise execution falls back to CPU). The exported model is saved to disk so that ORTQuantizer can load it and apply dynamic INT8 quantization. Benchmark both the ONNX and quantized ONNX variants to see how latency and accuracy change.

provider = "CUDAExecutionProvider" if DEVICE == "cuda" else "CPUExecutionProvider"
ort_model = ORTModelForSequenceClassification.from_pretrained(
   MODEL_ID, export=True, provider=provider, cache_dir=ORT_DIR
)


@torch.no_grad()
def ort_predict(toks):
   logits = ort_model(**{k: v.cpu() for k, v in toks.items()}).logits
   return logits.argmax(-1).cpu().tolist()


ort_ms, ort_sd = bench(ort_predict, texts)
ort_acc = run_eval(ort_predict, texts, labels)
print(f"[ONNX Runtime]    {ort_ms:.1f}±{ort_sd:.1f} ms | acc={ort_acc:.4f}")


Q_DIR.mkdir(parents=True, exist_ok=True)
quantizer = ORTQuantizer.from_pretrained(ORT_DIR)
# Dynamic INT8 quantization; pick the config matching your target CPU (arm64/avx2/avx512/avx512_vnni)
qconfig = AutoQuantizationConfig.avx512_vnni(is_static=False, per_channel=False)
quantizer.quantize(quantization_config=qconfig, save_dir=Q_DIR)


ort_quant = ORTModelForSequenceClassification.from_pretrained(Q_DIR, provider=provider)


@torch.no_grad()
def ortq_predict(toks):
   logits = ort_quant(**{k: v.cpu() for k, v in toks.items()}).logits
   return logits.argmax(-1).cpu().tolist()


oq_ms, oq_sd = bench(ortq_predict, texts)
oq_acc = run_eval(ortq_predict, texts, labels)
print(f"[ORT Quantized]   {oq_ms:.1f}±{oq_sd:.1f} ms | acc={oq_acc:.4f}")

Semantic sanity checks and results table

Run a sentiment-analysis pipeline on a few sample sentences to verify that the PyTorch and ONNX models agree. Then assemble a small table comparing mean latency, standard deviation and accuracy across engines. The closing notes point to further avenues (FlashAttention2 models, FP8 with TensorRT-LLM, static quantization, CPU thread tuning).

pt_pipe  = pipeline("sentiment-analysis", model=torch_model, tokenizer=tokenizer,
                   device=0 if DEVICE=="cuda" else -1)
ort_pipe = pipeline("sentiment-analysis", model=ort_model, tokenizer=tokenizer, device=-1)
samples = [
   "What a fantastic movie—performed brilliantly!",
   "This was a complete waste of time.",
   "I’m not sure how I feel about this one."
]
print("\nSample predictions (PT | ORT):")
for s in samples:
   a = pt_pipe(s)[0]["label"]
   b = ort_pipe(s)[0]["label"]
   print(f"- {s}\n  PT={a} | ORT={b}")


import pandas as pd
rows = [["PyTorch eager", pt_ms, pt_sd, pt_acc],
       ["ONNX Runtime",  ort_ms, ort_sd, ort_acc],
       ["ORT Quantized", oq_ms, oq_sd, oq_acc]]
if compile_ok: rows.insert(1, ["torch.compile", ptc_ms, ptc_sd, ptc_acc])
df = pd.DataFrame(rows, columns=["Engine", "Mean ms (↓)", "Std ms", "Accuracy"])
display(df)
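
display is only defined inside notebooks (Colab/Jupyter); when running the code as a plain script, printing the frame works just as well:

print(df.to_string(index=False))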


print("""
Notes:
- BetterTransformer is deprecated as of transformers>=4.49, hence omitted.
- For larger gains on GPU, also try FlashAttention2 models or FP8 with TensorRT-LLM.
- For CPU, tune threads: set OMP_NUM_THREADS/MKL_NUM_THREADS; try NUMA pinning.
- For static (calibrated) quantization, build the quantization config with is_static=True and calibrate on a small sample set (see the sketch below).
""")

This end-to-end recipe shows how Optimum and ONNX Runtime can take a Transformer from a research PyTorch checkpoint to a faster, production-ready runtime with quantization options while keeping accuracy comparable. The bench/run_eval helpers keep the comparison fair across execution engines and make it easy to plug in other backends and quantization modes, as in the sketch below.
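
For instance, a raw onnxruntime.InferenceSession can be dropped into the same helpers. The sketch below assumes the exported file kept Optimum's default name model.onnx inside ORT_DIR:

import onnxruntime as rt

sess = rt.InferenceSession(str(ORT_DIR / "model.onnx"), providers=["CPUExecutionProvider"])

def raw_ort_predict(toks):
    # Feed NumPy int64 inputs straight to the session and argmax the returned logits
    feeds = {"input_ids": toks["input_ids"].numpy(),
             "attention_mask": toks["attention_mask"].numpy()}
    return sess.run(None, feeds)[0].argmax(-1).tolist()

raw_ms, raw_sd = bench(raw_ort_predict, texts)
raw_acc = run_eval(raw_ort_predict, texts, labels)
print(f"[raw onnxruntime] {raw_ms:.1f}±{raw_sd:.1f} ms | acc={raw_acc:.4f}")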