vivo手机|PyTorch Parallel Training( 三 )


float16和float相比恰里 , 总结下来就是两个原因:内存占用更少 , 计算更快 。
内存占用更少:这个是显然可见的 , 通用的模型 fp16 占用的内存只需原来的一半 。memory-bandwidth 减半所带来的好处:
模型占用的内存更小 , 训练的时候可以用更大的batchsize 。
模型训练时 , 通信量(特别是多卡 , 或者多机多卡)大幅减少 , 大幅减少等待时间 , 加快数据的流通 。
计算更快:目前的不少GPU都有针对 fp16 的计算进行优化 。论文指出:在近期的GPU中 , 半精度的计算吞吐量可以是单精度的 2-8 倍;从下图我们可以看到混合精度训练几乎没有性能损失 。
vivo手机|PyTorch Parallel Training
文章图片

文章图片

3.2 使用方式3.2.1 混合精度
在混合精度训练上 , Apex 的封装十分优雅 。直接使用amp.initialize 包装模型和优化器 , apex 就会自动帮助我们管理模型参数和优化器的精度了 , 根据精度需求不同可以传入其他配置参数 。from apex import ampmodel, optimizer = amp.initialize(model, optimizer, opt_level='O1')
其中 opt_level 为精度的优化设置 , O0(第一个字母是大写字母O):
O0:纯FP32训练 , 可以作为accuracy的baseline;
【vivo手机|PyTorch Parallel Training】O1:混合精度训练(推荐使用) , 根据黑白名单自动决定使用FP16(GEMM, 卷积)还是FP32(Softmax)进行计算 。
O2:“几乎FP16”混合精度训练 , 不存在黑白名单 , 除了Batch norm , 几乎都是用FP16计算 。
O3:纯FP16训练 , 很不稳定 , 但是可以作为speed的baseline;3.2.2 并行训练
Apex也实现了并行训练模型的转换方式 , 改动并不大 , 主要是优化了NCCL的通信 , 因此代码和 torch.distributed 保持一致 , 换一下调用的API即可:from apex import ampfrom apex.parallel import DistributedDataParallelmodel, optimizer = amp.initialize(model, optimizer, opt_level='O1')model = DistributedDataParallel(model, delay_allreduce=True)# 反向传播时需要调用 amp.scale_loss , 用于根据loss值自动对精度进行缩放with amp.scale_loss(loss, optimizer) as scaled_loss:scaled_loss.backward()3.2.3 同步BN
Apex为我们实现了同步BN , 用于解决单GPU的minibatch太小导致BN在训练时不收敛的问题 。from apex.parallel import convert_syncbn_modelfrom apex.parallel import DistributedDataParallel# 注意顺序:三个顺序不能错model = convert_syncbn_model(UNet3d(n_channels=1, n_classes=1)).to(device)model, optimizer = amp.initialize(model, optimizer, opt_level='O1')model = DistributedDataParallel(model, delay_allreduce=True)
调用该函数后 , Apex会自动遍历model的所有层 , 将BatchNorm层替换掉 。3.3 汇总
Apex的并行训练部分主要与如下代码段有关:# main.pyimport torchimport argparseimport torch.distributed as distfrom apex.parallel import convert_syncbn_modelfrom apex.parallel import DistributedDataParallelparser = argparse.ArgumentParser()parser.add_argument('--local_rank', default=-1, type=int,help='node rank for distributed training')args = parser.parse_args()dist.init_process_group(backend='nccl')torch.cuda.set_device(args.local_rank)train_dataset = ...train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=..., sampler=train_sampler)model = ...model = convert_syncbn_model(model)model, optimizer = amp.initialize(model, optimizer)model = DistributedDataParallel(model, device_ids=[args.local_rank])optimizer = optim.SGD(model.parameters())for epoch in range(100):for batch_idx, (data, target) in enumerate(train_loader):images = images.cuda(non_blocking=True)target = target.cuda(non_blocking=True)...output = model(images)loss = criterion(output, target)optimizer.zero_grad()with amp.scale_loss(loss, optimizer) as scaled_loss:scaled_loss.backward()optimizer.step()


推荐阅读