Transformers can learn a lot from sequential data like words in a book, but they’ve shown limited ability to learn from data in the form of a graph. A new transformer variant gives graphs due attention.
What's new: Vijay Prakash Dwivedi and Xavier Bresson at Nanyang Technological University devised Graph Transformer (GT), a transformer layer designed to process graph data. Stacking GT layers provides a transformer-based alternative to typical graph neural networks, which process data in the form of nodes and edges that connect them, such as customers connected to products they’ve purchased or atoms connected to one another in a molecule.
Key insight: Previous work applied transformers to graph data by dedicating a token to each node and computing self-attention between every pair of nodes. This method encodes both local relationships, such as which nodes are neighbors (where a hyperparameter sets how many hops apart two nodes can be and still count as neighbors), and global information, such as a node’s distance from non-neighboring nodes. However, this approach is prohibitively expensive for large graphs, since the computation required for self-attention grows quadratically with the number of nodes. Applying attention only between neighboring nodes captures crucial local information while cutting the computational burden. Meanwhile, a positional vector that represents each node’s relative distance from all other nodes can capture global information in a compute-efficient way.
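To make the cost argument concrete, here is a minimal Python sketch with NumPy. It omits the learned query, key, and value projections a real transformer layer uses, and the function name `neighbor_attention` and the ring graph are illustrative assumptions, not the authors' code. It counts how many attention scores each variant computes: one for every pair of nodes under full attention versus one per neighbor when attention is restricted to each node's neighborhood.

```python
import numpy as np

def neighbor_attention(h, neighbors):
    """Hypothetical single-head self-attention in which node i attends only to neighbors[i].

    h: (n, d) array of node representations.
    neighbors: list where neighbors[i] is a list of node indices (self included).
    Returns the updated (n, d) representations and the number of attention
    scores computed, which is what drives the cost.
    """
    n, d = h.shape
    out = np.zeros_like(h)
    num_scores = 0
    for i in range(n):
        nbrs = neighbors[i]
        # Dot-product scores between node i and each of its neighbors only.
        # (Query/key/value projections are omitted for brevity.)
        scores = h[nbrs] @ h[i] / np.sqrt(d)
        num_scores += len(nbrs)
        weights = np.exp(scores - scores.max())  # numerically stable softmax
        weights /= weights.sum()
        out[i] = weights @ h[nbrs]
    return out, num_scores

n, d = 200, 16
rng = np.random.default_rng(0)
h = rng.normal(size=(n, d))

# Full attention: every node attends to every node (n * n scores).
everyone = [list(range(n)) for _ in range(n)]
_, full_cost = neighbor_attention(h, everyone)

# Ring graph: each node attends to itself and its two neighbors (3 * n scores).
ring = [[(i - 1) % n, i, (i + 1) % n] for i in range(n)]
_, sparse_cost = neighbor_attention(h, ring)

print(full_cost, sparse_cost)  # 40000 600
```

Doubling the number of nodes quadruples the full-attention count but only doubles the ring-graph count, which is the quadratic-versus-linear gap the paragraph describes (for graphs whose nodes have a bounded number of neighbors).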
How it works: The authors built three models, each comprising embedding layers, 10 GT layers (each including self-attention and fully connected layers), and a vanilla neural network. They trained each model on a different task: two-class classification of synthetic data, six-class classification of synthetic data, and a regression task that estimated the solubility of molecules drawn from the ZINC database of compounds.
- Given a graph, the embedding layers generated an embedding and a positional vector for each node. Using a contrastive approach, they generated similar positional vectors for nearby nodes and dissimilar positional vectors for distant nodes. Then they added each node’s embedding and positional vector to form its representation.
- The GT layer honed each node representation by applying self-attention between it and its neighbors. Then it passed the node representation to the fully connected layer.
- The model repeated these steps across all 10 GT layers and passed the final representations to the vanilla neural network, which performed classification or regression.
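The steps above can be sketched end to end in NumPy. This is a minimal, untrained illustration, not the authors' implementation: all weights are random, there is a single attention head, and the names `laplacian_positional_vectors`, `GTLayer`, and `graph_transformer` are hypothetical. For the positional vectors it swaps in eigenvectors of the graph's normalized Laplacian, a standard non-learned way to give nearby nodes similar vectors, rather than the contrastive procedure described above. For brevity, attention is computed densely and then masked to each node's neighbors; an efficient version would compute scores only along edges.

```python
import numpy as np

rng = np.random.default_rng(0)  # random, untrained weights throughout

def laplacian_positional_vectors(adj, k):
    """Hypothetical helper: k smallest non-trivial eigenvectors of the symmetric
    normalized Laplacian. Nearby nodes tend to receive similar vectors.
    Assumes every node has at least one edge."""
    deg = adj.sum(axis=1)
    d_inv_sqrt = 1.0 / np.sqrt(deg)
    lap = np.eye(len(adj)) - d_inv_sqrt[:, None] * adj * d_inv_sqrt[None, :]
    _, eigvecs = np.linalg.eigh(lap)  # eigenvalues come back in ascending order
    return eigvecs[:, 1:k + 1]        # drop the trivial first eigenvector

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def layer_norm(x, eps=1e-5):
    return (x - x.mean(axis=-1, keepdims=True)) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)

class GTLayer:
    """Hypothetical single-head layer: neighbor-only self-attention, then a
    fully connected block, each with a residual connection and normalization."""
    def __init__(self, d):
        self.wq, self.wk, self.wv = (rng.normal(scale=d**-0.5, size=(d, d)) for _ in range(3))
        self.w1 = rng.normal(scale=d**-0.5, size=(d, 2 * d))
        self.w2 = rng.normal(scale=(2 * d)**-0.5, size=(2 * d, d))

    def __call__(self, h, mask):
        q, k, v = h @ self.wq, h @ self.wk, h @ self.wv
        scores = q @ k.T / np.sqrt(h.shape[1])
        scores = np.where(mask, scores, -np.inf)  # non-neighbors get zero weight
        h = layer_norm(h + softmax(scores) @ v)
        return layer_norm(h + np.maximum(h @ self.w1, 0) @ self.w2)

def graph_transformer(node_feats, adj, num_layers=10, d=16, k=4, num_outputs=1):
    n, f = node_feats.shape
    w_feat = rng.normal(scale=f**-0.5, size=(f, d))
    w_pos = rng.normal(scale=k**-0.5, size=(k, d))
    # Node representation = feature embedding + projected positional vector.
    h = node_feats @ w_feat + laplacian_positional_vectors(adj, k) @ w_pos
    mask = (adj + np.eye(n)) > 0  # each node attends to itself and its neighbors
    for layer in [GTLayer(d) for _ in range(num_layers)]:
        h = layer(h, mask)
    # "Vanilla" head: a two-layer MLP applied to every node representation.
    w_h1 = rng.normal(scale=d**-0.5, size=(d, d))
    w_h2 = rng.normal(scale=d**-0.5, size=(d, num_outputs))
    return np.maximum(h @ w_h1, 0) @ w_h2

# Usage: an 8-node ring graph with 3 random input features per node.
n = 8
adj = np.zeros((n, n))
for i in range(n):
    adj[i, (i + 1) % n] = adj[(i + 1) % n, i] = 1.0
x = rng.normal(size=(n, 3))
node_scores = graph_transformer(x, adj, num_outputs=2)  # per-node scores, e.g. 2 classes
graph_value = graph_transformer(x, adj).mean()          # pooled to one value per graph
print(node_scores.shape)  # (8, 2)
```

Because weights are sampled fresh on each call, the two calls in the usage lines are different random models; a real model would share trained weights across calls. Averaging node outputs to get one value per graph, as the second call does, is just one common pooling choice and an assumption of this sketch.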
Results: The authors’ model achieved 84.81 percent accuracy on the two-class classification task and 73.17 percent accuracy on the six-class task. A baseline GAT graph neural network, which applies attention across neighboring node representations, achieved 78.27 percent and 70.58 percent accuracy, respectively. On the regression task, the authors’ model achieved a mean absolute error (MAE) of 0.226 compared to GAT’s 0.384 (lower is better). However, it slightly underperformed the state-of-the-art Gated Graph ConvNet on all three tasks.
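For readers unfamiliar with the regression metric, here is a small Python sketch of mean absolute error, followed by the relative improvement implied by the reported regression numbers. The function name `mean_absolute_error` is a hypothetical helper, and the toy inputs are made-up values for illustration, not data from the paper.

```python
import numpy as np

def mean_absolute_error(y_true, y_pred):
    """Average absolute difference between targets and predictions (lower is better)."""
    return float(np.mean(np.abs(np.asarray(y_true) - np.asarray(y_pred))))

# Toy illustration of the metric (made-up values, not from the paper).
print(mean_absolute_error([1.0, -2.0, 0.5], [1.5, -2.0, 0.0]))  # 0.333...

# Relative reduction in error using the reported MAE figures.
gt_mae, gat_mae = 0.226, 0.384
print(f"{(gat_mae - gt_mae) / gat_mae:.1%}")  # 41.1%
```

In other words, the Graph Transformer's average error on the regression task was roughly 41 percent lower than GAT's.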
Why it matters: Transformers have proven their value in processing text, images, and other data types. This work makes them more useful with graphs. Although the Graph Transformer model fell short of the best graph neural network, it establishes a strong baseline for further research in this area.
We're thinking: Pretrained and fine-tuned transformers handily outperform trained convolutional neural networks. Would fine-tuning a Graph Transformer model yield similarly outstanding results?