基于关系推理的自监督学习无标记训练( 三 )
为了比较关系推理方法在浅层模型和深层模型上的性能 , 我们将创建一个浅层模型(Conv4) , 并使用深层模型的结构(Resnet34) 。
backbone = Conv4() # 浅层模型backbone = models.resnet34(pretrained = False) # 深层模型
根据作者的建议 , 设置了一些超参数和增强策略 。 我们将在未标记的STL-10数据集上用关系头训练主干 。
# 模拟的超参数K = 16 # tot augmentations, 论文中 K=32 batch_size = 64 # 论文中使用64tot_epochs = 10 # 论文中使用200feature_size = 64 # Conv4 主干的单元数feature_size = 1000 # Resnet34 主干的单元数backbone# 扩充策略normalize = transforms.Normalize(mean=[0.4406, 0.4273, 0.3858],std=[0.2687, 0.2613, 0.2685])color_jitter = transforms.ColorJitter(brightness=0.8, contrast=0.8,saturation=0.8, hue=0.2)rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8)rnd_gray = transforms.RandomGrayscale(p=0.2)rnd_rcrop = transforms.RandomResizedCrop(size=96, scale=(0.08, 1.0),interpolation=2)rnd_hflip = transforms.RandomHorizontalFlip(p=0.5)train_transform = transforms.Compose([rnd_rcrop, rnd_hflip,rnd_color_jitter, rnd_gray,transforms.ToTensor(), normalize])# 加载到数据加载器torch.manual_seed(1)torch.cuda.manual_seed(1)train_set = MultiSTL10(K=K, root='data', split='unlabeled', transform=train_transform, download=True)train_loader = DataLoader(train_set,batch_size=batch_size, shuffle=True,num_workers=2, pin_memory=True)
文章插图
到目前为止 , 我们已经创造了训练我们模型所需的一切 。 现在我们将在10个时期和16个增强图像(K)中训练主干和关系头模型 , 使用1个GPU Tesla P100-PCIE-16GB在浅层模型(Conv4)上花费4个小时 , 在深层模型(Resnet34)上花费6个小时(你可以自由地更改时期数以及另一个超参数以获得更好的结果)
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")backbone.to(device)model = RelationalReasoning(backbone, feature_size)model.train(tot_epochs=tot_epochs, train_loader=train_loader)torch.save(model.backbone.state_dict(), 'model.tar')
在训练了我们的主干模型之后 , 我们丢弃了关系头 , 只将主干用于下游任务 。 我们需要使用STL-10(500个图像)中的标记数据来微调我们的主干 , 并在测试集中测试最终的模型(800个图像) 。 训练和测试数据集将加载到Dataloader中 , 而无需进行扩充 。
# set random seedtorch.manual_seed(1)torch.cuda.manual_seed(1)# no augmentations used for linear evaluationtransform_lineval = transforms.Compose([transforms.ToTensor(), normalize])# Download STL10 labeled train and test datasettrain_set_lineval = torchvision.datasets.STL10('data', split='train', transform=transform_lineval)test_set_lineval = torchvision.datasets.STL10('data', split='test', transform=transform_lineval)# Load dataset in data loadertrain_loader_lineval = DataLoader(train_set_lineval, batch_size=128, shuffle=True)test_loader_lineval = DataLoader(test_set_lineval, batch_size=128, shuffle=False)
我们将加载预训练的主干模型 , 并使用一个简单的线性模型将输出特性与数据集中的许多类连接起来 。
# linear modellinear_layer = torch.nn.Linear(64, 10) # if backbone is Conv4linear_layer = torch.nn.Linear(1000, 10) # if backbone is Resnet34# defining a raw backbone modelbackbone_lineval = Conv4() # Conv4backbone_lineval = models.resnet34(pretrained = False) # Resnet34# load modelcheckpoint = torch.load('model.tar') # name of pretrain weightbackbone_lineval.load_state_dict(checkpoint)
推荐阅读
- 华硕基于WRX80的主板现身 为AMD Ryzen Threadripper Pro打造
- 微软新版电子邮件客户端截图曝光:基于网页端Outlook
- 曝光 | 小鹏或春节前推送NGP更新,基于高精地图可自动变道
- 基于Spring+Angular9+MySQL开发平台
- 14款华为手机/平板公测EMUI 11:全部基于麒麟980
- AI赋能,让消防、用电更“智慧”
- 荷兰职员哭泣:中国明明说好自研光刻机,却跟日本尼康扯上关系
- 基于安卓11打造!魅族17系列将升级全新Flyme 8
- 谷歌为用户提供了基于AR的虚拟化妆体验
- 挺进云端AI训练&推理双赛道!独家对话燧原科技COO张亚林:揭秘超高效率背后的“内功”