• 【Pytorch深度学习开发实践学习】B站刘二大人课程笔记整理lecture09 Softmax多分类


    在这里插入图片描述

    在这里插入图片描述

    在这里插入图片描述

    在这里插入图片描述

    在这里插入图片描述

    在这里插入图片描述

    在这里插入图片描述

    在这里插入图片描述
    在这里插入图片描述

    在这里插入图片描述
    在这里插入图片描述

    在这里插入图片描述
    在这里插入图片描述

    在这里插入图片描述

    在这里插入图片描述在这里插入图片描述

    在这里插入图片描述
    在这里插入图片描述

    在这里插入图片描述
    在这里插入图片描述

    在这里插入图片描述

    在这里插入图片描述

    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述

    在这里插入图片描述
    在这里插入图片描述

    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述

    在这里插入图片描述

    在这里插入图片描述

    在这里插入图片描述

    代码:

    import torch
    from torchvision import datasets, transforms
    from torch.utils.data import DataLoader
    import torch.nn as nn
    import torch.nn.functional as F
    
    batch_size = 64
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) #把原始图像转为tensor  这是均值和方差
    
    train_set = datasets.MNIST(root='./data/mnist', train=True, download=True, transform=transform)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    
    test_set = datasets.MNIST(root='./data/mnist', train=False, download=True, transform=transform)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=True)
    
    class LinearClassifier(nn.Module):
        def __init__(self):
            super(LinearClassifier, self).__init__()
            self.l1 = nn.Linear(784, 512)
            self.l2 = nn.Linear(512, 256)
            self.l3 = nn.Linear(256, 128)
            self.l4 = nn.Linear(128, 64)
            self.l5 = nn.Linear(64, 10)
    
    
        def forward(self, x):
            x = x.view(-1, 784)
            x = F.relu(self.l1(x))
            x = F.relu(self.l2(x))
            x = F.relu(self.l3(x))
            x = F.relu(self.l4(x))
            return self.l5(x)
    
    criterion = nn.CrossEntropyLoss()
    model = LinearClassifier()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    
    def train(epoch):
        running_loss = 0.0
        for i, data in enumerate(train_loader, 0):
            inputs, labels = data
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            if i % 300 == 299:    # print every 300 mini-batches
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, i + 1, running_loss / 300))
                running_loss = 0.0
    
    def test(epoch):
        correct = 0
        total = 0
        with torch.no_grad():
            for data in test_loader:
                images, labels = data
                outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    
        print('Accuracy of the network on the 10000 test images: %d %%' % (
            100 * correct / total))
    
    if __name__ == '__main__':
        for epoch in range(100):
            train(epoch)
            if epoch % 10 == 0:
                test(epoch)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
  • 相关阅读:
    DNS入门学习:什么是TTL值?如何设置合适的TTL值?
    【39元用上Rockchip linux 1.5G双核开发板】-[板载Flash烧写镜像系统]-幸狐Luckfox Pico-超越树莓派PICO
    设备树——dtb格式到struct device node结构体的转换
    Java WebSocket框架
    【vue】 vue2 监听滚动条滚动事件
    LabVIEW样式检查表1
    酷睿处理器型号前面的字母代表什么
    python记录-01
    AMAZON LINUX/CENTOS 部署PYTHON+MYSQL+DJANGO项目流程记录
    元学习基础理解
  • 原文地址:https://blog.csdn.net/weixin_44184852/article/details/136253559