School teachers may not like to hear this, but sometimes you get the best answer by peeking at your neighbor’s paper. A new language model framework peeks at the training data for context when making a prediction.
What’s new: Facebook AI and Stanford researchers led by Urvashi Khandelwal enhanced language models that predict the next word in an incomplete sentence by enabling them to search for potential answers in the training data. They call their algorithm kNN-LM.
Key insight: It’s much easier for a model to identify two sentence fragments that have similar meanings than it is to complete them. kNN-LM takes advantage of the easier task to improve performance on the harder one. Given a sentence fragment and asked to predict the next words, it searches the training set for sentences similar to that fragment and uses what it finds to help predict the missing words. For example, the model might match a target sentence starting, “Dickens is the author of ___,” with the training sentence, “Dickens wrote Oliver Twist.” The model then knows that “Oliver Twist” may be an appropriate completion for the target.
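To make the insight concrete, here is a toy nearest-neighbor lookup over sentence fragments. It’s a minimal sketch: the embed() function stands in for a hypothetical sentence encoder and is not part of the authors’ method.

```python
import numpy as np

def nearest_training_sentence(fragment, train_sentences, embed):
    # embed() is a hypothetical encoder mapping text to a fixed-size vector.
    query = embed(fragment)
    vectors = np.stack([embed(s) for s in train_sentences])
    # Return the training sentence whose vector lies closest to the query.
    idx = int(np.argmin(np.linalg.norm(vectors - query, axis=1)))
    return train_sentences[idx]

# e.g., nearest_training_sentence("Dickens is the author of", corpus, embed)
# might return "Dickens wrote Oliver Twist."
```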
How it works: The authors combine a pretrained model, vector representations of training sequences, and an algorithm that merges the model’s own prediction with information retrieved via those vectors when analyzing a test sentence. Their approach works with any pretrained neural language model, but they used transformer networks in most experiments.
- kNN-LM starts by generating a vector representation of every sequence in the training set, each paired with the token that followed it. Then it searches these vectors for the k nearest neighbors of the new input sequence’s vector. The closer a training sequence’s vector is to the input’s, the more heavily kNN-LM weights that sequence’s next token.
- The neural language model also directly predicts the next token for the input.
- Then it combines the k-nearest-neighbors prediction and the language model’s prediction into a final decision. A hyperparameter controls how heavily it weights each one, as in the sketch below.
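The following sketch shows how the two predictions combine, assuming a precomputed datastore of training-prefix vectors (keys) paired with their next tokens. It’s a minimal NumPy illustration, not the authors’ implementation; the function name and default values are our own.

```python
import numpy as np

def knn_lm_probs(prefix_vec, lm_probs, keys, next_tokens, vocab_size,
                 k=8, lam=0.25):
    """Blend a language model's next-token distribution with a kNN estimate.

    keys:        (N, d) vectors of training prefixes (the datastore)
    next_tokens: (N,) token ids that followed each training prefix
    lam:         hyperparameter weighting the kNN estimate against the LM
    """
    # Find the k training prefixes whose vectors lie closest to the input's.
    dists = np.linalg.norm(keys - prefix_vec, axis=1)
    nn = np.argsort(dists)[:k]

    # Closer neighbors get more weight (softmax over negative distances).
    weights = np.exp(-dists[nn])
    weights /= weights.sum()

    # Each neighbor votes for the token that followed it in training.
    knn_probs = np.zeros(vocab_size)
    np.add.at(knn_probs, next_tokens[nn], weights)

    # A single hyperparameter balances the two probability distributions.
    return lam * knn_probs + (1.0 - lam) * lm_probs
```

In practice the datastore holds one entry per training token, far too many for the brute-force search above, so the authors rely on an approximate nearest-neighbor index (FAISS) instead.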
Results: Tested on a dataset of Wikipedia articles (Wikitext-103), kNN-LM achieved a perplexity of 15.79, more than 10 percent better than the previous state-of-the-art model. Perplexity measures how well a model predicts the next token; lower is better.
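For reference, perplexity is the exponential of the model’s average negative log-likelihood per token, so a perplexity of 16 roughly corresponds to choosing among 16 equally likely words at every step. A quick sanity check in Python:

```python
import numpy as np

def perplexity(next_token_probs):
    # next_token_probs: probability the model assigned to each actual token
    return float(np.exp(-np.mean(np.log(next_token_probs))))

# A model that always spreads probability over 16 equally likely tokens
# scores a perplexity of exactly 16.
print(perplexity(np.full(1000, 1 / 16)))  # -> 16.0
```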
Why it matters: Language models likely won’t interpret technical terms found in, say, the NeurIPS proceedings if they’re trained on Wikipedia. kNN-LM lets them retrieve rare but related words directly from the training data, potentially improving generalization to obscure subject matter.
We’re thinking: A key step for winning computer vision competitions like ImageNet has been to train multiple models and ensemble (or average) them. This confers perhaps a 1 percent boost in performance, but it’s impractical for most applications because of the computational expense. kNN-LM, too, appears to incur significant computational expense, and we look forward to researchers diving deeper into its computational implications.