Stacked Hourglass Networks

Apr 25, 2024
2 views
3D Model

论文介绍了一种新的网络结构用于人体姿态检测,作者在论文中展现了不断重复bottom-up、top-down过程以及运用intermediate supervison(中间监督)对于网络性能的提升,下面来介绍Stacked Hourglass Networks.

简介

理解人类的姿态对于一些高级的任务比如行为识别来说特别重要,而且也是一些人机交互任务的基础。作者提出了一种新的网络结构Stacked Hourglass Networks来对人体的姿态进行识别,这个网络结构能够捕获并整合图像所有尺度的信息。之所以称这种网络为Stacked Hourglass Networks,主要是它长得很像堆叠起来的沙漏,如下图所示:

image

这种堆叠在一起的Hourglass模块结构是对称的,bottom-up过程将图片从高分辨率降到低分辨率,top-down过程将图片从低分辨率升到高分辨率,这种网络结构包含了许多pooling和upsampling的步骤,pooling可以将图片降到一个很低的分辨率,upsampling可以结合多个分辨率的特征。

下面介绍具体的网络结构。

Hourglass Module

Hourglass模块设计的初衷就是为了捕捉每个尺度下的信息,因为捕捉像脸,手这些部分的时候需要局部的特征,而最后对人体姿态进行预测的时候又需要整体的信息。为了捕获图片在多个尺度下的特征,通常的做法是使用多个pipeline分别单独处理不同尺度下的信息,然后再网络的后面部分再组合这些特征,而作者使用的方法就是用带有skip layers的单个pipeline来保存每个尺度下的空间信息。

image

在Hourglass模块中,卷积和max pooling被用来将特征降到一个很低的分辨率,在每一个max pooling步骤中,网络产生分支并在原来提前池化的分辨率下使用更多的卷积,当到达最低的分辨率的时候,网络开始upsample并结合不同尺度下的特征。这里upsample(上采样)采用的方法是最邻近插值,之后再将两个特征集按元素位置相加。

当到达输出分辨率的时候,再接两个1×1的卷积层来进行最后的预测,网络的输出是一组heatmap,对于给定的heatmap,网络预测在每个像素处存在关节的概率。

网络结构

Residual Module

上图中的每个方框都由下面这样的残差块组成:

image

残差结构的代码如下

上图的残差块是论文中的原图,描述的不够详细,自己看了下源代码之后,画出了如下图所示的Residual Module:

image

整体结构

网络输入的图片分辨率为256×256,在hourglass模块中的最大分辨率为64×64,整个网络最开始要经过一个7×7的步长为2的卷积层,之后再经过一个残差块和Max pooling层使得分辨率从256降到64。下面贴出n阶Hourglass代码:

class Hourglass(nn.Module):
    def __init__(self, n, f, bn=None, increase=128):
        super(Hourglass, self).__init__()
        nf = f + increase
        self.up1 = Conv(f, f, 3, bn=bn)
        # Lower branch
        self.pool1 = Pool(2, 2)
        self.low1 = Conv(f, nf, 3, bn=bn)
        # Recursive hourglass
        if n > 1:
            self.low2 = Hourglass(n-1, nf, bn=bn)
        else:
            self.low2 = Conv(nf, nf, 3, bn=bn)
        self.low3 = Conv(nf, f, 3)
        self.up2 = nn.UpsamplingNearest2d(scale_factor=2)

    def forward(self, x):
        up1 = self.up1(x)
        pool1 = self.pool1(x)
        low1 = self.low1(pool1)
        low2 = self.low2(low1)
        low3 = self.low3(low2)
        up2 = self.up2(low3)
        return up1 + up2

image

整体网络代码如下

def __init__(self, nstack, inp_dim, oup_dim, bn=False, increase=128, **kwargs):
      super(PoseNet, self).__init__()
      self.pre = nn.Sequential(
          Conv(3, 64, 7, 2, bn=bn),
          Conv(64, 128, bn=bn),
          Pool(2, 2),
          Conv(128, 128, bn=bn),
          Conv(128, inp_dim, bn=bn)
      )
      self.features = nn.ModuleList( [
      nn.Sequential(
          Hourglass(4, inp_dim, bn, increase),
          Conv(inp_dim, inp_dim, 3, bn=False),
          Conv(inp_dim, inp_dim, 3, bn=False)
      ) for i in range(nstack)])

      self.outs = nn.ModuleList([Conv(inp_dim, oup_dim, 1, relu=False, bn=False) for i in range(nstack)] )
      self.merge_features = nn.ModuleList([Merge(inp_dim, inp_dim) for i in range(nstack-1)])
      self.merge_preds = nn.ModuleList([Merge(oup_dim, inp_dim) for i in range(nstack-1)] )

      self.nstack = nstack
      self.myAEloss = AEloss()
      self.heatmapLoss = HeatmapLoss()

  def forward(self, imgs):
      x = imgs.permute(0, 3, 1, 2)
      x = self.pre(x)
      preds = []
      for i in range(self.nstack):
          feature = self.features[i](x)
          preds.append(self.outs[i](feature))
          if i != self.nstack - 1:
              x = x + self.merge_preds[i](preds[-1]) + self.merge_features[i](feature)
      return torch.stack(preds, 1)

图中的4阶Hourglass Module就是前面讲的4阶Hourglass Module,可以看到整个网络还是挺庞大的,图中的渐变红色块就是加入了中间监督的地方,即在此处使用loss函数,下面讲一下中间监督。

注意,上面的整体网络结构图中中间监督的地方输出的通道数为16是针对于MPII Human Pose这个数据集,因为该数据集将人体划分为16个关节点,具体参见人体姿态估计数据集整理(Pose Estimation/Keypoint)

Intermediate Supervision

作者在整个网络结构中堆叠了许多hourglass模块,从而使得网络能够不断重复自底向上和自顶向下的过程,作者提到采用这种结构的关键是要使用中间监督来对每一个hourglass模块进行预测,即对中间的heatmaps计算损失。

关于中间监督的位置,作者在文中也进行了讨论。大多数高阶特征仅在较低的分辨率下存在,除非发生上采样时才出现。如果在网络进行上采样后进行监督,则无法在更大的全局上下文中重新评估这些特征;最终,作者将中间监督设计在如下图所示位置:

image

在整个网络中,作者共使用了8个hourglass模块,需要注意的是,这些hourglass模块的权重不是共享的,并且所有的模块都基于相同的ground truth添加了损失函数。