MaskLLM:
Learnable Semi-Structured Sparsity for Large Language Models

- NeurIPS 2024 Spotlight -
NVIDIA & National University of Singapore
*Work done at NVIDIA Research

Learnable Semi-Structured (or "N:M") Sparsity for Large Language Models. The learned mask can be further transferred to downstream tasks for lossless compression.

Abstract

Large Language Models (LLMs) are distinguished by their massive parameter counts, which typically result in significant redundancy. This work introduces MaskLLM, a learnable pruning method that establishes Semi-structured (or "N:M") Sparsity in LLMs, aimed at reducing computational overhead during inference. Instead of developing a new importance criterion, MaskLLM explicitly models N:M patterns as a learnable distribution through Gumbel Softmax sampling. This approach facilitates end-to-end training on large-scale datasets and offers two notable advantages: 1) High-quality Masks - our method effectively scales to large datasets and learns accurate masks; 2) Transferability - the probabilistic modeling of mask distribution enables the transfer learning of sparsity across domains or tasks. We assessed MaskLLM using 2:4 sparsity on various LLMs, including LLaMA-2, Nemotron-4, and GPT-3, with sizes ranging from 843M to 15B parameters, and our empirical results show substantial improvements over state-of-the-art methods. For instance, leading approaches achieve a perplexity (PPL) of 10 or greater on Wikitext compared to the dense model's 5.12 PPL, but MaskLLM achieves a significantly lower 6.72 PPL solely by learning the masks with frozen weights. Furthermore, MaskLLM's learnable nature allows customized masks for lossless application of 2:4 sparsity to downstream tasks or domains.
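To make the core idea from the abstract concrete, below is a minimal PyTorch sketch (not the authors' released code) of learnable 2:4 sparsity via Gumbel-Softmax: each group of 4 consecutive weights gets a learnable categorical distribution over the 6 candidate 2:4 binary masks, and a mask is sampled differentiably so gradients flow into the mask logits while the weights stay frozen. The class name `LearnableN2M4Mask` and the toy regression objective are illustrative assumptions standing in for the actual language-modeling loss and training setup.

```python
# Minimal sketch of learnable 2:4 sparsity with Gumbel-Softmax (illustrative only).
import itertools
import torch
import torch.nn as nn
import torch.nn.functional as F


class LearnableN2M4Mask(nn.Module):
    """Learnable 2:4 mask for a weight tensor of shape (out, in), with in % 4 == 0."""

    def __init__(self, weight: torch.Tensor):
        super().__init__()
        out_f, in_f = weight.shape
        assert in_f % 4 == 0
        # Enumerate all 6 binary masks with exactly 2 ones in a group of 4.
        candidates = [c for c in itertools.product([0.0, 1.0], repeat=4) if sum(c) == 2]
        self.register_buffer("candidates", torch.tensor(candidates))   # (6, 4)
        # One categorical distribution (6 logits) per group of 4 weights.
        self.logits = nn.Parameter(torch.zeros(out_f, in_f // 4, 6))

    def forward(self, tau: float = 1.0, hard: bool = True) -> torch.Tensor:
        # Gumbel-Softmax sample: (straight-through) one-hot over the 6 candidates.
        probs = F.gumbel_softmax(self.logits, tau=tau, hard=hard)      # (out, groups, 6)
        mask = torch.einsum("ogc,cf->ogf", probs, self.candidates)     # (out, groups, 4)
        return mask.reshape(self.logits.shape[0], -1)                  # (out, in)


# Usage sketch: learn the mask for a frozen linear layer with a toy objective.
layer = nn.Linear(16, 8, bias=False)
layer.weight.requires_grad_(False)                  # weights stay frozen
mask = LearnableN2M4Mask(layer.weight)
opt = torch.optim.AdamW(mask.parameters(), lr=1e-2)

x = torch.randn(4, 16)
target = x @ layer.weight.t()                       # mimic the dense output
for _ in range(100):
    opt.zero_grad()
    y = x @ (layer.weight * mask()).t()             # forward with the sampled 2:4 mask
    loss = F.mse_loss(y, target)
    loss.backward()                                 # gradients reach the mask logits
    opt.step()
```

Because the mask is parameterized as a distribution rather than chosen by a fixed importance criterion, the same logits can be fine-tuned on a downstream dataset to obtain task-specific masks, which is the transferability property highlighted above.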

Method


Key Findings

Acknowledgement

BibTeX


        @article{fang2024maskllm,
          title={MaskLLM: Learnable Semi-structured Sparsity for Large Language Models},
          author={Fang, Gongfan and Yin, Hongxu and Muralidharan, Saurav and Heinrich, Greg and Pool, Jeff and Kautz, Jan and Molchanov, Pavlo and Wang, Xinchao},
          journal={Advances in Neural Information Processing Systems},
          year={2024}
        }