[Paper-PreTrain] Big Transfer (BiT): General Visual Representation Learning

9 minute read

Published:

Last Updated: 2020-05-26

This paper, Big Transfer (BiT): General Visual Representation Learning, is proposed by researchers from Google Research.

Code: https://github.com/google-research/big_transfer

This paper presents Big Transfer (BiT), a paradigm of pre-training on large supervised datasets and then fine-tuning the model on a target task. By combining a few carefully selected components and transferring using a simple heuristic, BiT achieves strong performance on over 20 datasets.

Performance:

BiT performs well across a surprisingly wide range of data regimes, from 1 example per class to 1M total examples. BiT achieves 87.5% top-1 accuracy on ILSVRC-2012, 99.4% on CIFAR-10, and 76.3% on the 19-task Visual Task Adaptation Benchmark (VTAB). On small datasets, BiT attains 76.8% on ILSVRC-2012 with 10 examples per class, and 97.0% on CIFAR-10 with 10 examples per class.

1. Introduction

Three different scales of datasets for pre-training:

  1. BiT-L: JFT-300M dataset [51], which contains 300M noisily labelled images.
  2. BiT-M: ImageNet-21k [10], which contains 14M images.
  3. BiT-S: ILSVRC-2012 [46], which contains 1.3M images.

Transfer tasks:

  1. ILSVRC-2012 [10]
  2. CIFAR-10/100 [27]
  3. Oxford-IIIT Pet [41]
  4. Oxford-Flowers-102 [39] (including few-shot variants)
  5. 1000-sample VTAB-1k benchmark [66], which consists of 19 diverse datasets

BiT-L attains state-of-the-art performance on many of these tasks, and is surprisingly effective when very little downstream data is available (Figure 1).

2020-05-24-blog-post-13-1

2. Big Transfer

In this section, the authors highlight the most important components that make Big Transfer effective.

2.1. Upstream Pre-Training

1. Scale

The authors train BiT at three scales to study the effectiveness of scale and the interplay between computational budget (training time), architecture size, and dataset size.

2. Group Normalization (GN) and Weight Standardization (WS)

Batch normalization (BN) is detrimental for Big Transfer (large-scale pre-training + any transfer):

  1. For large-scale (distributed) training, BN performs poorly or incurs inter-device synchronization cost.
  2. For fine-tuning, BN requires its running statistics to be updated for the new task.

The authors show that the combination of GN and WS is useful for training with large batch sizes, and has a significant impact on transfer learning.
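For concreteness, below is a minimal sketch of a convolution with Weight Standardization followed by Group Normalization, in the spirit of this BN replacement. It is illustrative only: the layer names, group count, and epsilon are illustrative choices, not the official BiT implementation.

```python
# A minimal sketch (PyTorch) of a conv layer with Weight Standardization (WS)
# followed by Group Normalization (GN), replacing the usual conv + BN pattern.
# Illustrative only; not the official BiT implementation.
import torch
import torch.nn as nn
import torch.nn.functional as F


class StdConv2d(nn.Conv2d):
    """Conv2d whose filters are standardized to zero mean and unit variance."""

    def forward(self, x):
        w = self.weight
        # Standardize each output filter over its (in_channels, kH, kW) fan-in.
        mean = w.mean(dim=(1, 2, 3), keepdim=True)
        var = w.var(dim=(1, 2, 3), keepdim=True, unbiased=False)
        w = (w - mean) / torch.sqrt(var + 1e-10)
        return F.conv2d(x, w, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)


def conv_gn_relu(in_ch, out_ch, groups=32):
    """conv (WS) -> GroupNorm -> ReLU; no batch statistics are involved."""
    return nn.Sequential(
        StdConv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),
        nn.GroupNorm(num_groups=groups, num_channels=out_ch),
        nn.ReLU(inplace=True),
    )
```

Because neither WS nor GN keeps running batch statistics, such a block can be trained with very large (distributed) batches and fine-tuned on tiny downstream datasets without any statistics mismatch.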

2.2. Transfer to Downstream Tasks

This work proposes a cheap fine-tuning protocol, called BiT-HyperRule, that applies to many diverse downstream tasks. The following hyperparameters are set per fine-tuning task, based on the task's intrinsic image resolution and number of datapoints:

  1. training schedule length
  2. resolution
  3. whether to use MixUp regularization

During fine-tuning, the authors adopt a standard data pre-processing strategy: resize the image to a square. For training, they further crop out a smaller random square and randomly flip the image horizontally.
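As an illustration, this pre-processing could be expressed with torchvision transforms as below. The 448 -> 384 sizes follow the rule for larger images given later in Section 3.3, and the exact pipeline here is a sketch rather than the authors' code.

```python
# A sketch of the fine-tuning pre-processing: resize to a square, then
# (training only) a smaller random crop and a random horizontal flip.
# Sizes follow the "larger images" rule from Section 3.3 (448 -> 384).
from torchvision import transforms

train_tx = transforms.Compose([
    transforms.Resize((448, 448)),        # resize to a square
    transforms.RandomCrop((384, 384)),    # crop out a smaller random square
    transforms.RandomHorizontalFlip(),    # random horizontal flip
    transforms.ToTensor(),
])

test_size = 384  # illustrative; the test-time resolution is discussed below and in Section 3.3
test_tx = transforms.Compose([
    transforms.Resize((test_size, test_size)),  # only a deterministic resize at test time
    transforms.ToTensor(),
])
```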

Recent work has shown that existing augmentation methods introduce inconsistency between training and test resolutions for CNNs [57]. It is common to scale up the resolution by a small factor at test time. As an alternative, the authors add a step at which the trained model is fine-tuned to the test resolution.

MixUp is not useful for pre-training BiT, but it is sometimes useful for transfer (mostly for mid-sized datasets, not for few-shot transfer). The authors do not use any other regularization during fine-tuning, such as weight decay (towards zero or towards the initial parameters [31]) or dropout, because they find that setting an appropriate schedule length provides sufficient regularization.
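For reference, a minimal sketch of MixUp on a training batch is shown below; it assumes one-hot (or soft) label tensors and is an illustration rather than the authors' implementation (the α = 0.1 value they use appears in Section 3.3).

```python
# MixUp sketch: convexly combine a batch with a shuffled copy of itself.
# Assumes one-hot (or soft) label tensors; illustrative, not the official code.
import numpy as np
import torch


def mixup_batch(images, one_hot_labels, alpha=0.1):
    lam = float(np.random.beta(alpha, alpha))   # mixing coefficient ~ Beta(alpha, alpha)
    perm = torch.randperm(images.size(0))       # random pairing within the batch
    mixed_images = lam * images + (1.0 - lam) * images[perm]
    mixed_labels = lam * one_hot_labels + (1.0 - lam) * one_hot_labels[perm]
    return mixed_images, mixed_labels
```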

3. Experiments

3.1. Data for Upstream Training

  1. BiT-S is trained on the ILSVRC-2012 variant of ImageNet, which contains 1.28 million images and 1000 classes. Each image has a single label.
  2. BiT-M is trained on the full ImageNet-21k dataset [10], a public dataset containing 14.2 million images and 21k classes organized by the WordNet hierarchy. Images may contain multiple labels.
  3. BiT-L is trained on the JFT-300M dataset consisting of around 300 million images with 1.26 labels per image on average. The labels are organized into a hierarchy of 18,291 classes. Approximately 20% of the labels are noisy. The authors removed all images present in downstream test sets.

3.2. Downstream Tasks

The downstream tasks are long-standing benchmarks: ILSVRC-2012 [10], CIFAR-10/100 [27], Oxford-IIIT Pet [41], and Oxford Flowers-102 [39]. The authors fine-tune BiT on the official training split and report results on the official test split if it is publicly available; otherwise, they use the validation split.

To further assess the generality of representations learned by BiT, the authors evaluated on the Visual Task Adaptation Benchmark (VTAB) [66].

VTAB consists of 19 diverse visual tasks, each of which has 1000 training samples (the VTAB-1k variant). The tasks are organized into three groups: natural, specialized, and structured. The VTAB-1k score is top-1 recognition performance averaged over these 19 tasks. The natural group contains classical datasets of natural images captured with standard cameras. The specialized group also contains images captured in the real world, but through specialist equipment, such as satellite or medical images. Finally, the structured tasks assess understanding of the structure of a scene, and are mostly generated from synthetic environments. Example structured tasks include object counting and 3D depth estimation.

3.3. Hyperparameter Details

Model architecture: vanilla ResNet-v2-152x4. Batch Normalization is replaced with Group Normalization, and Weight Standardization is used in all convolutional layers.

BiT-HyperRule: Most hyperparameters are fixed across all datasets, but the schedule, resolution, and usage of MixUp depend on the task's image resolution and training set size.

Schedule length:

We call small tasks those with fewer than 20k labeled examples, medium tasks those with fewer than 500k, and any larger dataset a large task. We fine-tune BiT for 500 steps on small tasks, for 10k steps on medium tasks, and for 20k steps on large tasks.

Resolution:

We resize input images with area smaller than 96×96 pixels to 160×160 pixels, and then take a random crop of 128×128 pixels. We resize larger images to 448×448 and take a 384×384-sized crop. We apply random crops and horizontal flips for all tasks, except those for which cropping or flipping destroys the label semantics; see Supplementary section F for details.

Usage of Mixup:

We use MixUp [67], with α = 0.1, for medium and large tasks. See Supplementary section A for details.

Please refer to the original paper for other detailed settings.
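Putting the three choices together, the selection logic of BiT-HyperRule can be summarized in a short helper like the one below. The thresholds and values come from this section; the function name and return format are illustrative and do not mirror the official repository's API.

```python
# A sketch of the BiT-HyperRule selection logic (schedule, resolution, MixUp).
# Thresholds/values are from Section 3.3; the interface is illustrative only.
def bit_hyperrule(num_examples, image_width, image_height):
    # Schedule length from training-set size.
    if num_examples < 20_000:          # small task
        steps = 500
    elif num_examples < 500_000:       # medium task
        steps = 10_000
    else:                              # large task
        steps = 20_000

    # Resolution from the task's intrinsic image size.
    if image_width * image_height < 96 * 96:
        resize, crop = 160, 128
    else:
        resize, crop = 448, 384

    # MixUp (alpha = 0.1) only for medium and large tasks.
    mixup_alpha = 0.1 if num_examples >= 20_000 else None

    return {"steps": steps, "resize": resize, "crop": crop, "mixup_alpha": mixup_alpha}


# Example: CIFAR-10 (50k training images of 32x32 pixels)
# -> {'steps': 10000, 'resize': 160, 'crop': 128, 'mixup_alpha': 0.1}
print(bit_hyperrule(50_000, 32, 32))
```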

3.4. Standard Computer Vision Benchmarks

Comparison between BiT-L and current SOTA:

2020-05-25-blog-post-13-2

Comparison between BiT-S and BiT-M:

2020-05-25-blog-post-13-3

Top-5 accuracy on ILSVRC-2012, reported as median ± standard deviation across 3 runs:

  1. 98.46%±0.02% for BiT-L
  2. 97.69%±0.02% for BiT-M
  3. 95.65%±0.03% for BiT-S.

3.5. Tasks with Few Datapoints

The authors transfer BiT-L using subsets of ILSVRC-2012, CIFAR-10, and CIFAR-100, down to 1 example per class.

2020-05-25-blog-post-13-4

Performance on 19 VTAB-1k tasks compared with SOTA:

2020-05-25-blog-post-13-5

3.6. ObjectNet: Recognition on a “Real-World” Test Set

We evaluate BiT on the new test-only ObjectNet dataset [2]. There are 313 object classes in total, 113 of which overlap with ILSVRC-2012. Following the literature [2, 6], we evaluate our models on those 113 classes.

2020-05-26-blog-post-13-6
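One simple way to evaluate an ILSVRC-2012 classifier on only the 113 overlapping classes is to mask out the logits of the non-overlapping classes, as sketched below. The `overlap_indices` list (ILSVRC class indices of the overlapping classes) is assumed to be available, and whether this masking exactly matches the evaluation protocol of [2, 6] is an assumption here.

```python
# Sketch: restrict an ILSVRC-2012 classifier to the 113 ObjectNet-overlapping
# classes by masking the remaining logits. `overlap_indices` is assumed given.
import torch


@torch.no_grad()
def objectnet_accuracy(model, loader, overlap_indices):
    mask = torch.full((1000,), float("-inf"))
    mask[overlap_indices] = 0.0                   # keep only overlapping classes
    correct, total = 0, 0
    for images, labels in loader:                 # labels already in ILSVRC index space
        logits = model(images)
        logits = logits + mask.to(logits.device)  # non-overlapping classes can never win
        correct += (logits.argmax(dim=1) == labels).sum().item()
        total += labels.numel()
    return correct / total
```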

3.7. Object Detection

We use the COCO-2017 dataset [34] and train a top-performing object detector, RetinaNet [33], using pre-trained BiT models as backbones. Due to memory constraints, we use the ResNet-101x3 architecture for all of our BiT models. We fine-tune the detection models on the COCO-2017 train split and report results on the validation split using the standard metric [34] in Table 3.

Here, we do not use BiT-HyperRule but stick to the standard RetinaNet training protocol; see the Supplementary Material, section E, for details.

2020-05-26-blog-post-13-7

4. Analysis

4.1. Scaling Models and Datasets

Evaluation of BiT models of different sizes (ResNet-50x1, ResNet-50x3, ResNet-101x1, ResNet-101x3, and ResNet-152x4)

2020-05-26-blog-post-13-8

Figure 5 shows that:

  1. The benefit of larger models diminishes on the smaller dataset (ILSVRC-2012), while it is more pronounced on the two larger datasets.
  2. There is also limited (or even negative) benefit from training a small model on a larger dataset. (Perhaps surprisingly, a ResNet-50x1 trained on JFT-300M can even perform worse than the same architecture trained on the smaller ImageNet-21k.)

Figure 6 ablates few-shot performance across different pre-training datasets and architectures.

2020-05-26-blog-post-13-9

  1. In the extreme case of one example per class, larger architectures outperform smaller ones when pre-trained on large upstream data.
  2. Interestingly, on ILSVRC-2012 with few shots, BiT-L trained on JFT-300M outperforms models trained on the entire ILSVRC-2012 dataset itself. Note that, for comparability, the classifier head is re-trained from scratch during fine-tuning, even when transferring from full ILSVRC-2012 to few-shot ILSVRC-2012.

4.2. Optimization on Large Datasets

2020-05-26-blog-post-13-10

  1. A sufficient computational budget is crucial for training performant models on large datasets (Figure 7, left plot).

  2. The validation error may not improve for a long time (Figure 7, middle plot, “8 GPU weeks” zoom-in), although the model is still improving, as evidenced by the longer time window.

  3. Lower weight decay can result in an apparent acceleration of convergence (Figure 7, rightmost plot). However, this setting eventually results in an under-performing final model.

    Explanation:

    This counter-intuitive behavior stems from the interaction of weight decay and normalization layers [29,32]. Low weight decay results in growing weight norms, which in turn results in a diminishing effective learning rate. Initially this effect creates an impression of faster convergence, but it eventually prevents further progress. A sufficiently large weight decay is required to avoid this effect, and we use 1e-4 throughout (a short derivation sketch is given after this list).

  4. Optimizer: stochastic gradient descent with momentum without any modification.
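To make the effective-learning-rate argument from item 3 concrete, here is a short derivation sketch for a scale-invariant layer, i.e., one whose output does not change when its weights are rescaled, as is the case for a convolution followed by a normalization layer. The notation is generic and not taken from the paper.

```latex
L(c\,w) = L(w) \;\;\Rightarrow\;\; w \cdot \nabla_w L = 0
\quad\text{and}\quad \|\nabla_{cw} L\| = \tfrac{1}{c}\,\|\nabla_w L\| ,
\qquad\text{so}\qquad
\eta_{\text{eff}} \;\approx\; \frac{\eta}{\|w\|^{2}} .
```

With too little weight decay, the weight norm grows unchecked, the effective step size η/‖w‖² shrinks, and training stalls; a sufficiently large weight decay keeps the norm bounded and the effective learning rate stable.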

4.3. Large Batches, Group Normalization, Weight Standardization

  1. Batch Normalization (BN) performs worse when the number of images on each accelerator (per-device batch size) is too low [20].
  2. Group Normalization (GN) [60] and Weight Standardization (WS) [43] as alternatives to BN:

We found that GN alone does not scale to large batches; we observe a performance drop of 5.4% in ILSVRC-2012 top-1 accuracy compared to using BN in a ResNet-50x1. The addition of WS enables GN to scale to such large batches, even outperforming BN (see Table 4).

For fine-tuning, the results in Table 5 indicate that the GN/WS combination transfers better than BN.

2020-05-26-blog-post-13-11

5. Related Work

The paper also reviews related lines of work:

  1. Large-scale Weakly Supervised Learning of Representations
  2. Specialized Representations
  3. Unsupervised and Semi-Supervised Representation Learning
  4. Few-shot Learning