Implementing ResNet50 in PyTorch on TPUs (Part 2)


# Imports required by the code below.
import time

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

SERIAL_EXEC = xmp.MpSerialExecutor()
# Only instantiate the model weights once in memory.
WRAPPED_MODEL = xmp.MpModelWrapper(ResNet50())

def train_resnet50():
    torch.manual_seed(1)

    def get_dataset():
        norm = transforms.Normalize(
            mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010))
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            norm,
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            norm,
        ])
        train_dataset = datasets.CIFAR10(
            root=FLAGS['data_dir'],
            train=True,
            download=True,
            transform=transform_train)
        test_dataset = datasets.CIFAR10(
            root=FLAGS['data_dir'],
            train=False,
            download=True,
            transform=transform_test)
        return train_dataset, test_dataset

    # Using the serial executor avoids multiple processes
    # downloading the same data.
    train_dataset, test_dataset = SERIAL_EXEC.run(get_dataset)

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=FLAGS['batch_size'],
        sampler=train_sampler,
        num_workers=FLAGS['num_workers'],
        drop_last=True)
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=FLAGS['batch_size'],
        shuffle=False,
        num_workers=FLAGS['num_workers'],
        drop_last=True)

    # Scale the learning rate by the number of TPU cores.
    learning_rate = FLAGS['learning_rate'] * xm.xrt_world_size()

    # Get the loss function, optimizer, and model.
    device = xm.xla_device()
    model = WRAPPED_MODEL.to(device)
    optimizer = optim.SGD(model.parameters(), lr=learning_rate,
                          momentum=FLAGS['momentum'], weight_decay=5e-4)
    loss_fn = nn.NLLLoss()

    def train_loop_fn(loader):
        tracker = xm.RateTracker()
        model.train()
        for x, (data, target) in enumerate(loader):
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS['batch_size'])
            if x % FLAGS['log_steps'] == 0:
                print('[xla:{}]({}) Loss={:.2f} Time={}'.format(
                    xm.get_ordinal(), x, loss.item(), time.asctime()), flush=True)

    def test_loop_fn(loader):
        total_samples = 0
        correct = 0
        model.eval()
        data, pred, target = None, None, None
        for data, target in loader:
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
            total_samples += data.size()[0]
        accuracy = 100.0 * correct / total_samples
        print('[xla:{}] Accuracy={:.2f}%'.format(
            xm.get_ordinal(), accuracy), flush=True)
        return accuracy, data, pred, target

    # Training and evaluation loop.
    accuracy = 0.0
    data, pred, target = None, None, None
    for epoch in range(1, FLAGS['num_epochs'] + 1):
        para_loader = pl.ParallelLoader(train_loader, [device])
        train_loop_fn(para_loader.per_device_loader(device))
        xm.master_print("Finished training epoch {}".format(epoch))

        para_loader = pl.ParallelLoader(test_loader, [device])
        accuracy, data, pred, target = test_loop_fn(para_loader.per_device_loader(device))
        if FLAGS['metrics_debug']:
            xm.master_print(met.metrics_report(), flush=True)

    return accuracy, data, pred, target

Now we can start training ResNet50. Training runs for the 50 epochs defined in our parameters. We record the time before training starts and print the total elapsed time once training finishes.
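The code above reads its hyperparameters from a FLAGS dictionary that is set up in the first part of this article. For reference, a minimal sketch of what it could contain is shown below; apart from the 50 epochs mentioned above, the values are illustrative assumptions rather than the article's exact settings.

# A sketch of the FLAGS hyperparameter dictionary used above.
# Except for num_epochs (50, as stated in the text), the values are
# assumptions for illustration, not the article's exact settings.
FLAGS = {
    'data_dir': '/tmp/cifar',   # where CIFAR-10 is downloaded
    'batch_size': 128,          # per-core batch size
    'num_workers': 4,           # DataLoader worker processes
    'learning_rate': 0.02,      # base LR, later scaled by xm.xrt_world_size()
    'momentum': 0.9,            # SGD momentum
    'num_epochs': 50,           # the 50 epochs mentioned above
    'num_cores': 8,             # number of TPU cores to spawn processes on
    'log_steps': 20,            # print a loss line every N training steps
    'metrics_debug': False,     # dump the XLA metrics report after each epoch
}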
start_time = time.time()

# Launch the training processes.
def training(rank, flags):
    global FLAGS
    FLAGS = flags
    torch.set_default_tensor_type('torch.FloatTensor')
    accuracy, data, pred, target = train_resnet50()
    if rank == 0:
        # Retrieve the tensors on TPU core 0 and plot them.
        plot_results(data.cpu(), pred.cpu(), target.cpu())

xmp.spawn(training, args=(FLAGS,), nprocs=FLAGS['num_cores'],
          start_method='fork')

After training finishes, we print the time the whole training process took.
end_time = time.time()
print("Time taken = ", end_time - start_time)

Finally, we visualize the predictions the model made on sample test data during the training run.
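The plot_results helper called on TPU core 0 is defined in the first part of this article and is not repeated here. As a rough sketch, assuming matplotlib and the standard CIFAR-10 class names, such a helper could look like the following; the function name matches the call above, but the body is illustrative.

import matplotlib.pyplot as plt

# Standard CIFAR-10 class names, used to label the predictions.
CIFAR10_CLASSES = ('plane', 'car', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck')

def plot_results(images, predictions, targets, max_images=16):
    # images: CPU tensor of shape [batch, 3, 32, 32]
    # predictions / targets: predicted and true class indices
    num_images = min(max_images, images.size(0))
    fig, axes = plt.subplots(4, 4, figsize=(8, 8))
    for i, ax in enumerate(axes.flat):
        ax.axis('off')
        if i >= num_images:
            continue
        img = images[i].permute(1, 2, 0).numpy()                   # CHW -> HWC
        img = (img - img.min()) / (img.max() - img.min() + 1e-8)   # rescale for display
        ax.imshow(img)
        ax.set_title('pred: {} / true: {}'.format(
            CIFAR10_CLASSES[int(predictions[i])],
            CIFAR10_CLASSES[int(targets[i])]), fontsize=8)
    plt.tight_layout()
    plt.show()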