Large NLP models like BERT can answer questions about a document thanks to the transformer network, a sequence-processing architecture that retains information across much longer sequences than earlier architectures such as recurrent networks. Until now, however, transformers have had little success in reinforcement learning.
What’s new: Research in reinforcement learning (RL) has focused primarily on immediate tasks such as moving a single object. Transformers could support tasks that require longer-term memory, but past research struggled to train transformer-based RL models. Emilio Parisotto and a team at DeepMind made the combination work with Gated Transformer-XL, or GTrXL, a network that can substitute directly for an LSTM in RL applications.
Key insight: A transformer’s attention component models relationships between elements that are far apart in a sequence, not just neighboring ones. Consider a block-stacking task where the first and sixth actions are the most important for predicting whether the stack will end up in the right order. GTrXL modifies the transformer architecture so it can learn relationships between consecutive steps early in training (say, between the first and second actions, where the first action places the initial block and the second identifies which block to pick up next) before it has learned to relate distant steps.
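How it works: GTrXL modifies Transformer-XL (TrXL), a transformer variant that reuses hidden states from earlier segments so attention can reach further back in time, as shown in the diagram above.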
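- GTrXL replaces the typical transformer’s residual connections with gated connections. Instead of always adding a sub-layer’s output to its input, each gate learns how much of that output to let through. This keeps errors in the sub-layers’ outputs from flowing unchecked along the residual path.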
- GTrXL applies layer normalization only to the inputs of the transformer’s sub-layers (attention and feed-forward), not to the path through the gated connections. This leaves an unnormalized route from each layer’s input to its output, so information, including information derived directly from the input, can pass through many layers intact while the sub-layers still receive normalized inputs.
- These modifications allow the network to learn from the order of the input data before the attention mechanism has learned to model longer-term relationships. Shorter-term relationships are easier to model early in training, so starting with them makes training more stable.
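The gating and layer-normalization changes above can be sketched in plain Python. This is a minimal, hypothetical illustration, not DeepMind’s implementation: it assumes the GRU-style gate g(x, y) that the GTrXL paper reports as its best-performing variant, where x is the stream entering the block and y is the sub-layer’s output. The names `GRUGate`, `gated_block`, and `bias` (standing in for the paper’s gate bias) are invented here, and a random linear map stands in for the attention or feed-forward sub-layer. A large positive bias pushes the update gate toward 0, so the block starts out passing its input through almost unchanged; as we read the paper, this near-identity start is what lets the agent learn simple behavior before its attention layers contribute.

```python
import math
import random


def sigmoid(v):
    return [1.0 / (1.0 + math.exp(-a)) for a in v]


def matvec(W, v):
    return [sum(w * x for w, x in zip(row, v)) for row in W]


def add(a, b):
    return [p + q for p, q in zip(a, b)]


def mul(a, b):
    return [p * q for p, q in zip(a, b)]


def relu(v):
    return [max(0.0, a) for a in v]


def layer_norm(v, eps=1e-5):
    mean = sum(v) / len(v)
    var = sum((a - mean) ** 2 for a in v) / len(v)
    return [(a - mean) / math.sqrt(var + eps) for a in v]


class GRUGate:
    """Hypothetical GRU-style gate mixing the stream x with a sub-layer output y."""

    def __init__(self, dim, bias=2.0, scale=0.01, seed=0):
        rng = random.Random(seed)

        def mat():
            return [[rng.gauss(0.0, scale) for _ in range(dim)] for _ in range(dim)]

        self.Wr, self.Ur = mat(), mat()
        self.Wz, self.Uz = mat(), mat()
        self.Wg, self.Ug = mat(), mat()
        self.bias = bias  # positive bias keeps the update gate z near 0 (pass x through)

    def __call__(self, x, y):
        # Reset gate: how much of x feeds the candidate update.
        r = sigmoid(add(matvec(self.Wr, y), matvec(self.Ur, x)))
        # Update gate, shifted down by the bias.
        z = sigmoid([a - self.bias for a in add(matvec(self.Wz, y), matvec(self.Uz, x))])
        # Candidate update built from y and the reset-scaled x.
        h = [math.tanh(a) for a in add(matvec(self.Wg, y), matvec(self.Ug, mul(r, x)))]
        # Mostly x when z is near 0, mostly the candidate when z is near 1.
        return [(1.0 - zi) * xi + zi * hi for xi, zi, hi in zip(x, z, h)]


def gated_block(x, sublayer, gate):
    # Layer norm touches only the sub-layer's input; the path from x
    # through the gate to the output is never normalized.
    y = relu(sublayer(layer_norm(x)))
    return gate(x, y)


# Usage: a random linear map as a hypothetical stand-in for attention or an MLP.
dim = 4
rng = random.Random(42)
W_sub = [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(dim)]


def sublayer(v):
    return matvec(W_sub, v)


x = [1.0, -2.0, 0.5, 3.0]
closed = gated_block(x, sublayer, GRUGate(dim, bias=10.0))   # gate nearly shut
open_ = gated_block(x, sublayer, GRUGate(dim, bias=-10.0))   # gate nearly open
print(max(abs(a - b) for a, b in zip(closed, x)))  # close to 0: block acts like the identity
```

With the gate nearly shut, the block returns essentially its input; with it nearly open, the output is replaced by the gate’s candidate update. Training can then move the gate between these extremes as the sub-layers become useful.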
Results: On DMLab-30, a suite of 30 tasks in the DeepMind Lab 3D environment that includes puzzles requiring long-term memory, GTrXL outperformed the previous state of the art, MERLIN, averaged across all 30 tasks. It also outperformed an LSTM, the recurrent layer most commonly used in RL research.
Why it matters: LSTMs have been essential to sequence-processing neural networks, but they work best on relatively short-term dependencies. GTrXL gives such networks longer-term memory. Longer time horizons eventually may help boost performance in lifelong learning and meta-learning.
We’re thinking: Since the original paper describing transformer networks was published in 2017, researchers have developed many extensions. This work continues to show that, when it comes to transformers, there’s more than meets the eye.