可视化特征图:python读取pth模型,并可视化特征图。亲测有效。 您所在的位置:网站首页 dsc特征值 可视化特征图:python读取pth模型,并可视化特征图。亲测有效。

可视化特征图:python读取pth模型,并可视化特征图。亲测有效。

2024-07-12 02:35| 来源: 网络整理| 查看: 265

一、前言

我们有时候需要可视化特征图,尤其是发paper,或者对比算法等情况。而且通过可视化特征图,也可以让我们对这个整个cnn模型更加熟悉,废话不多说了。

二、效果图

下面我会给出代码,效果图分为单channel绘图和1:1通道特征图融合图。

我生成了很多特征图,我就简单的放两张吧,意思意思。

                                                             单通道特征图

                                                              叠加后的特征图

三、代码

我再次描述清楚我的需求以及我现有的东西,我有网络的结构和网络的预训练权重,我想通过输入图片,得到图片在网络特定层的特征图。

从main()开始看代码,我会说得详细一点,尽量让大家看懂, 这样你修改起来会方便很多。

图片保存和读取的路径相关的问题,我就不说了,这里大家应该懂。

1.首先我们看导入的包,DepthCompletionFrontNet 这是我的网络结构,首先你要搭建起的你的网络(这个得有)。

2.看main()函数,定位到get_feature()函数

3.get_feature做了下面得几个事儿,第一,读取图片,也就是要输入网络得图片(我得网络是双分支,所以是读取两个图,这里你读取一个图就行,就 img_rgb 就行,把 img_pc 相关内容注释);第二,定义网络,实例化,载入预训练权重模型;第三,定义我们要提取出得特定层,这里必须和你网络结构定义得层一模一样,一模一样,一模一样。

4.已经定义的网络结构需要进行修改,假设你网络定义的代码如下:

# 仅仅举例子,我懒得补全了,直接csdn手打的 class Net(nn.Module): super(Net,self).__init__() self.conv1 = nn.conv1 self.conv2 = nn.conv2 self.conv3 = nn.conv3 forward(self,x): x = conv1(x) x = conv2(x) x = conv3(x) return x

网络的定义不需要修改,我们需要修改下网络的 forward,加入字典 all_dict去存储每层的tensor,forward修改如下:

forward(self,x): all_dict = {} x = conv1(x) all_dict['conv1'] = x x = conv2(x) all_dict['conv2'] = x x = conv3(x) all_dict['conv3'] = x return x,all_dict

这样子就修改完成了

总结一下:首先读入模型和图片,图片在前向传播的过程中,我们通过字典保存每层的tensor,需要提取哪层,就从哪层去获取tensor,进而可视化。

大家有问题可以留言,我看到一定会回复。如可以运行,麻烦点赞下,谢谢!希望帮到大家。

 

 

完整代码如下(网络结构我的很复杂,就不放了, 网络结构修改就像上面我说的一样,你可以直接读取img_rgb,在模型的前向传播输入img_rgb,我的网络是双分支,所以我输入两个图组合的字典):

import torch import torchvision.transforms as transforms import skimage.data import skimage.io import skimage.transform import numpy as np import matplotlib.pyplot as plt from completion_segmentation_model import DepthCompletionFrontNet # from completion_segmentation_model_v3_eca_attention import DepthCompletionFrontNet import math #https://blog.csdn.net/missyougoon/article/details/85645195 # https://blog.csdn.net/grayondream/article/details/99090247 # 定义是否使用GPU device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 定义数据预处理方式(将输入的类似numpy中arrary形式的数据转化为pytorch中的张量(tensor)) transform = transforms.ToTensor() def get_picture(picture_dir, transform): ''' 该算法实现了读取图片,并将其类型转化为Tensor ''' img = skimage.io.imread(picture_dir) img256 = skimage.transform.resize(img, (128, 256)) img256 = np.asarray(img256) img256 = img256.astype(np.float32) return transform(img256) def get_picture_rgb(picture_dir): ''' 该函数实现了显示图片的RGB三通道颜色 ''' img = skimage.io.imread(picture_dir) img256 = skimage.transform.resize(img, (256, 256)) skimage.io.imsave('4.jpg', img256) # 取单一通道值显示 # for i in range(3): # img = img256[:,:,i] # ax = plt.subplot(1, 3, i + 1) # ax.set_title('Feature {}'.format(i)) # ax.axis('off') # plt.imshow(img) # r = img256.copy() # r[:,:,0:2]=0 # ax = plt.subplot(1, 4, 1) # ax.set_title('B Channel') # # ax.axis('off') # plt.imshow(r) # g = img256.copy() # g[:,:,0]=0 # g[:,:,2]=0 # ax = plt.subplot(1, 4, 2) # ax.set_title('G Channel') # # ax.axis('off') # plt.imshow(g) # b = img256.copy() # b[:,:,1:3]=0 # ax = plt.subplot(1, 4, 3) # ax.set_title('R Channel') # # ax.axis('off') # plt.imshow(b) # img = img256.copy() # ax = plt.subplot(1, 4, 4) # ax.set_title('image') # # ax.axis('off') # plt.imshow(img) img = img256.copy() ax = plt.subplot() ax.set_title('image') # ax.axis('off') plt.imshow(img) plt.show() def visualize_feature_map_sum(item,name): ''' 将每张子图进行相加 :param feature_batch: :return: ''' feature_map = item.squeeze(0) c = item.shape[1] print(feature_map.shape) feature_map_combination=[] for i in range(0,c): feature_map_split = feature_map.data.cpu().numpy()[i, :, :] feature_map_combination.append(feature_map_split) feature_map_sum = sum(one for one in feature_map_combination) # feature_map = np.squeeze(feature_batch,axis=0) plt.figure() plt.title("combine figure") plt.imshow(feature_map_sum) plt.savefig('E:/Dataset/qhms/feature_map/feature_map_sum_'+name+'.png') # 保存图像到本地 plt.show() def get_feature(): # 输入数据 root_path = 'E:/Dataset/qhms/data/small_data/' pic_dir = 'test_umm_000067.png' pc_path = root_path+'knn_pc_crop_0.6/'+pic_dir rgb_path = root_path+'train_image_2_lane_crop_0.6/'+pic_dir img_rgb = get_picture(rgb_path, transform) # 插入维度 img_rgb = img_rgb.unsqueeze(0) img_rgb = img_rgb.to(device) img_pc = get_picture(pc_path, transform) # 插入维度 img_pc = img_pc.unsqueeze(0) img_pc = img_pc.to(device) # 加载模型 checkpoint = torch.load('E:/Dataset/qhms/all_result/v3/crop_0.6_old/hah/checkpoint-195.pth.tar') args = checkpoint['args'] print(args) model = DepthCompletionFrontNet(args) print(model.keys()) model.load_state_dict(checkpoint['model']) model.to(device) exact_list = ["conv1","conv2","conv3","conv4","convt4","convt3","convt2_","convt1_","lane"] # myexactor = FeatureExtractor(model, exact_list) img1 = { 'pc': img_pc, 'rgb': img_rgb } # print(img1['pc']) # x = myexactor(img1) result,all_dict = model(img1) outputs = [] # 挑选exact_list的层 for item in exact_list: x = all_dict[item] outputs.append(x) # 特征输出可视化 x = outputs k=0 print(x[0].shape[1]) for item in x: c = item.shape[1] plt.figure() name = exact_list[k] plt.suptitle(name) for i in range(c): wid = math.ceil(math.sqrt(c)) ax = plt.subplot(wid, wid, i + 1) ax.set_title('Feature {}'.format(i)) ax.axis('off') figure_map = item.data.cpu().numpy()[0, i, :, :] plt.imshow(figure_map, cmap='jet') plt.savefig('E:/Dataset/qhms/feature_map/feature_map_' + name + '.png') # 保存图像到本地 visualize_feature_map_sum(item,name) k = k + 1 plt.show() # 训练 if __name__ == "__main__": # get_picture_rgb(pic_dir) get_feature()

 

参考:

https://blog.csdn.net/missyougoon/article/details/85645195 https://blog.csdn.net/grayondream/article/details/99090247

 

 

 

 



【本文地址】

公司简介

联系我们

今日新闻

    推荐新闻

    专题文章
      CopyRight 2018-2019 实验室设备网 版权所有