Neural networks designed to process data in the form of a graph (a collection of nodes connected by edges) have delivered nearly state-of-the-art results with only a handful of layers. This capability raises the question: Do deeper graph neural networks have any advantage? New research shows that they do.
What’s new: Ravichandra Addanki and colleagues at DeepMind probed the impact of depth on the performance of graph neural networks.
GNN basics: A graph neural network (GNN) operates on graphs that link, for instance, customers to products they've purchased, papers to the other papers they cite, or pixels adjacent to one another in an image. A GNN typically represents nodes and edges as vectors and updates them iteratively based on the states of neighboring nodes and edges. Some GNNs represent an entire graph as a vector and update it according to the representations of nodes and edges.
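To make the update idea concrete, here is a minimal, hypothetical message-passing step in Python (not the authors' code): each node's vector is refreshed from its own state and the mean of its neighbors' vectors. The function name, the mean aggregation, and the tanh projection are illustrative assumptions, not details from the paper.

```python
import numpy as np

# Minimal, hypothetical message-passing step: each node is updated from its own
# state and the mean of its neighbors' states. Not the authors' implementation.
def message_passing_step(node_feats, edges, weight):
    """node_feats: (num_nodes, dim); edges: list of (src, dst) pairs;
    weight: (2 * dim, dim) projection applied to [node, aggregated neighbors]."""
    num_nodes, _ = node_feats.shape
    agg = np.zeros_like(node_feats)
    counts = np.zeros((num_nodes, 1))
    for src, dst in edges:
        agg[dst] += node_feats[src]          # message from a neighboring node
        counts[dst] += 1
    agg /= np.maximum(counts, 1)             # mean over incoming messages
    combined = np.concatenate([node_feats, agg], axis=1)
    return np.tanh(combined @ weight)        # updated node representations

# Toy example: three nodes connected in a chain 0-1-2 (edges in both directions).
rng = np.random.default_rng(0)
feats = rng.normal(size=(3, 4))
proj = rng.normal(size=(8, 4))
edges = [(0, 1), (1, 0), (1, 2), (2, 1)]
feats = message_passing_step(feats, edges, proj)
```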
Key insight: Previous work found that adding a few layers to a shallow GNN barely improved performance. That study used graphs that comprised hundreds of thousands of nodes and edges. Since then, graphs have emerged with hundreds of millions of nodes and edges. Deeper GNNs may achieve superior performance on these larger datasets.
How it works: The authors built GNNs of varying depth, the largest more than 100 layers deep, each including an encoder (a vanilla neural network), a graph network made up of message-passing blocks (each a trio of vanilla neural networks), and a decoder (another vanilla neural network). Among other experiments, they trained a GNN on 4 million graphs of molecules, in which nodes are atoms and edges are bonds between them, to estimate a key property called the HOMO-LUMO gap. (This property helps determine a molecule’s behavior in the presence of light, electricity, and other chemicals.)
- Given a graph, the encoder generated an initial representation of each edge, each node, and the entire graph.
- A series of message-passing blocks updated the representations iteratively (a sketch of one block follows this list): (1) A three-layer vanilla neural network updated each edge representation based on its previous value, the representations of the two nodes it connects, and the graph representation. (2) A three-layer vanilla neural network updated each node representation based on its previous value, the representations of all connected edges, and the graph representation. (3) A three-layer vanilla neural network updated the graph representation based on its previous value and the representations of all edges and all nodes.
- Given the final representation of the graph, the decoder computed the HOMO-LUMO gap.
- To improve the representations, the authors used Noisy Nodes self-supervision, which perturbed node or edge representations and penalized the GNN according to how poorly it reconstructed them.
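The steps above map naturally onto code. Below is a minimal sketch of one message-passing block, written in PyTorch under assumptions the article doesn't specify: the mean aggregation, hidden sizes, and names (mlp, MessagePassingBlock, edge_mlp, and so on) are illustrative, and the Noisy Nodes auxiliary loss is omitted. This is not the authors' implementation.

```python
import torch
import torch.nn as nn

def mlp(in_dim, hidden, out_dim):
    # Three-layer vanilla network per update, as described above; sizes are assumptions.
    return nn.Sequential(nn.Linear(in_dim, hidden), nn.ReLU(),
                         nn.Linear(hidden, hidden), nn.ReLU(),
                         nn.Linear(hidden, out_dim))

class MessagePassingBlock(nn.Module):
    """One block: update edge vectors, then node vectors, then the graph vector."""
    def __init__(self, dim, hidden=128):
        super().__init__()
        self.edge_mlp = mlp(4 * dim, hidden, dim)   # edge + its two endpoint nodes + graph
        self.node_mlp = mlp(3 * dim, hidden, dim)   # node + aggregated incident edges + graph
        self.graph_mlp = mlp(3 * dim, hidden, dim)  # graph + aggregated edges + aggregated nodes

    def forward(self, nodes, edges, senders, receivers, graph):
        # nodes: (N, d), edges: (E, d), senders/receivers: (E,) index tensors, graph: (d,)
        g_e = graph.expand(edges.size(0), -1)
        edges = self.edge_mlp(
            torch.cat([edges, nodes[senders], nodes[receivers], g_e], dim=-1))

        # Mean of updated edge vectors arriving at each node (aggregation choice assumed).
        agg = torch.zeros_like(nodes).index_add_(0, receivers, edges)
        counts = torch.zeros(nodes.size(0), 1).index_add_(
            0, receivers, torch.ones(edges.size(0), 1)).clamp(min=1)
        g_n = graph.expand(nodes.size(0), -1)
        nodes = self.node_mlp(torch.cat([nodes, agg / counts, g_n], dim=-1))

        graph = self.graph_mlp(torch.cat([graph, edges.mean(0), nodes.mean(0)], dim=-1))
        return nodes, edges, graph
```

In the setup described above, an encoder would produce the initial node, edge, and graph vectors, a stack of such blocks would refine them, and a decoder would read the final graph vector to predict the HOMO-LUMO gap.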
Results: The authors tested GNNs with different numbers of message-passing blocks. Performance on the validation set improved progressively up to 32 blocks (104 layers total) but showed no benefit beyond that depth. Versions with 8, 16, and 32 message-passing blocks achieved mean absolute errors of roughly 0.128, 0.124, and 0.121, respectively.
Why it matters: Not all types of data can be represented easily as an image or text — consider a social network — but almost all can be represented as a graph. This suggests that deep GNNs could prove useful in solving a wide variety of problems.
We’re thinking: CNNs and RNNs have become more powerful with increasing depth. GNNs may have a lot of room to grow.