normalized_weights = nn.softmax(weights, axis=-1)
# attention (B, N, T, D)
attention = jnp.matmul(normalized_weights, v)
attention = self.att_drop(attention, deterministic=not train)
# gather heads
attention = attention.transpose(0, 2, 1, 3).reshape(B, T, N*D)
# project
out = self.proj_drop(self.proj_net(attention), deterministic=not train)
return out
5、使用CLS嵌入进行分类
最后是MLP头(分类头)。
class ViT(nn.Module):
    """Vision Transformer classifier.

    Pipeline: patch extraction -> patch embedding (+positions/CLS, done by
    PatchEncoder) -> dropout -> `num_layers` transformer encoder blocks ->
    linear classification head on the [CLS] token.

    Attributes:
        patch_size: side length of square image patches.
        embed_dim: patch embedding dimension fed to the encoder blocks.
        hidden_dim: hidden dimension used by PatchEncoder / encoder blocks.
        n_heads: number of attention heads per encoder block.
        drop_p: dropout probability.
        num_layers: number of (independent) transformer encoder blocks.
        mlp_dim: hidden size of each encoder block's MLP.
        num_classes: number of output classes.
    """
    patch_size: int
    embed_dim: int
    hidden_dim: int
    n_heads: int
    drop_p: float
    num_layers: int
    mlp_dim: int
    num_classes: int

    def setup(self):
        self.patch_extracter = Patches(self.patch_size, self.embed_dim)
        self.patch_encoder = PatchEncoder(self.hidden_dim)
        self.dropout = nn.Dropout(self.drop_p)
        # BUG FIX: the original created a single TransformerEncoder and
        # applied it num_layers times, which ties the weights of every layer
        # together. A standard ViT has independent parameters per layer, so
        # build one encoder module per layer instead.
        self.transformer_encoders = [
            TransformerEncoder(self.embed_dim, self.hidden_dim, self.n_heads,
                               self.drop_p, self.mlp_dim)
            for _ in range(self.num_layers)
        ]
        self.cls_head = nn.Dense(features=self.num_classes)

    def __call__(self, x, train=True):
        """Classify a batch of images.

        Args:
            x: image batch (NHWC — assumed from the (5, 32, 32, 3) init input;
               confirm against Patches).
            train: when True, dropout is active.

        Returns:
            Logits of shape (batch, num_classes).
        """
        x = self.patch_extracter(x)
        x = self.patch_encoder(x)
        x = self.dropout(x, deterministic=not train)
        for encoder in self.transformer_encoders:
            x = encoder(x, train)
        # MLP head: classify from the [CLS] token (position 0) only.
        x = x[:, 0]
        return self.cls_head(x)
使用JAX/Flax训练
现在已经创建了模型 , 下面就是使用JAX/Flax来训练 。
数据集
这里我们直接使用torchvision的CIFAR10。
首先是一些工具函数:
def image_to_numpy(img):
    """Convert an image (e.g. PIL) to a standardized float32 numpy array.

    Pixels are rescaled to [0, 1] and then normalized with the dataset
    statistics DATA_MEANS / DATA_STD (module-level constants).
    """
    pixels = np.array(img, dtype=np.float32) / 255.
    return (pixels - DATA_MEANS) / DATA_STD
def numpy_collate(batch):
    """Collate function producing numpy batches (for a torch DataLoader).

    Stacks a batch of arrays into one array; recursively transposes a batch
    of tuples/lists (e.g. (image, label) pairs) into a list of batched
    components; anything else is wrapped into a plain numpy array.
    """
    first = batch[0]
    if isinstance(first, np.ndarray):
        return np.stack(batch)
    if isinstance(first, (tuple, list)):
        # zip(*batch) turns [(img, lbl), ...] into (imgs, lbls); collate each.
        return [numpy_collate(group) for group in zip(*batch)]
    return np.array(batch)
然后是训练和测试的dataloader
# Eval-time transform: normalization only, no augmentation.
test_transform = image_to_numpy
# Train-time transform: random flip + resized crop, then normalization.
# IMAGE_SIZE / CROP_SCALES / CROP_RATIO are module-level constants.
train_transform = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomResizedCrop((IMAGE_SIZE, IMAGE_SIZE), scale=CROP_SCALES, ratio=CROP_RATIO),
image_to_numpy
])
# Validation set should not use the augmentation.
# The CIFAR-10 train split is loaded twice with different transforms, then
# split identically (same SEED) so train/val indices are disjoint and the
# val subset uses the augmentation-free transform.
train_dataset = CIFAR10('data', train=True, transform=train_transform, download=True)
val_dataset = CIFAR10('data', train=True, transform=test_transform, download=True)
train_set, _ = torch.utils.data.random_split(train_dataset, [45000, 5000], generator=torch.Generator().manual_seed(SEED))
_, val_set = torch.utils.data.random_split(val_dataset, [45000, 5000], generator=torch.Generator().manual_seed(SEED))
test_set = CIFAR10('data', train=False, transform=test_transform, download=True)
# numpy_collate yields numpy batches (not torch tensors) for JAX consumption.
train_loader = torch.utils.data.DataLoader(
train_set, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, num_workers=2, persistent_workers=True, collate_fn=numpy_collate,
)
val_loader = torch.utils.data.DataLoader(
val_set, batch_size=BATCH_SIZE, shuffle=False, drop_last=False, num_workers=2, persistent_workers=True, collate_fn=numpy_collate,
)
test_loader = torch.utils.data.DataLoader(
test_set, batch_size=BATCH_SIZE, shuffle=False, drop_last=False, num_workers=2, persistent_workers=True, collate_fn=numpy_collate,
)
初始化模型
初始化ViT模型
def initialize_model(
    seed=42,
    patch_size=16, embed_dim=192, hidden_dim=192,
    n_heads=3, drop_p=0.1, num_layers=12, mlp_dim=768, num_classes=10,
    image_size=32,
):
    """Build a ViT and initialize its parameters from a dummy batch.

    Args:
        seed: PRNG seed for parameter/dropout initialization.
        patch_size, embed_dim, hidden_dim, n_heads, drop_p, num_layers,
        mlp_dim, num_classes: forwarded to the ViT constructor.
        image_size: side length of the square dummy input used to trace
            shapes at init time (generalized from the hard-coded 32).

    Returns:
        (model, params, main_rng) — the Flax module, its initialized
        parameter pytree, and the PRNG key to thread through training.
    """
    main_rng = jax.random.PRNGKey(seed)
    # Dummy NHWC batch; only its shape matters for initialization.
    x = jnp.ones(shape=(5, image_size, image_size, 3))
    # ViT
    model = ViT(
        patch_size=patch_size,
        embed_dim=embed_dim,
        hidden_dim=hidden_dim,
        n_heads=n_heads,
        drop_p=drop_p,
        num_layers=num_layers,
        mlp_dim=mlp_dim,
        num_classes=num_classes,
    )
    # Consistency fix: use jax.random throughout instead of mixing
    # `jax.random.PRNGKey` with a bare `random.split` alias.
    main_rng, init_rng, drop_rng = jax.random.split(main_rng, 3)
    # Separate streams for parameter init and dropout masks.
    params = model.init({'params': init_rng, 'dropout': drop_rng}, x, train=True)['params']
    return model, params, main_rng

vit_model, vit_params, vit_rng = initialize_model()
创建TrainState
在Flax中,常见的模式是创建一个管理训练状态的类,包括轮次、优化器状态和模型参数等等。还可以通过在TrainState中指定apply_fn来减少训练循环中的函数参数列表,apply_fn对应于模型的前向传播。
def create_train_state(model, params, learning_rate):
    """Bundle the model's apply_fn, its params and an Adam optimizer
    into a single flax TrainState for the training loop."""
    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=optax.adam(learning_rate),
    )

state = create_train_state(vit_model, vit_params, 3e-4)
循环训练
def train_model(train_loader, val_loader, state, rng, num_epochs=100):
推荐阅读
- 理论+实践,教你如何使用Nginx实现限流
- 使用美国主机时,哪些因素会影响建站时间?
- 怎样实现电脑微信多开?请试试以下方法!
- iPhone照相机夜间模式怎么打开?
- iPhone怎么调时间?
- iphone应用商店下载不了软件怎么办,应用商店的正确使用教程分享给大家
- seo精准引流如何实现,百度精准引流推广方法
- 禁用445端口存储怎么使用!如何关闭445端口?
- 淘宝店铺优惠券怎么用,淘宝店铺优惠券使用规则
- 美图秀秀在线使用证件照……如何使用手机美图秀秀抠图?