• Building MobileNetV2 with PyTorch and training it with transfer learning


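    The MobileNetV2 architecture is shown in Figure 1. For a detailed walkthrough of the network, see the blog post: MobileNet系列(2):MobileNet-V2 网络详解 (MobileNet series (2): MobileNet-V2 explained)

    Figure 1: MobileNet V2 network architecture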

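    As the table shows, the model is essentially a stack of inverted residual blocks (bottlenecks), followed by a regular 1x1 convolution, then an average pooling layer with a 7x7 kernel, and finally a 1x1 convolution that produces the output. The inverted residual block is the key to building this network: once it is in place, the rest of the network is easy to assemble.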
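The 7x7 pooling kernel follows from the strides in the network. As a rough sketch, assuming a 224x224 input and the stage strides from the MobileNetV2 paper's table (the helper name `conv_out_size` is made up for this illustration), the feature map size can be traced like this:

```python
def conv_out_size(size, kernel_size=3, stride=1):
    """Output height/width of a conv that pads with (kernel_size - 1) // 2."""
    padding = (kernel_size - 1) // 2
    return (size + 2 * padding - kernel_size) // stride + 1


# Stride of the first conv, then the stride of the first block in each of the
# seven stages (later blocks in a stage use stride 1 and keep the size).
strides = [2, 1, 2, 2, 2, 1, 2, 1]

size = 224  # assumed input resolution
sizes = [size]
for s in strides:
    size = conv_out_size(size, kernel_size=3, stride=s)
    sizes.append(size)

print(sizes)
```

Five stride-2 layers give a total downsampling factor of 32, so a 224x224 image ends up as a 7x7 feature map before pooling.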

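    Building the network in PyTorch

    In model.py, first define the basic building blocks of the network.
    Almost every convolution in MobileNetV2 is a Conv + BN + ReLU6 combination.

    Convolution component

    Conv+BN+ReLU6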

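    import torch
    import torch.nn as nn

    class ConvBNReLU(nn.Sequential):
        # kernel_size defaults to 3 so the 3x3 convs can omit it
        def __init__(self, in_channel, out_channel, kernel_size=3, stride=1, groups=1):
            # padding that keeps the spatial size at stride 1 (odd kernel sizes)
            padding = (kernel_size - 1) // 2
            super(ConvBNReLU, self).__init__(
                nn.Conv2d(in_channel, out_channel, kernel_size, stride, padding,
                          groups=groups, bias=False),
                nn.BatchNorm2d(out_channel),
                nn.ReLU6(inplace=True)
            )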
    

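    Note that groups=1 builds a regular convolution, while groups equal to in_channel (with out_channel equal to in_channel) builds a depthwise (DW) convolution. Because a BN layer follows the convolution, the convolution bias is not needed and is set to False.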
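To make the effect of groups concrete, here is a small sketch that counts the weights of a bias-free convolution. The helper `conv_weight_count` and the channel count 32 are made up for illustration; the formula follows from the weight shape of a grouped convolution, (out_channel, in_channel // groups, k, k).

```python
def conv_weight_count(in_channel, out_channel, kernel_size, groups=1):
    """Number of weights in a bias-free 2D convolution.

    With groups > 1, each output channel only sees in_channel // groups inputs.
    """
    assert in_channel % groups == 0 and out_channel % groups == 0
    return out_channel * (in_channel // groups) * kernel_size * kernel_size


channels = 32  # hypothetical channel count
regular = conv_weight_count(channels, channels, kernel_size=3, groups=1)
depthwise = conv_weight_count(channels, channels, kernel_size=3, groups=channels)

print(regular, depthwise)  # 9216 288
```

With groups equal to the channel count, each filter covers a single input channel, so the weight count drops by a factor equal to the number of channels.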

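    Inverted residual block

    Define an InvertedResidual class that inherits from the nn.Module parent class. The structure of the inverted residual block is shown below:
    [Figure: inverted residual block]
    An inverted residual block is similar to an ordinary residual block. An ordinary residual block is wide at both ends and narrow in the middle, whereas an inverted residual block is the opposite: narrow at both ends and wide in the middle. For details, see MobileNet系列(2):MobileNet-V2 网络详解. The number of DW filters equals the number of input channels, and each DW filter handles exactly one channel, so a DW convolution does not change the number of channels.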

    class InvertedResidual(nn.Module):
        def __init__(self, in_channel, out_channel, stride, expand_ratio):
            super(InvertedResidual, self).__init__()
            hidden_channel = in_channel * expand_ratio
            # the shortcut is used only when input and output shapes match
            self.use_shortcut = stride == 1 and in_channel == out_channel
            layers = []
            if expand_ratio != 1:
                # 1x1 conv (expansion)
                layers.append(ConvBNReLU(in_channel, hidden_channel, kernel_size=1))
            layers.extend([
                # 3x3 depthwise conv
                ConvBNReLU(hidden_channel, hidden_channel, kernel_size=3,
                           stride=stride, groups=hidden_channel),
                # 1x1 conv (linear: no activation afterwards)
                nn.Conv2d(hidden_channel, out_channel, kernel_size=1, bias=False),
                nn.BatchNorm2d(out_channel),
            ])
            self.conv = nn.Sequential(*layers)

        def forward(self, x):
            if self.use_shortcut:
                return x + self.conv(x)
            else:
                return self.conv(x)
    
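The channel bookkeeping inside one block can be traced without PyTorch. The sketch below mirrors the logic of the class above; the function `describe_block` and the example numbers are made up for illustration.

```python
def describe_block(in_channel, out_channel, stride, expand_ratio):
    """Return ([(layer, in_ch, out_ch, stride), ...], use_shortcut) for one block."""
    hidden = in_channel * expand_ratio
    layers = []
    if expand_ratio != 1:
        # the expansion conv is skipped when there is nothing to expand
        layers.append(("1x1 expand + BN + ReLU6", in_channel, hidden, 1))
    layers.append(("3x3 depthwise + BN + ReLU6", hidden, hidden, stride))
    layers.append(("1x1 project + BN (linear)", hidden, out_channel, 1))
    use_shortcut = stride == 1 and in_channel == out_channel
    return layers, use_shortcut


layers, use_shortcut = describe_block(24, 24, stride=1, expand_ratio=6)
for row in layers:
    print(row)
print("shortcut:", use_shortcut)
```

For a 24-channel input with expand_ratio=6 the hidden width is 144 channels: the block is narrow at both ends and wide in the middle, and the shortcut applies because stride is 1 and the channel counts match.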

    MobileNetV2 network structure

    Define a MobileNetV2 class that inherits from nn.Module. The complete network code is as follows:

    class MobileNetV2(nn.Module):
        def __init__(self, num_classes=100, alpha=1.0, round_nearest=8):
            super(MobileNetV2, self).__init__()
            block = InvertedResidual
            input_channel = _make_divisible(32 * alpha, round_nearest)
            last_channel = _make_divisible(1280 * alpha, round_nearest)

            inverted_residual_setting = [
                # t, c, n, s
                [1, 16, 1, 1],
                [6, 24, 2, 2],
                [6, 32, 3, 2],
                [6, 64, 4, 2],
                [6, 96, 3, 1],
                [6, 160, 3, 2],
                [6, 320, 1, 1]
            ]

            features = []
            # conv1 layer
            features.append(ConvBNReLU(3, input_channel, kernel_size=3, stride=2))
            # build inverted residual blocks
            for t, c, n, s in inverted_residual_setting:
                # _make_divisible rounds the channel count to a multiple of round_nearest
                output_channel = _make_divisible(c * alpha, round_nearest)
                for i in range(n):
                    # only the first block of each stage uses stride s
                    stride = s if i == 0 else 1
                    features.append(block(input_channel, output_channel, stride, expand_ratio=t))
                    input_channel = output_channel
            # build the last several layers
            features.append(ConvBNReLU(input_channel, last_channel, 1))
            # combine feature layers
            self.features = nn.Sequential(*features)

            # build the classifier
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
            self.classifier = nn.Sequential(
                nn.Dropout(0.2),
                nn.Linear(last_channel, num_classes)
            )

            # weight initialization
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode='fan_out')
                    if m.bias is not None:
                        nn.init.zeros_(m.bias)
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.ones_(m.weight)
                    nn.init.zeros_(m.bias)
                elif isinstance(m, nn.Linear):
                    # normal(0, 0.01) for the linear layer, as in torchvision
                    nn.init.normal_(m.weight, 0, 0.01)
                    nn.init.zeros_(m.bias)

        # forward pass
        def forward(self, x):
            x = self.features(x)
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.classifier(x)
            return x
    
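As a sanity check on the nested loop over the t, c, n, s table, the sketch below expands the settings into one entry per block. It assumes alpha=1.0 (so no channel rounding is needed) and uses the table values from the MobileNetV2 paper; the function `expand_settings` is made up for illustration.

```python
settings = [
    # t, c, n, s  (values from the MobileNetV2 paper)
    [1, 16, 1, 1],
    [6, 24, 2, 2],
    [6, 32, 3, 2],
    [6, 64, 4, 2],
    [6, 96, 3, 1],
    [6, 160, 3, 2],
    [6, 320, 1, 1],
]


def expand_settings(settings, input_channel=32):
    """Return one (t, in_ch, out_ch, stride) tuple per inverted residual block."""
    blocks = []
    for t, c, n, s in settings:
        for i in range(n):
            stride = s if i == 0 else 1  # only the first block of a stage downsamples
            blocks.append((t, input_channel, c, stride))
            input_channel = c
    return blocks


blocks = expand_settings(settings)
shortcuts = sum(1 for _, cin, cout, s in blocks if s == 1 and cin == cout)

print(len(blocks), shortcuts)  # 17 10
```

The table expands to 17 blocks, of which 10 can use the shortcut. The first block of every stage changes the channel count or the stride, so it never has one.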

    The _make_divisible function comes from the official TensorFlow implementation:

    def _make_divisible(ch, divisor=8, min_ch=None):
        """
        https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
        """
        if min_ch is None:
            min_ch = divisor
        # round ch to the nearest multiple of divisor, but never below min_ch
        new_ch = max(min_ch, int(ch + divisor / 2) // divisor * divisor)
        # Make sure that rounding down does not go down by more than 10%
        if new_ch < 0.9 * ch:
            new_ch += divisor
        return new_ch
    

    Model training

    First, a note on how to download the official pretrained weights, using MobileNet as an example:

    import torchvision.models.mobilenet
    

    Click torchvision.models.mobilenet to jump to the official definition. It contains a model_urls dict, and the URL in it is the download link for the pretrained weights:

    model_urls= {
    	'mobilenet_v2':'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth'
    }
    

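    Copy the model URL into a download tool such as Xunlei (迅雷), save the downloaded file in the current project directory, and name it mobilenet_v2.pth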

    Training script

    train.py

    1. Import Python packages

    import torch
    import torch.nn as nn
    from torchvision import transforms, datasets
    import json
    import os
    import torch.optim as optim
    from tqdm import tqdm  # progress bars used in the training loop
    from model import MobileNetV2
    

    2. Data preparation

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),  # Normalize expects a tensor
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
    }

    data_root = os.path.abspath(os.path.join(os.getcwd(), '../..'))  # get data root path
    image_path = data_root + "/data_set/flower_data/"  # flower data set path

    train_dataset = datasets.ImageFolder(root=image_path + "train", transform=data_transform["train"])
    train_num = len(train_dataset)

    # {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflower': 3, 'tulips': 4}
    flower_list = train_dataset.class_to_idx
    # invert the mapping: index -> class name
    cla_dict = dict((val, key) for key, val in flower_list.items())
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 16
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=0)
    validate_dataset = datasets.ImageFolder(root=image_path + "val",
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=0)
    
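The class-index file deserves a quick look, because JSON changes the key type. The sketch below uses a hand-written mapping that mirrors the comment in the code above (the real one comes from ImageFolder's class_to_idx) and skips the file by round-tripping through a string.

```python
import json

# Hypothetical mapping, mirroring the comment in the data preparation code.
class_to_idx = {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflower': 3, 'tulips': 4}

# Invert it so that a predicted index can be turned back into a class name.
idx_to_class = {idx: name for name, idx in class_to_idx.items()}

json_str = json.dumps(idx_to_class, indent=4)
loaded = json.loads(json_str)

# JSON object keys are always strings, so look up with str(index).
predicted_index = 2
print(loaded[str(predicted_index)])  # roses
```

Any prediction script that reads class_indices.json back therefore needs to convert the predicted index to a string before the lookup.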

    3. Load the model

    # device was not defined in the original; this is the usual choice
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    net = MobileNetV2(num_classes=5)
    model_weight_path = "./mobilenet_v2.pth"
    # load pretrained weights
    assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)
    pre_weights = torch.load(model_weight_path, map_location=device)
    # delete classifier weights (their element count differs from the 5-class head)
    net_dict = net.state_dict()
    pre_dict = {k: v for k, v in pre_weights.items() if net_dict[k].numel() == v.numel()}
    # strict=False loads only the weights that can be matched
    missing_keys, unexpected_keys = net.load_state_dict(pre_dict, strict=False)

    # freeze features weights
    for param in net.features.parameters():
        param.requires_grad = False
    net.to(device)
    
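The numel-based filter can be illustrated without loading any weights. The sketch below works on shapes only: the three keys and their shapes are illustrative stand-ins for the pretrained ImageNet checkpoint (1000 classes) and the 5-class model, not a dump of the real state dict.

```python
from math import prod

# Illustrative shapes: pretrained checkpoint vs. the 5-class model.
pretrained_shapes = {
    "features.0.0.weight": (32, 3, 3, 3),
    "classifier.1.weight": (1000, 1280),
    "classifier.1.bias": (1000,),
}
model_shapes = {
    "features.0.0.weight": (32, 3, 3, 3),
    "classifier.1.weight": (5, 1280),
    "classifier.1.bias": (5,),
}

# Same rule as the loading code: keep a tensor only if the element counts agree.
kept = {k: shape for k, shape in pretrained_shapes.items()
        if prod(model_shapes[k]) == prod(shape)}

# Keys the model has but the filtered dict lacks; these stay randomly initialized.
missing = sorted(set(model_shapes) - set(kept))

print(sorted(kept))
print(missing)
```

Only the classifier entries are dropped, which is why load_state_dict with strict=False reports them as missing keys. Comparing element counts is a loose test, so comparing the shapes themselves would be a stricter variant.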

    4. Training the model

    # define loss function
    loss_function = nn.CrossEntropyLoss()

    # construct an optimizer (only for parameters that still require gradients)
    params = [p for p in net.parameters() if p.requires_grad]
    optimizer = optim.Adam(params, lr=0.0001)

    epochs = 5  # not defined in the original; assumed value
    best_acc = 0.0
    save_path = './MobileNetV2.pth'
    train_steps = len(train_loader)

    for epoch in range(epochs):
        # train
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            logits = net(images.to(device))
            loss = loss_function(logits, labels.to(device))
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()

            train_bar.desc = "train epoch [{}/{}] loss:{:.3f}".format(epoch + 1, epochs, loss.item())

        # validate
        net.eval()
        acc = 0.0  # accumulated number of correct predictions in this epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                # index of the largest logit is the predicted class
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

                val_bar.desc = "valid epoch [{}/{}]".format(epoch + 1, epochs)

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        # keep only the weights with the best validation accuracy
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished Training')
    
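The validation loop adds up correct predictions over all batches and divides by the dataset size once. A small pure-Python sketch of that bookkeeping follows; the logits, labels, and the helper `count_correct` are made up for illustration, with plain lists standing in for tensors.

```python
def count_correct(logits, labels):
    """Number of rows whose arg-max index equals the label."""
    correct = 0
    for row, label in zip(logits, labels):
        predicted = max(range(len(row)), key=row.__getitem__)
        correct += int(predicted == label)
    return correct


# Two hypothetical validation batches of different sizes: (logits, labels).
batches = [
    ([[2.0, 0.1, 0.3], [0.2, 1.5, 0.1]], [0, 1]),
    ([[0.1, 0.2, 3.0]], [1]),
]

correct = sum(count_correct(logits, labels) for logits, labels in batches)
total = sum(len(labels) for _, labels in batches)
accuracy = correct / total

print(correct, total, round(accuracy, 3))
```

Averaging the per-batch accuracies instead would give 0.5 here, because the smaller last batch would carry the same weight as the full one.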
  • Original article: https://blog.csdn.net/weixin_38346042/article/details/125358925