Reading

Swin-Transformer

前言

首先看论文题目。Swin Transformer: Hierarchical Vision Transformer using Shifted Windows。即:Swin Transformer是一个用了移动窗口的层级式Vision Transformer

所以Swin来自于 Shifted Windows , 它能够使Vision Transformer像卷积神经网络一样,做层级式的特征提取,这样提取出来的特征具有多尺度的概念 ,这也是 Swin Transformer这篇论文的主要贡献。

标准的Transformer直接用到视觉领域有一些挑战,即:

  • 多尺度问题:比如一张图片里的各种物体尺度不统一,NLP中没有这个问题;
  • 分辨率太大:如果将图片的每一个像素值当作一个token直接输入Transformer,计算量太大,不利于在多种机器视觉任务中的应用。
    基于这两点,本文提出了 hierarchical Transformer,通过移动窗口来学习特征。
  • 移动窗口学习,即只在滑动窗口内部计算自注意力,所以称为W-MSA(Window Multi-Self-Attention)。
  • W-MSA大大降低了降低了计算复杂度。同时通过Shiting(移动)的操作可以使相邻的两个窗口之间进行交互,也因此上下层之间有了cross-window connection,从而变相达到了全局建模的能力。
  • 分层结构使得模型能够灵活处理不同尺度的图片,并且计算复杂度与图像大小呈线性关系,这样模型就可以处理更大分辨率的图片(为作者后面提出的Swin V2铺平了道路)。

Swin-Transformer 对比 VIT

Vision Transformer:进行MSA(多头注意力)计算时,任何一个patch都要与其他所有的patch都进行attention计算,计算量与图片的大小成平方增长
Swin Transformer:采用了W-MSA,只对window内部计算MSA,当图片大小增大时,计算量仅仅是呈线性增加。
image

可以看出主要区别有两个:

  1. 层次化构建方法(Hierarchical feature maps) :Swin Transformer使用了类似卷积神经网络中的层次化构建方法。
    • 对于计算机视觉的下游任务,尤其是密集预测型的任务(检测、分割),有多尺寸的特征至关重要的。(比如目标检测里的FPN、分割里面的UNet等等)
    • Vision Transformer中是一开始就直接下采样16倍,这样模型自始至终都是处理的16倍下采样率过后的特征,这样在处理需要多尺寸特征的任务时,效果不够好。
    • Swin Transformer 使用patch merging,可以把相邻的四个小的patch合成一个大的patch,提高了感受野,这样就能获取多尺度的特征(类似CNN中的池化效果)。这些特征通过FPN结构就可以做检测,通过UNet结构就可以做分割了。
  1. 使用W-MSA ,好处有两点:
    • Swin Transformer使用窗口(Window)的形式将特征图划分成了多个不相交的区域,并且只在每个窗口内进行多头注意力计算,大大减少计算量。
    • 获得了和CNN一样的归纳偏置特性——locality。 归纳偏置:一种先验知识或者说提前的假设 locality:CNN是以滑动窗口的形式一点一点地在图片上进行卷积的,所以假设图片上相邻的区域会有相邻的特征,靠得越近的东西相关性越强

在 Swin Transformer里,默认每个窗口有49个patch,第一层每个patch尺寸是4*4。

locality进一步说明:对于图片来说,语义相近的不同物体还是大概率会出现在相连的地方,所以即使是在一个小范围的窗口内计算自注意力也是差不多够用的,全局计算自注意力对于视觉任务来说,其实是有点浪费资源的。

W-MSA虽然减少了计算量,但也会隔绝不同窗口之间的信息传递。所以在论文中作者又提出了 SW-MSA的概念,通过此方法能够让信息在相邻的窗口中进行传递,后面会细讲。

方法部分

模型结构

原论文中给出的关于Swin Transformer(Swin-T)网络的架构图如下:

image

前向过程:

  1. Patch Partition层:类似ViT一样将图片分割成一个个4*4大小的patch([224,224,3]—>[56,56,48])
  2. Linear Embeding层:将每个像素的channel数调整为C,并对每个channel做一次Layer Norm。([56,56,48]—>[56,56,96])
    1. 假设输入的是RGB三通道图片,那么每个patch就有4x4=16个像素,然后每个像素有R、G、B三个值,所以展平后是16x3=48。
    2. swin-transformer有T、S、B、L等不同大小,其C的值也不同,比如Swin-Tiny中,C=96。
    3. 在源码中Patch Partition和Linear Embeding就是直接通过一个卷积层实现的,和之前Vision Transformer中讲的 Embedding层结构一模一样。(kernel size=4×4,stride=4,num_kernel=48)
  3. 将每49个patch划分为一个窗口,后续只在窗口内进行计算。
  4. 通过四个Stage构建不同大小的特征图。其中后三个stage都是先通过一个Patch Merging层进行2倍的下采样。([56,56,96]—>[28,28,192]—>[14,14,384]—>[7,7,768])
  5. 每个stage中,重复堆叠Swin Transformer Block偶数次(结构见上图右侧,分别使用W-MSA和SW-MSA,两个结构成对出现)。
  6. 如果是分类任务,后面还会接上一个Layer Norm层、全局池化层以及全连接层得到最终输出。[7,7,768]—>[1,768]—>[1,num_class](也就是做序列的全局平均,类似CNN的做法,而不是加上CLS做分类)
    1. 如果不划分窗口,以Swin-Tiny举例,Linear Embeding层输出矩阵为[56,56,96]。如果计算全局注意力的话,输入序列长度为56*56=3136,每个元素是96维,这个序列就太长了。
    2. 引入 Shifted Windows后,每个序列长度固定为49。
    3. 与ViT还有一点不同的是:ViT在输入时会给embedding加上1D-位置编码。而Swin-T这里则是作为一个可选项(self.ape)。另外Swin-T在计算Attention时用的是相对位置编码

看完整个前向过程之后,就会发现 Swin Transformer 有四个 stage,还有类似于池化的 patch merging 操作,自注意力还是在小窗口之内做的,以及最后还用的是全局平均池化 。所以可以说 Swin Transformer是披着Transformer皮的卷积神经网络,将二者进行了完美的结合。

接下来,在分别对Patch Merging、W-MSA、SW-MSA以及使用到的相对位置偏置(relative position bias)进行详解。

图片预处理:分块和降维 (Patch Partition)

Swin Transformer 首先把\(x\in H\times W \times 3\) 的图片,变成一个 \(x_p\in N\times(P^2\cdot C)\) 的2维的image patches。它可以看做是一系列的展平的2D块的序列,这个序列中一共有 \(N=HW/P^2\)个展平的2D块,每个块的维度是 \(P^2\cdot 3\) 。其中 \(P\) 是块大小。

在 Swin Transformer 中,块的大小 \(P=4\) ,所以得到的 \(x_p\in N×48\) ,这里的 \(N=HW/16=\frac{H}{4}\times\frac{W}{4}\) 。

所以经过了这一步的分块操作,一张\(x\in H\times W \times 3\) 的图片就变成了 \(\frac{H}{4}\times\frac{W}{4}\times 48\) 的张量,可以理解成是 \(\frac{H}{4}\times\frac{W}{4}\) 个图片块,每个块是一个 48 维的 token。

线性变换 (Linear Embedding)

现在得到的向量维度是: \(\frac{H}{4}\times\frac{W}{4}\times 48\) ,还需要做一步叫做Linear Embedding的步骤,对每个向量都做一个线性变换(即全连接层),变换后的维度为 \(C\) ,这里我们称其为 Linear Embedding。这一步之后得到的张量维度是: \(\frac{H}{4}\times\frac{W}{4}\times C\) 。

Stage1: Swin Transformer Block

接下来 \(\frac{H}{4}\times\frac{W}{4}\times 48\) 这个张量进入2个连续的 Swin Transformer Block 中,这被称作 Stage 1,在整个的 Stage 1 里面 token 的数量一直维持 \(\frac{H}{4}\times\frac{W}{4}\) 不变。

Swin Transformer Block 具体是如何操作的呢?

image

Swin Transformer Block 的结构如上图所示。上图是2个连续的 Swin Transformer Block。其中一个 Swin Transformer Block 由一个带两层 MLP 的 Shifted Window-based MSA 组成,另一个 Swin Transformer Block 由一个带两层 MLP 的 **Window-based MSA **组成。在每个 MSA 模块和每个 MLP 之前使用 LayerNorm(LN) 层,并在每个 MSA 和 MLP之后使用残差连接。

可以看到 Swin Transformer Block 和 ViT Block 的区别就在于将 ViT 的多头注意力机制 MSA 替换为了 Shifted Window-based MSA 和 Window-based MSA

Stage1: Swin Transformer Block:Window-based MSA

标准 ViT 的多头注意力机制 MSA 采用的是全局自注意力机制,即:计算每个 token 和所有其他 token 的 attention map。全局自注意力机制的计算复杂度是 \(O(N^2d)\) ,其中, \(N\) 是 token的数量, \(d\) 是 Embedding dimension。全局自注意力机制的计算复杂度与序列长度 \(N\) 成平方关系。当图片分辨率较高或是密集预测任务中计算量会过大。

Window-based MSA 不同于普通的 MSA,它在一个个 window 里面去计算 self-attention。假设每个 window 里面包括 \(M×M \)个 image patches,则 Window-based MSA 和普通的 MSA 的计算量分别为:

\[Ω(MSA)=4ℎwC^2+2(ℎw)^2C,\\ Ω(W-MSA)=4ℎwC^2+2M^2ℎwC,\]

由于 Window 的 patch 数量 \(M\) 远小于图片patch数量 \(ℎw\) ,Window-based MSA 的计算量与序列长度 \(N=hw\) 成线性关系。

Stage1: Swin Transformer Block:Shifted Window-based MSA

与W-MSA不同的地方在于这个模块存在窗口滑动,所以叫做shifted window。滑动距离是window_size//2,方向是向右和向下。
滑动窗口是为了解决W-MSA计算attention时,窗口与窗口之间无法进行信息传递的问题。如下图所示,左侧是网络第L层使用的W-MSA模块,右侧是第L+1层使用SW-MSA模块。对比可以发现,窗口(Windows)发生了偏移。

比如在L+1层特征图上,对于第一行第2列的2x4的窗口,它能够使第L层的第一排的两个窗口信息进行交流。再比如,第二行第二列的4x4的窗口,他能够使第L层的四个窗口信息进行交流,其他同理。

image

但是引入 Shifted Window 会带来另一个问题就是会造成 window 数发生改变,而且有的 window 大,有的 window 小

image

一种简单的解决办法是把所有 window 都做 padding 操作,使之达到相同的大小。但是这会因为 window 数量的增加 (从\( ⌈\frac{h}{M}⌉×⌈\frac{w}{M}⌉\) 增加到 \( (⌈\frac{h}{M}⌉+1)×(⌈\frac{w}{M}⌉+1)\)而增加计算量。所以作者在这里提出了一种更加高效的 batch computation 计算方法,通过 cycle shift 的方法,合并小的 windows,仔细看上图,将 A,B,C 这3个小的 windows 进行循环移位,使之合并小的 windows。

经过了 cycle shift 的方法,一个 window 可能会包括来自不同 window 的内容。比如图中右下角的 window,来自4个不同的 sub-window。因此,要采用 masked MSA 机制将 self-attention 的计算限制在每个子窗口内。最后通过 reverse cycle shift 的方法将每个 window 的 self-attention 结果返回。

这里进行下简单的图解,下图代表 cycle shift 的过程,这9个 window 通过移位从左边移动到右侧的位置。

image

这样再按照之前的 window 划分,就能够得到 window 5 的attention 的结果了。但是这样操作会使得 window 6 和 4 的 attention 混在一起,window 1,3,7 和 9 的 attention 混在一起。所以需要采用 masked MSA 机制将 self-attention 的计算限制在每个子窗口内。具体怎么做呢?

按照 Swin Transformer 的代码实现 (下面会有讲解),还是做正常的 self-attention (在 window_size 上做),之后要进行一次 mask 操作,把不需要的 attention 值给它置为0。

  • 例1:比如右上角这个 window,如下图所示。它由4个 patch 组成,所以应该计算出的 attention map是4×4的。但是6和4是2个不同的 sub-window,我们又不想让它们的 attention 发生交叠。所以我们希望的 attention map 应该是下图这个样子。
image.png
image.png

因此我们就需要如下图8所示的 mask。

image.png
  • 例2:比如右下角这个 window,如下图9所示。它由4个 patch 组成,所以应该计算出的 attention map是4×4的。但是1,3,7和9是4个不同的 sub-window,我们又不想让它们的 attention 发生交叠。所以我们希望的 mask 应该是这个样子。
image.jpeg
image.jpeg

Stage 2/3/4

Stage 2 的输入是维度是 \(\frac{H}{4}\times\frac{W}{4}\times C\) 的张量。从 Stage 2 到 Stage 4 的每个 stage 的初始阶段都会先做一步 Patch Merging 操作,Patch Merging 操作的目的是为了减少 tokens 的数量,它会把相邻的 2×2 个 tokens 给合并到一起,得到的 token 的维度是 \(4C\) 。Patch Merging 操作再通过一个\(1\times 1\)的卷积把维度降为 \(2C\) 。至此,维度是 \(\frac{H}{4}\times\frac{W}{4}\times C\) 的张量经过Patch Merging 操作变成了维度是 \(\frac{H}{8}\times\frac{W}{8}\times 2C\) 的张量。

同理,Stage 3 的Patch Merging 操作会把维度是\(\frac{H}{8}\times\frac{W}{8}\times 2C\) 的张量变成维度是 \(\frac{H}{16}\times\frac{W}{16}\times 4C\) 的张量。Stage 4 的Patch Merging 操作会把维度是 \(\frac{H}{16}\times\frac{W}{16}\times 4C\) 的张量变成维度是 \(\frac{H}{32}\times\frac{W}{32}\times 8C\) 的张量。

每个 Stage 都会改变张量的维度,形成一种层次化的表征。因此,这种层次化的表征可以方便地替换为各种视觉任务的骨干网络。

相对位置编码

注意 Swin Transformer 的位置编码是加在 attention 矩阵上的,attention 是个四维张量,它的维度是:

(num_windows,num_heads,windows_sizewindows_size,windows_sizewindows_size)

具体操作为

\[Attention(Q, K, V ) = SoftMax(QK^T/√d + B)V\]

下面具体讲解什么是Relative Position Bias(假设特征图大小为2×2)

image
  1. 计算相对位置索引:比如蓝色像素,在蓝色像素使用q与所有像素k进行匹配过程中,是以蓝色像素为参考点。然后用蓝色像素的绝对位置索引与其他位置索引进行相减,就得到其他位置相对蓝色像素的相对位置索引,同理可以得到其他位置相对蓝色像素的相对位置索引矩阵(第一排四个位置矩阵)。
  2. 展平拼接:将每个相对位置索引矩阵按行展平,并拼接在一起可以得到第二排的这个4x4矩阵
  3. 索引转换为一维:在源码中作者为了方便把二维索引给转成了一维索引。
    • 首先在原始的相对位置索引上加上\(M-1\)(\(M\)为窗口的大小,在本示例中\(M=2\)),加上之后索引中就不会有负数了。
    • 接着将所有的行标都乘上\(2M-1\)
    • 最后将行标和列标进行相加。这样即保证了相对位置关系,而且不会出现直接相加后位置重叠的问题。(0±1和-1+0结果都一样,但其实其位置不一样)
image.png
image.png
image.png
  1. 取出相对位置偏置参数。真正使用到的可训练参数B 是保存在relative position bias table表里的,其长度是等于 \((2M-1) \times (2M-1)\)。相对位置偏置参数B,是根据相对位置索引来查relative position bias table表得到的,如下图所示。

为啥表长是 \((2M-1) \times (2M-1)\)?考虑两个极端位置,(0,0)能取到的相对位置极值为(-1,-1),(-1,-1)能取到的极值是(1,1),即行和列都能取到(2M-1)个数。考虑到所有的排列组合,表的长度就是\(( 2 M − 1 ) × ( 2 M − 1 ) \)

Swin Transformer 的结构

Swin Transformer 分为 Swin-T,Swin-S,Swin-B,Swin-L 这四种结构。使用的 window 的大小统一为 \(M=7\) ,每个 head 的embedding dimension 都是 32,每个 stage 的层数如下:

  • Swin-T:\(C=96\)  ,layer number:\(\{2,2,6,2\}\)
  • Swin-T:\(C=96\)  ,layer number:\(\{2,2,18,2\}\)
  • Swin-T:\(C=128\)  ,layer number:\(\{2,2,18,2\}\)
  • Swin-T:\(C=192\)  ,layer number:\(\{2,2,18,2\}\)

Experiments:

图像分类:

数据集:ImageNet

(a)表是直接在 ImageNet-1k 上训练,(b)表是先在 ImageNet-22k 上预训练,再在 ImageNet-1k 上微调。

对标 88M 参数的 DeiT-B 模型,它在 ImageNet-1k 上训练的结果是83.1% Top1 Accuracy,Swin-B 模型的参数是80M,它在 ImageNet-1k 上训练的结果是83.5% Top1 Accuracy,优于DeiT-B 模型。

image

图像分类上比 ViT、DeiT等 Transformer 类型的网络效果更好,但是比不过 CNN 类型的EfficientNet,猜测 Swin Transformer 还是更加适用于更加复杂、尺度变化更多的任务。

目标检测:

数据集:COCO 2017 (118k Training, 5k validation, 20k test)

(a) 表是在 Cascade Mask R-CNN, ATSS, RepPoints v2, 和 Sparse RCNN 上对比 Swin-T 和 ResNet-50 作为 Backbone 的性能。

(b) 表是使用 Cascade Mask R-CNN 模型的不同 Backbone 的性能对比。

(c) 表是整体的目标检测系统的对比,在 COCO test-dev 上达到了 58.7 box AP 和 51.1 mask AP。

image

语义分割:

数据集:ADE20K (20k Training, 2k validation, 3k test)

下图13列出了不同方法/Backbone的mIoU、模型大小()、FLOPs和FPS。从这些结果可以看出,Swin-S 比具有相似计算成本的 DeiT-S 高出+5.3 mIoU (49.3 vs . 44.0)。也比ResNet-101 高+4.4 mIoU,比 ResNeSt-101 高 +2.4 mIoU。

image

Reference

Vision Transformer 超详细解读 (原理分析+代码解读) (十七)

李沐论文精读系列二:Vision Transformer、MAE、Swin-Transformer