[Paper-PreTrain] Representation Learning with Contrastive Predictive Coding

13 minute read

Published:

Last Updated: 2020-04-13

This paper, Representation Learning with Contrastive Predictive Coding, was proposed by researchers from DeepMind.

Code: https://github.com/jefflai108/Contrastive-Predictive-Coding-PyTorch

1. Introduction

Unsupervised representation learning is an important stepping stone towards robust and generic representation learning, because representations learned for a specific supervised task rarely transfer well to other tasks.

One of the most common strategies for unsupervised learning is to predict future, missing or contextual information.

The authors' explanation for this strategy is that:

We hypothesize that these approaches are fruitful partly because the context from which we predict related values are often conditionally dependent on the same shared high-level latent information. And by casting this as a prediction problem, we automatically infer these features of interest to representation learning.

Three steps of the proposed method:

  1. Compress high-dimensional data into a more compact latent space.
  2. Use autoregressive models in this latent space to make predictions many steps in the future.
  3. Rely on Noise-Contrastive Estimation [12] for the loss function which allows the model to be trained end-to-end.

2. Contrastive Predictive Coding

2.1. Motivation

Learn representations that encode the underlying shared information between different parts of the high-dimensional raw signal, while discarding low-level local information and noise.

Challenges:

  1. Unimodal losses such as MSE or cross-entropy are not very useful here.
  2. Powerful conditional generative models are computationally intensive and waste capacity modeling the complex relationships in the data x while ignoring the context c (e.g., when all we want is a label, which is a single integer, such a model still has to reconstruct the whole image).

Intuition:

  1. P(x|c) is not optimal for the purpose of extracting shared information between x and c.
  2. Instead, we encode both x and c into a latent space in a way that maximally preserves the mutual information between x and c (a small numeric example follows the equation):
\[I(x;c)=\sum_{x,c}p(x,c)\log\frac{p(x,c)}{p(x)p(c)}=\sum_{x,c}p(x,c)\log\frac{p(x|c)}{p(x)}\]
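To make the definition concrete, here is a tiny numerical sketch (my own toy example, not from the paper) that evaluates I(x;c) for a hand-picked 2x2 joint distribution:

```python
import numpy as np

# Toy joint distribution p(x, c) over two values of x and two values of c
# (purely illustrative, not from the paper).
p_xc = np.array([[0.4, 0.1],
                 [0.1, 0.4]])

p_x = p_xc.sum(axis=1, keepdims=True)   # marginal p(x)
p_c = p_xc.sum(axis=0, keepdims=True)   # marginal p(c)

# I(x;c) = sum_{x,c} p(x,c) * log( p(x,c) / (p(x) p(c)) )
mi = np.sum(p_xc * np.log(p_xc / (p_x * p_c)))
print(f"I(x;c) = {mi:.4f} nats")        # about 0.19 nats for this table
```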

2.2. Contrastive Predictive Coding

The figure below shows the architecture of CPC models.

[Figure: overview of the CPC architecture (encoder g_enc, autoregressive model g_ar, and predictions of future latents)]

Data Flow:

  1. Input sequence x_t -> non-linear encoder g_enc -> z_t = g_enc(x_t)
  2. Previous latents z_{<=t} -> autoregressive model g_ar -> context c_t = g_ar(z_{<=t})
  3. Model a density ratio that preserves the mutual information between x_{t+k} and c_t (a minimal code sketch of this flow follows below)
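A minimal PyTorch sketch of this data flow (module shapes and layer choices here are illustrative assumptions, not the authors' exact code):

```python
import torch
import torch.nn as nn

class TinyCPC(nn.Module):
    """Illustrative sketch of the CPC data flow: x_t -> z_t -> c_t."""
    def __init__(self, in_dim=1, z_dim=512, c_dim=256):
        super().__init__()
        # g_enc: non-linear encoder that maps the raw input to latents z_t
        self.g_enc = nn.Sequential(
            nn.Conv1d(in_dim, z_dim, kernel_size=10, stride=5), nn.ReLU())
        # g_ar: autoregressive model that summarizes z_{<=t} into a context c_t
        self.g_ar = nn.GRU(z_dim, c_dim, batch_first=True)

    def forward(self, x):                     # x: (batch, in_dim, time)
        z = self.g_enc(x).transpose(1, 2)     # (batch, T, z_dim) latents
        c, _ = self.g_ar(z)                   # (batch, T, c_dim) contexts
        return z, c

z, c = TinyCPC()(torch.randn(2, 1, 16000))    # e.g. one second of fake 16 kHz audio
```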

Note that in step 3, the authors do not predict x_{t+k} directly with a generative model p(x_{t+k}|c_t).

Instead, they model a density ratio f_k that is proportional to the ratio between the conditional and the marginal distribution:

\[f_k(x_{t+k},c_t) \propto \frac{p(x_{t+k}|c_t)}{p(x_{t+k})}\]

The authors choose to use a simple log-bilinear model here:

\[f_k(x_{t+k},c_t) = \exp(z_{t+k}^T W_k c_t)\]

Note that the matrix W_k is different for every step k. Alternatively, a non-linear network or an RNN could be used here.
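In code, only the log of f_k is ever needed, since the loss below applies a softmax over the scores. A hedged sketch of the log-bilinear score (dimensions are assumptions carried over from the sketch above):

```python
import torch
import torch.nn as nn

z_dim, c_dim = 512, 256
W_k = nn.Linear(c_dim, z_dim, bias=False)      # one such matrix per prediction step k

def log_f_k(z_future, c_t):
    """log f_k(x_{t+k}, c_t) = z_{t+k}^T W_k c_t.

    z_future: (batch, z_dim) latent at step t+k, c_t: (batch, c_dim) context at step t.
    The exp is never taken explicitly; the softmax inside the loss handles it.
    """
    return torch.sum(z_future * W_k(c_t), dim=-1)   # (batch,) scores

scores = log_f_k(torch.randn(4, z_dim), torch.randn(4, c_dim))
```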

By using a density ratio f(x_{t+k}, c_t) and inferring z_{t+k} with an encoder, we relieve the model from modeling the high-dimensional distribution of x_{t+k}. Although we cannot evaluate p(x) or p(x|c) directly, we can use samples from these distributions, allowing us to use techniques such as Noise-Contrastive Estimation [12,14,15] and Importance Sampling [16] that are based on comparing the target value with randomly sampled negative values.

In the proposed model, either z_t or c_t can be used as the representation for downstream tasks.

Finally, note that any type of encoder and autoregressive model can be used in the proposed framework. For simplicity we opted for standard architectures such as strided convolutional layers with resnet blocks for the encoder, and GRUs [17] for the autoregressive model. More recent advancements in autoregressive modeling such as masked convolutional architectures [18,19] or self-attention networks [20] could help improve results further.

2.3. InfoNCE Loss and Mutual Information Estimation

Both the encoder and autoregressive model are trained to jointly optimize a loss based on NCE, which we will call InfoNCE.

Given a set X = {x_1, …, x_N} of N samples containing one positive sample from p(x_{t+k}|c_t) and N-1 negative samples from the 'proposal' distribution p(x_{t+k}), the InfoNCE loss is:

\[L_N = -\mathbb{E}_X\left[\log\frac{f_k(x_{t+k}, c_t)}{\sum_{x_j\in X} f_k(x_j, c_t)}\right]\]

Optimizing this loss results in f_k(x_{t+k}, c_t) estimating the density ratio given above. This can be shown as follows.

Let [d = i] denote the event that sample x_i is the positive sample. The equation below gives the probability that x_i was drawn from the conditional distribution p(x_{t+k}|c_t) rather than from the proposal distribution p(x_{t+k}):

\[p(d=i|X, c_t)=\frac{p(x_i|c_t) \prod_{l\neq i}p(x_l)}{\sum_{j=1}^N p(x_j|c_t)\prod_{l\neq j}p(x_l)}=\frac{\frac{p(x_i|c_t)}{p(x_i)}}{\sum_{j=1}^N\frac{p(x_j|c_t)}{p(x_j)}}\]

Then a lower bound on the mutual information between c_t and x_{t+k} can be written as follows:

\[I(x_{t+k}; c_t) \geq \log(N) - L_N\]

(Proof of the equation above is in the Appendix of the original paper)
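In code, the InfoNCE loss is just a cross-entropy over the scores f_k. Below is a minimal sketch under the common convention that the other examples in the minibatch serve as negatives (the paper explores several negative-sampling strategies, so this is an illustration rather than the authors' exact setup):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def info_nce(z_pos, c_t, W_k):
    """InfoNCE for a single prediction step k.

    z_pos: (N, z_dim) positive latents z_{t+k}, one per example in the batch.
    c_t:   (N, c_dim) contexts.
    W_k:   nn.Linear(c_dim, z_dim, bias=False).
    For row i, the other N-1 rows act as negative samples.
    """
    pred = W_k(c_t)                          # (N, z_dim) predicted latents
    logits = pred @ z_pos.t()                # (N, N) scores z_j^T W_k c_i
    labels = torch.arange(logits.size(0))    # the positive sits on the diagonal
    return F.cross_entropy(logits, labels)   # = -E[log softmax over the N candidates]

loss = info_nce(torch.randn(8, 512), torch.randn(8, 256), nn.Linear(256, 512, bias=False))
```

Because the bound is I(x_{t+k}; c_t) >= log(N) - L_N, a larger number of negatives N allows the estimator to capture more mutual information.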

3. Experiments

The authors conduct four experiments on different domains: speech, images, natural language and reinforcement learning.

3.1. Audio

For audio, the authors use a 100-hour subset of the publicly available LibriSpeech dataset [30]. The dataset contains speech from 251 different speakers.

Although the dataset does not provide labels other than the raw text, we obtained force-aligned phone sequences with the Kaldi toolkit [31] and pre-trained models on Librispeech. We have made the aligned phone labels and our train/test split available for download on Google Drive.

The encoder architecture g_enc used in our experiments consists of a strided convolutional neural network that runs directly on the 16KHz PCM audio waveform. We use five convolutional layers with strides [5, 4, 2, 2, 2], filter-sizes [10, 8, 4, 4, 4] and 512 hidden units with ReLU activations. The total downsampling factor of the network is 160 so that there is a feature vector for every 10ms of speech, which is also the rate of the phoneme sequence labels obtained with Kaldi.
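As a sanity check on those numbers, here is a sketch of that encoder stack in PyTorch (padding and initialization are not specified in the text, so this is an approximation):

```python
import torch
import torch.nn as nn

strides = [5, 4, 2, 2, 2]           # product = 160 -> one latent per 10 ms of 16 kHz audio
filters = [10, 8, 4, 4, 4]

layers, in_ch = [], 1
for k, s in zip(filters, strides):
    layers += [nn.Conv1d(in_ch, 512, kernel_size=k, stride=s), nn.ReLU()]
    in_ch = 512
g_enc = nn.Sequential(*layers)

wav = torch.randn(1, 1, 20480)      # one training window of raw PCM samples
z = g_enc(wav)
print(z.shape)                      # roughly (1, 512, 20480 / 160); slightly fewer frames without padding
```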

We then use a GRU RNN [17] for the autoregressive part of the model, g_ar, with a 256-dimensional hidden state. The output of the GRU at every timestep is used as the context c, from which we predict 12 timesteps in the future using the contrastive loss. We train on sampled audio windows of length 20480. We use the Adam optimizer [32] with a learning rate of 2e-4, and use 8 GPUs, each with a minibatch of 8 examples from which the negative samples in the contrastive loss are drawn. The model is trained until convergence, which happens roughly at 300,000 updates.
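A hedged sketch of how the 12-step contrastive objective could be wired up on top of such latents (the choice of a single reference timestep and in-batch negatives are simplifications, not the authors' exact procedure):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

K = 12                                          # number of future steps to predict
g_ar = nn.GRU(512, 256, batch_first=True)
W = nn.ModuleList([nn.Linear(256, 512, bias=False) for _ in range(K)])

def cpc_loss(z):                                # z: (N, T, 512) latents from g_enc
    c, _ = g_ar(z)                              # (N, T, 256) contexts
    t = z.size(1) - K - 1                       # one reference timestep (a simplification)
    loss = 0.0
    for k in range(K):
        pred = W[k](c[:, t])                    # (N, 512) prediction for step t + k + 1
        logits = pred @ z[:, t + k + 1].t()     # other batch items act as negatives
        loss = loss + F.cross_entropy(logits, torch.arange(z.size(0)))
    return loss / K

# Optimizer as described above (wiring of g_enc omitted):
# opt = torch.optim.Adam(list(g_ar.parameters()) + list(W.parameters()), lr=2e-4)
```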

The figure below shows the accuracy of the model at predicting latents 1 to 20 timesteps in the future.

[Figure: accuracy of predicting future latents, from 1 to 20 timesteps ahead]


The authors extract the 256-dimensional outputs c_t from the GRU for the whole dataset after the model has converged and test them on two tasks: phone classification (41 classes) and speaker classification (251 classes). The results are shown below.

Note that the classifier is a multi-class linear logistic regression classifier without hidden layers because the authors want to see how much encoded information is linearly accessible.

The supervised baseline is trained end-to-end on the labeled data and is built from the same two components (encoder and autoregressive model) as the architecture used to extract the CPC representations, to show what is achievable with this architecture.
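For illustration, a linear probe of this kind could look as follows (scikit-learn and the random placeholder features are my assumptions; the paper does not specify an implementation):

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

# Placeholder features: in practice these are the frozen 256-d contexts c_t from the
# converged CPC model, paired with frame-level phone labels.
rng = np.random.default_rng(0)
features = rng.normal(size=(5000, 256))
labels = rng.integers(0, 41, size=5000)          # 41 phone classes

probe = LogisticRegression(max_iter=1000)        # linear, no hidden layers
probe.fit(features, labels)
print("accuracy:", probe.score(features, labels))
```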

[Table: phone and speaker classification accuracy of CPC representations compared with MFCC features and the supervised baseline]

We can see that the results using the representations learned by CPC are much better than MFCC features and close to the fully supervised model. The authors also find that not all of the encoded information is linearly accessible: when a single hidden layer is added to the classifier, phone classification accuracy increases from 64.6 to 72.5.


The table below shows the results of the ablation studies.

The first set shows that predicting multiple steps is important for the model to learn useful features.

The second set compares different strategies for drawing negative samples. In all strategies the model predicts 12 steps, which gave the best result in the first ablation.

[Table: ablations over the number of prediction steps and the negative-sampling strategy]

Additionally, the figure below shows a t-SNE visualization [33] of how discriminative the embeddings are for speaker voice-characteristics. It is important to note that the window size (maximum context size for the GRU) has a big impact on the performance, and longer segments would give better results. Our model had a maximum of 20480 timesteps to process, which is slightly longer than a second.

[Figure: t-SNE visualization of the audio representations for different speakers]


3.2. Vision

Experimental setup:

In our visual representation experiments we use the ILSVRC ImageNet competition dataset [34]. The ImageNet dataset has been used to evaluate unsupervised vision models by many authors [28,11,35,10,29,36]. We follow the same setup as [36] and use a ResNet v2 101 architecture [37] as the image encoder g_enc to extract CPC representations (note that this encoder is not pretrained). We did not use Batch-Norm [38]. After unsupervised training, a linear layer is trained to measure classification accuracy on ImageNet labels.

Data preprocessing & augmentation:

The training procedure is as follows: from a 256x256 image we extract a 7x7 grid of 64x64 crops with 32 pixels overlap. Simple data augmentation proved helpful on both the 256x256 images and the 64x64 crops. The 256x256 images are randomly cropped from a 300x300 image, horizontally flipped with a probability of 50% and converted to greyscale. For each of the 64x64 crops we randomly take a 60x60 subcrop and pad them back to a 64x64 image.
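As a quick check, a 7x7 grid of 64x64 crops with 32-pixel overlap corresponds to sliding a 64x64 window with stride 32 over the 256x256 image. A sketch of the patch extraction (the unfold-based implementation is my assumption):

```python
import torch

img = torch.randn(3, 256, 256)                       # one (already augmented) image
patches = img.unfold(1, 64, 32).unfold(2, 64, 32)    # slide 64x64 windows with stride 32
patches = patches.permute(1, 2, 0, 3, 4)             # (7, 7, 3, 64, 64): the 7x7 crop grid
assert patches.shape == (7, 7, 3, 64, 64)
```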

Training details:

Each crop is then encoded by the ResNet-v2-101 encoder. We use the outputs from the third residual block, and spatially mean-pool to get a single 1024-d vector per 64x64 patch. This results in a 7x7x1024 tensor. Next, we use a PixelCNN-style autoregressive model [19] (a convolutional row-GRU PixelRNN [39] gave similar results) to make predictions about the latent activations in following rows top-to-bottom, visualized in Figure 4. We predict up to five rows from the 7x7 grid, and we apply the contrastive loss for each patch in the row. We used Adam optimizer with a learning rate of 2e-4 and trained on 32 GPUs each with a batch size of 16.
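Below is a much-simplified sketch of the row-wise prediction idea. Note that the paper uses a PixelCNN-style autoregressive model over the 7x7 grid, whereas this illustration substitutes a plain per-column GRU, so it shows the objective rather than the authors' architecture:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

feats = torch.randn(4, 7, 7, 1024)                   # (batch, rows, cols, 1024) pooled patch latents

row_gru = nn.GRU(1024, 256, batch_first=True)        # summarizes rows top-to-bottom
W = nn.ModuleList([nn.Linear(256, 1024, bias=False) for _ in range(5)])   # up to 5 rows ahead

B, R, C, D = feats.shape
cols = feats.permute(0, 2, 1, 3).reshape(B * C, R, D)   # one top-to-bottom sequence per column
ctx, _ = row_gru(cols)                                   # (B*C, R, 256) contexts

t, k = 1, 2                                              # predict the row k+1 steps below row t
pred = W[k](ctx[:, t])                                   # (B*C, 1024)
logits = pred @ cols[:, t + 1 + k].t()                   # same grid position in other sequences = negatives
loss = F.cross_entropy(logits, torch.arange(B * C))
```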

[Figure: CPC applied to images (7x7 grid of patches with top-to-bottom row predictions)]

For the linear classifier trained on top of the CPC features we use SGD with a momentum of 0.9, a learning rate schedule of 0.1, 0.01 and 0.001 for 50k, 25k and 10k updates and batch size of 2048 on a single GPU. Note that when training the linear classifier we first spatially mean-pool the 7x7x1024 representation to a single 1024 dimensional vector. This is slightly different from [36] which uses a 3x3x1024 representation without pooling, and thus has more parameters in the supervised linear mapping (which could be advantageous).
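That optimizer and schedule could be expressed as follows (the MultiStepLR formulation and the bare linear layer are my assumptions; the data pipeline is omitted):

```python
import torch
import torch.nn as nn

linear = nn.Linear(1024, 1000)                 # mean-pooled CPC feature -> 1000 ImageNet classes
opt = torch.optim.SGD(linear.parameters(), lr=0.1, momentum=0.9)
# lr 0.1 for the first 50k updates, 0.01 for the next 25k, 0.001 for the final 10k
sched = torch.optim.lr_scheduler.MultiStepLR(opt, milestones=[50_000, 75_000], gamma=0.1)
```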

Results:

Tables 3 and 4 show the top-1 and top-5 classification accuracies compared with the state-of-the-art. Despite being relatively domain agnostic, CPCs improve upon state-of-the-art by 9% absolute in top-1 accuracy, and 4% absolute in top-5 accuracy.

[Tables: ImageNet top-1 and top-5 linear classification accuracy compared with prior unsupervised methods]


3.3. Natural Language

For the NLP tasks, the model is trained on the BookCorpus dataset [42]. The model then works as a generic feature extractor and is evaluated on a set of classification tasks.

To cope with words that are not seen during training, we employ vocabulary expansion the same way as [26], where a linear mapping is constructed between word2vec and the word embeddings learned by the model.

The datasets used for classification tasks are as follows:

  1. movie review sentiment (MR) [43]
  2. customer product reviews (CR) [44]
  3. subjectivity/objectivity [45]
  4. opinion polarity (MPQA) [46]
  5. question-type classification (TREC) [47]

Model architecture and training details:

Our model consists of a simple sentence encoder g_enc (a 1D-convolution + ReLU + mean-pooling) that embeds a whole sentence into a 2400-dimensional vector z, followed by a GRU (2400 hidden units) which predicts up to 3 future sentence embeddings with the contrastive loss to form c. We used the Adam optimizer with a learning rate of 2e-4, trained on 8 GPUs, each with a batch size of 64. We found that more advanced sentence encoders did not significantly improve the results, which may be due to the simplicity of the transfer tasks (e.g., in MPQA most datapoints consist of one or a few words), and the fact that bag-of-words models usually perform well on many NLP tasks [48].
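A hedged sketch of such a sentence encoder (vocabulary size, embedding dimension and kernel width are assumptions not given in the text):

```python
import torch
import torch.nn as nn

class SentenceEncoder(nn.Module):
    """1D-convolution + ReLU + mean-pooling over word embeddings -> 2400-d sentence vector z."""
    def __init__(self, vocab_size=20000, emb_dim=300, out_dim=2400):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim)
        self.conv = nn.Conv1d(emb_dim, out_dim, kernel_size=3, padding=1)

    def forward(self, token_ids):                  # (batch, seq_len) word indices
        x = self.emb(token_ids).transpose(1, 2)    # (batch, emb_dim, seq_len)
        h = torch.relu(self.conv(x))               # (batch, out_dim, seq_len)
        return h.mean(dim=2)                       # (batch, 2400) sentence embedding z

g_ar = nn.GRU(2400, 2400, batch_first=True)        # predicts up to 3 future sentence embeddings
z = SentenceEncoder()(torch.randint(0, 20000, (8, 30)))
```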

The results are as follows.

[Table: accuracy on the sentence classification transfer tasks compared with prior methods]

The performance of our method is very similar to the skip-thought vector model, with the advantage that it does not require a powerful LSTM as a word-level decoder and is therefore much faster to train. Although this is a standard transfer learning benchmark, we found that models that learn better relationships in the source books did not necessarily perform better on the target tasks (which are very different: movie reviews etc.). We note that better results [49,27] have been published on these target datasets, by transfer learning from a different source task.


3.4. Reinforcement Learning

The authors evaluate the proposed unsupervised learning approach on five reinforcement learning tasks in 3D environments from DeepMind Lab [51]: rooms_watermaze, explore_goal_locations_small, seekavoid_arena_01, lasertag_three_opponents_small and rooms_keys_doors_puzzle.

Please refer to the original paper for details about the reinforcement learning experiment.


4. Conclusion

Contrastive Predictive Coding (CPC) is a framework for extracting compact latent representations that encode predictions over future observations. CPC combines autoregressive modeling and noise-contrastive estimation with intuitions from predictive coding to learn abstract representations in an unsupervised fashion.

The simplicity and low computational requirements of training the model, together with the encouraging results in challenging reinforcement learning domains when used in conjunction with the main loss, are exciting developments towards useful unsupervised learning that applies universally to many more data modalities.