vivo手机|PyTorch Parallel Training( 四 )


使用 launch 启动:
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 main.py

4 多卡训练时的数据记录(TensorBoard、torch.save)

4.1 记录Loss曲线
在我们使用多进程时 , 每个进程有自己计算得到的 Loss , 我们在进行数据记录时 , 希望对不同进程上的 Loss 取平均(也就是 map-reduce 的做法) , 对于其他需要记录的数据也都是一样的做法:

def reduce_tensor(tensor: torch.Tensor) -> torch.Tensor:
    rt = tensor.clone()
    distributed.all_reduce(rt, op=distributed.ReduceOp.SUM)
    rt /= distributed.get_world_size()
    return rt

# calculate loss
loss = criterion(predict, labels)
reduced_loss = reduce_tensor(loss.data)
train_epoch_loss += reduced_loss.item()

注意在写入 TensorBoard 的时候只让一个进程写入就够了:

# TensorBoard
if args.local_rank == 0:
    writer.add_scalars('Loss/training', {
        'train_loss': train_epoch_loss,
        'val_loss': val_epoch_loss
    }, epoch + 1)

4.2 torch.save
在保存模型的时候 , 由于是 Apex 混合精度模型 , 我们需要使用 Apex 提供的保存、载入方法(见 Apex README):

# Save checkpoint
checkpoint = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'amp': amp.state_dict()
}
torch.save(checkpoint, 'amp_checkpoint.pt')
...

# Restore
model = ...
optimizer = ...
checkpoint = torch.load('amp_checkpoint.pt')
model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
amp.load_state_dict(checkpoint['amp'])

# Continue training
...

5 多卡后的 batch_size 和 learning_rate 的调整
见: https://www.zhihu.com/question/64134994/answer/217813386
从理论上来说 , lr = batch_size * base lr , 因为 batch_size 的增大会导致你 update 次数的减少 , 所以为了达到相同的效果 , 应该是同比例增大的 。
但是更大的 lr 可能会导致收敛的不够好 , 尤其是在刚开始的时候 , 如果你使用很大的 lr , 可能会直接爆炸 , 所以可能会需要一些 warmup 来逐步的把 lr 提高到你想设定的 lr 。
实际应用中发现不一定要同比例增长 , 有时候可能增大到 batch_size/2 倍的效果已经很不错了 。
在我的实验中 , 使用8卡训练 , 则增大batch_size 8倍 , learning_rate 4倍是差不多的 。6 完整代码示例(我用来训练3D U-Net的)import osimport datetimeimport argparsefrom tqdm import tqdmimport torchfrom torch import distributed, optimfrom torch.utils.data import DataLoaderfrom torch.utils.data.distributed import DistributedSamplerfrom torch.utils.tensorboard import SummaryWriterfrom apex import ampfrom apex.parallel import convert_syncbn_modelfrom apex.parallel import DistributedDataParallelfrom models import UNet3dfrom datasets import IronGrain3dDatasetfrom losses import BCEDiceLossfrom eval import eval_nettrain_images_folder = '../../datasets/IronGrain/74x320x320/train_patches/images/'train_labels_folder = '../../datasets/IronGrain/74x320x320/train_patches/labels/'val_images_folder = '../../datasets/IronGrain/74x320x320/val_patches/images/'val_labels_folder = '../../datasets/IronGrain/74x320x320/val_patches/labels/'def parse():parser = argparse.ArgumentParser()parser.add_argument('--local_rank', type=int, default=0)args = parser.parse_args()return argsdef main():args = parse()torch.cuda.set_device(args.local_rank)distributed.init_process_group('nccl',init_method='env://')train_dataset = IronGrain3dDataset(train_images_folder, train_labels_folder)val_dataset = IronGrain3dDataset(val_images_folder, val_labels_folder)train_sampler = DistributedSampler(train_dataset)val_sampler = DistributedSampler(val_dataset)epochs = 100batch_size = 8lr = 2e-4weight_decay = 1e-4device = torch.device(f'cuda:{args.local_rank}')train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=4,pin_memory=True, sampler=train_sampler)val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=4,pin_memory=True, sampler=val_sampler)net = convert_syncbn_model(UNet3d(n_channels=1, n_classes=1)).to(device)optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay=weight_decay)net, optimizer = amp.initialize(net, optimizer, opt_level='O1')net = DistributedDataParallel(net, delay_allreduce=True)scheduler = 
optim.lr_scheduler.MultiStepLR(optimizer, milestones=[25, 50, 75], gamma=0.2)criterion = BCEDiceLoss().to(device)if args.local_rank == 0:print(f'''Starting training:Epochs:{epochs}Batch size:{batch_size}Learning rate:{lr}Training size:{len(train_dataset)}Validation size: {len(val_dataset)}Device:{device.type}''')writer = SummaryWriter(log_dir=f'runs/irongrain/unet3d_32x160x160_BS_{batch_size}_{datetime.datetime.now()}')for epoch in range(epochs):train_epoch_loss = 0with tqdm(total=len(train_dataset), desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:images = Nonelabels = Nonepredict = None# trainnet.train()for batch_idx, batch in enumerate(train_loader):images = batch['image']labels = batch['label']images = images.to(device, dtype=torch.float32)labels = labels.to(device, dtype=torch.float32)predict = net(images)# calculate lossloss = criterion(predict, labels)reduced_loss = reduce_tensor(loss.data)train_epoch_loss += reduced_loss.item()# optimizeoptimizer.zero_grad()with amp.scale_loss(loss, optimizer) as scaled_loss:scaled_loss.backward()optimizer.step()scheduler.step()# set progress barpbar.set_postfix(**{'loss (batch)': loss.item()})pbar.update(images.shape[0])train_epoch_loss /= (batch_idx + 1)# evalval_epoch_loss, dice, iou = eval_net(net, criterion, val_loader, device, len(val_dataset))# TensorBoardif args.local_rank == 0:writer.add_scalars('Loss/training', {'train_loss': train_epoch_loss,'val_loss': val_epoch_loss}, epoch + 1)writer.add_scalars('Metrics/validation', {'dice': dice,'iou': iou}, epoch + 1)writer.add_images('images', images[:, :, 0, :, :], epoch + 1)writer.add_images('Label/ground_truth', labels[:, :, 0, :, :], epoch + 1)writer.add_images('Label/predict', torch.sigmoid(predict[:, :, 0, :, :]) > 0.5, epoch + 1)if args.local_rank == 0:torch.save(net, f'unet3d-epoch{epoch + 1}.pth')def reduce_tensor(tensor: torch.Tensor) -> torch.Tensor:rt = tensor.clone()distributed.all_reduce(rt, op=distributed.reduce_op.SUM)rt /= 
distributed.get_world_size()return rtif __name__ == '__main__':main()


推荐阅读