Researchers cut the processing required to train transformers by around 20 percent with only a slight degradation in performance.
What’s new: Xiuying Wei and colleagues at the Swiss Federal Institute of Technology Lausanne replaced a transformer’s linear layers with computationally efficient low-rank approximations.
Key insight: A low-rank approximation replaces a matrix with a product of two smaller matrices. This technique is widely used to streamline fine-tuning via LoRA, which modifies the weights in each of a transformer’s linear layers by adding a learned low-rank approximation. As a direct replacement for the weights in linear layers, low-rank approximation saves processing during training, but it also causes unstable fluctuations in the training loss and slower convergence. The authors mitigated these undesirable effects by training each full-size layer in parallel with a low-rank approximation of the layer while gradually phasing out the full-size layer. This approach costs more memory and computation initially, but it saves those resources in the long run.
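For intuition, here’s a minimal PyTorch sketch (the dimensions and rank are hypothetical, not the authors’ settings) of how factoring a weight matrix into a product of two smaller matrices cuts parameters and matrix-multiply work:

```python
import torch
import torch.nn as nn

# Hypothetical sizes: approximate a d x d linear layer at rank r.
d, r = 4096, 1024

full = nn.Linear(d, d, bias=False)        # computes Wx: d*d parameters
low_rank = nn.Sequential(                 # computes U(Vx): 2*d*r parameters
    nn.Linear(d, r, bias=False),          # V: projects d -> r
    nn.Linear(r, d, bias=False),          # U: projects r -> d
)

x = torch.randn(8, d)
assert full(x).shape == low_rank(x).shape  # both produce (8, d) outputs

# Parameter count (and matmul operations) drop from d*d to 2*d*r,
# here from about 16.8 million to 8.4 million, roughly half.
print(sum(p.numel() for p in full.parameters()),
      sum(p.numel() for p in low_rank.parameters()))
```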
How it works: The authors modified a transformer (1.3 billion parameters) to use low-rank approximations, which trimmed the parameter count to 985 million. They trained both the modified and unmodified models on 25.5 billion tokens of web text that had been filtered and deduplicated.
- The authors replaced each of the larger transformer’s linear layers with two smaller linear layers, approximating its weight matrix with a product of two smaller matrices. (In mathematical terms, if a standard linear layer computes Wx, where W is the weights and x is the input, the replacement computes U(Vx), where U and V are smaller than W.)
- During the first half of training, they trained both the usual and low-rank layers in parallel, and the output of each layer was a weighted sum of the two. Initially they weighted the usual layer at 1 and the low-rank layers at 0. As training progressed, they decreased the usual layer’s weighting to 0 and increased the low-rank layers’ weighting to 1 (see the sketch after this list).
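A minimal sketch of that blending scheme in PyTorch; the class name and the linear ramp over the first half of training are assumptions about the schedule’s shape, not the authors’ exact implementation:

```python
import torch.nn as nn

class BlendedLowRankLinear(nn.Module):
    """Trains a full linear layer (weights W) in parallel with a low-rank
    replacement (weights U, V) and blends their outputs, annealing the mix
    from all-full to all-low-rank. Once alpha reaches 1, the full layer no
    longer contributes and can be dropped, which realizes the compute savings."""

    def __init__(self, d_in, d_out, rank):
        super().__init__()
        self.full = nn.Linear(d_in, d_out, bias=False)  # W
        self.v = nn.Linear(d_in, rank, bias=False)      # V
        self.u = nn.Linear(rank, d_out, bias=False)     # U
        self.alpha = 0.0  # weight on the low-rank path: 0 at the start

    def set_alpha(self, step, total_steps):
        # Assumed schedule: ramp linearly to 1 over the first half of training.
        self.alpha = min(1.0, step / (0.5 * total_steps))

    def forward(self, x):
        return (1 - self.alpha) * self.full(x) + self.alpha * self.u(self.v(x))
```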
Results: The authors measured the perplexity (a measure of the likelihood that a model will predict the next word, lower is better) of both the modified and full-size transformers on 500 million tokens from the validation set. The modified version achieved 12.86 perplexity, slightly worse than the full-size version’s 12.46. However, training the modified version required more than 20 percent less processing and 14 percent less time: 1.66×10^20 FLOPs and 302 hours, versus 2.10×10^20 FLOPs and 352 hours for the full-size version.
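For reference, perplexity is the exponential of the average per-token cross-entropy loss (in nats), which is why lower is better. A quick sketch with toy numbers (not the authors’ data):

```python
import math

def perplexity(token_nlls):
    """Perplexity from a list of per-token negative log-likelihoods (natural log)."""
    return math.exp(sum(token_nlls) / len(token_nlls))

# Toy example: an average loss of about 2.55 nats corresponds to perplexity
# near 12.8, the range reported above.
print(perplexity([2.5, 2.6, 2.55]))  # ~12.81
```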
Why it matters: Training large transformers requires a lot of computation. Low-rank approximation lightens the load: this work approximates a transformer’s linear layers themselves, saving processing and parameters, while the earlier GaLore approximates the gradient to save optimizer memory.
We’re thinking: The authors note that this approach also works for fine-tuning pretrained models, making it a potential alternative to LoRA. Simply replace each pretrained linear layer (with weights W) with two linear layers (with weights U and V), and initialize U and V so that their product UV approximates W as closely as possible.
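A sketch of that initialization in PyTorch, assuming a truncated SVD (which yields the closest rank-r factors); the function name and bias handling are illustrative, not the paper’s code:

```python
import torch
import torch.nn as nn

def factorize_linear(layer: nn.Linear, rank: int) -> nn.Sequential:
    """Replace a pretrained linear layer (weights W) with two smaller layers
    whose product approximates W. Bias handling is omitted for brevity."""
    W = layer.weight.data                        # shape (d_out, d_in)
    U, S, Vh = torch.linalg.svd(W, full_matrices=False)
    U_r = U[:, :rank] * S[:rank]                 # (d_out, rank), absorbs singular values
    V_r = Vh[:rank, :]                           # (rank, d_in)

    v = nn.Linear(W.shape[1], rank, bias=False)  # V: d_in -> rank
    u = nn.Linear(rank, W.shape[0], bias=False)  # U: rank -> d_out
    v.weight.data.copy_(V_r)
    u.weight.data.copy_(U_r)
    return nn.Sequential(v, u)                   # computes U_r(V_r x) ≈ Wx
```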