pytorch实现 GoogLeNet——CNN经典网络模型详解( 五 )


pytorch实现 GoogLeNet——CNN经典网络模型详解

文章插图
 
在这里插入图片描述
#predict.pyimport torchfrom model import GoogLeNetfrom PIL import Imagefrom torchvision import transformsimport matplotlib.pyplot as pltimport jsondata_transform = transforms.Compose(    [transforms.Resize((224, 224)),     transforms.ToTensor(),     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# load imageimg = Image.open("../tulip.jpg")plt.imshow(img)# [N, C, H, W]img = data_transform(img)# expand batch dimensionimg = torch.unsqueeze(img, dim=0)# read class_indicttry:    json_file = open('./class_indices.json', 'r')    class_indict = json.load(json_file)except Exception as e:    print(e)    exit(-1)# create modelmodel = GoogLeNet(num_classes=5, aux_logits=False)# load model weightsmodel_weight_path = "./googleNet.pth"missing_keys, unexpected_keys = model.load_state_dict(torch.load(model_weight_path), strict=False)model.eval()with torch.no_grad():    # predict class    output = torch.squeeze(model(img))    predict = torch.softmax(output, dim=0)    predict_cla = torch.argmax(predict).numpy()print(class_indict[str(predict_cla)])plt.show()



推荐阅读