• 深度学习--实战 LeNet5


    深度学习--实战 LeNet5

    数据集

    数据集选用CIFAR-10的数据集,Cifar-10 是由 Hinton 的学生 Alex Krizhevsky、Ilya Sutskever 收集的一个用于普适物体识别的计算机视觉数据集,它包含 60000 张 32 X 32 的 RGB 彩色图片,总共 10 个分类。其中,包括 50000 张用于训练集,10000 张用于测试集。

    模型实现

    模型需要继承nn.module

    import torch
    from torch import nn
    class Lenet5(nn.Module):
    """
    for cifar10 dataset.
    """
    def __init__(self):
    super(Lenet5,self).__init__()
    self.conv_unit = nn.Sequential(
    #input:[b,3,32,32] ===> output:[b,6,x,x]
    #Conv2d(Input_channel:输入的通道数,kernel_channels:卷积核的数量,输出的通道数,kernel_size:卷积核的大小,stride:步长,padding:边缘补足)
    nn.Conv2d(3,6,kernel_size=5,stride=1,padding=0),
    #池化
    nn.MaxPool2d(kernel_size=2,stride=2,padding=0),
    #卷积层
    nn.Conv2d(6,16,kernel_size=5,stride=1,padding=0),
    #池化
    nn.AvgPool2d(kernel_size=2,stride=2,padding=0)
    #output:[b,16,5,5]
    )
    #flatten
    #Linear层
    self.fc_unit=nn.Sequential(
    nn.Linear(16*5*5,120),
    nn.ReLU(),
    nn.Linear(120,84),
    nn.ReLU(),
    nn.Linear(84,10)
    )
    #测试卷积输出到全连接层的输入
    #tmp = torch.rand(2,3,32,32)
    #out = self.conv_unit(tmp)
    #print("conv_out:",out.shape)
    #Loss评价 Cross Entropy Loss 分类 在其中包含一个softmax()操作
    #self.criteon = nn.MSELoss() 回归
    #self.criteon = nn.CrossEntropyLoss()
    def forward(self,x):
    """
    :param x:[b,3,32,32]
    :return:
    """
    batchsz = x.size(0)
    #[b,3,32,32]=>[b,16,5,5]
    x = self.conv_unit(x)
    #[b,16,5,5]=>[b,16*5*5]
    x = x.view(batchsz,16*5*5)
    #[b,16*5*5]=>[b,10]
    logits = self.fc_unit(x)
    return logits
    # [b,10]
    # pred = F.softmax(logits,dim=1) 这步在CEL中包含了,所以不需要再写一次
    #loss = self.criteon(logits,y)
    def main():
    net = Lenet5()
    tmp = torch.rand(2,3,32,32)
    out = net(tmp)
    print("lenet_out:",out.shape)
    if __name__ == '__main__':
    main()

    训练与测试

    import torch
    from torchvision import datasets
    from torchvision import transforms
    from torch.utils.data import DataLoader
    from lenet5 import Lenet5
    import torch.nn.functional as F
    from torch import nn,optim
    def main():
    batch_size = 32
    epochs = 1000
    learn_rate = 1e-3
    #导入图片,一次只导入一张
    cifer_train = datasets.CIFAR10('cifar',train=True,transform=transforms.Compose([
    transforms.Resize((32,32)),
    transforms.ToTensor()
    ]),download=True)
    #加载图
    cifer_train = DataLoader(cifer_train,batch_size=batch_size,shuffle=True)
    #导入图片,一次只导入一张
    cifer_test = datasets.CIFAR10('cifar',train=False,transform=transforms.Compose([
    transforms.Resize((32,32)),
    transforms.ToTensor()
    ]),download=True)
    #加载图
    cifer_test = DataLoader(cifer_test,batch_size=batch_size,shuffle=True)
    #iter迭代器,__next__()方法可以获得数据
    x, label = iter(cifer_train).__next__()
    print("x:",x.shape,"label:",label.shape)
    #x: torch.Size([32, 3, 32, 32]) label: torch.Size([32])
    device = torch.device('cuda')
    model = Lenet5().to(device)
    print(model)
    criteon = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(),lr=learn_rate)
    for epoch in range(epochs):
    model.train()
    for batchidx,(x,label) in enumerate(cifer_train):
    x,label = x.to(device),label.to(device)
    logits = model(x)
    #logits:[b,10]
    loss = criteon(logits,label)
    #backprop
    optimizer.zero_grad() #梯度清零
    loss.backward()
    optimizer.step() #梯度更新
    #
    print(epoch,loss.item())
    model.eval()
    with torch.no_grad():
    #test
    total_correct = 0
    total_num = 0
    for x,label in cifer_test:
    x,label = x.to(device),label.to(device)
    #[b,10]
    logits = model(x)
    #[b]
    pred =logits.argmax(dim=1)
    #[b] vs [b] => scalar tensor
    total_correct += torch.eq(pred,label).float().sum().item()
    total_num += x.size(0)
    acc = total_correct/total_num
    print("epoch:",epoch,"acc:",acc)
    if __name__ == '__main__':
    main()
  • 相关阅读:
    网络运维与网络安全 学习笔记2023.11.17
    [SECCON CTF 2022] 只两个小题pwn_koncha,rev_babycmp
    大数据教程-01HDFS的基本组成和原理
    虚拟化基本知识及virtio-net初探
    [SpringBoot]配置文件②(多环境配置,配置文件分类)
    从零开始写 Docker(十)---实现 mydocker logs 查看容器日志
    J2EE——自定义MVC框架的CRUD操作
    安信证券携手共议量化行业的赋能发展
    hive外部表加载parquet类型的数据文件
    element中el-input 输入框 在自动填充(auto-complete=“on“)时,背景颜色会自动改变问题
  • 原文地址:https://www.cnblogs.com/ssl-study/p/17349754.html