• 使用 PyTorch 搭建网络 - train_py篇


    train_py篇碰到的问题

    python中采用驼峰书写法且首字母大写的变量符号一般表示类名。

    学习网络步骤:看原论文+看别人对原论文的理解,学习网络结构,看损失函数计算,看数据集,看别人写的代码,复现代码。

    经历以上步骤我们便可以选择合适的框架复现代码,这里使用PyTorch复现网络结构。我们用PyTorch搭建网络可以分为以下几个module,数据处理dataloader.py网络模型model.py,训练模块train.py,工具模块utils.py,预测模块predict.py。接下来我们将以LeNet为例进行讲解。

    由于LeNet使用了CIFAR10数据集,所以我们直接用内置方法生成dataset便可。

    目录如下:

    • train.py 模块综述
    • torch.utils.data.DataLoader
    • iter()和next()
    • torch.nn.CrossEntropyLoss类
    • torch.optim.Adam类
    • python中enumerate()方法
    • torch.optim.Adam.zero_grad()方法
    • FP,BP
    • 待解决问题
    • 源码

    train_py 综述

    在train_py模块中需要导入torch.utils.data.DataLoader

    需要导入自己写的model.LeNet

    torch.utils.data.DataLoader

    DataLoader的作用是接收一个dataset对象,并生成一个DataLoader对象,它的函数声明如下:

    torch.utils.data.DataLoader(dataset, batch_size=1, 
    							shuffle=None, sampler=None, batch_sampler=None, num_workers=0, 
    							collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, 
    							*, prefetch_factor=2, persistent_workers=False, pin_memory_device='')
    
    • 1
    • 2
    • 3
    • 4

    其实我们只要知道DataLoader接收一个dataset对象并生成一个DataLoader对象便可,我们需要指定DataLoader中的dataset对象,batch_size每一次迭代(一个epoch)导入的图片的个数,batch_size由硬件设备显存决定,一般batch_size越大训练效果越好,shuffle是否打乱,num_workers载入数据的线程数(在linux下可以定义,在windows下设置为0)。

    iter()和next()

    iter()和next()事python自带的函数,iter() 函数接收一个支持迭代的集合对象(注意list不是迭代器),返回一个迭代器。object – 支持迭代的集合对象,函数定义如下:

    iter(object[, sentinel])
    
    • 1

    next()会调用迭代器的下一个元素。

    torch.nn.CrossEntropyLoss类

    指定损失函数为CrossEntropyLoss,它的函数定义如下:

    torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=- 100, reduce=None, reduction='mean', label_smoothing=0.0)
    
    • 1

    往往不需要传实参,直接默认值便可。

    torch.optim.Adam类

    定义优化器为Adam优化器,函数声明如下:

    torch.optim.Adam(params, lr=0.001, 
    				betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False, *, foreach=None, maximize=False, capturable=False)
    
    • 1
    • 2

    需要指定paramslr参数,其中params处往往传入网络的全部参数net.parameters(),torch.nn.Module继承而来的方法,使传递全部参数,指定初始学习率lr=0.001

    python中enumerate()方法

    enumerate() 函数用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,同时列出数据和数据下标,以tuple类型返回。一般用在 for 循环当中,常见案例如下:

    >>> seq = ['one', 'two', 'three']
    >>> for i, element in enumerate(seq):
    ...     print i, element
    ...
    0 one
    1 two
    2 three
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    dataloader 是个迭代器,故将dataloader传入后会输出每一次迭代的batch,如这个数据集每次的batch就是一个四维tensor(images)和一个一维tensor(labels)。

    torch.optim.Adam.zero_grad()方法

    没一次batch的images进行处理后都需要调用该方法清除梯度,每一次batch要累加loss。每一次epoch清除loss。但每一次epoch和batch都会更新net.parameters()

    FP,BP

    # 每一次batch都要做一次这个
    outputs = net(inputs)	# 正向传播计算y估
    loss 	= loss_function(outputs, lables)	# 应该又是一个回调函数,计算损失函数的值,即y估和y的残差
    loss.backward()			# 误差的反向传播,这一步骤和下一步骤才是BP的完整算法-更新参数(即计算偏导,给参数赋值)
    optimizer.step()		# 更新参数 update parameters in net
    running_loss += loss.item()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6

    loss是一个tensor(1.1)的0维tensor,使用item()将其转换为标量。

    with torch.no_grad()

    在验证集中计算准确值时候不进行梯度的自动更新。

    predict_y = torch.max(outputs, dim=1)[1]

    最终经过网络输出的outputs是一个[batch, labels]的[10000, 10]的Tensor,torch.max(outputs, dim=1)指我们对outsputs的第一个维度(10个数据中)取最大值,torch.max(outputs, dim=1)[1]指的是将所取数的序列号(0-9)返回给predict_y。所以predict_y是一个10000的Tensor。

    accuracy = (predict_y == test_label).sum().item() / test_label.size(0)

    predict_y和test_label都是torch.Size([10000])的tensor,索引tensor([])的方法可以看做list[0, 1]用索引调值,索引torch.Size([])的方法需要使用tensor.size(0)索引值。特殊的,对于tensor(1)这样的0维向量,则使用.item()方法将其转换为数值。

    .sum()语句计算predict_y和test_label中相等元素的个数,返回一个0维的tensor(12)变量,使用.item()方法获取它的数值。使用test_label.size(0)获得test_label在第一维度的值(10000)。相除即得accuracy准确率,precision是精度。

    vscode stepinto 不能进入代码

    launch.json里添加

    "purpose":["debug-in-terminal"]

    待解决疑问

    running_loss += loss.item()
    predict_y = torch.max(outputs, dim=1)[1]
    # mac vscode为何不能debug进别人代码
    
    • 1
    • 2
    • 3

    源码

    # add path
    import os, sys
    project_path = os.path.dirname(os.path.dirname(__file__))
    sys.path.append(project_path)
    # add other's package
    import torch
    import torch.utils.data
    import torch.nn
    import torch.optim
    # add my package
    import utils.recite_dataloader
    import utils.recite_model
    
    
    train_loader = torch.utils.data.DataLoader(
        utils.recite_dataloader.LeNetDataSet(os.path.join(project_path, "data")).train_set,
        batch_size=36, shuffle=True, num_workers=0,
    )
    test_loader  = torch.utils.data.DataLoader(
        utils.recite_dataloader.LeNetDataSet(os.path.join(project_path, "data")).test_set,
        batch_size=10000, shuffle=False, num_workers=0,
    )
    test_data_iter = iter(test_loader)
    test_image, test_label = test_data_iter.next()
    
    # show img
    def imshow(img):
        pass
    
    net = utils.recite_model.LeNet()
    loss_function = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr = 0.001)
    
    for epoch in range(1):  # loop
        running_loss = 0.0
        for step, data in enumerate(train_loader, start=0):
            inputs, labels = data   # [images, labels], data is a list
            # print(type(step), type(data), type(inputs))
            # print("step is {}, and inputs/images shape is {}, and labels shape is {}".format(step, inputs.shape, labels.shape))
            optimizer.zero_grad()
            outputs = net(inputs)
            # print("outputs shape is {}".format(outputs.shape))
            loss    = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()    # update parameters in net
    
            running_loss += loss.item()
            if step % 500 == 499:
                with torch.no_grad():   # 在下面的操作中不会自动计算梯度
                    outputs = net(test_image)   # [batch, 10]
                    # print("test_image's shape is {}, and outputs' shape is {}".format(test_image.shape, outputs.shape))
                    predict_y = torch.max(outputs, dim=1)[1]
                    accuracy = (predict_y == test_label).sum().item() / test_label.size(0) 
    
                    print("[%d, %5d] train_loss: %.3f test_accuracy: %.3f" %(epoch + 1, step + 1, running_loss / 500, accuracy))
                    running_loss = 0.0
    
    
    save_path = os.path.join(project_path, "utils", "te")
    torch.save(net.state_dict(), save_path)
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
  • 相关阅读:
    基于langchain的开源大模型应用开发1
    修复dinput8.dll文件的缺失,以及修复dinput8.dll文件时需要注意什么
    字节跳动端智能工程链路 Pitaya 的架构设计
    【HTML专栏4】常用标签(标题、段落、换行、文本格式化、注释及特殊字符)
    【考研】数据结构考点——希尔排序
    windows mysql安装
    从0搭建vue3组件库: Input组件
    Mybatis简介
    原生小程序,手机键盘弹出会将输入框文本顶上去
    报错信息Missing unknown database driver(MySQLdb模块)
  • 原文地址:https://blog.csdn.net/qq_43369406/article/details/127440690