SERIAL_EXEC = xmp.MpSerialExecutor()
# Instantiate the model weights only once in host memory.
WRAPPED_MODEL = xmp.MpModelWrapper(ResNet50())


def train_resnet50():
    """Train and evaluate ResNet50 on CIFAR-10 on this process's XLA device.

    Reads its configuration from the module-level ``FLAGS`` dict
    (``data_dir``, ``batch_size``, ``num_workers``, ``learning_rate``,
    ``momentum``, ``log_steps``, ``num_epochs``, ``metrics_debug``).

    Returns:
        A tuple ``(accuracy, data, pred, target)`` taken from the last
        evaluation pass: the accuracy in percent, and the final test batch
        with its predictions and labels. If ``FLAGS['num_epochs']`` is 0
        the tuple is ``(0.0, None, None, None)``.
    """
    torch.manual_seed(1)

    def get_dataset():
        """Build the CIFAR-10 train/test datasets (downloading if needed)."""
        norm = transforms.Normalize(
            mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010))
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            norm,
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            norm,
        ])
        train_dataset = datasets.CIFAR10(
            root=FLAGS['data_dir'],
            train=True,
            download=True,
            transform=transform_train)
        test_dataset = datasets.CIFAR10(
            root=FLAGS['data_dir'],
            train=False,
            download=True,
            transform=transform_test)
        return train_dataset, test_dataset

    # Using the serial executor avoids multiple processes
    # downloading the same data.
    train_dataset, test_dataset = SERIAL_EXEC.run(get_dataset)

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=FLAGS['batch_size'],
        sampler=train_sampler,
        num_workers=FLAGS['num_workers'],
        drop_last=True)
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=FLAGS['batch_size'],
        shuffle=False,
        num_workers=FLAGS['num_workers'],
        drop_last=True)

    # Scale the learning rate by the number of participating replicas.
    learning_rate = FLAGS['learning_rate'] * xm.xrt_world_size()

    # Get the loss function, optimizer, and model.
    device = xm.xla_device()
    model = WRAPPED_MODEL.to(device)
    optimizer = optim.SGD(model.parameters(), lr=learning_rate,
                          momentum=FLAGS['momentum'], weight_decay=5e-4)
    # NOTE(review): NLLLoss expects log-probabilities; this assumes ResNet50
    # (defined elsewhere) ends with log_softmax -- confirm.
    loss_fn = nn.NLLLoss()

    def train_loop_fn(loader):
        """Run one training epoch over ``loader``, logging periodically."""
        tracker = xm.RateTracker()
        model.train()
        for step, (data, target) in enumerate(loader):
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS['batch_size'])
            if step % FLAGS['log_steps'] == 0:
                print('[xla:{}]({}) Loss={:.2f} Time={}'.format(
                    xm.get_ordinal(), step, loss.item(), time.asctime()),
                    flush=True)

    def test_loop_fn(loader):
        """Evaluate on ``loader``; return accuracy and the last batch seen."""
        total_samples = 0
        correct = 0
        model.eval()
        # Pre-bind so the return statement is valid even for an empty loader.
        data, pred, target = None, None, None
        # Inference only: skip autograd bookkeeping during evaluation.
        with torch.no_grad():
            for data, target in loader:
                output = model(data)
                pred = output.max(1, keepdim=True)[1]
                correct += pred.eq(target.view_as(pred)).sum().item()
                total_samples += data.size()[0]

        accuracy = 100.0 * correct / total_samples
        print('[xla:{}] Accuracy={:.2f}%'.format(
            xm.get_ordinal(), accuracy), flush=True)
        return accuracy, data, pred, target

    # Train and evaluate loop.
    accuracy = 0.0
    data, pred, target = None, None, None
    for epoch in range(1, FLAGS['num_epochs'] + 1):
        # Re-seed the sampler's shuffle for this epoch; without this call
        # DistributedSampler yields the same ordering every epoch.
        train_sampler.set_epoch(epoch)
        para_loader = pl.ParallelLoader(train_loader, [device])
        train_loop_fn(para_loader.per_device_loader(device))
        xm.master_print("Finished training epoch {}".format(epoch))

        para_loader = pl.ParallelLoader(test_loader, [device])
        accuracy, data, pred, target = test_loop_fn(
            para_loader.per_device_loader(device))
        if FLAGS['metrics_debug']:
            xm.master_print(met.metrics_report(), flush=True)

    return accuracy, data, pred, target
现在,我们将开始ResNet50的训练。训练将按照我们在参数中定义的50个epoch进行。训练开始前,我们会记录起始时间;训练结束后,我们将打印总耗时。
# Wall-clock reference taken just before the training processes are launched.
start_time = time.time()


def training(rank, flags):
    """Per-process entry point handed to ``xmp.spawn``.

    Publishes ``flags`` as the module-level ``FLAGS``, runs
    ``train_resnet50()``, and on rank 0 only plots the final test batch
    with its predictions and labels.

    Args:
        rank: Index of this process as supplied by ``xmp.spawn``.
        flags: Configuration dict to expose globally as ``FLAGS``.
    """
    global FLAGS
    FLAGS = flags
    torch.set_default_tensor_type('torch.FloatTensor')

    _, data, pred, target = train_resnet50()

    # Only the first process draws the results.
    if rank != 0:
        return

    # Retrieve the tensors held by TPU core 0 and plot them.
    host_tensors = [tensor.cpu() for tensor in (data, pred, target)]
    plot_results(*host_tensors)


# Start the training processes.
xmp.spawn(
    training,
    args=(FLAGS,),
    nprocs=FLAGS['num_cores'],
    start_method='fork',
)
![使用TPU在PyTorch中实现ResNet50](http://img.jiangsulong.com/220420/0K4396422-3.jpg)
文章插图
![使用TPU在PyTorch中实现ResNet50](http://img.jiangsulong.com/220420/0K4393U3-4.jpg)
文章插图
![使用TPU在PyTorch中实现ResNet50](http://img.jiangsulong.com/220420/0K4393C1-5.jpg)
文章插图
训练结束后,我们会打印训练过程所花费的时间 。
![使用TPU在PyTorch中实现ResNet50](http://img.jiangsulong.com/220420/0K4394559-6.jpg)
文章插图
最后,我们将训练过程中模型对样本测试数据的预测结果可视化。
# Report the elapsed wall-clock time measured from the file-level start_time.
end_time = time.time()
print(
    "Time taken = ",
    end_time - start_time,
)
![使用TPU在PyTorch中实现ResNet50](http://img.jiangsulong.com/220420/0K4392259-7.jpg)
文章插图
【使用TPU在PyTorch中实现ResNet50】
推荐阅读
- 想知道是什么占用你的电脑空间,正确使用Windows 10查看磁盘空间
- Kafka-manager部署与使用简单介绍
- 如何使用 Squid 配置 SSH 代理服务器
- 小白一键重装系统备份文件在哪
- 外星飞碟是否真的存在 UFO外星飞碟
- RabbitMq七种工作模式,结合简单的java实例使用,别再说你不会
- 泾阳茯茶的功效与作用,茯茶正在重新焕发活力
- 何时使用约束求解而不是机器学习
- 在Windows和Linux中找出磁盘分区使用的文件系统,就是这么简单
- 用 Excel 将证件照蓝底换成红底