
LLM-Pruning Collection: JAX Repo for LLM Compression

Explore an innovative repo consolidating LLM pruning algorithms for enhanced performance and efficiency.

Overview

Zlab Princeton researchers have released LLM-Pruning Collection, a JAX-based repository that consolidates major pruning algorithms for large language models into a single, reproducible framework. Its goal is to make block-level, layer-level, and weight-level pruning methods easy to compare under one consistent training and evaluation stack on both GPUs and TPUs.

What Does LLM-Pruning Collection Contain?

This repo is organized into three main directories:

  • pruning: Holds implementations for several pruning methods: Minitron, ShortGPT, Wanda, SparseGPT, Magnitude, Sheared Llama, and LLM-Pruner.
  • training: Provides integration with FMS-FSDP for GPU training and MaxText for TPU training.
  • eval: Exposes JAX-compatible evaluation scripts built around lm-eval-harness, with accelerate-based support for MaxText that gives a speedup of roughly 2 to 4 times.

Pruning Methods Covered

LLM-Pruning Collection includes several families of pruning algorithms with varying granularity levels:

Minitron

Minitron, developed by NVIDIA, compresses Llama 3.1 8B and Mistral NeMo 12B to 4B and 8B, respectively, while maintaining performance. It combines depth pruning with joint width pruning of the hidden size, attention, and MLP dimensions, followed by distillation. The pruning/minitron folder includes scripts such as prune_llama3.1-8b.sh that run Minitron-style pruning on Llama 3.1 8B.

ShortGPT

ShortGPT identifies redundant Transformer layers using Block Influence, a metric that measures how much each layer changes the hidden states passing through it. Layers with low influence are deleted outright. The method is reported to outperform previous pruning methods on multiple-choice and generative tasks. In the collection, ShortGPT is implemented within the Minitron folder, through the script prune_llama2-7b.sh.
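
As a rough illustration of the Block Influence idea, the sketch below scores each layer as one minus the mean cosine similarity between the hidden states entering and leaving it, then picks the lowest-scoring layers for removal. It is written in plain Python to stay dependency-free and is not the repository's code: the function names (`block_influence`, `layers_to_drop`) and the toy hidden states are assumptions made for this example.

```python
import math

def cosine(u, v):
    """Cosine similarity between two vectors given as lists of floats."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def block_influence(hidden_in, hidden_out):
    """1 - mean per-token cosine similarity between a layer's input and output."""
    sims = [cosine(x, y) for x, y in zip(hidden_in, hidden_out)]
    return 1.0 - sum(sims) / len(sims)

def layers_to_drop(hidden_states, num_drop):
    """hidden_states[i] holds the token vectors entering layer i;
    the last entry holds the output of the final layer."""
    scores = [
        block_influence(hidden_states[i], hidden_states[i + 1])
        for i in range(len(hidden_states) - 1)
    ]
    ranked = sorted(range(len(scores)), key=lambda i: scores[i])
    return sorted(ranked[:num_drop]), scores

# Toy data (hypothetical): 3 layers, 2 tokens, hidden size 2.
toy_states = [
    [[1.0, 0.0], [0.0, 1.0]],  # input to layer 0
    [[0.0, 1.0], [1.0, 0.0]],  # layer 0 rotates every token by 90 degrees
    [[0.0, 2.0], [2.0, 0.0]],  # layer 1 only rescales, direction unchanged
    [[1.0, 1.0], [1.0, 1.0]],  # layer 2 rotates every token by 45 degrees
]
drop, scores = layers_to_drop(toy_states, num_drop=1)
print(drop)  # layer 1 changes direction least, so it is the removal candidate
```

In practice the hidden states would come from running the model on a calibration set, and the averaging would cover all tokens in that set.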

Wanda, SparseGPT, Magnitude

Wanda is a post-training pruning method that scores each weight by the product of its magnitude and the norm of the corresponding input activation, then prunes the lowest-scoring weights without retraining. SparseGPT uses a second-order-inspired reconstruction step to reach high sparsity ratios. Magnitude pruning is the classical baseline that removes weights with small absolute values. In the LLM-Pruning Collection, all three are available under pruning/wanda.
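
To make the difference between magnitude scoring and Wanda-style scoring concrete, here is a minimal sketch in plain Python. It is not taken from pruning/wanda. The helper names, the toy weight row, and the calibration activations are all hypothetical, and both methods are given the same per-output-row comparison group so that only the score differs, which is a simplification for the magnitude baseline.

```python
import math

def column_norms(activations):
    """L2 norm of each input feature across the calibration tokens."""
    n_features = len(activations[0])
    return [
        math.sqrt(sum(tok[j] ** 2 for tok in activations))
        for j in range(n_features)
    ]

def prune_rows(weights, scores, sparsity):
    """Zero out the lowest-scoring fraction of weights in each output row."""
    pruned = []
    for w_row, s_row in zip(weights, scores):
        k = int(len(w_row) * sparsity)
        drop = set(sorted(range(len(w_row)), key=lambda j: s_row[j])[:k])
        pruned.append([0.0 if j in drop else w for j, w in enumerate(w_row)])
    return pruned

def magnitude_prune(weights, sparsity):
    """Baseline: score = |w|."""
    scores = [[abs(w) for w in row] for row in weights]
    return prune_rows(weights, scores, sparsity)

def wanda_prune(weights, activations, sparsity):
    """Wanda-style: score = |w| * L2 norm of the matching input feature."""
    norms = column_norms(activations)
    scores = [[abs(w) * norms[j] for j, w in enumerate(row)] for row in weights]
    return prune_rows(weights, scores, sparsity)

# Toy layer (hypothetical): one output neuron with four inputs.
W = [[0.5, -0.1, 0.3, 0.2]]
# Two calibration tokens; input feature 1 has very large activations.
X = [[0.1, 9.0, 0.1, 1.0],
     [0.1, 9.0, 0.1, 1.0]]

mag = magnitude_prune(W, sparsity=0.5)   # keeps the two largest |w|
wan = wanda_prune(W, X, sparsity=0.5)    # keeps weights fed by strong activations
print(mag, wan)
```

The small weight on feature 1 is dropped by magnitude pruning but kept by the activation-aware score, because that feature carries large activations. SparseGPT's reconstruction step is not shown here.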

Sheared Llama

Sheared LLaMA applies structured pruning by learning masks over layers, attention heads, and hidden dimensions, then continues training the pruned architecture. Its directory in the collection covers models at multiple scales, including 2.7B and 1.3B.

LLM-Pruner

LLM-Pruner performs structural pruning of large language models. It uses gradient-based importance scores to remove non-critical structures, then applies a short LoRA tuning stage to recover performance. The collection includes scripts for LLaMA, LLaMA 2, and Llama 3.1 8B under pruning/LLM-Pruner.
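
The sketch below shows one common form of gradient-based importance for structural pruning: a first-order estimate |w * g| per weight, summed over each group of coupled weights, with the lowest-scoring groups selected for removal. This is an assumption about the general technique, not a copy of the scripts in pruning/LLM-Pruner. The group names, weights, and gradients are made up, and the LoRA recovery stage is not shown.

```python
def group_importance(weights, grads):
    """First-order importance of a group: sum of |w * g| over its weights."""
    return sum(abs(w * g) for w, g in zip(weights, grads))

def select_groups_to_prune(groups, prune_ratio):
    """Return the names of the least important groups and all group scores.

    groups maps a group name to a (weights, grads) pair, where grads are the
    loss gradients for those weights on a small calibration batch.
    """
    scores = {name: group_importance(w, g) for name, (w, g) in groups.items()}
    k = int(len(groups) * prune_ratio)
    ranked = sorted(scores, key=scores.get)  # ascending importance
    return ranked[:k], scores

# Hypothetical groups, e.g. attention heads with their coupled weights.
groups = {
    "head_0": ([0.8, -0.6], [0.01, 0.02]),
    "head_1": ([0.1, 0.2], [0.5, -0.4]),
    "head_2": ([1.0, 1.0], [0.001, 0.002]),
    "head_3": ([-0.3, 0.4], [0.3, 0.2]),
}
to_prune, scores = select_groups_to_prune(groups, prune_ratio=0.5)
print(to_prune)  # the two groups whose removal is estimated to change the loss least
```

Note that head_2 has the largest weights yet the lowest score, since its gradients are tiny. That is the practical difference from pure magnitude-based selection.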

Key Takeaways

  • LLM-Pruning Collection is an Apache-2.0 licensed repo from zlab-princeton that unifies modern LLM pruning methods, with shared pruning, training, and evaluation pipelines for both GPUs and TPUs.
  • It covers block-level, layer-level, and weight-level pruning techniques, including Minitron, ShortGPT, Wanda, SparseGPT, and more, with scripts for Llama family models.
  • The training pipeline integrates FMS-FSDP on GPU and MaxText on TPU, and the JAX-compatible evaluation scripts give a speedup of roughly 2 to 4 times.
  • The repository reproduces key results from prior pruning research, publishing side-by-side “paper vs reproduced” tables for validation.

Check out the GitHub Repo for more information.
