Datasets that were scraped from the web tend to be unbalanced, meaning examples of some classes (say, cats) are plentiful while examples of others (say, caterpillars) are scarce. A model that’s trained on an unbalanced dataset tends to perform unevenly across classes, but the labor required to balance the data manually can be prohibitive. A new automated method addresses such imbalances.
What’s new: Huy V. Vo and colleagues at Meta, France’s National Institute for Research in Digital Science and Technology, Université Paris Saclay, and Google proposed a method that automatically selects a balanced subset of a text or image dataset.
Key insight: A naive way to balance a dataset automatically is to cluster it using k-means to define implicit categories and then draw an equal number of points randomly from each resulting cluster. But k-means tends to form many clusters in dense regions of the distribution, so categories with many examples end up over-represented. For instance, when the authors applied k-means to web images and associated the clusters with their nearest neighbors in ImageNet, around 300 of 10,000 clusters corresponded to the ImageNet class “website.” However, the centroids produced by clustering are distributed a bit more uniformly than the data itself. Applying k-means repeatedly, each time to points near the previous round’s centroids, distributes the centroids (and thus the clusters) more uniformly still. After several iterations, each cluster is more likely to represent a distinct category, and selecting equal numbers of examples from each cluster yields a balanced dataset.
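To make the naive approach concrete, here is a minimal Python sketch (NumPy only) on synthetic 2-D points that stand in for embeddings: a dense blob plays the role of an over-represented category such as “website,” and a small blob plays a rare one. The helper names `kmeans` and `naive_balanced_sample`, the blob sizes, and the cluster counts are illustrative assumptions, not the authors’ code. Because most k-means centroids settle where the data is dense, drawing an equal number of points per cluster still over-samples the dense region.

```python
import numpy as np


def kmeans(x, k, iters=20, seed=0):
    """Plain Lloyd's k-means; centroids are initialized from random data points."""
    rng = np.random.default_rng(seed)
    centroids = x[rng.choice(len(x), size=k, replace=False)].copy()
    for _ in range(iters):
        # Squared distance from every point to every centroid, shape (n, k)
        dists = ((x[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
        labels = dists.argmin(axis=1)
        for j in range(k):
            members = x[labels == j]
            if len(members) > 0:  # an empty cluster keeps its old centroid
                centroids[j] = members.mean(axis=0)
    return centroids, labels


def naive_balanced_sample(x, k, per_cluster, seed=0):
    """Naive balancing (hypothetical helper): cluster once, then draw up to
    `per_cluster` random points from every cluster."""
    rng = np.random.default_rng(seed)
    _, labels = kmeans(x, k, seed=seed)
    picked = []
    for j in range(k):
        members = np.flatnonzero(labels == j)
        take = min(per_cluster, len(members))
        if take > 0:
            picked.extend(rng.choice(members, size=take, replace=False).tolist())
    return np.array(picked, dtype=int)


# Synthetic stand-in for embeddings: 9,000 points in a dense region, 1,000 in a sparse one
rng = np.random.default_rng(0)
x = np.concatenate([
    rng.normal(0.0, 1.0, size=(9000, 2)),    # over-represented "category"
    rng.normal(10.0, 1.0, size=(1000, 2)),   # under-represented "category"
])

centroids, _ = kmeans(x, k=20)
in_dense = int((np.linalg.norm(centroids, axis=1) < 5).sum())
sample = naive_balanced_sample(x, k=20, per_cluster=25)
dense_share = float((sample < 9000).mean())  # indices below 9000 come from the dense blob
print(f"{in_dense} of 20 centroids in the dense region; {dense_share:.0%} of the sample from it")
```

Even though every cluster contributes the same number of points, the dense blob owns most of the clusters, so it still dominates the “balanced” sample.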
How it works: The authors balanced image and text datasets using several iterations of k-means clustering. Their image dataset started with 743 million examples from a “publicly available repository of crawled web data.” For text, they started with CCNet, a version of Common Crawl that was filtered to match the distribution of language and topics found in Wikipedia. The following approach balanced sampling across every level of the resulting cluster hierarchy, among high-level classes (such as animal, vehicle, and sport) as well as lower-level subclasses (such as dog, airplane, and football):
- The authors embedded the data. To build an image-embedding model, they trained a ViT-L (307 million parameters) on ImageNet-1k using the DINOv2 self-supervised training method. To embed text, they used a pretrained SBERT model.
- They clustered the data via k-means to produce 10 million clusters.
- They selected a small number of points closest to the centroid of each cluster. Then they applied k-means to the selected points to find new centroids. They repeated this process four times, each time decreasing the number of clusters, so the new clusters represented higher-level categories. With each iteration, the distribution of centroids became more uniform.
- Using the resulting hierarchy of clusters, the authors randomly selected balanced datasets of 100 million images and 210 billion text tokens. Specifically, starting with the highest-level clusters, they computed the number of samples to draw from each cluster. Then they looked up which lower-level clusters were contained within each cluster at the current level and determined the number of samples to draw from each of these subclusters. They repeated this process level by level. In this way, when they reached the lowest level, they knew how many points to draw randomly from each of the lowest-level clusters. The points they drew made up the balanced dataset.
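The steps above can be sketched in Python with NumPy. This is a simplified, hypothetical rendering of the procedure as summarized here, not the authors’ implementation: 2-D synthetic points stand in for DINOv2 or SBERT embeddings, and the cluster counts per level (`ks`), the number of points kept near each centroid (`r`), and the rule for dividing a quota among subclusters (`split_evenly`, capped by how many points each subcluster holds) are all assumptions. Each higher level runs k-means on the points nearest the previous level’s centroids and links every lower-level cluster to its nearest new centroid; sampling then walks the hierarchy from the top down.

```python
import numpy as np


def kmeans(x, k, iters=20, seed=0):
    """Plain Lloyd's k-means; centroids start at random data points."""
    rng = np.random.default_rng(seed)
    centroids = x[rng.choice(len(x), size=k, replace=False)].copy()
    for _ in range(iters):
        labels = ((x[:, None, :] - centroids[None]) ** 2).sum(-1).argmin(1)
        for j in range(k):
            if np.any(labels == j):  # an empty cluster keeps its old centroid
                centroids[j] = x[labels == j].mean(0)
    return centroids, labels


def build_hierarchy(x, ks, r=5, seed=0):
    """Hypothetical sketch. ks lists cluster counts per level, lowest level first.
    Returns each point's lowest-level label and, for every level but the top,
    an array mapping each cluster to its parent cluster one level up."""
    centroids, leaf_labels = kmeans(x, ks[0], seed=seed)
    pool, parents = x, []
    for level, k in enumerate(ks[1:], start=1):
        # Keep the r pool points closest to each current centroid
        dists = ((pool[:, None, :] - centroids[None]) ** 2).sum(-1)  # (n_pool, k_prev)
        pool = pool[np.unique(np.argsort(dists, axis=0)[:r])]
        new_centroids, _ = kmeans(pool, k, seed=seed + level)
        # Link each current cluster to its nearest higher-level centroid
        link = ((centroids[:, None, :] - new_centroids[None]) ** 2).sum(-1).argmin(1)
        parents.append(link)
        centroids = new_centroids
    return leaf_labels, parents


def split_evenly(budget, capacity):
    """Share `budget` as evenly as possible among children without exceeding any
    child's capacity; whatever small children cannot absorb goes to larger ones."""
    alloc = np.zeros(len(capacity), dtype=int)
    budget = min(int(budget), int(np.sum(capacity)))
    while budget > 0:
        open_idx = np.flatnonzero(alloc < capacity)
        share = max(budget // len(open_idx), 1)
        for i in open_idx:
            give = min(share, int(capacity[i] - alloc[i]), budget)
            alloc[i] += give
            budget -= give
            if budget == 0:
                break
    return alloc


def balanced_sample(leaf_labels, parents, ks, budget, seed=0):
    """Top-down: split the budget over top-level clusters, then over their
    subclusters, level by level, and finally draw points at random from each leaf."""
    rng = np.random.default_rng(seed)
    # Capacity = number of data points under each cluster, aggregated upward
    caps = [np.bincount(leaf_labels, minlength=ks[0])]
    for link, k in zip(parents, ks[1:]):
        caps.append(np.bincount(link, weights=caps[-1], minlength=k).astype(int))
    alloc = split_evenly(budget, caps[-1])
    for level in range(len(parents) - 1, -1, -1):
        child_alloc = np.zeros(ks[level], dtype=int)
        for parent, quota in enumerate(alloc):
            kids = np.flatnonzero(parents[level] == parent)
            child_alloc[kids] = split_evenly(quota, caps[level][kids])
        alloc = child_alloc
    picked = [rng.choice(np.flatnonzero(leaf_labels == j), size=q, replace=False)
              for j, q in enumerate(alloc) if q > 0]
    return np.concatenate(picked)


rng = np.random.default_rng(0)
x = np.concatenate([
    rng.normal([0, 0], 1.0, size=(8000, 2)),   # common "category"
    rng.normal([12, 0], 1.0, size=(1500, 2)),
    rng.normal([0, 12], 1.0, size=(500, 2)),   # rare "category"
])
ks = [40, 10, 3]  # lowest level first
leaf_labels, parents = build_hierarchy(x, ks)
sample = balanced_sample(leaf_labels, parents, ks, budget=300)
print(len(sample), "points sampled")
```

Dividing each quota evenly among subclusters is what keeps large subtrees from dominating: a cluster that absorbed thousands of near-duplicate points receives the same share as a sparse sibling, as long as the sibling holds enough points to fill it.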
Results: Both vision and language models that were pretrained on the balanced data outperformed models that were pretrained on the corresponding unbalanced datasets.
- To test their balancing method on image classifiers, the authors pretrained ViT-g models on their balanced dataset and, separately, on the unbalanced raw data. They froze the pretrained models and trained a linear layer on top of each to classify ImageNet. Pretrained on the balanced dataset, ViT-g achieved 85.7 percent accuracy on the ImageNet-1k validation set. Pretrained on the unbalanced dataset, it achieved 85.0 percent accuracy.
- To test their method on language models, they compared the performance on various tasks of LLaMA-7B models pretrained on their balanced 210-billion-token subset of CCNet and on the unbalanced CCNet. For instance, on the HellaSwag commonsense sentence-completion benchmark (zero-shot), the model pretrained on balanced data achieved 52.7 percent accuracy, while the model pretrained on unbalanced data achieved 51.9 percent accuracy. Similarly, on ARC-C (challenging grade-school science questions, for example about the buoyancy of wood; zero-shot), the model pretrained on balanced data achieved 40.1 percent accuracy, while the model pretrained on unbalanced data achieved 35.5 percent accuracy.
Why it matters: The old-school machine learning algorithm k-means can organize quantities of pretraining data that are too large for manual inspection yet crucial to data-hungry models. Breaking down data into clusters also makes it possible to manually inspect cluster elements, which might help identify unwanted data.
We’re thinking: Even in the era of foundation models, data-centric AI — that is, systematically engineering the data used to train such models — remains a critical, often under-appreciated step. This paper offers a promising way to create more balanced datasets. The encouraging results suggest fruitful avenues for further study.