SimCLR系列

Dec 29, 2024
1 views
Self-Supervised

Self-Supervised Learning,又称为自监督学习,我们知道一般机器学习分为有监督学习,无监督学习和强化学习。 而 Self-Supervised Learning 是无监督学习里面的一种,主要是希望能够学习到一种通用的特征表达用于下游任务 (Downstream Tasks)。 其主要的方式就是通过自己监督自己。作为代表作的 kaiming 的 MoCo 引发一波热议, Yann Lecun也在 AAAI 上讲 Self-Supervised Learning 是未来的大势所趋。所以在这个系列中,我会系统地解读 Self-Supervised Learning 的经典工作。

总结下 Self-Supervised Learning 的方法,用 4 个英文单词概括一下就是:

Unsupervised Pre-train, Supervised Fine-tune.

这段话先放在这里,可能你现在还不一定完全理解,后面还会再次提到它。

在预训练阶段我们使用无标签的数据集 (unlabeled data),因为有标签的数据集很贵,打标签得要多少人工劳力去标注,那成本是相当高的,所以这玩意太贵。相反,无标签的数据集网上随便到处爬,它便宜。在训练模型参数的时候,我们不追求把这个参数用带标签数据从初始化的一张白纸给一步训练到位,原因就是数据集太贵。于是 Self-Supervised Learning 就想先把参数从 一张白纸 训练到 初步成型,再从 初步成型 训练到 完全成型。注意这是2个阶段。这个训练到初步成型的东西,我们把它叫做 Visual Representation。预训练模型的时候,就是模型参数从 一张白纸 到 初步成型 的这个过程,还是用无标签数据集。等我把模型参数训练个八九不离十,这时候再根据你 下游任务 (Downstream Tasks) 的不同去用带标签的数据集把参数训练到 完全成型,那这时用的数据集量就不用太多了,因为参数经过了第1阶段就已经训练得差不多了。

第1个阶段不涉及任何下游任务,就是拿着一堆无标签的数据去预训练,没有特定的任务,这个话用官方语言表达叫做:in a task-agnostic way。第2个阶段涉及下游任务,就是拿着一堆带标签的数据去在下游任务上 Fine-tune,这个话用官方语言表达叫做:in a task-specific way (这俩英文很重要啊,会不断出现)

以上这些话就是 Self-Supervised Learning 的核心思想,如下图所示,后面还会再次提到它。

image

Self-Supervised Learning 经典工作的分类如下图1所示。在上篇文章中主要介绍了 Self-Supervised Learning 在 NLP 领域 的经典工作:BERT模型的原理及其变体GPT, MASS, BART, ELECTRA等等,这些方法都是属于 Prediction 类别的。本文主要介绍Self-Supervised Learning 在 CV 领域 的经典工作之一:SimCLR和SimCLR v2,它们都是属于 Contrastive 类别的。

那 Prediction 类别和 Constractive 类别有什么不同呢?

Prediction 类别比如说BERT,是会用一堆没有label的句子去训练BERT做填空题 (详见上篇文章):给一个句子随机盖住 (mask掉) 一个token,输入这个BERT,期望它输出盖住的部分,使用这种办法让BERT无监督地学习到结合上下文做Embedding的能力,学习的过程是一种Prediction的行为。Contrastive 类别方法并不要求模型能够重建原始输入,而是希望模型能够在特征空间上对不同的输入进行分辨,这也会在SimCLR的训练过程中体现。

image

图1:Self-Supervised Learning 经典工作的分类

1 SimCLR 原理分析

论文名称:A Simple Framework for Contrastive Learning of Visual Representations

论文地址:

https://arxiv.org/pdf/2002.05709.pdfarxiv.org/pdf/2002.05709.pdf

SimCLR 是Hinton团队在 Self-Supervised Learning 领域的一个系列的经典工作。先来通过图2直观地感受下它的性能:SimCLR (4×) 这个模型可以在 ImageNet 上面达到 76.5% 的 Top 1 Accuracy,比当时的 SOTA 模型高了7个点。如果把这个预训练模型用 1%的ImageNet的标签给 Fine-tune 一下,借助这一点点的有监督信息,SimCLR 就可以再达到 85.5% 的 Top 5 Accuracy,也就是再涨10个点。

image

图2:SimCLR性能

那根据上一篇文章BERT的描述,我们说 Self-Supervised Learning 的目的一般是使用大量的无 label 的资料去Pre-train一个模型,这么做的原因是无 label 的资料获取比较容易,且数量一般相当庞大,我们希望先用这些廉价的资料获得一个预训练的模型,接着根据下游任务的不同在不同的有 label 数据集上进行 Fine-tune 即可

作为 Self-Supervised Learning 的工作之一,SimCLR 自然也遵循这样的思想。我们回忆一下之前 BERT 会用一堆没有label的句子去训练BERT做填空题 (详见上篇文章):给一个句子随机盖住 (mask掉) 一个token,输入这个BERT,期望它输出盖住的部分。这就是BERT进行自监督学习的做法,那么在 SimCLR 里面是如何做的呢?一个核心的词汇叫做:Contrastive

这个词翻译成中文是 对比 的意思,它的实质就是:试图教机器区分相似和不相似的事物

image

图3:对比学习试图教机器区分相似和不相似的事物

这个话是什么意思呢?比如说现在我们有任意的 4 张 images,如下图4所示。前两张都是dog 这个类别,后两张是其他类别,以第1张图为例,我们就希望它与第2张图的相似性越高越好,而与第3,第4张图的相似性越低越好

但是以上做法其实都是很理想的情形,因为:

  1. 我们只有大堆images,没有任何标签,不知道哪些是 dog 这个类的,哪些是其他类的。
  2. 没办法找出哪些图片应该去 Maximize Similarity,哪些应该去 Minimize Similarity。
    image

图4:试图教机器区分相似和不相似的事物

所以,SimCLR是怎么解决这个问题的呢?它的framework如下图5所示:

假设现在有1张任意的图片 \(x\) ,叫做Original Image,先对它做数据增强,得到2张增强以后的图片 \(x_i, x_j\) 。注意数据增强的方式有以下3种:

  • 随机裁剪之后再resize成原来的大小 (Random cropping followed by resize back to the original size)。
  • 随机色彩失真 (Random color distortions)。
  • 随机高斯模糊 (Random Gaussian Deblur)。
    接下来把增强后的图片 \(x_i, x_j\) 输入到Encoder里面,注意这2个Encoder是共享参数的,得到representation \(h_i, h_j\) ,再把 \(h_i, h_j\) 继续通过 Projection head 得到 representation \(z_i, z_j\) ,这里的2个 Projection head 依旧是共享参数的,且其具体的结构表达式是:

接下来的目标就是最大化同一张图片得到的 \(z_i, z_j\)

image

图5:SimCLR框架

以上是对SinCLR框架的较为笼统的叙述,下面具体地看每一步的做法:

回到起点,一开始我们有的training corpus就是一大堆 unlabeled images,如下图6所示。

image

图6:我们有的training corpus

1.1 数据增强

比如batch size的大小是 N ,实际使用的batch size是8192,为了方便我们假设 N=2 。

image

图7:Batch Size=2

注意数据增强的方式有以下3种:

  • 随机裁剪之后再resize成原来的大小 (Random cropping followed by resize back to the original size)。代码:

    torchvision:transforms:RandomResizedCrop
    

  • 随机色彩失真 (Random color distortions)。代码:

    **from** torchvision **import** transforms
    **def** **get_color_distortion**(s**=**1.0):
    *# s is the strength of color distortion.*color_jitter **=** transforms**.**ColorJitter(0.8*****s, 0.8*****s, 0.8*****s, 0.2*****s)
        rnd_color_jitter **=** transforms**.**RandomApply([color_jitter], p**=**0.8)
        rnd_gray **=** transforms**.**RandomGrayscale(p**=**0.2)
        color_distort **=** transforms**.**Compose([
        rnd_color_jitter,
        rnd_gray])
    
        **return** color_distort
    

  • 随机高斯模糊 (Random Gaussian Deblur)。

    random (crop **+** flip **+** color jitter **+** grayscale)
    

image

图8:对Input Image进行数据增强

对每张图片我们得到2个不同的数据增强结果,所以1个Batch 一共有 4 个 Image。

image

图9:对1个Batch的数据做增强,每个图片得到2个结果

1.2 通过Encoder获取图片表征

第一步得到的2张图片\(x_i, x_j\) 会通过Encoder获取图片的表征,如下图10所示。所用的编码器是通用的,可以用其他架构代替。下面显示的2个编码器共享权重,我们得到向量 \(h_i, h_j\) 。

image

图10:通过Encoder获取图片表征

本文使用了 ResNet-50 作为 Encoder,输出是 2048 维的向量 ℎ 。

1.3 预测头

使用预测头 Projection head。在 SimCLR 中,Encoder 得到的2个 visual representation再通过Prediction head (g(.))进一步提特征,预测头是一个 2 层的MLP,将 visual representation 这个 2048 维的向量\(h_i, h_j\)进一步映射到 128 维隐空间中,得到新的representation \(z_i, z_j\)。利用 \(z_i, z_j\) 去求loss 完成训练,训练完毕后扔掉预测头,保留 Encoder 用于获取 visual representation。

image

图11:预测头

1.4 相似图片输出更接近

到这一步以后对于每个Batch,我们得到了如下图12所示的Representation \(z_i,...,z_4\) 。

image

图12:最终得到的Representation

首先定义Representation之间的相似度:使用余弦相似度Cosine Similarity:

image

Cosine Similarity把计算两张 Augmented Images \(x_i, x_j\) 的相似度转化成了计算两个Projected Representation \(z_i, z_j\) 的相似度,定义为:

\[ s_{i, j} = \frac{z_i^{T}z_j}{\tau||z_i|||||z_j||} \]

式中, \(\tau\) 是可调节的Temperature 参数。它能够scale 输入并扩展余弦相似度[-1, 1]这个范围。

使用上述公式计算batch里面的每个Augmented Images \(x_i, x_j\) 的成对余弦相似度。 如下图13所示,在理想情况下,狗的增强图像之间的相似度会很高,而狗和鲸鱼图像之间的相似度会较低。

image

图13:Augmented Images的余弦相似度

现在我们有了衡量相似度的办法,但是这还不够,要最终转化成一个能够优化的 Loss Function 才可以。

SimCLR用了一种叫做 NT-Xent loss (Normalized Temperature-Scaled Cross-Entropy Loss)的对比学习损失函数。

我们先拿出Batch里面的第1个Pair:

image

图14:依次拿出Batch里面的每个Pair

使用 softmax 函数来获得这两个图像相似的概率:

image

图15:使用 softmax 函数来获得这两个图像相似的概率

这种 softmax 计算等价于获得第2张增强的狗的图像与该对中的第1张狗的图像最相似的概率。 在这里,分母中的其余的项都是其他图片的增强之后的图片,也是negative samples。

所以我们希望上面的softmax的结果尽量大,所以损失函数取了softmax的负对数:

\[ l(i, j)=\frac{exp(s_{i, j})}{\sum_{k=1}^{2N}1[k!=i]exp(s_{i, k})} \]

image

图16:损失函数取softmax的负对数

再对同一对图片交换位置以后计算损失:

image

图17:同一对图片交换位置以后计算损失

最后,计算每个Batch里面的所有Pair的损失之和取平均:

\[ L=\frac{1}{2N}\sum_{k=1}^N[l(2k-1, 2k)+l(2k, 2k-1)] \]