• 使用 PyTorch 搭建网络 - train_py篇


    train.py

    目录如下:

    • 导包
    • train.py
    • argparse配置参数
    • main函数
    • torch.nn.CrossEntropyLoss类
    • torch.optim.Adam类
    • python中enumerate()方法
    • torch.optim.Adam.zero_grad()方法
    • FP,BP
    • 待解决问题
    • 源码

    导包

    我们需要导入timedatetime用于计算训练时间;导入torch用于使用Pytorch框架;导入网络from model import UNet;导入需要的工具方法from utils import [你需要的方法];导入我们的DIYDatesetfrom dataset import DriveDataest;导入transforms文件import transforms as T

    import os
    import time
    import datetime
    
    import torch
    
    from model import UNet
    from utils import train_one_epoch, evaluate, create_lr_scheduler
    from dataset import DriveDataset
    import transforms as T
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10

    train.py

    train.py中我们对网络进行训练,我们首先使用argparse配置参数,将参数传入main进行训练。

    案例如下:

    if __name__ == '__main__':
        args = parse_args()
        main(args)
    
    • 1
    • 2
    • 3

    argparse配置参数

    使用argparse封装需要的参数。

    参看https://blog.csdn.net/qq_43369406/article/details/127787799

    argparse函数案例如下:

    def parse_args():
        import argparse
        parser = argparse.ArgumentParser(description="pytorch unet training")
    
        parser.add_argument("--data-path", default="./", help="DRIVE root")
        # exclude background
        parser.add_argument("--num-classes", default=1, type=int)
        parser.add_argument("--device", default="cuda", help="training device")
        parser.add_argument("-b", "--batch-size", default=4, type=int)
        parser.add_argument("--epochs", default=200, type=int, metavar="N",
                            help="number of total epochs to train")
    
        parser.add_argument('--lr', default=0.01, type=float, help='initial learning rate')
        parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                            help='momentum')
        parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                            metavar='W', help='weight decay (default: 1e-4)',
                            dest='weight_decay')
        parser.add_argument('--print-freq', default=1, type=int, help='print frequency')
        parser.add_argument('--resume', default='', help='resume from checkpoint')
        parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                            help='start epoch')
        parser.add_argument('--save-best', default=True, type=bool, help='only save best dice weights')
        # Mixed precision training parameters
        parser.add_argument("--amp", default=False, type=bool,
                            help="Use torch.cuda.amp for mixed precision training")
    
        args = parser.parse_args()
    
        return args
    
    
    if __name__ == '__main__':
        args = parse_args()
    	args.data_path
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35

    main函数

    main函数包括训练的全过程,我们一般这样组织main中结构:DataLoader - 训练参数 - epoch

    DataLoader

    在DataLoader环节我们需要选择合适的Transforms传入Dataset,向DataLoader中传入Dataset和batch,DataLoader就会每次从Dataset中取出batch个数据。其中最为重要的就是选定适合的Transforms传入Dataset中,设定合适的DataLoader。

    Transforms选定如下:

    DataLoader案例如下:

    # dataloader
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    batch_size = args.batch_size
    num_workers = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])
    # segmentation nun_classes + background
    num_classes = args.num_classes + 1
    # using compute_mean_std.py
    mean = (0.709, 0.381, 0.224)
    std = (0.127, 0.079, 0.043)
    train_dataset = DriveDataset(args.data_path,
                                 train=True,
                                 transforms=get_transform(train=True, mean=mean, std=std))
    val_dataset = DriveDataset(args.data_path,
                               train=False,
                               transforms=get_transform(train=False, mean=mean, std=std))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               num_workers=num_workers,
                                               shuffle=True,
                                               pin_memory=True,
                                               collate_fn=train_dataset.collate_fn)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=1,
                                             num_workers=num_workers,
                                             pin_memory=True,
                                             collate_fn=val_dataset.collate_fn)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26

    训练参数

    在该步骤中需要指定模型,优化器,加载预训练权重(迁移学习)。

    # 模型
    model = create_model(num_classes=num_classes)
    model.to(device)
    
    params_to_optimize = [p for p in model.parameters() if p.requires_grad]
    # 优化器
    optimizer = torch.optim.SGD(
            params_to_optimize,
            lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay
    )
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10

    我们使用torch.load()torch.save()用来加载和保存训练超参,我们在load和save中指定model.load_state_load()model.state_dict()用来将训练的权重超参保存为字典格式进行存储,如下:

    # 保存.pth文件
    # 设定文件存储的格式
    save_file = {"model": model.state_dict(),
                 "optimizer": optimizer.state_dict(),	# 优化器中参数
                 "lr_scheduler": lr_scheduler.state_dict(),
                 "epoch": epoch,
                 "args": args}
    torch.save(save_file, "save_weights/best_model.pth")
    
    # 加载.pth文件
    # 从.pth文件中取数据
    checkpoint = torch.load(args.resume, map_location='cpu')	# args.resume="save_weights/best_model.pth"; map_location指的是映射到CPU上加载模型
            model.load_state_dict(checkpoint['model'])			# 从dictionary中根据key取value,若是用.state_dict()进行存储,则需要用.load_state_dict()将值取出
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1			# 若不是用.state_dict()取出,则直接取出来用便可
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16

    epoch

    每一次训练都是在epoch中进行,每一个epoch需要进行训练和测试并将训练结果进行存储,并记录每一轮训练时长。

    训练的完整代码如下:

    # 用来保存训练以及验证过程中信息
        results_file = "/home/yingmuzhi/unet/results{}.txt".format(datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))   
        best_dice = 0.
        start_time = time.time()
        for epoch in range(args.start_epoch, args.epochs):
            mean_loss, lr = train_one_epoch(model, optimizer, train_loader, device, epoch, num_classes,
                                            lr_scheduler=lr_scheduler, print_freq=args.print_freq, scaler=scaler)
    
            confmat, dice = evaluate(model, val_loader, device=device, num_classes=num_classes)
            val_info = str(confmat)
            print(val_info)
            print(f"dice coefficient: {dice:.3f}")
            # write into txt
            with open(results_file, "a") as f:
                # 记录每个epoch对应的train_loss、lr以及验证集各指标
                train_info = f"[epoch: {epoch}]\n" \
                             f"train_loss: {mean_loss:.4f}\n" \
                             f"lr: {lr:.6f}\n" \
                             f"dice coefficient: {dice:.3f}\n"
                f.write(train_info + val_info + "\n\n")
    
            if args.save_best is True:
                if best_dice < dice:
                    best_dice = dice
                else:
                    continue
    
            save_file = {"model": model.state_dict(),
                         "optimizer": optimizer.state_dict(),
                         "lr_scheduler": lr_scheduler.state_dict(),
                         "epoch": epoch,
                         "args": args}
            if args.amp:
                save_file["scaler"] = scaler.state_dict()
    
            if args.save_best is True:
                torch.save(save_file, "save_weights/best_model.pth")
            else:
                torch.save(save_file, "save_weights/model_{}.pth".format(epoch))
    
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print("training time {}".format(total_time_str))
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
  • 相关阅读:
    一个.Net开发的功能强大、易于使用的流媒体服务器和管理系统
    智慧渔港:海域感知与岸线监控实施方案(智慧渔港渔船综合管控平台)
    超声波清洗机品牌哪些好用?好评不断的超声波清洗机推荐
    Arduino驱动ADXL345三轴加速度传感器(惯性测量传感器篇)
    如何修复 Windows 11/10上的 0x8007023e Windows 更新错误
    shell编程中的流程控制
    解决 /bin/bash^M: bad interpreter: No such file or directory
    OJ练习第172题——可以攻击国王的皇后
    【运维】docker如何删除所有容器
    fastadmin 表单页面,根据一个字段的值显示不同字段
  • 原文地址:https://blog.csdn.net/qq_43369406/article/details/127791648