The latest large, pretrained language models rely on trendy layers based on transformer networks. New research shows that these newfangled layers may not be necessary.
What’s new: Networks such as BERT and ERNIE take advantage of multi-headed attention layers to outcompete LSTM language models. But training these layers requires lots of compute on enormous GPU clusters. Stephen Merity of d⁄dx Times Labs struck a blow for garage AI with Single Headed Attention RNN (SHA-RNN), which nearly matched state-of-the-art performance after training on a single GPU for less than 24 hours. As he puts it in a tartly worded paper, “Take that, Sesame Street.”
Key insight: The author set out to find a high-performance language model suitable for his personal computer. He used a single attention head out of skepticism that multiple heads are worth their computational cost. Simplifying the transformer’s feed-forward network enabled him to run the model on a single GPU.
How it works: SHA-RNN is built on an LSTM to represent the sequential nature of text more explicitly.
- The model reads an input text sequence token by token and predicts the next token, usually a word or root of a word. The LSTM’s memory component stores important learned features.
- The LSTM’s output layer feeds the single-headed attention layer, which models relationships between tokens across the sequence.
- The attention layer’s output feeds a so-called boom layer. This layer replaces the transformer’s usual pair of feed-forward layers with a single feed-forward layer that expands the vector, followed by a summing step that shrinks it back to its original length (see the sketch below).
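To make the data flow above concrete, here is a simplified PyTorch sketch of one such block. The class names, dimensions, residual connection, and the use of PyTorch’s stock LSTM and attention modules are illustrative assumptions; Merity’s actual implementation adds details such as layer normalization and caching of past states.

```python
import torch
import torch.nn as nn

class Boom(nn.Module):
    """Expand the hidden vector with one feed-forward layer, then sum
    fixed-size chunks to shrink it back to its original width."""
    def __init__(self, d_model: int, expansion: int = 4):
        super().__init__()
        self.expand = nn.Linear(d_model, d_model * expansion)
        self.act = nn.GELU()
        self.expansion = expansion

    def forward(self, x):
        z = self.act(self.expand(x))                   # (..., d_model * expansion)
        z = z.view(*x.shape[:-1], self.expansion, -1)  # (..., expansion, d_model)
        return z.sum(dim=-2)                           # back to (..., d_model)

class SHABlock(nn.Module):
    """One SHA-RNN-style block: LSTM -> single-head attention -> boom."""
    def __init__(self, d_model: int):
        super().__init__()
        self.lstm = nn.LSTM(d_model, d_model, batch_first=True)
        self.attn = nn.MultiheadAttention(d_model, num_heads=1, batch_first=True)
        self.boom = Boom(d_model)

    def forward(self, x):
        h, _ = self.lstm(x)  # sequential modeling of the token sequence
        # Causal mask so each position attends only to itself and earlier positions.
        seq_len = x.size(1)
        mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device),
            diagonal=1,
        )
        a, _ = self.attn(h, h, h, attn_mask=mask)
        return self.boom(a + h)  # residual connection, then the boom layer
```

A full model would place a stack of blocks like this between an embedding layer and a softmax output layer; how deep the stack is and where attention appears are choices made in the paper, not fixed by this sketch.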
Results: Merity tested SHA-RNN by compressing the enwik8 dataset. More accurate language models use fewer bits to represent a sequence because they know, to some extent, which words will occur. SHA-RNN achieved 1.068 bits per character versus 0.99 for Sparse Transformer: slightly less accurate, but with about half as many parameters.
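For readers unfamiliar with the metric, bits per character is simply the model’s average cross-entropy over characters expressed in base 2. A minimal conversion sketch (the example loss value is back-calculated from the reported 1.068 bpc, not quoted from the paper):

```python
import math

def bits_per_character(nats_per_char: float) -> float:
    """Convert average cross-entropy in nats per character to bits per character."""
    return nats_per_char / math.log(2)

# A cross-entropy of roughly 0.7403 nats per character corresponds to
# SHA-RNN's reported 1.068 bits per character on enwik8.
print(round(bits_per_character(0.7403), 3))  # 1.068
```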
Yes, but: An LSTM is a good choice for sequential language-prediction tasks like enwik8. In non-sequential tasks such as fill-in-the-blanks, multi-headed attention is a better choice. A version of Transformer-XL that has even fewer parameters than SHA-RNN performed better on the compression task.
Why it matters: SHA-RNN isn’t an out-and-out replacement for transformer-based networks. But it shows that LSTMs remain relevant and useful in language modeling. And if you’re looking for a way to get people to read your research, the author’s style offers pointers: this paper is a very entertaining read!
We’re thinking: Researchers like to focus on optimizing state-of-the-art methods, and media hype frequently chases the latest leaderboard topper. Yet foundational algorithms remain valuable in a variety of contexts.