• MobileNet Series (5): Building MobileNetV3 with PyTorch and Training It with Transfer Learning


    The MobileNetV3 code in this post is adapted from the official PyTorch (torchvision) implementation of MobileNet.

    [Figure: MobileNetV3 network structure table]

    For a detailed explanation of the MobileNetV3 architecture, see the previous post: MobileNet Series (4): MobileNetV3 Network Explained.

    Code walkthrough

    Open the model_v3.py file.

    • Rounding a channel count to the nearest multiple of 8
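    from functools import partial
    from typing import Callable, List, Optional
    import torch
    from torch import nn, Tensor
    from torch.nn import functional as F
    
    def _make_divisible(ch,divisor=8,min_ch=None):
    	# Round ch to the nearest multiple of divisor, but never below min_ch
    	if min_ch is None:
    		min_ch=divisor
    	new_ch=max(min_ch,int(ch+divisor/2)//divisor*divisor)
    	if new_ch <0.9*ch:  # never shrink the channel count by more than 10%
    		new_ch += divisor
    	return new_ch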
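    The rounding rule is easiest to see on concrete numbers. Below is a standalone copy of `_make_divisible` (renamed `make_divisible` so it runs without PyTorch) applied to channel counts of the kind MobileNetV3 produces, for example SE squeeze channels `expanded_c // 4`. Note the second step: when rounding down would remove more than 10% of the channels, one extra `divisor` is added, which is why 18 becomes 24 rather than 16.

```python
def make_divisible(ch, divisor=8, min_ch=None):
    """Standalone copy of _make_divisible: round ch to a multiple of divisor."""
    if min_ch is None:
        min_ch = divisor
    # Nearest multiple of divisor (halves round up), but never below min_ch.
    new_ch = max(min_ch, int(ch + divisor / 2) // divisor * divisor)
    # If that lost more than 10% of ch, go up by one more divisor.
    if new_ch < 0.9 * ch:
        new_ch += divisor
    return new_ch


# 18 rounds to 16, but 16 < 0.9 * 18 = 16.2, so the result is bumped to 24.
for ch in [3, 10, 12, 18, 30, 120, 168]:
    print(ch, "->", make_divisible(ch))
```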
    
    • Convolution block of the V3 network
      Almost every convolution in MobileNetV3 is a Conv + BN + activation sequence, so we define a ConvBNActivation class for it.
    class ConvBNActivation(nn.Sequential):
    	"""Conv2d (no bias) -> normalization layer -> activation."""
    	def __init__(self,
    				in_planes:int,
    				out_planes:int,
    				kernel_size:int =3,
    				stride:int =1,
    				groups:int=1,
    				norm_layer:Optional[Callable[...,nn.Module]]=None,
    				activation_layer:Optional[Callable[...,nn.Module]]=None
    				):
    		padding=(kernel_size-1)//2  # keeps H, W unchanged at stride 1 for odd kernels
    		if norm_layer is None:
    			norm_layer =nn.BatchNorm2d
    		if activation_layer is None:
    			activation_layer = nn.ReLU6
    		super(ConvBNActivation,self).__init__(nn.Conv2d(in_channels=in_planes,
    														out_channels=out_planes,
    														kernel_size=kernel_size,
    														stride=stride,
    														padding=padding,
    														groups=groups,
    														bias=False),  # BN follows, so no conv bias
    											  norm_layer(out_planes),
    											  activation_layer(inplace=True))
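    `ConvBNActivation` computes `padding=(kernel_size-1)//2`. For the odd kernel sizes used in MobileNetV3 (1, 3 and 5) this keeps height and width unchanged at stride 1 and halves them (rounding up) at stride 2. A quick check with the standard Conv2d output-size formula `out = (H + 2*padding - kernel_size) // stride + 1` (dilation 1); the helper name `conv_out_size` is only for illustration:

```python
def conv_out_size(h, kernel_size, stride):
    """Output height/width of a Conv2d that uses ConvBNActivation's padding rule."""
    padding = (kernel_size - 1) // 2
    return (h + 2 * padding - kernel_size) // stride + 1


for k in (1, 3, 5):
    for s in (1, 2):
        print(f"kernel={k} stride={s}: 224 -> {conv_out_size(224, k, s)}")
```

An odd input such as 7 at stride 2 gives 4, i.e. ceil(7 / 2).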
    
    • SE module
      [Figure: SE module]
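    • As discussed in the previous post, the SE module is essentially two fully connected layers applied after global average pooling. The first FC layer has 1/4 as many nodes as the input feature map has channels (in the code this is rounded to a multiple of 8 with _make_divisible), and the second has exactly as many nodes as the input channels.
      Note: the first layer uses ReLU as its activation and the second uses h-sigmoid. In the code both layers are implemented as 1x1 convolutions, which is equivalent to a fully connected layer on the 1x1 pooled feature map.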
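    Both hard activations are built from ReLU6, which is what makes them cheap on mobile hardware: h-sigmoid(x) = ReLU6(x + 3) / 6 and h-swish(x) = x * h-sigmoid(x). A scalar sketch in plain Python (the function names are illustrative; in the model code the corresponding PyTorch ops are `F.hardsigmoid` and `nn.Hardswish`):

```python
def relu6(x):
    return min(max(0.0, x), 6.0)


def h_sigmoid(x):
    # Piecewise-linear sigmoid: 0 for x <= -3, 1 for x >= 3, x/6 + 1/2 in between.
    return relu6(x + 3.0) / 6.0


def h_swish(x):
    # Piecewise-linear approximation of swish(x) = x * sigmoid(x).
    return x * h_sigmoid(x)


for x in (-4.0, -3.0, 0.0, 1.5, 3.0, 4.0):
    print(f"x={x:+.1f}  h_sigmoid={h_sigmoid(x):.3f}  h_swish={h_swish(x):.3f}")
```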

    The SE module code is as follows:

    class SqueezeExcitaion(nn.Module):
    	def __init__(self,input_c:int,squeeze_factor:int=4):
    		super(SqueezeExcitaion,self).__init__()
    		# squeeze channels: about input_c / 4, rounded to a multiple of 8
    		squeeze_c=_make_divisible(input_c//squeeze_factor,8)
    		# 1x1 convolutions play the role of the two fully connected layers
    		self.fc1 = nn.Conv2d(input_c,squeeze_c,1)
    		self.fc2 = nn.Conv2d(squeeze_c,input_c,1)
    	
    	def forward(self,x:Tensor)  -> Tensor:
    		scale = F.adaptive_avg_pool2d(x,output_size=(1,1))  # [N, C, 1, 1]
    		scale =self.fc1(scale)
    		scale=F.relu(scale,inplace=True)
    		scale=self.fc2(scale)
    		scale=F.hardsigmoid(scale,inplace=True)  # one weight in [0, 1] per channel
    		return scale * x
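    The forward pass is global average pooling, then fc1 (1x1 conv), ReLU, fc2 (1x1 conv), h-sigmoid, and finally a channel-wise rescaling of `x`. Because the pooled tensor is 1x1, each 1x1 convolution reduces to a matrix-vector product. The sketch below spells this out for a single sample in plain Python; the function name `se_forward`, the weight layout (`w1[j][i]` connects input channel `i` to squeeze channel `j`) and the weight values are assumptions made for illustration.

```python
def se_forward(x, w1, b1, w2, b2):
    """Squeeze-and-excitation for one sample (no batch dimension).

    x:  C channels, each an H x W list of lists
    w1: squeeze_c rows of C weights, b1: squeeze_c biases  (fc1)
    w2: C rows of squeeze_c weights, b2: C biases          (fc2)
    """
    # Squeeze: global average pooling gives one value per channel
    pooled = [sum(sum(row) for row in ch) / (len(ch) * len(ch[0])) for ch in x]
    # fc1 followed by ReLU
    hidden = [max(0.0, sum(w * p for w, p in zip(w_row, pooled)) + b)
              for w_row, b in zip(w1, b1)]
    # fc2 followed by h-sigmoid: one scale in [0, 1] per channel
    scale = [min(max(0.0, sum(w * h for w, h in zip(w_row, hidden)) + b + 3.0), 6.0) / 6.0
             for w_row, b in zip(w2, b2)]
    # Excite: multiply every element of channel c by scale[c]
    return [[[v * s for v in row] for row in ch] for ch, s in zip(x, scale)]


# Two 2x2 channels, squeeze_c = 1, made-up weights
x = [[[1.0, 2.0], [3.0, 4.0]],
     [[4.0, 4.0], [4.0, 4.0]]]
out = se_forward(x, w1=[[1.0, 0.0]], b1=[0.0], w2=[[1.0], [-2.0]], b2=[0.0, 0.0])
print(out)
```

With these weights the second channel gets a negative fc2 output, h-sigmoid clamps its scale to 0, and the channel is suppressed entirely; with all-zero weights every scale would be h-sigmoid(0) = 0.5.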
    

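    Network configuration: InvertedResidualConfig
    The parameters of each MobileNetV3 layer are listed in the table below:
    [Figure: MobileNetV3 per-layer parameter table]
    The configuration class that holds these parameters is InvertedResidualConfig: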

    class InvertedResidualConfig:
    	def __init__(self,
    				input_c:int,
    				kernel:int,
    				expanded_c:int,
    				out_c:int,
    				use_se:bool,
    				activation:str,
    				stride:int,
    				width_multi:float):  # width_multi: channel multiplier applied to every conv layer
    		self.input_c=self.adjust_channels(input_c,width_multi)
    		self.kernel=kernel
    		self.expanded_c=self.adjust_channels(expanded_c,width_multi)
    		self.out_c=self.adjust_channels(out_c,width_multi)
    		self.use_se=use_se
    		self.use_hs=activation=="HS"  # whether to use h-swish ("RE" means ReLU)
    		self.stride=stride
    
    	@staticmethod
    	def adjust_channels(channels:int,width_multi:float):
    		return _make_divisible(channels*width_multi,8)
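    `width_multi` scales every channel count and `_make_divisible` then rounds it to a multiple of 8. The sketch below reproduces that rule in plain Python, applies a hypothetical multiplier of 0.75 to some MobileNetV3-Large channel counts, and shows how `functools.partial` pre-binds `width_multi`, the same pattern `mobilenet_v3_large` uses for `bneck_conf`. The dataclass `BneckConfig` and the function `make_config` are illustrative stand-ins for `InvertedResidualConfig`, not the class itself.

```python
from dataclasses import dataclass
from functools import partial


def make_divisible(ch, divisor=8, min_ch=None):
    min_ch = divisor if min_ch is None else min_ch
    new_ch = max(min_ch, int(ch + divisor / 2) // divisor * divisor)
    return new_ch + divisor if new_ch < 0.9 * ch else new_ch


def adjust_channels(channels, width_multi):
    return make_divisible(channels * width_multi, 8)


@dataclass
class BneckConfig:
    """Illustrative stand-in for InvertedResidualConfig (channels already scaled)."""
    input_c: int
    kernel: int
    expanded_c: int
    out_c: int
    use_se: bool
    use_hs: bool
    stride: int


def make_config(input_c, kernel, expanded_c, out_c, use_se, activation, stride, width_multi):
    return BneckConfig(adjust_channels(input_c, width_multi), kernel,
                       adjust_channels(expanded_c, width_multi),
                       adjust_channels(out_c, width_multi),
                       use_se, activation == "HS", stride)


bneck_conf = partial(make_config, width_multi=0.75)  # hypothetical multiplier
cfg = bneck_conf(40, 5, 120, 40, True, "RE", 1)
print(cfg)
scaled = [adjust_channels(c, 0.75) for c in (16, 24, 40, 80, 112, 160)]
print(scaled)
```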
    

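    bneck module
    MobileNetV3 is built by stacking a series of bneck blocks; see the previous post for a detailed explanation of the bneck module.
    [Figure: bneck module structure]

    Create the class InvertedResidual, which inherits from nn.Module: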

    class InvertedResidual(nn.Module):
    	def __init__(self,
    				cnf:InvertedResidualConfig,
    				norm_layer:Callable[...,nn.Module]):
    		super(InvertedResidual,self).__init__()
    		
    		if cnf.stride not in [1,2]:
    			raise ValueError("illegal stride value.")
    		# shortcut only when input and output shapes match
    		self.use_res_connect=(cnf.stride ==1 and cnf.input_c==cnf.out_c)
    		layers:List[nn.Module] = []
    		# nn.Hardswish requires PyTorch 1.7 or later
    		activation_layer = nn.Hardswish if cnf.use_hs else nn.ReLU 
    		
    		# expand
    		# In the first bneck, input channels == exp size, i.e. there is no 1x1 expansion conv
    		if cnf.expanded_c != cnf.input_c:  # different: add a 1x1 conv that expands the channels
    			layers.append(ConvBNActivation(cnf.input_c,
    											cnf.expanded_c,
    											kernel_size=1,
    											norm_layer=norm_layer,
    											activation_layer=activation_layer
    										))
    		# depthwise
    		layers.append(ConvBNActivation(
    									 cnf.expanded_c,
    									 cnf.expanded_c,
    									 kernel_size=cnf.kernel,
    									 stride=cnf.stride,
    									 groups=cnf.expanded_c,  # depthwise conv: groups = number of channels
    									 norm_layer=norm_layer,
    									 activation_layer=activation_layer
    									))
    		if cnf.use_se:
    			layers.append(SqueezeExcitaion(cnf.expanded_c))
    		
    		# 1x1 projection conv that reduces the channels
    		layers.append(ConvBNActivation( cnf.expanded_c,
    										cnf.out_c,
    										kernel_size=1,
    										norm_layer=norm_layer,
    										activation_layer=nn.Identity  # linear activation
    										))
    		self.block=nn.Sequential(*layers)
    		self.out_channel=cnf.out_c
    	
    	def forward(self,x:Tensor) -> Tensor:
    		result=self.block(x)
    		if self.use_res_connect:
    			result +=x
    		return result
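    Two decisions in `InvertedResidual` depend only on the config: the 1x1 expansion conv is skipped when `expanded_c == input_c`, and the shortcut is used only when `stride == 1` and `input_c == out_c` (otherwise `x` and `result` would have different shapes). The plain-Python sketch below (the helper `bneck_plan` is hypothetical) lists the layers that would be built for a few rows of the MobileNetV3-Large table, and compares the weight count of a depthwise conv with that of a regular conv of the same kernel size.

```python
def bneck_plan(input_c, kernel, expanded_c, out_c, use_se, activation, stride):
    """Return (layer descriptions, uses_shortcut) mirroring InvertedResidual's logic."""
    act = "Hardswish" if activation == "HS" else "ReLU"
    layers = []
    if expanded_c != input_c:  # 1x1 expansion only when the channel count changes
        layers.append(f"1x1 conv {input_c}->{expanded_c} + BN + {act}")
    layers.append(f"{kernel}x{kernel} depthwise conv on {expanded_c} ch, stride {stride} + BN + {act}")
    if use_se:
        layers.append(f"SE({expanded_c})")
    layers.append(f"1x1 conv {expanded_c}->{out_c} + BN (linear)")
    return layers, stride == 1 and input_c == out_c


rows = [
    (16, 3, 16, 16, False, "RE", 1),  # first bneck of MobileNetV3-Large
    (16, 3, 64, 24, False, "RE", 2),
    (40, 5, 120, 40, True, "RE", 1),
]
for row in rows:
    layers, shortcut = bneck_plan(*row)
    print(row, "-> shortcut" if shortcut else "-> no shortcut")
    for layer in layers:
        print("   ", layer)

# Weight count (bias=False) of a 5x5 conv over 120 channels
k, c = 5, 120
depthwise_weights = k * k * c      # groups=120: one k x k filter per channel
regular_weights = k * k * c * c    # groups=1: every output channel sees every input channel
print(depthwise_weights, regular_weights)
```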
    
    • Building MobileNetV3
    class MobileNetV3(nn.Module):
    	def __init__(self,
    				inverted_residual_setting:List[InvertedResidualConfig],
    				last_channel:int,
    				num_classes:int =1000,
    				block:Optional[Callable[...,nn.Module]]=None,
    				norm_layer:Optional[Callable[...,nn.Module]]=None
    				):
    		super(MobileNetV3,self).__init__()
    		
    		if not inverted_residual_setting:
    			raise ValueError("The inverted_residual_setting should not be empty")
    		elif not (isinstance(inverted_residual_setting,list) and
    		          all([isinstance(s,InvertedResidualConfig) for s in inverted_residual_setting])):
    			raise TypeError("The inverted_residual_setting should be List[InvertedResidualConfig]")
    		
    		if block is None:
    			block = InvertedResidual
    		
    		if norm_layer is None:
    			norm_layer=partial(nn.BatchNorm2d,eps=0.001,momentum=0.01)
    		
    		layers:List[nn.Module] = []
    		
    		# building first layer
    		firstconv_output_c =inverted_residual_setting[0].input_c
    		layers.append(ConvBNActivation(3,
    										firstconv_output_c,
    										kernel_size=3,
    										stride=2,
    										norm_layer=norm_layer,
    										activation_layer=nn.Hardswish))
    		# building inverted residual blocks
    		for cnf in inverted_residual_setting:
    			layers.append(block(cnf,norm_layer))
    		
    		# building last several layers
    		lastconv_input_c=inverted_residual_setting[-1].out_c
    		lastconv_output_c=6*lastconv_input_c
    		layers.append(ConvBNActivation(lastconv_input_c,
    									  lastconv_output_c,
    									  kernel_size=1,
    									  norm_layer =norm_layer,
    									  activation_layer=nn.Hardswish))
    		self.features =nn.Sequential(*layers)
    		self.avgpool=nn.AdaptiveAvgPool2d(1)
    		self.classifier=nn.Sequential(nn.Linear(lastconv_output_c,last_channel),
    									 nn.Hardswish(inplace=True),
    									 nn.Dropout(p=0.2,inplace=True),
    									 nn.Linear(last_channel,num_classes))
    		
    		# initial weights
    		for m in self.modules():
    			if isinstance(m,nn.Conv2d):
    				nn.init.kaiming_normal_(m.weight,mode="fan_out")
    				if m.bias is not None:
    					nn.init.zeros_(m.bias)
    			elif isinstance(m,(nn.BatchNorm2d,nn.GroupNorm)):
    				nn.init.ones_(m.weight)
    				nn.init.zeros_(m.bias)
    			elif isinstance(m,nn.Linear):
    				nn.init.normal_(m.weight,0,0.01)
    				nn.init.zeros_(m.bias)
    
    	def _forward_impl(self,x:Tensor) ->Tensor:
    		x=self.features(x)
    		x=self.avgpool(x)
    		x=torch.flatten(x,1)
    		x=self.classifier(x)
    		return x
    	
    	def forward(self,x:Tensor) ->Tensor:
    		return self._forward_impl(x)
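    Two numbers implied by the class above, computed in plain Python as a sketch. The first is the standard deviation that `nn.init.kaiming_normal_(m.weight, mode="fan_out")` uses. With its default `nonlinearity="leaky_relu"` and `a=0` the gain is sqrt(2) and fan_out = out_channels * kH * kW, so std = sqrt(2 / fan_out). The second is the parameter count of the two `nn.Linear` layers in the classifier for MobileNetV3-Large (`lastconv_output_c = 6 * 160 = 960`, `last_channel = 1280`). The helper names `kaiming_fan_out_std` and `linear_params` are hypothetical.

```python
import math


def kaiming_fan_out_std(weight_shape):
    """Std of kaiming_normal_(w, mode="fan_out") with the default gain sqrt(2)."""
    out_channels, _, kh, kw = weight_shape  # Conv2d weight: (out, in // groups, kH, kW)
    fan_out = out_channels * kh * kw
    return math.sqrt(2.0) / math.sqrt(fan_out)


shapes = {
    "first conv 3->16, 3x3": (16, 3, 3, 3),
    "depthwise 5x5 on 120 channels": (120, 1, 5, 5),
    "1x1 projection 120->40": (40, 120, 1, 1),
}
for name, shape in shapes.items():
    print(f"{name}: std = {kaiming_fan_out_std(shape):.4f}")


def linear_params(in_features, out_features):
    # nn.Linear has in * out weights plus out biases
    return in_features * out_features + out_features


head_params = linear_params(960, 1280) + linear_params(1280, 1000)
print("classifier parameters (1000 classes):", head_params)
```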
    
    • Building MobileNetV3-Large
    def mobilenet_v3_large(num_classes: int = 1000,
                           reduced_tail: bool = False) -> MobileNetV3:
        """
        Constructs a large MobileNetV3 architecture from
        "Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>.
    
        weights_link:
        https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth
    
        Args:
            num_classes (int): number of classes
            reduced_tail (bool): If True, reduces the channel counts of all feature layers
                between C4 and C5 by 2. It is used to reduce the channel redundancy in the
                backbone for Detection and Segmentation.
        """
        width_multi = 1.0
        bneck_conf = partial(InvertedResidualConfig, width_multi=width_multi)
        adjust_channels = partial(InvertedResidualConfig.adjust_channels, width_multi=width_multi)
    
        reduce_divider = 2 if reduced_tail else 1
    
        inverted_residual_setting = [
            # input_c, kernel, expanded_c, out_c, use_se, activation, stride
            bneck_conf(16, 3, 16, 16, False, "RE", 1),
            bneck_conf(16, 3, 64, 24, False, "RE", 2),  # C1
            bneck_conf(24, 3, 72, 24, False, "RE", 1),
            bneck_conf(24, 5, 72, 40, True, "RE", 2),  # C2
            bneck_conf(40, 5, 120, 40, True, "RE", 1),
            bneck_conf(40, 5, 120, 40, True, "RE", 1),
            bneck_conf(40, 3, 240, 80, False, "HS", 2),  # C3
            bneck_conf(80, 3, 200, 80, False, "HS", 1),
            bneck_conf(80, 3, 184, 80, False, "HS", 1),
            bneck_conf(80, 3, 184, 80, False, "HS", 1),
            bneck_conf(80, 3, 480, 112, True, "HS", 1),
            bneck_conf(112, 3, 672, 112, True, "HS", 1),
            bneck_conf(112, 5, 672, 160 // reduce_divider, True, "HS", 2),  # C4
            bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1),
            bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1),
        ]
        last_channel = adjust_channels(1280 // reduce_divider)  # C5
    
        return MobileNetV3(inverted_residual_setting=inverted_residual_setting,
                           last_channel=last_channel,
                           num_classes=num_classes)
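    Some properties can be read straight off this table, for instance which bnecks use SE and how wide their squeeze layers are (`SqueezeExcitaion` uses `_make_divisible(expanded_c // 4, 8)`). The sketch below copies the `(input_c, kernel, expanded_c, out_c, use_se, activation, stride)` rows above into plain tuples (with `reduced_tail=False`), computes those values, counts the h-swish blocks, and shows what `reduced_tail=True` does to `last_channel`.

```python
def make_divisible(ch, divisor=8, min_ch=None):
    min_ch = divisor if min_ch is None else min_ch
    new_ch = max(min_ch, int(ch + divisor / 2) // divisor * divisor)
    return new_ch + divisor if new_ch < 0.9 * ch else new_ch


# (input_c, kernel, expanded_c, out_c, use_se, activation, stride), reduced_tail=False
large = [
    (16, 3, 16, 16, False, "RE", 1), (16, 3, 64, 24, False, "RE", 2),
    (24, 3, 72, 24, False, "RE", 1), (24, 5, 72, 40, True, "RE", 2),
    (40, 5, 120, 40, True, "RE", 1), (40, 5, 120, 40, True, "RE", 1),
    (40, 3, 240, 80, False, "HS", 2), (80, 3, 200, 80, False, "HS", 1),
    (80, 3, 184, 80, False, "HS", 1), (80, 3, 184, 80, False, "HS", 1),
    (80, 3, 480, 112, True, "HS", 1), (112, 3, 672, 112, True, "HS", 1),
    (112, 5, 672, 160, True, "HS", 2), (160, 5, 960, 160, True, "HS", 1),
    (160, 5, 960, 160, True, "HS", 1),
]

# fc1 width of each SE block
se_squeeze = [make_divisible(row[2] // 4, 8) for row in large if row[4]]
hs_blocks = sum(row[5] == "HS" for row in large)
print("SE blocks:", len(se_squeeze), "squeeze channels:", se_squeeze)
print("h-swish blocks:", hs_blocks)

width_multi = 1.0
for reduced_tail in (False, True):
    reduce_divider = 2 if reduced_tail else 1
    last_channel = make_divisible(1280 // reduce_divider * width_multi, 8)
    print(f"reduced_tail={reduced_tail}: last out_c={160 // reduce_divider}, last_channel={last_channel}")
```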
    
    • Building MobileNetV3-Small
    def mobilenet_v3_small(num_classes: int = 1000,
                           reduced_tail: bool = False) -> MobileNetV3:
        """
        Constructs a small MobileNetV3 architecture from
        "Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>.
    
        weights_link:
        https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth
    
        Args:
            num_classes (int): number of classes
            reduced_tail (bool): If True, reduces the channel counts of all feature layers
                between C4 and C5 by 2. It is used to reduce the channel redundancy in the
                backbone for Detection and Segmentation.
        """
        width_multi = 1.0
        bneck_conf = partial(InvertedResidualConfig, width_multi=width_multi)
        adjust_channels = partial(InvertedResidualConfig.adjust_channels, width_multi=width_multi)
    
        reduce_divider = 2 if reduced_tail else 1
    
        inverted_residual_setting = [
            # input_c, kernel, expanded_c, out_c, use_se, activation, stride
            bneck_conf(16, 3, 16, 16, True, "RE", 2),  # C1
            bneck_conf(16, 3, 72, 24, False, "RE", 2),  # C2
            bneck_conf(24, 3, 88, 24, False, "RE", 1),
            bneck_conf(24, 5, 96, 40, True, "HS", 2),  # C3
            bneck_conf(40, 5, 240, 40, True, "HS", 1),
            bneck_conf(40, 5, 240, 40, True, "HS", 1),
            bneck_conf(40, 5, 120, 48, True, "HS", 1),
            bneck_conf(48, 5, 144, 48, True, "HS", 1),
            bneck_conf(48, 5, 288, 96 // reduce_divider, True, "HS", 2),  # C4
            bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1),
            bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1)
        ]
        last_channel = adjust_channels(1024 // reduce_divider)  # C5
    
        return MobileNetV3(inverted_residual_setting=inverted_residual_setting,
                           last_channel=last_channel,
                           num_classes=num_classes)
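    Comparing the two factory functions: Small has 11 bnecks instead of 15, its last bneck outputs 96 channels, so the last 1x1 conv produces `6 * 96 = 576` channels (960 for Large), and the classifier's hidden layer is 1024 wide instead of 1280. The sketch below derives these numbers from the kernel sizes, strides and out channels in the two tables (plain Python, `reduced_tail=False`). It also traces a 224x224 input to the final 7x7 feature map using the padding rule of `ConvBNActivation`. The helper names are hypothetical.

```python
def conv_out_size(h, kernel_size, stride):
    padding = (kernel_size - 1) // 2
    return (h + 2 * padding - kernel_size) // stride + 1


# (kernel, stride, out_c) of each bneck, copied from the two tables
large = [(3, 1, 16), (3, 2, 24), (3, 1, 24), (5, 2, 40), (5, 1, 40), (5, 1, 40),
         (3, 2, 80), (3, 1, 80), (3, 1, 80), (3, 1, 80), (3, 1, 112), (3, 1, 112),
         (5, 2, 160), (5, 1, 160), (5, 1, 160)]
small = [(3, 2, 16), (3, 2, 24), (3, 1, 24), (5, 2, 40), (5, 1, 40), (5, 1, 40),
         (5, 1, 48), (5, 1, 48), (5, 2, 96), (5, 1, 96), (5, 1, 96)]


def summarize(bnecks, last_channel, input_size=224):
    size = conv_out_size(input_size, 3, 2)  # first 3x3 conv, stride 2
    for kernel, stride, _ in bnecks:
        size = conv_out_size(size, kernel, stride)
    return {
        "num_bnecks": len(bnecks),
        "lastconv_output_c": 6 * bnecks[-1][2],  # last 1x1 conv: 6x the last out_c
        "last_channel": last_channel,
        "final_size": size,
    }


print("large:", summarize(large, 1280))
print("small:", summarize(small, 1024))
```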
    

    Network training

    Download the training image dataset
    train.py

    import os
    import json
    
    import torch
    import torch.nn as nn
    import torch.optim as optim
    from torchvision import transforms, datasets
    from tqdm import tqdm
    
    from model_v2 import MobileNetV2  # can be replaced with MobileNetV3-Large / MobileNetV3-Small, e.g. from model_v3 import mobilenet_v3_large
    
    
    def main():
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        print("using {} device.".format(device))
    
        batch_size = 16  # adjust according to the available GPU memory
        epochs = 5
    
        data_transform = {
            "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                         transforms.RandomHorizontalFlip(),
                                         transforms.ToTensor(),
                                         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
            "val": transforms.Compose([transforms.Resize(256),
                                       transforms.CenterCrop(224),
                                       transforms.ToTensor(),
                                       transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}
    
        data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
        image_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set path
        assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
        train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                             transform=data_transform["train"])
        train_num = len(train_dataset)
    
        # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
        flower_list = train_dataset.class_to_idx
        cla_dict = dict((val, key) for key, val in flower_list.items())
        # write dict into json file
        json_str = json.dumps(cla_dict, indent=4)
        with open('class_indices.json', 'w') as json_file:
            json_file.write(json_str)
    
        nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
        print('Using {} dataloader workers every process'.format(nw))
    
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=batch_size, shuffle=True,
                                                   num_workers=nw)
    
        validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                                transform=data_transform["val"])
        val_num = len(validate_dataset)
        validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                      batch_size=batch_size, shuffle=False,
                                                      num_workers=nw)
    
        print("using {} images for training, {} images for validation.".format(train_num,
                                                                               val_num))
    
        # create model
        net = MobileNetV2(num_classes=5)  # can be replaced with MobileNetV3-Large / MobileNetV3-Small, e.g. mobilenet_v3_large(num_classes=5)
    
        # load pretrain weights
        # download url: https://download.pytorch.org/models/mobilenet_v2-b0353104.pth
        model_weight_path = "./mobilenet_v2.pth"  # replace with the MobileNetV3-Large / MobileNetV3-Small pretrained weights when training V3
    
        assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)
        pre_weights = torch.load(model_weight_path, map_location=device)
    
        # delete classifier weights
        pre_dict = {k: v for k, v in pre_weights.items() if net.state_dict()[k].numel() == v.numel()}
        missing_keys, unexpected_keys = net.load_state_dict(pre_dict, strict=False)
    
        # freeze features weights
        for param in net.features.parameters():
            param.requires_grad = False
    
        net.to(device)
    
        # define loss function
        loss_function = nn.CrossEntropyLoss()
    
        # construct an optimizer
        params = [p for p in net.parameters() if p.requires_grad]
        optimizer = optim.Adam(params, lr=0.0001)
    
        best_acc = 0.0
        save_path = './MobileNetV2.pth'
        train_steps = len(train_loader)
        for epoch in range(epochs):
            # train
            net.train()
            running_loss = 0.0
            train_bar = tqdm(train_loader)
            for step, data in enumerate(train_bar):
                images, labels = data
                optimizer.zero_grad()
                logits = net(images.to(device))
                loss = loss_function(logits, labels.to(device))
                loss.backward()
                optimizer.step()
    
                # print statistics
                running_loss += loss.item()
    
                train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                         epochs,
                                                                         loss)
    
            # validate
            net.eval()
            acc = 0.0  # accumulate accurate number / epoch
            with torch.no_grad():
                val_bar = tqdm(validate_loader)
                for val_data in val_bar:
                    val_images, val_labels = val_data
                    outputs = net(val_images.to(device))
                    # loss = loss_function(outputs, test_labels)
                    predict_y = torch.max(outputs, dim=1)[1]
                    acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
    
                    val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1,
                                                               epochs)
            val_accurate = acc / val_num
            print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
                  (epoch + 1, running_loss / train_steps, val_accurate))
    
            if val_accurate > best_acc:
                best_acc = val_accurate
                torch.save(net.state_dict(), save_path)
    
        print('Finished Training')
    
    
    if __name__ == '__main__':
        main()
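    The two transfer-learning steps in `train.py` are: (1) keep only the pretrained tensors whose element count matches the new model, which drops the final classifier layer because it was trained for 1000 ImageNet classes rather than 5 flower classes; (2) freeze `net.features` so only the classifier is optimized. The sketch below reproduces both steps on shape dictionaries in plain Python. The key names follow the `features.*` / `classifier.*` naming of the models in this post, but the particular keys and shapes listed are illustrative. Unlike the original one-liner, it also skips keys that are missing from the model instead of raising a `KeyError`.

```python
from math import prod

# Illustrative tensor shapes of an ImageNet-pretrained MobileNetV3-Large (1000 classes)
pretrained_shapes = {
    "features.0.0.weight": (16, 3, 3, 3),
    "classifier.0.weight": (1280, 960),
    "classifier.0.bias": (1280,),
    "classifier.3.weight": (1000, 1280),
    "classifier.3.bias": (1000,),
}
# The same keys in a model built with num_classes=5
model_shapes = {**pretrained_shapes, "classifier.3.weight": (5, 1280), "classifier.3.bias": (5,)}

# Step 1: keep a pretrained tensor only if the model has that key with the same numel
pre_dict = {k: v for k, v in pretrained_shapes.items()
            if k in model_shapes and prod(model_shapes[k]) == prod(v)}
missing_keys = sorted(set(model_shapes) - set(pre_dict))  # stay at their random init

# Step 2: freezing net.features leaves only the non-"features." parameters trainable
trainable = [k for k in model_shapes if not k.startswith("features.")]

print("loaded:", sorted(pre_dict))
print("missing (newly initialized):", missing_keys)
print("trainable:", trainable)
```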
    
    

    Model prediction

    predict.py

    import os
    import json
    
    import torch
    from PIL import Image
    from torchvision import transforms
    import matplotlib.pyplot as plt
    
    from model_v2 import MobileNetV2
    
    
    def main():
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    
        data_transform = transforms.Compose(
            [transforms.Resize(256),
             transforms.CenterCrop(224),
             transforms.ToTensor(),
             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
    
        # load image
        img_path = "../tulip.jpg"
        assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path)
        img = Image.open(img_path)
        plt.imshow(img)
        # [N, C, H, W]
        img = data_transform(img)
        # expand batch dimension
        img = torch.unsqueeze(img, dim=0)
    
        # read class_indict
        json_path = './class_indices.json'
        assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)
    
        with open(json_path, "r") as json_file:
            class_indict = json.load(json_file)
    
        # create model
        model = MobileNetV2(num_classes=5).to(device)
        # load model weights
        model_weight_path = "./MobileNetV2.pth"
        model.load_state_dict(torch.load(model_weight_path, map_location=device))
        model.eval()
        with torch.no_grad():
            # predict class
            output = torch.squeeze(model(img.to(device))).cpu()
            predict = torch.softmax(output, dim=0)
            predict_cla = torch.argmax(predict).numpy()
    
        print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                     predict[predict_cla].numpy())
        plt.title(print_res)
        print(print_res)
        plt.show()
    
    
    if __name__ == '__main__':
        main()
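    `predict.py` converts the 5 logits into probabilities with `torch.softmax`, takes the index of the largest with `torch.argmax`, and looks it up in `class_indices.json`, whose keys are strings (hence `str(predict_cla)`). Below are the same steps in plain Python. The class names follow the comment in `train.py`, and the logits are made up for illustration.

```python
import json
import math

# Same format as the class_indices.json written by train.py (index -> class name)
class_indict = json.loads('{"0": "daisy", "1": "dandelion", "2": "roses", "3": "sunflower", "4": "tulips"}')


def softmax(logits):
    m = max(logits)  # subtracting the max avoids overflow in exp
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


logits = [0.1, -1.2, 0.3, 0.0, 3.5]  # hypothetical model output for one image
predict = softmax(logits)
predict_cla = max(range(len(predict)), key=lambda i: predict[i])  # argmax
print("class: {}   prob: {:.3}".format(class_indict[str(predict_cla)], predict[predict_cla]))
```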
    

    Prediction result
    [Figure: model prediction result]

    Source code download

  • Original article: https://blog.csdn.net/weixin_38346042/article/details/125520003