• 【Pytorch深度学习开发实践学习】B站刘二大人课程笔记整理lecture10 Basic_CNN


    Pytorch深度学习开发实践学习】B站刘二大人课程笔记整理lecture10 Basic_CNN

    部分课件内容:
    在这里插入图片描述

    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述

    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述

    在这里插入图片描述

    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述

    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述

    在这里插入图片描述
    在这里插入图片描述

    在这里插入图片描述
    在这里插入图片描述在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述

    在这里插入图片描述

    在这里插入图片描述
    在这里插入图片描述

    在这里插入图片描述
    在这里插入图片描述

    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述

    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述

    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述

    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述

    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述

    在这里插入图片描述

    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述

    代码:

    import torch
    from torchvision import datasets, transforms
    from torch.utils.data import DataLoader
    import torch.nn as nn
    import torch.nn.functional as F
    
    batch_size = 64
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) #把原始图像转为tensor  这是均值和方差
    
    train_set = datasets.MNIST(root='./data/mnist', train=True, download=True, transform=transform)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    
    test_set = datasets.MNIST(root='./data/mnist', train=False, download=True, transform=transform)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=True)
    class Net(torch.nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.conv1 = torch.nn.Conv2d(1, 10, kernel_size=5)
            self.conv2 = torch.nn.Conv2d(10, 20, kernel_size=5)
            self.pooling = torch.nn.MaxPool2d(kernel_size=2)
            self.fc1 = torch.nn.Linear(320, 10)
    
        def forward(self, x):
            batch_size = x.size(0)
            x = F.relu(self.pooling(self.conv1(x), ))
            x = F.relu(self.pooling(self.conv2(x), ))
            x = x.view(batch_size,-1)    # flatten
            x = self.fc1(x)
            return x
    
    model = Net()
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  #把模型迁移到GPU
    model = model.to(device)   #把模型迁移到GPU
    
    def train(epoch):
        running_loss = 0.0
        for i, data in enumerate(train_loader, 0):
            inputs, labels = data
            inputs,labels = inputs.to(device), labels.to(device)  #训练内容迁移到GPU上
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            if i % 300 == 299:    # print every 300 mini-batches
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, i + 1, running_loss / 300))
                running_loss = 0.0
    
    def test(epoch):
        correct = 0
        total = 0
        with torch.no_grad():
            for data in test_loader:
                images, labels = data
                images,labels = images.to(device), labels.to(device)  #测试内容迁移到GPU上
                outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    
        print('Accuracy of the network on the 10000 test images: %d %%' % (
            100 * correct / total))
    
    if __name__ == '__main__':
        for epoch in range(100):
            train(epoch)
            if epoch % 10 == 0:
                test(epoch)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
  • 相关阅读:
    安装配置MySQL5.7详细教程
    Python 操作mongodb库
    Python爬虫:selenium动态加载HTML的常用方法【汇总笔记】
    说说我的实习收获
    企业工程项目管理系统源码(三控:进度组织、质量安全、预算资金成本、二平台:招采、设计管理)
    AI高考志愿填报:大厂神仙打架,考生付费围观
    【JavaSE专栏89】Java字符串和XML数据结构的转换,高效灵活转变数据
    Ajax中什么时候用同步,什么时候用异步?
    Qt QObject Cannot create children for a parent that is in a different thread
    uniapp 小程序拍照上传,百度识别人体关键点,显示拖拽元素,生成海报
  • 原文地址:https://blog.csdn.net/weixin_44184852/article/details/136253652