• python 可视化解释模型


    1. 自定义DataSet

    MakeDataset.py
    首先准备好一个数据集文件,这里以mydata文件夹存放图片数据, 实现自定义DataSet

    class MyDataset(Dataset):
    	def __init__(self,resize):
    		super(MyDataset,self).__init__()
    		self.resize = resize
    	
    	def __len__(self):
    		return len(images)
    	
    	def __getitem__(self,idx):
    		img = images[idx]
    		tf = transforms.Compose([
    			lambda x:Image.open(x).convert('RGB'),
    			transforms.Resize((self.resize,self.resize)),
    			transforms.ToTensor(),
    			transforms.Normalize(mean = [0.485,0.456,0.406],
    										std = [0.229,0.224,0.225])		
    		])
    		img_tensor = tf(image)
    		# `mydata\\ICH\\1470718-1.JPG`
    		label_tensor = torch.tensor(class_name_index[image.split(os.sep)[-2]])
    		return img_tensor,label_tensor
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21

    2 模型定义及训练

    2.1 模型

    这里以一个玩具模型作为演示,模型的定义如下:
    MyModle.py

    class MyResNet(nn.Module):
    	def __init__(self):
    		super(MyResNet,self).__init__()
    		general_features = 32
    		
    		# Initial convolution block
    		self.conv0 = nn.Conv2d(3,general_features,3,1,padding=1)
    		self.conv1 = nn.Conv2d(general_features,general_features,3,1,padding =1)
    		self.relu1 = nn.ReLU()
    		self.conv2 = nn.Conv2d(general_features,general_features,3,1,padding=1)
    		self.relu2 = nn.ReLU()
    		
    		# Down sample 1/2
    		self.downsample0 = nn.Maxpool2d(2,2)
    		self.downsample1 = nn.Maxpool2d(2,2)
    		self.downsample2 = nn.Maxpool2d(2,2)
    		self.downsample3 = nn.Maxpool2d(2,2)
    		
    		self.fc0 = nn.Linear(32*8*8, 2)
    	
    	def forward(self,x):
    		x = self.conv0(x)			#[1,32,128,128]
    		x = self.downsample0(x)		#[1,32,64,64]
    		x = self.downsample1(x)		#[1,32,32,32]
    		
    		x = self.relu1(self.conv1(x)) #[1,32,32,32]
    		x = self.downsample2(x)		  # [1,32,16,16]
    		
    		x = self.relu2(self.conv2(x)) #[1,32,16,16]
    		x = self.downsample3(x)		  # [1,32,8,8]
    		
    		x = x.view(x.shape[0],-1)     # Flatten
    		
    		x = x.softmax(self.fc0(x),dim=1)
    		return x
    
    # x = torch.randn(1,3,128,128)
    # m = myResNet()
    # summary(m,(3,128,128))
    # print(m(x).shape)	
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40

    2.2 训练

    训练train.py获得权重文件

    import torch
    from torch import optim,nn
    from torch.utils.data import Dataloader
    from MakeDataSet import MyDataset
    
    from MyModel import MyResNet
    
    train_db = MyDataset(resize = 128)
    train_loader = DataLoader(train_db,batch_size=4,shuffle=True)
    print('num_train:',len(train_loader.dataset))
    
    model = MyResNet()
    
    optimizer = optim.Adam(model.parameters(),lr =0.001)
    criteon = nn.CrossEntropyLoss()
    
    epochs = 5
    
    for epoch in range(epochs):
    	for step,(x,y) in enumerate(train_loader):
    		model.train()
    		logits = model(x)
    		loss = criteon(logits,y)
    		
    		optimizer.zero_grad()
    		loss.backward()
    		optimizer.step()
    	print('Epochs:',epoch,'Loss:',loss)
    
    torch.save(model.state_dict(),'weights_MyResNet.mdl')
    print('Save Done')
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31

    3 利用SmoothGradCAMpp对特征图可视化

    Visualize_featrue_map, 这里介绍smooth gradcampp用法

    import torch
    from torchvision import transforms
    from torchvision.transforms.functional import to_pil_image
    from torchcam.methods import SmoothGradCAMpp,CAM,GradCAM,GradCAMpp,XGradCAM,ScoreCAM
    from torchcam.utils import overlay_mask
    from MyModel import MyResNet
    from PIL import image
    import matplotlib.pyplot import plt
    
    tf = transforms.Compose([
    		lambda x:Image.open(x).convert('RGB')
    		transforms.Resize(128,128)
    		transforms.ToTensor(),
    		transforms.Normalize(mean = [0.485,0.456,0.406],
    							std = [0.229,0.224,0.225]
    							)
    ])
    
    img_ICH_test = tf('ICH_test.jpg').unsqueeze(dim=0)
    #print(img_ICH_test.shape)
    
    img_Normal_test = tf('Normal_test.jpg').unsqueeze(dim=0)
    
    model = MyResNet()
    model.load_state_dict(torch.load('weights_MyResNet.mdl'))
    print('loaded from ckpt')
    model.eval()
    
    cam_extractor = SmoothGradCAMpp(model,input_shape=(3,128,128))
    # cam_extractor = GradCAMpp(model,input_shape=(3,128,128))
    # cam_extractor = XGradCAM(model,input_shape=(3,128,128))
    # cam_extractor = ScoreCAM(model,input_shape=(3,128,128))
    # cam_extractor = SSCAM(model,input_shape=(3,128,128))
    # cam_extractor =ISCAM(model,input_shape=(3,128,128))
    # cam_extractor = LayerCAM(model,input_shape=(3,128,128))
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 载入测试图片Normal_test.jpg
      在这里插入图片描述
    • 加载预训练权重,实例化模型
    output = model(img_Normal_test)
    print(output)
    
    activation_map = cam_extractor(output.sequeeze(0).argmax().item(),output)
    print(activation_map[0],activation_map[0].min(),activation_map[0].max(),activation_map[0].shape)
    
    #fused_map = cam_extractor.fuse_cams(activation_map)
    #print(fused_map[0],fused_map[0].min(),fused_map[0].max(),fused_map[0].shape)
    
    
    result = overlay_mask(to_pil_image(img_Normal_test[0]),
    					to_pil_image(activation_map[0],mode='F'),alpha=0.3)
    plt.imshow(result)
    plt.show()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 将模型的输出预测类别索引送入到构建的cam_extractor对象中,由于activation_map输出的是一个tuple,通过索引0取值
    • 接着用overlay_mask进行可视化效果展示,传入原图和激活map,并利用alpha参数设置一定的透明度
    • 由于输出的result是PIL格式,所以可以直接用imshow显示
      在这里插入图片描述
      最热的区域就是模型主要依据这部分来判断类别,这里没有指定可视化feature map的哪一层的话,就默认是全连接测上一层feature map

    这个包的主页在: https://pypi.org/project/torchcam/,感兴趣的可以看看

  • 相关阅读:
    设计模式之订阅发布模式
    Spring framework Day10:JSR330注入注解
    【HBZ分享】Mysql的InnoDB原理
    《golong入门教程📚》,从零开始入门❤️(建议收藏⭐️)
    【回归预测】基于DBO-BP(蜣螂优化算法优化BP神经网络)的回归预测 多输入单输出【Matlab代码#68】
    【408数据结构与算法】—栈的抽象数据类型定义(十)
    [R] Underline your idea with ggplot2
    Java 调用Python+Opencv实现图片定位
    SpringBoot详解(二)
    简单工厂、工厂方法、抽象工厂对比
  • 原文地址:https://blog.csdn.net/weixin_38346042/article/details/127875669