NeurIPS Day 7: Self-Supervised Learning Workshop
This is a very exciting and inspiring workshop with speakers such as Yann LeCun, and I had personally been looking forward to it a lot; it is held on the last day of the conference. The workshop is centered on self-supervised learning, a trending research area.
Why Self-supervised learning (SSL)?
Supervised learning is hitting a bottleneck. Not only does it rely heavily on expensive manual labeling, but it also suffers from generalization errors, spurious correlations, and vulnerability to adversarial attacks.
To address these issues, a range of self-supervised learning (SSL) models have emerged, including frameworks such as pre-trained language models, Generative Adversarial Networks (GANs), autoencoders, and contrastive coding. But what is SSL? It can be described as obtaining “labels” from the data itself through a “semi-automatic” process, and then predicting part of the data from other parts. Unsupervised learning, by contrast, concentrates on detecting specific patterns in the data.
SSL frameworks can be categorized as generative, contrastive, or generative-contrastive (adversarial), with particular subtypes within each category.
SSL in NLP Representation Learning
One SSL example in NLP is ELMo (Embeddings from Language Models), which generates contextual word embeddings. It is a feature-based approach: it learns multiple representations (one per language-model layer) for each word token, and how these representations are combined into a word embedding is decided by the downstream task.
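In the ELMo paper, the task-specific combination is a softmax-weighted sum of the layer representations scaled by a learned factor, ELMo_k = γ · Σ_j s_j · h_{k,j} with s = softmax(w). Below is a minimal pure-Python sketch of that weighting for a single token; the toy vectors and the parameter names `w` and `gamma` are illustrative assumptions, and in practice both are learned together with the downstream model.

```python
import math


def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def scalar_mix(layers, w, gamma):
    """Combine the per-layer vectors of one token into a single embedding.

    layers: one vector per biLM layer for the same token
    w:      unnormalized layer weights (learned per task in ELMo)
    gamma:  task-specific scale (also learned)
    """
    s = softmax(w)
    dim = len(layers[0])
    return [gamma * sum(s[j] * layers[j][d] for j in range(len(layers)))
            for d in range(dim)]


# Three hypothetical layer outputs for one token (2-D for readability)
layers = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
# Equal weights give a plain average of the layers, scaled by gamma
print(scalar_mix(layers, w=[0.0, 0.0, 0.0], gamma=1.0))
```

A task that relies mostly on one layer would learn a large weight for it, which shifts the mix toward that layer's vector.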
Unlike feature-based approaches such as ELMo, BERT (Bidirectional Encoder Representations from Transformers) is a fine-tuning approach. The model is first pre-trained on a large corpus through self-supervised learning, and then fine-tuned with labeled data. As the name indicates, BERT uses the Transformer as its encoder. During pre-training, BERT masks some tokens in each sentence and is trained to predict the masked words. To use BERT, we initialize the model with the pre-trained weights and then fine-tune it on the downstream task.
SSL in Computer Vision (CV) Representation Learning
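The residual neural network (ResNet) is not itself an SSL method, but it is the standard backbone encoder for SSL in CV; methods such as MoCo, SimCLR, and BYOL use ResNets to produce image representations. Instead of asking every few stacked layers to directly learn a desired underlying mapping H(x), the ResNet authors let those layers learn a residual mapping F(x) = H(x) - x, so the block outputs F(x) + x. The core idea of ResNet is the introduction of shortcut connections, which skip over one or more layers.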
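To make the shortcut idea concrete, here is a minimal sketch of a residual block, y = relu(F(x) + x). It assumes dense (fully connected) layers without biases in place of the convolutions and batch normalization of the original ResNet, and the helper names are hypothetical.

```python
def relu(v):
    return [max(0.0, x) for x in v]


def linear(weights, v):
    """Matrix-vector product; `weights` is a list of rows."""
    return [sum(w * x for w, x in zip(row, v)) for row in weights]


def residual_block(x, w1, w2):
    """y = relu(F(x) + x) with F(x) = W2 relu(W1 x); '+ x' is the shortcut."""
    f = linear(w2, relu(linear(w1, x)))
    return relu([a + b for a, b in zip(f, x)])


x = [1.0, 2.0]
zeros = [[0.0, 0.0], [0.0, 0.0]]
# With F(x) = 0 the block simply passes (non-negative) x through
print(residual_block(x, zeros, zeros))
```

This is why residual blocks are easy to optimize: if the identity is close to the desired mapping, the layers only need to learn a small correction.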
SSL in Graph Representation Learning
Generative SSL
Important generative SSL methods include auto-regressive (AR) models, flow-based models, auto-encoding (AE) models, and hybrid generative models.
1. Auto-regressive (AR) Model
AR models can be viewed as a “Bayes net structure” (a directed graphical model): the joint distribution is factorized as a product of conditionals. The advantage of auto-regressive models is that they model context dependency well. However, one shortcoming is that the token at each position can only access its context from one direction.
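The factorization reads $p(x) = \prod_{t=1}^{T} p(x_t \mid x_{<t})$, so the log-likelihood of a sequence is a sum of log-conditionals. The sketch below illustrates this with a counting-based bigram model, a deliberate simplification (an assumption) that conditions only on the previous token, whereas models such as GPT condition on the whole prefix; the `<s>` start marker is also an illustrative choice.

```python
import math
from collections import Counter, defaultdict

BOS = "<s>"  # hypothetical start-of-sequence marker


def fit_bigram(corpus):
    """Estimate p(x_t | x_{t-1}) by counting: a first-order simplification
    of the full conditional p(x_t | x_1, ..., x_{t-1})."""
    counts = defaultdict(Counter)
    for seq in corpus:
        prev = BOS
        for tok in seq:
            counts[prev][tok] += 1
            prev = tok
    return {prev: {tok: c / sum(ctr.values()) for tok, c in ctr.items()}
            for prev, ctr in counts.items()}


def log_likelihood(model, seq):
    """Chain rule: log p(x) = sum over t of log p(x_t | context)."""
    total, prev = 0.0, BOS
    for tok in seq:
        total += math.log(model[prev][tok])
        prev = tok
    return total


corpus = [["a", "b", "a"], ["a", "b", "b"]]
model = fit_bigram(corpus)
print(log_likelihood(model, ["a", "b", "a"]))
```

No smoothing is applied, so an unseen transition raises a `KeyError`; a real language model assigns every token a non-zero probability.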
In NLP, the objective of auto-regressive language modeling is usually to maximize the likelihood under the forward autoregressive factorization, $\max_\theta \log p_\theta(x) = \sum_{t=1}^{T} \log p_\theta(x_t \mid x_{<t})$. GPT and GPT-2 use the Transformer decoder architecture for language modeling. In computer vision, PixelRNN and PixelCNN use auto-regressive methods to model images pixel by pixel: lower (right) pixels are generated by conditioning on upper (left) pixels. Auto-regressive models can also be applied to graph problems such as graph generation; for example, GraphRNN generates realistic graphs with deep auto-regressive models. It decomposes the graph generation process into a sequence of node and edge generations conditioned on the graph generated so far, and its objective is defined as the likelihood of the observed graph generation sequences.
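The raster-scan ordering used by PixelRNN/PixelCNN decides which pixels each prediction may see. This hypothetical helper lists that context for one pixel and builds the equivalent strictly lower-triangular mask over flattened positions (the same kind of causal mask a GPT-style decoder uses over tokens). It is a sketch of the ordering only: PixelCNN itself enforces it with masked convolutions over a local window rather than an explicit mask over all pixels.

```python
def raster_context(height, width, row, col):
    """Pixels a raster-scan AR model may condition on when generating
    (row, col): every pixel in earlier rows, plus pixels to the left
    in the same row."""
    return [(r, c) for r in range(height) for c in range(width)
            if r < row or (r == row and c < col)]


def causal_mask(height, width):
    """mask[i][j] == 1 if flattened pixel j may be seen when predicting
    pixel i (raster order); a strictly lower-triangular matrix."""
    n = height * width
    return [[1 if j < i else 0 for j in range(n)] for i in range(n)]


print(raster_context(3, 3, 1, 1))
```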
2. Auto-encoding (AE) Model
The goal of the auto-encoding model is to reconstruct (part of) inputs from (corrupted) inputs. Due to its flexibility, the AE model is probably the most popular generative model with a number of variants.
The autoencoder (AE) was first introduced for pre-training artificial neural networks and is typically used for dimensionality reduction. Generally, an autoencoder is a feed-forward neural network trained to reproduce its input at the output layer. It consists of an encoder network $h = f_{enc}(x)$ and a decoder network $x' = f_{dec}(h)$. The objective of an AE is to make $x$ and $x'$ as similar as possible (for example, by minimizing the mean-square error).
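A minimal sketch of the reconstruction objective, assuming a linear encoder and decoder with a 1-D code, toy 2-D data that lies on a line, and plain full-batch gradient descent (all illustrative choices; practical autoencoders use deeper nonlinear networks and a library optimizer). Because the data is really one-dimensional, the reconstruction error can be driven close to zero.

```python
def train_linear_autoencoder(data, lr=0.01, steps=1000):
    """Full-batch gradient descent on the mean squared reconstruction error.

    Encoder: h = w . x   (2-D input -> 1-D code)
    Decoder: x' = h * v  (1-D code  -> 2-D reconstruction)
    Returns the learned weights and the loss recorded before each update.
    """
    w = [0.1, 0.0]  # fixed small initialisation keeps the run deterministic
    v = [0.0, 0.1]
    n = len(data)
    history = []
    for _ in range(steps):
        grad_w, grad_v, loss = [0.0, 0.0], [0.0, 0.0], 0.0
        for x in data:
            h = w[0] * x[0] + w[1] * x[1]            # encode
            r = [h * v[0] - x[0], h * v[1] - x[1]]   # x' - x
            loss += (r[0] ** 2 + r[1] ** 2) / n
            r_dot_v = r[0] * v[0] + r[1] * v[1]
            for i in range(2):
                grad_v[i] += 2 * r[i] * h / n        # dLoss/dv
                grad_w[i] += 2 * r_dot_v * x[i] / n  # dLoss/dw
        history.append(loss)
        for i in range(2):
            w[i] -= lr * grad_w[i]
            v[i] -= lr * grad_v[i]
    return w, v, history


# Toy 2-D points on the line x2 = 2 * x1, so a 1-D code suffices
data = [(t, 2 * t) for t in (-1.0, -0.5, 0.5, 1.0)]
w, v, history = train_linear_autoencoder(data)
print(f"initial MSE {history[0]:.3f}, final MSE {history[-1]:.2e}")
```

The bottleneck is the point: the code h has fewer dimensions than x, so the network must discover the structure of the data to reconstruct it.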
CBOW and Skip-Gram are pioneering works. CBOW aims to predict the center (input) token based on its context tokens; in contrast, Skip-Gram aims to predict the context tokens based on the center token. Negative sampling is usually employed to ensure computational efficiency and scalability.
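The difference between the two objectives is easiest to see in the training examples they generate. A minimal sketch, assuming a symmetric context window (the `window` parameter is an illustrative choice); negative sampling and the embedding model itself are not shown.

```python
def skipgram_pairs(tokens, window=2):
    """(center, context) pairs: predict each context token from the center."""
    pairs = []
    for i, center in enumerate(tokens):
        lo, hi = max(0, i - window), min(len(tokens), i + window + 1)
        pairs.extend((center, tokens[j]) for j in range(lo, hi) if j != i)
    return pairs


def cbow_examples(tokens, window=2):
    """(context tokens, center) examples: predict the center from its context."""
    examples = []
    for i, center in enumerate(tokens):
        lo, hi = max(0, i - window), min(len(tokens), i + window + 1)
        context = [tokens[j] for j in range(lo, hi) if j != i]
        examples.append((context, center))
    return examples


sentence = ["the", "cat", "sat", "down"]
print(skipgram_pairs(sentence, window=1))
print(cbow_examples(sentence, window=1))
```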
DeepWalk samples truncated random walks to learn latent node embeddings based on the Skip-Gram model, treating random walks as the equivalent of sentences. LINE, by contrast, aims to generate neighbors based on the current node, and it also uses negative sampling to sample multiple negative edges to approximate its objective.
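A minimal sketch of DeepWalk's corpus generation, assuming an unweighted graph stored as an adjacency dict and uniform neighbor sampling; the function and parameter names are hypothetical. Each walk can then be treated as a sentence and turned into Skip-Gram (center, context) pairs.

```python
import random


def truncated_random_walks(adj, walk_length, walks_per_node, seed=0):
    """Uniform random walks of at most `walk_length` nodes from every node.

    adj: dict mapping each node to a list of its neighbours.
    Each walk plays the role of a 'sentence' for a Skip-Gram model.
    """
    rng = random.Random(seed)
    walks = []
    for _ in range(walks_per_node):
        for start in adj:
            walk = [start]
            # stop early only if the current node has no neighbours
            while len(walk) < walk_length and adj[walk[-1]]:
                walk.append(rng.choice(adj[walk[-1]]))
            walks.append(walk)
    return walks


graph = {"a": ["b", "c"], "b": ["a", "c"], "c": ["a", "b", "d"], "d": ["c"]}
for walk in truncated_random_walks(graph, walk_length=4, walks_per_node=1):
    print(walk)
```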
2.1 Denoising AE Model
The intuition behind denoising autoencoder models is that representations should be robust to the introduction of noise. The masked language model (MLM) can be regarded as a denoising AE model because the tokens to be predicted are masked in its input. To model a text sequence, MLM randomly masks some of the input tokens and then predicts them from their context. BERT is the most representative work in this field. However, one shortcoming of this method is that there are no [MASK] tokens in the inputs of downstream tasks. To mitigate this, the BERT authors do not always replace the selected tokens with [MASK] during training: of the tokens selected for prediction, 80% are replaced with [MASK], 10% with a random token, and 10% are left unchanged.
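A minimal sketch of this corruption scheme, assuming each position is selected independently with probability `select_prob` (a simplification; the default 0.15 follows BERT's 15% rate) and whitespace tokens instead of WordPiece. The labels mark which positions the model is asked to reconstruct.

```python
import random

MASK = "[MASK]"


def mask_tokens(tokens, vocab, select_prob=0.15, seed=0):
    """BERT-style corruption: select positions as prediction targets; of those,
    80% become [MASK], 10% a random vocabulary token, 10% stay unchanged.
    Returns (corrupted input, labels) where labels[i] is None if position i
    is not predicted."""
    rng = random.Random(seed)
    corrupted, labels = list(tokens), [None] * len(tokens)
    for i, tok in enumerate(tokens):
        if rng.random() >= select_prob:
            continue
        labels[i] = tok  # the model must recover the original token here
        roll = rng.random()
        if roll < 0.8:
            corrupted[i] = MASK
        elif roll < 0.9:
            corrupted[i] = rng.choice(vocab)
        # else: keep the original token
    return corrupted, labels


vocab = ["the", "cat", "sat", "on", "mat", "dog"]
sentence = ["the", "cat", "sat", "on", "the", "mat"]
corrupted, labels = mask_tokens(sentence, vocab, select_prob=0.5, seed=1)
print(corrupted)
print(labels)
```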
Compared with the AR model, in a denoising AE for language modeling the predicted tokens have access to contextual information from both sides. However, MLM assumes that the predicted tokens are independent of each other given the unmasked tokens.
2.2 Variational AE Model
The variational auto-encoding model assumes that data are generated from an underlying latent (unobserved) representation. The posterior distribution over a set of unobserved variables Z = {z1, z2, …, zn} given some data X is approximated by a variational distribution q(z|x). The Variational Autoencoder (VAE) is one important example of variational inference in this setting. The VAE assumes that the prior p(z) and the approximate posterior q(z|x) both follow Gaussian distributions.
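Two pieces make the VAE trainable with a Gaussian q(z|x) and prior p(z) = N(0, I): the reparameterization trick, which writes a sample as z = μ + σ · ε with ε ~ N(0, I), and the closed-form KL divergence term of the objective. A minimal sketch, assuming a diagonal Gaussian parameterized by `mu` and `log_var` (hypothetical names); the reconstruction term and the encoder/decoder networks are omitted.

```python
import math
import random


def reparameterize(mu, log_var, rng):
    """Sample z ~ N(mu, diag(exp(log_var))) as z = mu + sigma * eps, so the
    randomness lives in eps and gradients can flow through mu and sigma."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0)
            for m, lv in zip(mu, log_var)]


def kl_to_standard_normal(mu, log_var):
    """Closed-form KL( N(mu, diag(sigma^2)) || N(0, I) )."""
    return 0.5 * sum(m * m + math.exp(lv) - 1.0 - lv
                     for m, lv in zip(mu, log_var))


rng = random.Random(0)
mu, log_var = [0.5, -1.0], [0.0, math.log(0.25)]
print(reparameterize(mu, log_var, rng))
print(kl_to_standard_normal(mu, log_var))
```

The KL term pulls q(z|x) toward the prior, while the (omitted) reconstruction term rewards codes that let the decoder recover x.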
Nowadays, VAE and its variants are widely used in computer vision, for example in image representation learning, image generation, and video generation. Variational auto-encoding models have also been employed for node representation learning on graphs. For example, the Variational Graph Auto-Encoder (VGAE) uses the same variational inference technique as VAE, with a graph convolutional network (GCN) as the encoder. Because of the uniqueness of graph-structured data, the objective of VGAE is to reconstruct the adjacency matrix of the graph by measuring node proximity.
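VGAE's decoder measures node proximity with an inner product, Â_ij = σ(z_i · z_j). A minimal sketch, assuming the node embeddings `z` have already been produced by the GCN encoder (the values here are made up) and using a plain binary cross-entropy reconstruction loss; the KL term and the GCN itself are not shown.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def inner_product_decoder(z):
    """A_hat[i][j] = sigmoid(z_i . z_j): nodes whose embeddings point the
    same way are predicted to be linked."""
    n = len(z)
    return [[sigmoid(sum(a * b for a, b in zip(z[i], z[j]))) for j in range(n)]
            for i in range(n)]


def reconstruction_loss(a_hat, adj):
    """Mean binary cross-entropy between predicted and observed adjacency."""
    eps = 1e-12
    total = 0.0
    for row_hat, row in zip(a_hat, adj):
        for p, a in zip(row_hat, row):
            total -= a * math.log(p + eps) + (1 - a) * math.log(1 - p + eps)
    return total / (len(adj) ** 2)


# Hypothetical 2-D node embeddings and an adjacency matrix with self-loops
z = [[2.0, 0.0], [1.5, 0.5], [-2.0, 0.0]]
adj = [[1, 1, 0], [1, 1, 0], [0, 0, 1]]
a_hat = inner_product_decoder(z)
print(round(a_hat[0][1], 3), round(a_hat[0][2], 3))
```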
2.3 Combining AR and AE Model
In NLP, the Permutation Language Model (PLM) is a representative model that combines the advantages of auto-regressive and auto-encoding models. XLNet, which introduces PLM, is a generalized auto-regressive pre-training method. XLNet learns bidirectional contexts by maximizing the expected likelihood over all permutations of the factorization order.
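Permuting the factorization order, not the input itself, is what gives each token access to context on both sides. A minimal sketch, assuming we only track which positions are visible to each prediction (XLNet additionally keeps the original positional information and uses two-stream attention, which is not modeled here); the helper name is hypothetical.

```python
import random


def permutation_contexts(order):
    """For a factorization order (a permutation of positions), each position
    is predicted from the positions that come earlier in that order."""
    contexts, seen = {}, []
    for pos in order:
        contexts[pos] = sorted(seen)  # positions visible when predicting pos
        seen.append(pos)
    return contexts


tokens = ["New", "York", "is", "a", "city"]
rng = random.Random(0)
order = list(range(len(tokens)))
rng.shuffle(order)
for pos, ctx in sorted(permutation_contexts(order).items()):
    print(tokens[pos], "<-", [tokens[i] for i in ctx])
```

With the order [2, 0, 1], for example, position 1 is predicted from positions 0 and 2, i.e. from both its left and right neighbors, which a left-to-right AR model never sees.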
Contrastive learning
From a statistical perspective, machine learning models fall into two classes: generative models and discriminative models.
For a long time, people believed that generative models were the only choice for representation learning. However, recent breakthroughs in contrastive learning, such as Deep InfoMax, MoCo, and SimCLR, shed light on the potential of discriminative models for representation learning. Contrastive learning aims to “learn to compare” through a Noise Contrastive Estimation (NCE) objective.
Contrastive learning has become increasingly popular for image and graph representation learning.
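Methods such as MoCo and SimCLR use an InfoNCE-style variant of the NCE objective: classify the positive among a set of negatives with a softmax over similarities. A minimal sketch for a single anchor, assuming cosine similarity and a temperature of 0.1 (illustrative values); in practice the vectors come from an encoder, and the negatives come from other samples in the batch or from a memory queue.

```python
import math


def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def info_nce(anchor, positive, negatives, temperature=0.1):
    """-log( exp(s+/t) / (exp(s+/t) + sum_k exp(s_k/t)) ), s = cosine similarity."""
    logits = [cosine(anchor, positive) / temperature]
    logits += [cosine(anchor, neg) / temperature for neg in negatives]
    m = max(logits)  # log-sum-exp with the max subtracted for stability
    log_denom = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_denom - logits[0]


anchor = [1.0, 0.0]
negatives = [[0.0, 1.0], [-1.0, 0.0]]
print(info_nce(anchor, [0.9, 0.1], negatives))   # positive close to the anchor
print(info_nce(anchor, [0.0, -1.0], negatives))  # positive far from the anchor
```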
Context-Instance Contrast
The context-instance contrast, also called global-local contrast, focuses on modeling the belonging relationship between the local features of a sample and its global context representation. When we learn the representation of a local feature, we want it to be associated with the representation of the global context: stripes with tigers, sentences with their paragraph, and nodes with their neighborhoods.
There are two main types of Context-Instance Contrast: Predict Relative Position (PRP) and Maximize Mutual Information (MI):
- PRP focuses on learning the relative positions between local components. The global context serves as an implicit requirement for predicting these relations (for example, understanding what an elephant looks like is critical for predicting the relative position of its head and tail).
- MI focuses on learning the explicit belonging relationships between local parts and the global context; the relative positions between local parts are ignored.
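As an example of PRP, the context-prediction pretext task (Doersch et al., 2015) shows a network a center patch and one of its eight neighbors and asks which neighbor it is. A minimal sketch of how such self-labeled examples could be sampled, assuming patches indexed on a grid (the names and grid size are hypothetical); cropping the actual pixels and training the classifier are omitted.

```python
import random

# The 8 neighbour offsets (row, col) around a centre patch; the index is the label
OFFSETS = [(-1, -1), (-1, 0), (-1, 1), (0, -1), (0, 1), (1, -1), (1, 0), (1, 1)]


def sample_prp_example(grid_h, grid_w, rng):
    """Pick a centre patch and one neighbour on a grid of patches (grid must be
    at least 3x3); the label is the neighbour's relative position, which needs
    no human annotation."""
    r = rng.randrange(1, grid_h - 1)  # keep the centre away from the border
    c = rng.randrange(1, grid_w - 1)
    label = rng.randrange(len(OFFSETS))
    dr, dc = OFFSETS[label]
    return (r, c), (r + dr, c + dc), label


rng = random.Random(0)
for _ in range(3):
    print(sample_prp_example(5, 5, rng))
```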
Context-Context Contrast
As an alternative, context-context contrastive learning discards MI and directly studies the relationships between the global representations of different samples, as metric learning does. Early on, researchers borrowed ideas from semi-supervised learning to produce pseudo labels via cluster-based discrimination, and achieved rather good representations. More recently, CMC, MoCo, SimCLR, and BYOL further demonstrated the strength of this approach: through direct context-to-context comparison, they outperform context-instance-based methods and achieve results competitive with supervised methods under the linear classification protocol. The two main families here are cluster-based discrimination, proposed earlier, and the instance discrimination advocated by these later methods.
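Cluster-based discrimination turns features into pseudo labels by clustering them, as in DeepCluster, which alternates k-means on the current features with training a classifier on the resulting cluster ids. A minimal sketch of the clustering step only, assuming precomputed 2-D features and hand-picked initial centroids (both hypothetical).

```python
def kmeans_pseudo_labels(features, centroids, iters=10):
    """Assign each feature to its nearest centroid, refit centroids, repeat.
    The final cluster ids serve as pseudo labels for a classification loss."""
    def sqdist(a, b):
        return sum((x - y) ** 2 for x, y in zip(a, b))

    centroids = [list(c) for c in centroids]
    labels = []
    for _ in range(iters):
        labels = [min(range(len(centroids)), key=lambda k: sqdist(f, centroids[k]))
                  for f in features]
        for k in range(len(centroids)):
            members = [f for f, lab in zip(features, labels) if lab == k]
            if members:  # an empty cluster keeps its previous centroid
                centroids[k] = [sum(dim) / len(members) for dim in zip(*members)]
    return labels, centroids


features = [[0.0, 0.1], [0.2, 0.0], [5.0, 5.1], [4.9, 5.0]]
labels, centroids = kmeans_pseudo_labels(features, centroids=[[0.0, 0.0], [1.0, 1.0]])
print(labels)
```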
A more radical step was taken by BYOL (presented at NeurIPS 2020), which discards negative sampling entirely yet achieves an even better result. The contrastive methods mentioned above learn representations by predicting different views of the same image, casting the prediction problem directly in representation space. However, predicting directly in representation space can lead to collapsed representations, because multiple views are generally too predictive of each other: without negative samples, the networks could trivially match the views, for example by mapping every input to the same representation. BYOL avoids this collapse by training an online network, which has an extra predictor, to predict the output of a target network whose weights are an exponential moving average of the online network's weights.
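Two ingredients of BYOL can be sketched compactly: the exponential-moving-average update of the target weights, and the loss between the online prediction and the target projection, which is the squared distance between L2-normalized vectors (equivalently 2 - 2 · cosine similarity). A minimal sketch, assuming weights flattened into plain lists and an illustrative decay `tau`; the networks, augmentations, stop-gradient on the target, and loss symmetrization are not shown.

```python
import math


def ema_update(target, online, tau=0.99):
    """Target weights move slowly toward the online weights:
    t <- tau * t + (1 - tau) * o."""
    return [tau * t + (1.0 - tau) * o for t, o in zip(target, online)]


def byol_loss(prediction, target_projection):
    """Squared distance between L2-normalised vectors (= 2 - 2 * cosine)."""
    def normalize(v):
        norm = math.sqrt(sum(x * x for x in v))
        return [x / norm for x in v]

    p, z = normalize(prediction), normalize(target_projection)
    return sum((a - b) ** 2 for a, b in zip(p, z))


print(byol_loss([1.0, 1.0], [2.0, 2.0]))  # same direction: loss near 0
print(ema_update([0.0, 0.0], [1.0, -1.0], tau=0.9))
```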
Self-supervised Contrastive Pre-training for Semisupervised Self-training
While contrastive self-supervised learning continues to push the boundaries on various benchmarks, labels are still important, because there is a gap between the training objectives of self-supervised and supervised learning. In other words, no matter how much self-supervised models improve, they are still only powerful feature extractors, and to transfer to a downstream task we still need abundant labels. To bridge the gap between self-supervised pre-training and downstream tasks, semi-supervised learning is exactly what we are looking for.
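The pseudo-labeling step of one self-training round can be sketched as follows, assuming a teacher model (pre-trained with SSL, then fine-tuned on the labeled set) has already produced class probabilities for unlabeled examples; the 0.9 confidence threshold and the function name are illustrative choices. The selected examples would be added to the training set of the next (student) model.

```python
def select_pseudo_labels(probabilities, threshold=0.9):
    """Keep unlabeled examples whose top predicted class is confident enough;
    returns (example index, predicted class) pairs."""
    selected = []
    for idx, probs in enumerate(probabilities):
        best = max(range(len(probs)), key=probs.__getitem__)
        if probs[best] >= threshold:
            selected.append((idx, best))
    return selected


# Class probabilities from a hypothetical teacher on three unlabeled examples
teacher_probs = [[0.95, 0.05], [0.6, 0.4], [0.02, 0.98]]
print(select_pseudo_labels(teacher_probs))
```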
That's it for the introduction to SSL; now we can (finally) go over the speakers’ ideas and slides.
Speaker 1: Oriol Vinyals
Oriol gave a nice introduction to SSL and InfoNCE++. He also touched on semi-supervised learning, stating that it shows promising results.
Speaker 2: Ruslan Salakhutdinov
Ruslan talks about capsule networks, which use inverted attention: features are aggregated from low-level parts (capsules), and each part tries to tell which high-level object (capsule) it should belong to.
The cool thing is that the model has multi-view capability: it can generate views of an object from different angles.
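One way to read “inverted attention” is that each lower-level capsule's agreement scores with the higher-level capsules are normalized with a softmax over the higher-level capsules, so each part distributes itself across the objects it might belong to. A minimal sketch of one such routing step, assuming dot-product agreement and precomputed vote vectors (all values hypothetical); the published method also iterates this step and applies layer normalization, which is omitted.

```python
import math


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]


def inverted_attention_routing(votes, parent_poses):
    """One routing step.

    votes[i][j]: vote vector of lower capsule i for higher capsule j
    parent_poses[j]: current pose vector of higher capsule j
    Agreement is normalised over the *parents* for each part.
    """
    n_parents, dim = len(parent_poses), len(parent_poses[0])
    routing = []
    for part_votes in votes:
        agreement = [sum(a * b for a, b in zip(part_votes[j], parent_poses[j]))
                     for j in range(n_parents)]
        routing.append(softmax(agreement))
    # each parent's new pose is the routing-weighted sum of the votes it received
    new_poses = [[sum(routing[i][j] * votes[i][j][d] for i in range(len(votes)))
                  for d in range(dim)]
                 for j in range(n_parents)]
    return routing, new_poses


votes = [[[1.0, 0.0], [0.0, 1.0]],   # part 0 agrees with both parents
         [[1.0, 0.0], [0.0, -1.0]]]  # part 1 agrees with parent 0 only
parent_poses = [[1.0, 0.0], [0.0, 1.0]]
routing, new_poses = inverted_attention_routing(votes, parent_poses)
print(routing)
```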
Speaker 3: Yejin Choi
Yejin presents the Comet project, which builds commonsense models to help machines tell what might be true, and the talk also incorporated the Mosaic project. Comet is based on ConceptNet (declarative knowledge) and images.
Speaker 4: Jitendra Malik
Jitendra talks about multimodal SSL: predicting human gestures from how people talk.
Speaker 5: Jia Deng
Jia talks about SSL in computer vision.
Speaker 6: Alexei Efros
Alexei talks about performing random walks on video and turning the video into a palindrome (playing it forward and then backward), so that the walk should return to where it started, which provides a self-supervised training signal. Using video as data augmentation is a brilliant way to get continuous views of objects, i.e., extra data for free (paper: Space-Time Correspondence as a Contrastive Random Walk).
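The palindrome trick can be sketched as follows, assuming each frame is already represented by a list of node (patch) embeddings: transition probabilities between adjacent frames come from a softmax over dot-product affinities, the walk goes forward and then back, and the loss rewards returning to the starting node. This is an illustrative sketch of the idea with made-up values and a hypothetical temperature, not the paper's implementation; it needs at least two frames.

```python
import math


def row_softmax(scores, temperature):
    out = []
    for row in scores:
        m = max(row)
        exps = [math.exp((s - m) / temperature) for s in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def transition(frame_a, frame_b, temperature=0.1):
    """Row-stochastic matrix of stepping from each node in frame_a to the
    nodes in frame_b, based on dot-product affinity."""
    scores = [[sum(x * y for x, y in zip(a, b)) for b in frame_b] for a in frame_a]
    return row_softmax(scores, temperature)


def matmul(p, q):
    return [[sum(p[i][k] * q[k][j] for k in range(len(q))) for j in range(len(q[0]))]
            for i in range(len(p))]


def palindrome_walk_loss(frames, temperature=0.1):
    """Walk f0 -> ... -> fT -> ... -> f0; the target for node i is node i."""
    sequence = frames + frames[-2::-1]  # e.g. f0, f1, f2, f1, f0
    walk = None
    for a, b in zip(sequence, sequence[1:]):
        step = transition(a, b, temperature)
        walk = step if walk is None else matmul(walk, step)
    n = len(walk)
    return -sum(math.log(walk[i][i]) for i in range(n)) / n


consistent = [[[1.0, 0.0], [0.0, 1.0]], [[0.9, 0.1], [0.1, 0.9]], [[1.0, 0.0], [0.0, 1.0]]]
ambiguous = [[[1.0, 0.0], [1.0, 0.0]]] * 3  # nodes indistinguishable
print(palindrome_walk_loss(consistent), palindrome_walk_loss(ambiguous))
```

When the features make correspondences unambiguous, the round trip returns home with high probability and the loss is small; indistinguishable features give a uniform walk and a loss of log 2 here.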
Speaker 7: Yann LeCun
The purpose of SSL is to learn representations so that we can train with fewer variables, and to learn what the world is like: learn to fill in the blanks and to figure out when two things are the same, but without attempting to predict everything.
Yann LeCun talks about energy-based models (EBMs), which he teaches in a full course: the energy function is pushed down on training samples, and the energy is minimized to reconstruct or predict.
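A toy illustration of both halves, assuming a hypothetical quadratic energy E(x, y) = (y - w·x)² with a single parameter w: training lowers the energy of an observed (x, y) pair by gradient descent on w, and inference picks the candidate y with the lowest energy. Because this energy has a fixed parametric form, pushing energy down cannot make it flat everywhere; more flexible EBMs need an extra mechanism (such as contrastive terms) to push energy up elsewhere.

```python
def energy(x, y, w):
    """Hypothetical energy: low when y is compatible with x (y close to w * x)."""
    return (y - w * x) ** 2


def train_step(w, x, y, lr=0.1):
    """Push the energy of an observed (x, y) pair down by gradient descent on w."""
    grad = -2.0 * x * (y - w * x)  # dE/dw
    return w - lr * grad


def infer(x, w, candidates):
    """Inference in an EBM: return the output with the lowest energy."""
    return min(candidates, key=lambda y: energy(x, y, w))


w = 0.0
for _ in range(50):
    w = train_step(w, x=1.0, y=2.0)  # one observed training sample
candidates = [i / 10 for i in range(-50, 51)]  # y values from -5.0 to 5.0
print(w, infer(1.5, w, candidates))
```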
Self-Supervised Learning (SSL) encompasses both supervised and unsupervised learning. The objective of an SSL pretext task is to learn a good representation of the input so that it can subsequently be used for supervised tasks. In SSL, the model is trained to predict one part of the data given other parts of the data. For example, BERT was trained using SSL techniques, and the Denoising Auto-Encoder (DAE) has shown particularly strong, state-of-the-art results in Natural Language Processing (NLP). SSL tasks can be defined as follows:
- Predict the future from the past.
- Predict the masked from the visible.
- Predict any occluded parts from all available parts.
Although vision models can fill in missing regions, they have not shared the same level of success as NLP systems. If you take the internal representations generated by these models as input to a computer vision system, it is unable to beat a model pre-trained in a supervised manner on ImageNet. The difference is that text is discrete whereas images are continuous. The gap in success arises because in the discrete domain we know how to represent uncertainty (we can use a big softmax over the possible outputs), whereas in the continuous domain we do not.
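A toy illustration of the uncertainty argument, assuming two equally likely outcomes for the same context: a single regression output trained with mean squared error ends up at their average, which matches neither outcome (the blurry-prediction problem for continuous signals), while a softmax over discrete outcomes can put probability 0.5 on each. The counts and outcome names are made up.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]


def mse_optimal_prediction(samples):
    """The single value minimising mean squared error is the sample mean."""
    return sum(samples) / len(samples)


# Continuous case: two equally likely "futures" for the same context
futures = [-1.0, 1.0]
print(mse_optimal_prediction(futures))  # an average that matches neither future

# Discrete case: log-counts as logits give a distribution over outcomes
outcomes = ["left", "right"]
counts = {"left": 1, "right": 1}
probs = softmax([math.log(counts[o]) for o in outcomes])
print(dict(zip(outcomes, probs)))
```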
More of Yann’s handwritten notes from the talk:
An interesting point Yann made is that the state-of-the-art computer vision models we have now still can’t compare to human observation. We are biased toward certain things/representations when we use vision to identify things, while we do not focus on others such as texture, and the motivation is different.
Speaker 8: Kristen Grauman
Kristen talks about self-supervised visual learning from audio in the environment.
Speaker 9: Katerina Fragkiadaki
Katerina talks about simulating object environments.
Speaker 10: Abhinav Gupta
Abhinav talks about multimodal SSL from videos.
Speaker 11: Quoc V. Le
Quoc talks about the teacher-student setting for SSL.
Speaker 12: Yuandong Tian
Yuandong talks about SimCLR and BYOL.
Panel discussion:
Why does SSL only work for solving perception tasks but not reasoning tasks?
Example: for semantic role labeling, first train on labeled data, then let the model create labels itself (self-training) and feed them back in.
SSL mostly works because of data augmentation (e.g., cropping), which makes the hidden state invariant. A sentence is fairly fixed, so a language model can compress it, but CV can’t do that from raw pixels; language is easier in this sense.
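Cropping-based augmentation can be sketched as two independent random crops of the same image, which form a positive pair whose representations the model is trained to make invariant. A minimal sketch, assuming the image is a 2-D list of pixel values and a fixed crop size (hypothetical choices; real pipelines also resize, flip, and color-jitter).

```python
import random


def random_crop(image, size, rng):
    """Cut a size x size window at a random location from a 2-D pixel grid."""
    top = rng.randrange(len(image) - size + 1)
    left = rng.randrange(len(image[0]) - size + 1)
    return [row[left:left + size] for row in image[top:top + size]]


def two_views(image, size, seed=0):
    """Two independent crops of the same image: a positive pair for SSL."""
    rng = random.Random(seed)
    return random_crop(image, size, rng), random_crop(image, size, rng)


# Pixel value encodes its position (10 * row + col) so crops are easy to check
image = [[10 * r + c for c in range(6)] for r in range(6)]
view_a, view_b = two_views(image, size=4)
print(view_a[0][0], view_b[0][0])
```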
Also, language has symbols and sequences, so high-level concepts can be represented as high-level abstractions, whereas computer vision works with low-level pixels. Predicting the next frame operates at a low level, so CV is behind compared with language (structure and space in CV are trickier). CV needs to work with higher-level semantics, e.g., entities (objects, people) that have counterparts in language.
We are far from reasoning, at least in CV. Nice numbers on benchmarks don’t mean we are good at it. If data augmentation from cropping contributes 90%, we are still overfitting to the data. We need to step back, start from easier tasks, and build up, instead of starting from big, fun tasks.
The question also comes down to what you mean by reasoning: the way humans reason, or the way machines reason? For example, computers add much faster than humans. The definition of reasoning depends on your goal and resources. The goals of reasoning and SSL are not aligned: SSL has effectively infinite resources and time, which is not great for reasoning because there is no starvation (using encoded rules plus example representations might work better and generalize better to other tasks).
Are we talking about induction or deduction (deriving conclusions from new information)? An SSL language model can be fine-tuned on downstream tasks, with more focus put on inference-time algorithms (e.g., backprop-based inference for induction). But it does not reason between what happened in a previous image and a recent one, such as what will happen if we change some part of the image, i.e., learning over a sequence of SSL tasks.
Yann: the hard technical issue to solve in SSL is prediction under uncertainty. There are many possible futures for a video, and we can’t represent a distribution over all possible video segments or parametrize that distribution properly. Language is ahead because text is discrete: we can compute a softmax over a large set of words, and the Transformer captures the dependencies between words, so we can get the model’s distribution. That is not true in computer vision, because we can’t train that way on high-dimensional, continuous signals like video and images, and inference is hard without correct density estimation.
Contrastive learning works OK in low dimensions but can’t scale to high dimensions because it is very inefficient: it pushes up the energy of things we don’t want, and that is not sustainable since there are too many places to push up on. We should try to get rid of contrastive methods by using joint embeddings (e.g., BYOL). Methods based on joint embeddings work better than methods based on reconstruction or prediction; it is intrinsically bad to predict pixels or latent variables.
GPT-3 only learns what people say and can’t generalize out of context (e.g., how many eyes a horse has), even though humans read much less. We should combine vision and language not through existing supervised labels but through concept connections.