告别CNN?一张图等于16x16个字,计算机视觉也用上Transformer了( 二 )


告别CNN?一张图等于16x16个字,计算机视觉也用上Transformer了文章插图
模型原理
该研究提出了一种称为Vision Transformer(ViT)的模型 , 在设计上是尽可能遵循原版Transformer结构 , 这也是为了尽可能保持原版的性能 。
虽然可以并行处理 , 但Transformer依然是以一维序列作为输入 , 然而图片数据都是二维的 , 因此首先要解决的问题是如何将图片以合适的方式输入到模型中 。 本文采用的是切块 + embedding的方法 , 如下图:
告别CNN?一张图等于16x16个字,计算机视觉也用上Transformer了文章插图
首先将原始图片划分为多个子图(patch) , 每个子图相当于一个word , 这个过程也可以表示为:
告别CNN?一张图等于16x16个字,计算机视觉也用上Transformer了文章插图
其中x是输入图片 , xp则是处理后的子图序列 , P2则是子图的分辨率 , N则是切分后的子图数量(即序列长度) , 显然有
告别CNN?一张图等于16x16个字,计算机视觉也用上Transformer了文章插图
。 由于Transformer只接受1D序列作为输入 , 因此还需要对每个patch进行embedding , 通过一个线性变换层将二维的patch嵌入表示为长度为D的一维向量 , 得到的输出被称为patch嵌入 。
类似于BERT模型的[class] token机制 , 对每一个patch嵌入
告别CNN?一张图等于16x16个字,计算机视觉也用上Transformer了文章插图
, 都会额外预测一个可学习的嵌入表示 , 然后将这个嵌入表示在encoder中的最终输出(
告别CNN?一张图等于16x16个字,计算机视觉也用上Transformer了文章插图
)作为对应patch的表示 。 在预训练和微调阶段 , 分类头都依赖于 。
此外还加入了位置嵌入信息(图中的0 , 1 , 2 , 3…) , 因为序列化的patch丢失了他们在图片中的位置信息 。 作者尝试了各种不同的2D嵌入方法 , 但是相较于一般的1D嵌入并没有任何显著的性能提升 , 因此最终使用联合嵌入作为输入 。
模型结构与标准的Transformer相同(如上图右侧) , 即由多个交互层多头注意力(MSA)和多层感知器(MLP)构成 。 在每个模块前使用LayerNorm , 在模块后使用残差连接 。 使用GELU作为MLP的激活函数 。 整个模型的更新公式如下:
告别CNN?一张图等于16x16个字,计算机视觉也用上Transformer了文章插图
其中(1)代表了嵌入层的更新 , 公式(2)和(3)则代表了MSA和MLP的前向传播 。
此外本文还提出了一种直接采用ResNet中间层输出作为图片嵌入表示的方法 , 可以作为上述基于patch分割方法的替代 。
告别CNN?一张图等于16x16个字,计算机视觉也用上Transformer了文章插图
模型训练和分辨率调整
和之前常用的做法一样 , 在针对具体任务时 , 先在大规模数据集上训练 , 然后根据具体的任务需求进行微调 。 这里主要是更换最后的分类头 , 按照分类数来设置分类头的参数形状 。 此外作者还发现在更高的分辨率进行微调往往能取得更好的效果 , 因为在保持patch分辨率不变的情况下 , 原始图像分辨率越高 , 得到的patch数越大 , 因此得到的有效序列也就越长 。
告别CNN?一张图等于16x16个字,计算机视觉也用上Transformer了文章插图
对比实验
4.1 实验设置
首先作者设计了多个不同大小的ViT变体 , 分别对应不同的复杂度 。
告别CNN?一张图等于16x16个字,计算机视觉也用上Transformer了文章插图


推荐阅读