本文将展示如何使用JAX/Flax实现Vision Transformer (ViT) , 以及如何使用JAX/Flax训练ViT 。
Vision Transformer
在实现Vision Transformer时 , 首先要记住这张图 。
文章插图
以下是论文描述的ViT执行过程 。
从输入图像中提取补丁图像 , 并将其转换为平面向量 。
投影到 Transformer Encoder 来处理的维度
预先添加一个可学习的嵌入([class]标记) , 并添加一个位置嵌入 。
由 Transformer Encoder 进行编码处理
使用[class]令牌作为输出 , 输入到MLP进行分类 。
细节实现
下面 , 我们将使用JAX/Flax创建每个模块 。
1、图像到展平的图像补丁
下面的代码从输入图像中提取图像补丁 。这个过程通过卷积来实现 , 内核大小为patch_size * patch_size, stride为patch_size * patch_size , 以避免重复 。
class Patches(nn.Module):
patch_size: int
embed_dim: int
def setup(self):
self.conv = nn.Conv(
features=self.embed_dim,
kernel_size=(self.patch_size, self.patch_size),
strides=(self.patch_size, self.patch_size),
padding='VALID'
)
def __call__(self, images):
patches = self.conv(images)
b, h, w, c = patches.shape
patches = jnp.reshape(patches, (b, h*w, c))
return patches
2和3、对展平补丁块的线性投影/添加[CLS]标记/位置嵌入
Transformer Encoder 对所有层使用相同的尺寸大小hidden_dim 。上面创建的补丁块向量被投影到hidden_dim维度向量上 。与BERT一样 , 有一个CLS令牌被添加到序列的开头 , 还增加了一个可学习的位置嵌入来保存位置信息 。
class PatchEncoder(nn.Module):
hidden_dim: int
@nn.compact
def __call__(self, x):
assert x.ndim == 3
n, seq_len, _ = x.shape
# Hidden dim
x = nn.Dense(self.hidden_dim)(x)
# Add cls token
cls = self.param('cls_token', nn.initializers.zeros, (1, 1, self.hidden_dim))
cls = jnp.tile(cls, (n, 1, 1))
x = jnp.concatenate([cls, x], axis=1)
# Add position embedding
pos_embed = self.param(
'position_embedding',
nn.initializers.normal(stddev=0.02), # From BERT
(1, seq_len + 1, self.hidden_dim)
)
return x + pos_embed
4、Transformer encoder
如上图所示 , 编码器由多头自注意(MSA)和MLP交替层组成 。Norm层 (LN)在MSA和MLP块之前 , 残差连接在块之后 。
class TransformerEncoder(nn.Module):
embed_dim: int
hidden_dim: int
n_heads: int
drop_p: float
mlp_dim: int
def setup(self):
self.mha = MultiHeadSelfAttention(self.hidden_dim, self.n_heads, self.drop_p)
self.mlp = MLP(self.mlp_dim, self.drop_p)
self.layer_norm = nn.LayerNorm(epsilon=1e-6)
def __call__(self, inputs, train=True):
# Attention Block
x = self.layer_norm(inputs)
x = self.mha(x, train)
x = inputs + x
# MLP block
y = self.layer_norm(x)
y = self.mlp(y, train)
return x + y
MLP是一个两层网络 。激活函数是GELU 。本文将Dropout应用于Dense层之后 。
class MLP(nn.Module):
mlp_dim: int
drop_p: float
out_dim: Optional[int] = None
@nn.compact
def __call__(self, inputs, train=True):
actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim
x = nn.Dense(features=self.mlp_dim)(inputs)
x = nn.gelu(x)
x = nn.Dropout(rate=self.drop_p, deterministic=not train)(x)
x = nn.Dense(features=actual_out_dim)(x)
x = nn.Dropout(rate=self.drop_p, deterministic=not train)(x)
return x
多头自注意(MSA)
qkv的形式应为[B, N, T, D] , 如Single Head中计算权重和注意力后 , 应输出回原维度[B, T, C=N*D] 。
class MultiHeadSelfAttention(nn.Module):
hidden_dim: int
n_heads: int
drop_p: float
def setup(self):
self.q.NET = nn.Dense(self.hidden_dim)
self.k_net = nn.Dense(self.hidden_dim)
self.v_net = nn.Dense(self.hidden_dim)
self.proj_net = nn.Dense(self.hidden_dim)
self.att_drop = nn.Dropout(self.drop_p)
self.proj_drop = nn.Dropout(self.drop_p)
def __call__(self, x, train=True):
B, T, C = x.shape # batch_size, seq_length, hidden_dim
N, D = self.n_heads, C // self.n_heads # num_heads, head_dim
q = self.q_net(x).reshape(B, T, N, D).transpose(0, 2, 1, 3) # (B, N, T, D)
k = self.k_net(x).reshape(B, T, N, D).transpose(0, 2, 1, 3)
v = self.v_net(x).reshape(B, T, N, D).transpose(0, 2, 1, 3)
# weights (B, N, T, T)
weights = jnp.matmul(q, jnp.swapaxes(k, -2, -1)) / math.sqrt(D)
推荐阅读
- 理论+实践,教你如何使用Nginx实现限流
- 使用美国主机时,哪些因素会影响建站时间?
- 怎样实现电脑微信多开?请试试以下方法!
- iPhone照相机夜间模式怎么打开?
- iPhone怎么调时间?
- iphone应用商店下载不了软件怎么办,应用商店的正确使用教程分享给大家
- seo精准引流如何实现,百度精准引流推广方法
- 禁用445端口存储怎么使用!如何关闭445端口?
- 淘宝店铺优惠券怎么用,淘宝店铺优惠券使用规则
- 美图秀秀在线使用证件照……如何使用手机美图秀秀抠图?