Hook 函数与 CAM 算法


Hook 函数与 CAM 算法

文章插图
 
这篇文章主要介绍了如何使用 Hook 函数提取网络中的特征图进行可视化,和 CAM(class activation map, 类激活图)
Hook 函数概念Hook 函数是在不改变主体的情况下,实现额外功能 。由于 PyTorch 是基于动态图实现的,因此在一次迭代运算结束后,一些中间变量如非叶子节点的梯度和特征图,会被释放掉 。在这种情况下想要提取和记录这些中间变量,就需要使用 Hook 函数 。
PyTorch 提供了 4 种 Hook 函数 。
torch.Tensor.register_hook(hook)功能:注册一个反向传播 hook 函数,仅输入一个参数,为张量的梯度 。
hook函数:
hook(grad)参数:
  • grad:张量的梯度
代码如下:
w = torch.tensor([1.], requires_grad=True)x = torch.tensor([2.], requires_grad=True)a = torch.add(w, x)b = torch.add(w, 1)y = torch.mul(a, b)# 保存梯度的 lista_grad = list()# 定义 hook 函数,把梯度添加到 list 中def grad_hook(grad): a_grad.Append(grad)# 一个张量注册 hook 函数handle = a.register_hook(grad_hook)y.backward()# 查看梯度print("gradient:", w.grad, x.grad, a.grad, b.grad, y.grad)# 查看在 hook 函数里 list 记录的梯度print("a_grad[0]: ", a_grad[0])handle.remove()结果如下:
【Hook 函数与 CAM 算法】gradient: tensor([5.]) tensor([2.]) None None Nonea_grad[0]:tensor([2.])在反向传播结束后,非叶子节点张量的梯度被清空了 。而通过hook函数记录的梯度仍然可以查看 。
hook函数里面可以修改梯度的值,无需返回也可以作为新的梯度赋值给原来的梯度 。代码如下:
w = torch.tensor([1.], requires_grad=True)x = torch.tensor([2.], requires_grad=True)a = torch.add(w, x)b = torch.add(w, 1)y = torch.mul(a, b)a_grad = list()def grad_hook(grad):grad *= 2return grad*3handle = w.register_hook(grad_hook)y.backward()# 查看梯度print("w.grad: ", w.grad)handle.remove()结果是:
w.grad:tensor([30.])torch.nn.Module.register_forward_hook(hook)功能:注册 module 的前向传播hook函数,可用于获取中间的 feature map 。
hook函数:
hook(module, input, output)参数:
  • module:当前网络层
  • input:当前网络层输入数据
  • output:当前网络层输出数据
下面代码执行的功能是 $3 times 3$ 的卷积和 $2 times 2$ 的池化 。我们使用register_forward_hook()记录中间卷积层输入和输出的 feature map 。
Hook 函数与 CAM 算法

文章插图
 
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(1, 2, 3)self.pool1 = nn.MaxPool2d(2, 2)def forward(self, x):x = self.conv1(x)x = self.pool1(x)return xdef forward_hook(module, data_input, data_output):fmap_block.append(data_output)input_block.append(data_input)# 初始化网络net = Net()net.conv1.weight[0].detach().fill_(1)net.conv1.weight[1].detach().fill_(2)net.conv1.bias.data.detach().zero_()# 注册hookfmap_block = list()input_block = list()net.conv1.register_forward_hook(forward_hook)# inferencefake_img = torch.ones((1, 1, 4, 4))# batch size * channel * H * Woutput = net(fake_img)# 观察print("output shape: {}noutput value: {}n".format(output.shape, output))print("feature maps shape: {}noutput value: {}n".format(fmap_block[0].shape, fmap_block[0]))print("input shape: {}ninput value: {}".format(input_block[0][0].shape, input_block[0]))输出如下:
output shape: torch.Size([1, 2, 1, 1])output value: tensor([[[[ 9.]],[[18.]]]], grad_fn=<MaxPool2DWithIndicesBackward>)feature maps shape: torch.Size([1, 2, 2, 2])output value: tensor([[[[ 9.,9.],[ 9.,9.]],[[18., 18.],[18., 18.]]]], grad_fn=<ThnnConv2DBackward>)input shape: torch.Size([1, 1, 4, 4])input value: (tensor([[[[1., 1., 1., 1.],[1., 1., 1., 1.],[1., 1., 1., 1.],[1., 1., 1., 1.]]]]),)torch.Tensor.register_forward_pre_hook()功能:注册 module 的前向传播前的hook函数,可用于获取输入数据 。
hook函数:
hook(module, input)参数:
  • module:当前网络层
  • input:当前网络层输入数据
torch.Tensor.register_backward_hook()功能:注册 module 的反向传播的hook函数,可用于获取梯度 。
hook函数:
hook(module, grad_input, grad_output)参数:
  • module:当前网络层
  • input:当前网络层输入的梯度数据
  • output:当前网络层输出的梯度数据
代码如下:
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(1, 2, 3)self.pool1 = nn.MaxPool2d(2, 2)def forward(self, x):x = self.conv1(x)x = self.pool1(x)return xdef forward_hook(module, data_input, data_output):fmap_block.append(data_output)input_block.append(data_input)def forward_pre_hook(module, data_input):print("forward_pre_hook input:{}".format(data_input))def backward_hook(module, grad_input, grad_output):print("backward hook input:{}".format(grad_input))print("backward hook output:{}".format(grad_output))# 初始化网络net = Net()net.conv1.weight[0].detach().fill_(1)net.conv1.weight[1].detach().fill_(2)net.conv1.bias.data.detach().zero_()# 注册hookfmap_block = list()input_block = list()net.conv1.register_forward_hook(forward_hook)net.conv1.register_forward_pre_hook(forward_pre_hook)net.conv1.register_backward_hook(backward_hook)# inferencefake_img = torch.ones((1, 1, 4, 4))# batch size * channel * H * Woutput = net(fake_img)loss_fnc = nn.L1Loss()target = torch.randn_like(output)loss = loss_fnc(target, output)loss.backward()


推荐阅读