
Fast and Optimal Weight Update for Pruned Large Language Models

(2401.02938)
Published Jan 1, 2024 in cs.CL and cs.LG

Abstract

Pruning LLMs is a challenging task due to their enormous size. The primary difficulty is fine-tuning the model after pruning, which is needed to recover the lost performance caused by dropping weights. Recent approaches have either ignored fine-tuning entirely, focusing on efficient pruning criteria, or attempted layer-wise weight updates, preserving the behavior of each layer. However, even layer-wise weight updates can be costly for LLMs, and previous works have resorted to various approximations. In our paper, we propose a fast and optimal weight update algorithm for pruned layers based on the Alternating Direction Method of Multipliers (ADMM). Coupled with a simple iterative pruning mask selection, our algorithm achieves state-of-the-art pruning performance across a wide range of LLMs. Code is available at https://github.com/fmfi-compbio/admm-pruning.

Overview

  • The paper addresses the challenge of deploying LLMs by introducing a novel algorithm for weight updates in pruned LLMs.

  • It utilizes the Alternating Direction Method of Multipliers (ADMM) to achieve optimal weight updates with a significantly reduced computational cost.

  • The authors propose a norm-based rule with preconditioning for pruning mask selection, together with an iterative pruning strategy that follows a sparsity schedule.

  • Experimental results demonstrate the algorithm's superiority over existing methods in terms of convergence speeds and the quality of weight updates.

  • The study highlights the ADMM-based method's potential for improving LLM scalability and suggests directions for future enhancements.

Introduction

The ongoing development of LLMs has led to remarkable advances in a wide variety of language tasks. Nevertheless, deploying these models poses significant challenges, mainly because of their size, which translates into substantial memory and computational requirements. Previous attempts to tackle these issues include parameter quantization and pruning, but the latter has gained less traction, primarily because of the difficulty of fine-tuning pruned networks. Existing solutions have either skipped fine-tuning altogether or relied on layer-wise weight updates, which, although intended to be efficient, remain costly and often resort to approximations, particularly in the context of LLMs.

Optimizing Pruning via Alternating Direction Method of Multipliers (ADMM)

In this context, the paper introduces a novel, efficient algorithm for updating the weights of pruned LLMs based on the Alternating Direction Method of Multipliers (ADMM), a classical optimization technique. Coupled with a straightforward iterative pruning mask selection, the algorithm avoids the approximations required by earlier methods and achieves state-of-the-art pruning performance while preserving the model's original behavior. The ADMM-based update requires only a single matrix inversion and a few simple iterations, yielding the optimal weight update for a given pruning mask.
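
As a concrete illustration, the following sketch applies the standard scaled-form ADMM splitting to the layer-wise reconstruction problem; the notation, NumPy implementation, and hyperparameters (rho, iteration count) are ours, not the authors' released code. It precomputes a single matrix inverse and then alternates a ridge-regularized reconstruction step, a projection onto the pruning mask, and a dual update.

```python
import numpy as np

def admm_weight_update(W, X, mask, rho=1.0, iters=20):
    """Solve min_Wp 0.5 * ||W X - Wp X||_F^2  s.t. Wp is zero outside `mask`.

    W: dense layer weights (out_features x in_features)
    X: calibration inputs (in_features x n_samples)
    mask: boolean keep-mask with the same shape as W
    rho, iters: ADMM penalty and iteration count (illustrative defaults)
    """
    XXt = X @ X.T                                              # in x in, computed once
    A_inv = np.linalg.inv(XXt + rho * np.eye(XXt.shape[0]))    # the single matrix inversion
    WXXt = W @ XXt                                             # fixed part of the right-hand side
    Z = W * mask                                               # feasible (masked) copy
    U = np.zeros_like(W)                                       # scaled dual variable
    for _ in range(iters):
        # W-step: ridge-regularized least-squares reconstruction of W X
        Wp = (WXXt + rho * (Z - U)) @ A_inv
        # Z-step: project onto the sparsity pattern given by the mask
        Z = (Wp + U) * mask
        # dual update
        U = U + Wp - Z
    return Z                                                   # sparse weights respecting the mask
```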

Pruning Mask Selection and Weight Update

To perform the pruning effectively, the paper also examines how to select the pruning mask. Drawing on recent literature on pruning mask selection, the authors use a norm-based rule to determine the significance of weights and their eligibility for removal. A preconditioning step scales the weight matrices and calibration inputs so that the subsequent magnitude-based pruning decisions are equivalent to those made by the Wanda algorithm. Their iterative pruning follows a sparsity schedule that splits the pruning process into several steps, which not only increases sparsity gradually but also lets mask selection and the weight update refine each other.
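
The sketch below shows how such a pipeline could look; wanda_scores, select_mask, and iterative_prune are hypothetical helper names, the per-row comparison group, the linear sparsity schedule, and the exact interleaving of mask selection with the weight update are our assumptions, and admm_weight_update refers to the sketch in the previous section.

```python
import numpy as np

def wanda_scores(W, X):
    # Wanda-style importance |W_ij| * ||X_j||_2: magnitude pruning after the
    # preconditioning step that scales weight columns by input activation norms.
    col_norms = np.linalg.norm(X, axis=1)       # ||X_j||_2 per input feature
    return np.abs(W) * col_norms[None, :]

def select_mask(scores, sparsity):
    # Drop the lowest-scoring weights within each row (comparison group may differ).
    mask = np.ones_like(scores, dtype=bool)
    k = int(round(scores.shape[1] * sparsity))  # number of weights to prune per row
    if k > 0:
        drop = np.argpartition(scores, k, axis=1)[:, :k]
        np.put_along_axis(mask, drop, False, axis=1)
    return mask

def iterative_prune(W, X, final_sparsity=0.5, steps=4, **admm_kwargs):
    # Raise sparsity gradually; at each step re-select the mask and rerun the
    # ADMM update (admm_weight_update from the sketch above) against the
    # original dense weights.
    Wp = W.copy()
    for s in range(1, steps + 1):
        sparsity = final_sparsity * s / steps   # linear schedule (an assumption)
        mask = select_mask(wanda_scores(Wp, X), sparsity)
        Wp = admm_weight_update(W, X, mask, **admm_kwargs)
    return Wp
```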

Experimental Validation and Conclusions

Through extensive experimentation, the paper validates the proposed algorithm against alternatives such as SparseGPT and AdaPrune, demonstrating faster convergence and higher-quality weight updates. The algorithm's efficacy is illustrated on LLMs across a range of pruning sparsities, where the new method outperforms existing approaches, particularly in iterative pruning setups. The authors acknowledge the limitations of their study, including that the weight update does not exploit the sparsity of the weights during computation, and leave potential improvements, such as nonuniform sparsity or more refined mask selection algorithms, for future work.

The research concludes on a positive note, presenting the ADMM-based weight update as a sound and practical way to improve the scalability and deployment feasibility of pruned LLMs. The contribution is notable not only for pushing the boundaries of pruning performance but also for setting a benchmark for future research in this area of deep learning.
