Semi-supervised learning — a set of training techniques that use a small number of labeled examples and a large number of unlabeled examples — typically treats all unlabeled examples the same way. But some examples are more useful for learning than others. A new approach lets models distinguish between them.
What’s new: Researchers Zhongzheng Ren, Raymond A. Yeh, and Alexander G. Schwing from the University of Illinois at Urbana-Champaign developed an algorithm that assigns greater weight to the most useful unlabeled examples.
Key insight: In its most common form, semi-supervised learning minimizes a weighted combination of supervised and unsupervised losses. Most previous approaches apply a single weight to the unsupervised loss, effectively treating every unlabeled example as equally important. Instead, the authors compute a separate weight for each unlabeled example automatically, by evaluating how that example influences the model during training.
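A minimal sketch of that objective in Python, assuming the per-example losses have already been computed. The function name `semi_supervised_loss` and the trade-off coefficient `lam` are hypothetical, not taken from the paper. Passing `weights=None` mimics the uniform treatment of earlier methods; passing a list gives each unlabeled example its own weight.

```python
def semi_supervised_loss(sup_losses, unsup_losses, weights=None, lam=1.0):
    """Mean supervised loss plus a weighted mean of per-example unsupervised losses.

    weights=None reproduces the common uniform treatment (every unlabeled
    example counts equally); a list assigns each unlabeled example its own weight.
    """
    sup = sum(sup_losses) / len(sup_losses)
    if weights is None:
        weights = [1.0] * len(unsup_losses)
    unsup = sum(w * loss for w, loss in zip(weights, unsup_losses)) / len(unsup_losses)
    return sup + lam * unsup


# Two labeled losses and two unlabeled losses; the second unlabeled example
# has a large loss (for instance, because its pseudo-label is wrong).
sup, unsup = [0.2, 0.4], [0.1, 2.0]
print(semi_supervised_loss(sup, unsup))              # uniform weights
print(semi_supervised_loss(sup, unsup, [1.0, 0.0]))  # ignore the second example
```

With uniform weights the total is about 1.35; setting the second example's weight to zero drops it to about 0.35, so that example no longer pulls on the model.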
How it works: The algorithm works with any semi-supervised model. Training alternates between optimizing the model’s parameters and optimizing the per-example weights.
- First, the authors trained the model on the training set while keeping the per-example weights fixed.
- Then they updated the per-example weights to reduce the loss on a labeled validation set while keeping the model parameters fixed.
- To compute the gradient of the validation loss with respect to each per-example weight, the authors derived an influence function. It measures how changing the weight assigned to an unlabeled training example shifts the model parameters, and thereby the validation loss.
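To make the alternating scheme and the influence-function step concrete, here is a toy sketch in plain Python. It is not the authors' implementation. It assumes a one-parameter linear model (`theta * x`), squared-error losses, fixed pseudo-labels standing in for the unsupervised loss, weights clipped to [0, 1], and hypothetical helper names (`train_theta`, `weight_gradients`, `fit`). Because the model has a single parameter, the Hessian H in the influence-function estimate dθ*/dw_i = -H⁻¹ ∇θ ℓ_i is a scalar, so no matrix inversion is needed.

```python
# Toy sketch, NOT the authors' code. Assumptions: a one-parameter model
# y_hat = theta * x, squared-error losses, fixed pseudo-labels as the
# unsupervised target, and per-example weights clipped to [0, 1].


def train_theta(theta, w, lab, unl, steps=200, lr=0.05):
    """Model step: gradient descent on theta with the weights w held fixed."""
    n_l, n_u = len(lab), len(unl)
    for _ in range(steps):
        grad = sum(2 * x * (theta * x - y) for x, y in lab) / n_l
        grad += sum(wi * 2 * x * (theta * x - p) for wi, (x, p) in zip(w, unl)) / n_u
        theta -= lr * grad
    return theta


def weight_gradients(theta, w, lab, unl, val):
    """Influence-function estimate of dL_val/dw_i for each unlabeled example."""
    n_l, n_u, n_v = len(lab), len(unl), len(val)
    # Hessian of the weighted training loss w.r.t. theta (a scalar here).
    hess = sum(2 * x * x for x, _ in lab) / n_l
    hess += sum(wi * 2 * x * x for wi, (x, _) in zip(w, unl)) / n_u
    # Gradient of the validation loss w.r.t. theta.
    g_val = sum(2 * x * (theta * x - y) for x, y in val) / n_v
    grads = []
    for x, p in unl:
        # dtheta*/dw_i = -H^{-1} * gradient of example i's term in the training loss.
        dtheta_dwi = -(2 * x * (theta * x - p) / n_u) / hess
        # Chain rule: dL_val/dw_i = dL_val/dtheta * dtheta*/dw_i.
        grads.append(g_val * dtheta_dwi)
    return grads


def fit(lab, unl, val, rounds=20, lr_w=1.0):
    """Alternate between the model step and the weight step."""
    theta, w = 0.0, [0.5] * len(unl)
    for _ in range(rounds):
        theta = train_theta(theta, w, lab, unl)            # weights fixed
        grads = weight_gradients(theta, w, lab, unl, val)  # model fixed
        w = [min(1.0, max(0.0, wi - lr_w * g)) for wi, g in zip(w, grads)]
    return theta, w


if __name__ == "__main__":
    lab = [(1.0, 2.0), (2.0, 4.0)]    # labeled examples on the line y = 2x
    val = [(1.5, 3.0), (0.5, 1.0)]    # labeled validation examples
    unl = [(1.0, 2.0), (0.5, 1.0),    # pseudo-labels consistent with y = 2x
           (1.0, -2.0), (1.5, -3.0)]  # misleading pseudo-labels
    theta, w = fit(lab, unl, val)
    print(f"theta = {theta:.3f}, weights = {[round(wi, 3) for wi in w]}")
```

In this toy setup, the weights on the misleading pseudo-labels shrink while the consistent ones grow, and `theta` moves toward the true slope of 2. A real network has millions of parameters, so the Hessian-inverse product would need to be approximated rather than computed directly.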
Results: On synthetic data, the authors showed that less useful examples received lower weights. In image classification on the CIFAR-10 and SVHN datasets, their approach outperformed previous state-of-the-art semi-supervised methods including FixMatch and UDA, in some cases only marginally. Using a Wide ResNet-28-2 and CIFAR-10 with 250 labeled examples, the authors’ method combined with FixMatch achieved a classification error of 5.05 percent compared to FixMatch’s 5.07 percent. Combined with UDA, the authors’ method on CIFAR-10 achieved a classification error of 5.53 percent compared to UDA’s 8.76 percent.
Why it matters: Unlabeled data is far more plentiful than labeled data. This work explores a path toward unlocking its value.
We’re thinking: Sometimes another 1,000 cat pictures don’t provide a model with any more useful information. But keep sending them anyway. The Batch team appreciates it!