# --- Training driver: train for `num_epochs` epochs, checkpoint on best val acc ---
best_eval = 0.0  # best validation accuracy observed so far
for epoch_idx in tqdm(range(1, num_epochs + 1)):
    state, rng = train_epoch(train_loader, epoch_idx, state, rng)
    # Evaluate on the validation set after every epoch.
    # (The original guard `epoch_idx % 1 == 0` was always true, so it is dropped.)
    eval_acc = eval_model(val_loader, state, rng)
    logger.add_scalar('val/acc', eval_acc, global_step=epoch_idx)
    # `>=` keeps the most recent checkpoint when accuracy ties the best.
    if eval_acc >= best_eval:
        best_eval = eval_acc
        save_model(state, step=epoch_idx)
    logger.flush()

# Final evaluation on the held-out test set after training completes.
test_acc = eval_model(test_loader, state, rng)
print(f'test_acc: {test_acc}')
def train_epoch(train_loader, epoch_idx, state, rng):
    """Run one optimization pass over `train_loader`.

    Per-batch loss/accuracy stay on-device during the epoch and are
    pulled to host once at the end, averaged, and logged.

    Returns the updated (state, rng) pair.
    """
    epoch_metrics = defaultdict(list)
    for batch in tqdm(train_loader, desc='Training', leave=False):
        state, rng, loss, acc = train_step(state, rng, batch)
        epoch_metrics['loss'].append(loss)
        epoch_metrics['acc'].append(acc)

    # One device->host transfer per metric, then mean over the epoch.
    for name, values in epoch_metrics.items():
        mean_val = np.stack(jax.device_get(values)).mean()
        logger.add_scalar('train/' + name, mean_val, global_step=epoch_idx)
        print(f'[epoch {epoch_idx}] {name}: {mean_val}')
    return state, rng
验证
def eval_model(data_loader, state, rng):
    """Return average accuracy over all examples in `data_loader`.

    Each batch's accuracy is weighted by its batch size, so the result is
    the exact per-example accuracy even when the last batch is smaller.
    """
    weighted_correct = 0
    num_examples = 0
    for batch in data_loader:
        rng, batch_acc = eval_step(state, rng, batch)
        batch_size = batch[0].shape[0]
        weighted_correct += batch_acc * batch_size
        num_examples += batch_size
    # .item() converts the device scalar to a plain Python float.
    return (weighted_correct / num_examples).item()
训练步骤
在train_step中定义损失函数 , 计算模型参数的梯度 , 并根据梯度更新参数;在value_and_gradients方法中 , 计算状态的梯度 。在apply_gradients中 , 更新TrainState 。交叉熵损失是通过apply_fn(与model.apply相同)计算logits来计算的 , apply_fn是在创建TrainState时指定的 。
@jax.jit
def train_step(state, rng, batch):
    """One optimization step: differentiate the loss w.r.t. params and apply grads."""

    # Close over everything except the parameters we differentiate against.
    def loss_fn(params):
        return calculate_loss(params, state, rng, batch, train=True)

    # has_aux=True: loss_fn returns (loss, (acc, rng)); only `loss` is differentiated.
    (loss, (acc, rng)), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
    state = state.apply_gradients(grads=grads)
    return state, rng, loss, acc
计算损失
def calculate_loss(params, state, rng, batch, train):
    """Cross-entropy loss and accuracy for one batch.

    Splits the RNG so dropout gets a fresh key each call, then returns
    (loss, (acc, new_rng)) — the aux structure train_step's
    value_and_grad(has_aux=True) expects.
    """
    imgs, labels = batch
    rng, drop_rng = random.split(rng)
    logits = state.apply_fn({'params': params}, imgs, train=train, rngs={'dropout': drop_rng})
    per_example_loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=labels)
    loss = per_example_loss.mean()
    predictions = logits.argmax(axis=-1)
    acc = (predictions == labels).mean()
    return loss, (acc, rng)
结果
训练结果如下所示 。在Colab pro的标准GPU上 , 训练时间约为1.5小时 。
test_acc: 0.7704000473022461
如果你对JAX感兴趣 , 请看这里是本文的完整代码:
https://avoid.overfit.cn/post/926b7965ba56464ba151cbbfb6a98a93
作者:satojkovic
【使用JAX实现完整的Vision Transformer】
推荐阅读
- 理论+实践,教你如何使用Nginx实现限流
- 使用美国主机时,哪些因素会影响建站时间?
- 怎样实现电脑微信多开?请试试以下方法!
- iPhone照相机夜间模式怎么打开?
- iPhone怎么调时间?
- iphone应用商店下载不了软件怎么办,应用商店的正确使用教程分享给大家
- seo精准引流如何实现,百度精准引流推广方法
- 禁用445端口存储怎么使用!如何关闭445端口?
- 淘宝店铺优惠券怎么用,淘宝店铺优惠券使用规则
- 美图秀秀在线使用证件照……如何使用手机美图秀秀抠图?