Transformer networks are gaining popularity as a high-accuracy alternative to recurrent neural networks. But they can run slowly when they’re applied to long sequences. New research converts transformers into functional RNNs for a major speed boost.
What’s new: Angelos Katharopoulos and colleagues at Idiap Research Institute, École Polytechnique Fédérale de Lausanne, and University of Washington accelerated transformers nearly a thousand-fold by outfitting them with linear attention.
Key insight: Researchers have used transformers instead of RNNs to analyze sequences, primarily sequences of words but also sequences of pixels. However, the number of calculations performed by the straightforward implementation of a transformer rises quadratically as sequence length increases, while the calculations performed by an RNN rise linearly. The authors modified the transformer’s attention computation so that an intermediate matrix acts like an RNN’s hidden state. This modification, along with a clever reordering of the computation, allows the transformer’s computations to scale linearly with sequence length.
How it works: Transformers extract features that capture the relationships between elements in a sequence. These features depend on comparing each token to every other token in the sequence.
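To see where the quadratic cost comes from, here’s a minimal sketch of standard scaled dot-product attention (our illustration in NumPy, not the authors’ code; single head, no learned projections). The N-by-N score matrix is the bottleneck that the steps below eliminate.

```python
# Illustrative sketch of standard (softmax) attention, not the paper's code.
import numpy as np

def softmax_attention(Q, K, V):
    """Q, K, V: arrays of shape (N, d) for a sequence of N tokens."""
    d = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d)                          # (N, N): every token compared to every other
    scores = scores - scores.max(axis=-1, keepdims=True)   # for numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)          # row-wise softmax
    return weights @ V                                      # (N, d) output features
```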
- The authors noticed that similarities among tokens could be reformulated as a dot product in an alternative feature space (a technique known as the kernel trick).
- The kernel trick enables linear attention to combine intermediate calculations into a single matrix that’s shared among all feature comparisons. The matrix’s size remains constant regardless of the number of tokens in the sequence, which avoids the quadratic slowdown (see the linear-attention sketch after this list).
- To mimic an RNN, the researchers compared the latest input token only to earlier tokens rather than to all tokens in the sequence. This technique, called causal masking, lets the transformer reuse the matrix across consecutive time steps instead of recomputing the entire layer as usual. Thus the matrix acts like the hidden state of an RNN (see the causal sketch below).
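To make the reordering concrete, here’s a rough sketch of linear attention over a whole sequence, using the simple feature map elu(x) + 1 that the paper proposes; the variable names, shapes, and single-head setup are our simplifications, not the released implementation.

```python
# Sketch of (non-causal) linear attention: group phi(K)^T V into one small matrix.
import numpy as np

def phi(x):
    """Feature map elu(x) + 1, which keeps every entry positive."""
    return np.where(x > 0, x + 1.0, np.exp(x))

def linear_attention(Q, K, V):
    """Q, K, V: (N, d). Cost grows as N * d^2 rather than N^2 * d."""
    Qf, Kf = phi(Q), phi(K)
    S = Kf.T @ V                         # (d, d): one matrix shared by every query
    z = Kf.sum(axis=0)                   # (d,): shared normalizer
    return (Qf @ S) / (Qf @ z)[:, None]  # (N, d) output features
```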
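And here’s a sketch of the causal, RNN-like version: the matrix S and vector z act as a hidden state that is updated once per token instead of being recomputed from scratch. Again, this is a simplified illustration under the same assumptions as above.

```python
# Sketch of causal linear attention computed token by token, like an RNN.
import numpy as np

def phi(x):
    return np.where(x > 0, x + 1.0, np.exp(x))   # same feature map, elu(x) + 1

def causal_linear_attention(Q, K, V):
    """Q, K, V: (N, d). Each output attends only to tokens 0..t."""
    N, d = Q.shape
    S = np.zeros((d, d))          # running sum of outer(phi(k_j), v_j): the "hidden state"
    z = np.zeros(d)               # running sum of phi(k_j): the normalizer
    out = np.zeros_like(V)
    for t in range(N):
        kf, qf = phi(K[t]), phi(Q[t])
        S += np.outer(kf, V[t])   # constant-size update, reused at the next step
        z += kf
        out[t] = (qf @ S) / (qf @ z)
    return out
```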
Results: Linear attention generated synthetic MNIST images over 400 times faster than Reformer, the pace-setting transformer on this task, and it was more accurate, too. In speech recognition on the WSJ dataset, linear attention achieved a lower error rate (8 percent) than both Reformer (9.3 percent) and a bi-LSTM (10.9 percent).
Why it matters: This work demonstrated advantages over typical transformers without incurring any apparent costs. It remains to be seen whether these benefits extend to all situations.
We’re thinking: Estimates of the cost of training gargantuan transformer-based language models run to millions of dollars. It sure would be nice to trim those budgets by a few orders of magnitude.