如何从PyTorch中获取过程特征图实例详解
ViperL1 人气:0一、获取Tensor
神经网络在运算过程中实际上是以Tensor为格式进行计算的,我们只需稍稍改动一下forward函数即可从运算过程中抓到Tensor
代码如下:
base_feature = self.extractor.forward(x) #正常的前向传递 feature=base_feature.detach() #抓取tensor feature_imshow(feature) #展示函数(关键代码)
通过将过程张量赋值给一个临时变量,即可将其从前向传递中分离出来且不影响原来的前向传递函数,这种方法远比复杂的hook函数更实用。
将Tensor数据取到后到可视化还需要进行以下几步:
①类型转换
如果网络是在cuda中进行运算,则需要将提取到的tensor转换为cpu类型才能进行接下来的运算
inp = inp.cpu() #类型转换
②张量拆解
网络中的张量一般是高维度的,需要对其进行降维,一般降至两维即可进行显示。这里以Faster R-CNN中的resnet50特征提取网络为例:输出其特征图尺寸为:[1,1024,68,38],可以很明显的看出,第一维实际上是batch_size,在图像显示中不需要,可以直接去除;第二维1024则是网络提取到的特征图张数,故可以对第二维进行遍历;而第3,4维是特征图的尺寸,直接显示即可。
inp=inp.squeeze(0) #除去第一维 for i in range(len(inp)): plt.imshow(transforms.ToPILImage()(inp[i])) #遍历第二维并将其转换为图像
③图像展示
选取你需要的特征图像,进行保存或使用plt进展示
完整的展示函数如下:
def feature_imshow(inp, title=None): inp = inp.cpu() inp=inp.squeeze(0) print(inp.shape) plt.figure(figsize=(12, 7)) for i in range(len(inp)): plt.subplot(4, 5, i+1) #第一二个参数为图像个数,第三参数为图像位置 plt.imshow(transforms.ToPILImage()(inp[i])) i+=1 plt.show() plt.pause(0.001)
总结
加载全部内容