• AugMixDataset的一些示例


    第一个示例

    # 2022.02.14-Changed for main script for wavemlp model
    #            Huawei Technologies Co., Ltd. 
    #!/usr/bin/env python
    """ ImageNet Training Script
    
    This is intended to be a lean and easily modifiable ImageNet training script that reproduces ImageNet
    training results with some of the latest networks and training techniques. It favours canonical PyTorch
    and standard Python style over trying to be able to 'do it all.' That said, it offers quite a few speed
    and training result improvements over the usual PyTorch example scripts. Repurpose as you see fit.
    
    This script was started from an early version of the PyTorch ImageNet example
    (https://github.com/pytorch/examples/tree/master/imagenet)
    
    NVIDIA CUDA specific speedups adopted from NVIDIA Apex examples
    (https://github.com/NVIDIA/apex/tree/master/examples/imagenet)
    
    Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
    """
    
    
    import warnings
    warnings.filterwarnings('ignore')
    import argparse
    import time
    import yaml
    import os
    import logging
    from collections import OrderedDict
    from contextlib import suppress
    from datetime import datetime
    
    import math
    
    import torch
    import torch.nn as nn
    import torchvision.utils
    from torch.nn.parallel import DistributedDataParallel as NativeDDP
    
    from timm.data import Dataset, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset 
    from timm.models import create_model, resume_checkpoint, convert_splitbn_model
    from timm.utils import *
    from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
    from timm.optim import create_optimizer
    from timm.scheduler import create_scheduler
    from timm.utils import ApexScaler, NativeScaler
    
    from data.myloader import create_loader
    import models.wavemlp
    
    # Optional NVIDIA Apex support: the names below are only bound when Apex imports cleanly.
    try:
        from apex import amp
        from apex.parallel import DistributedDataParallel as ApexDDP
        from apex.parallel import convert_syncbn_model
    except ImportError:
        has_apex = False
    else:
        has_apex = True

    # Native AMP is considered available when torch.cuda.amp exposes a non-None `autocast`.
    try:
        has_native_amp = torch.cuda.amp.autocast is not None
    except AttributeError:
        has_native_amp = False

    # Enable the cuDNN benchmark flag and create the logger used by this script.
    torch.backends.cudnn.benchmark = True
    _logger = logging.getLogger('train')
    
    # Two parsers are used:
    #   * `config_parser` only understands -c/--config and is run first; the YAML file it
    #     names supplies key/values that override the defaults of the main parser.
    #   * `parser` declares every training option. Explicit command-line flags win over
    #     values coming from the YAML file because the file only changes defaults.
    config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False)
    parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
                        help='YAML config file specifying default arguments')


    parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')

    # Dataset / Model parameters
    parser.add_argument('data', metavar='DIR',
                        help='path to dataset')
    parser.add_argument('--model', default='resnet101', type=str, metavar='MODEL',
                        help='Name of model to train (default: "resnet101")')
    parser.add_argument('--pretrained', action='store_true', default=False,
                        help='Start with pretrained version of specified network (if avail)')
    parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
                        help='Initialize model from this checkpoint (default: none)')
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='Resume full model and optimizer state from checkpoint (default: none)')
    parser.add_argument('--no-resume-opt', action='store_true', default=False,
                        help='prevent resume of optimizer state when resuming model')
    parser.add_argument('--num-classes', type=int, default=1000, metavar='N',
                        help='number of label classes (default: 1000)')
    parser.add_argument('--gp', default=None, type=str, metavar='POOL',
                        help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
    parser.add_argument('--img-size', type=int, default=None, metavar='N',
                        help='Image patch size (default: None => model default)')
    parser.add_argument('--crop-pct', default=None, type=float,
                        metavar='N', help='Input image center crop percent (for validation only)')
    parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
                        help='Override mean pixel value of dataset')
    parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
                        help='Override std deviation of of dataset')
    parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
                        help='Image resize interpolation type (overrides model)')
    parser.add_argument('-b', '--batch-size', type=int, default=32, metavar='N',
                        help='input batch size for training (default: 32)')
    parser.add_argument('-vb', '--validation-batch-size-multiplier', type=int, default=1, metavar='N',
                        help='ratio of validation batch size to training batch size (default: 1)')

    # Optimizer parameters
    parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
                        help='Optimizer (default: "sgd"')
    parser.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON',
                        help='Optimizer Epsilon (default: None, use opt default)')
    parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
                        help='Optimizer Betas (default: None, use opt default)')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='Optimizer momentum (default: 0.9)')
    parser.add_argument('--weight-decay', type=float, default=0.0001,
                        help='weight decay (default: 0.0001)')
    parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
                        help='Clip gradient norm (default: None, no clipping)')



    # Learning rate schedule parameters
    parser.add_argument('--sched', default='step', type=str, metavar='SCHEDULER',
                        help='LR scheduler (default: "step"')
    parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
                        help='learning rate noise on/off epoch percentages')
    parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
                        help='learning rate noise limit percent (default: 0.67)')
    parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
                        help='learning rate noise std-dev (default: 1.0)')
    parser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',
                        help='learning rate cycle len multiplier (default: 1.0)')
    parser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',
                        help='learning rate cycle limit')
    parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
                        help='warmup learning rate (default: 0.0001)')
    parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
    parser.add_argument('--epochs', type=int, default=200, metavar='N',
                        help='number of epochs to train (default: 200)')
    parser.add_argument('--start-epoch', default=None, type=int, metavar='N',
                        help='manual epoch number (useful on restarts)')
    parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
                        help='epoch interval to decay LR')
    parser.add_argument('--warmup-epochs', type=int, default=3, metavar='N',
                        help='epochs to warmup LR, if scheduler supports')
    parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
                        help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
    parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
                        help='patience epochs for Plateau LR scheduler (default: 10')
    parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                        help='LR decay rate (default: 0.1)')

    # Augmentation & regularization parameters
    parser.add_argument('--no-aug', action='store_true', default=False,
                        help='Disable all training augmentation, override other train aug args')
    parser.add_argument('--repeated-aug', action='store_true')
    parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
                        help='Random resize scale (default: 0.08 1.0)')
    parser.add_argument('--ratio', type=float, nargs='+', default=[3./4., 4./3.], metavar='RATIO',
                        help='Random resize aspect ratio (default: 0.75 1.33)')
    parser.add_argument('--hflip', type=float, default=0.5,
                        help='Horizontal flip training aug probability')
    parser.add_argument('--vflip', type=float, default=0.,
                        help='Vertical flip training aug probability')
    parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
                        help='Color jitter factor (default: 0.4)')
    parser.add_argument('--aa', type=str, default=None, metavar='NAME',
                        help='Use AutoAugment policy. "v0" or "original". (default: None)')
    parser.add_argument('--aug-splits', type=int, default=0,
                        help='Number of augmentation splits (default: 0, valid: 0 or >=2)')
    parser.add_argument('--jsd', action='store_true', default=False,
                        help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
    parser.add_argument('--reprob', type=float, default=0., metavar='PCT',
                        help='Random erase prob (default: 0.)')
    parser.add_argument('--remode', type=str, default='const',
                        help='Random erase mode (default: "const")')
    parser.add_argument('--recount', type=int, default=1,
                        help='Random erase count (default: 1)')
    parser.add_argument('--resplit', action='store_true', default=False,
                        help='Do not random erase first (clean) augmentation split')
    parser.add_argument('--mixup', type=float, default=0.0,
                        help='mixup alpha, mixup enabled if > 0. (default: 0.)')
    parser.add_argument('--cutmix', type=float, default=0.0,
                        help='cutmix alpha, cutmix enabled if > 0. (default: 0.)')
    parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
                        help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
    parser.add_argument('--mixup-prob', type=float, default=1.0,
                        help='Probability of performing mixup or cutmix when either/both is enabled')
    parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
                        help='Probability of switching to cutmix when both mixup and cutmix enabled')
    parser.add_argument('--mixup-mode', type=str, default='batch',
                        help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
    parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
                        help='Turn off mixup after this epoch, disabled if 0 (default: 0)')
    parser.add_argument('--smoothing', type=float, default=0.1,
                        help='Label smoothing (default: 0.1)')
    parser.add_argument('--train-interpolation', type=str, default='random',
                        help='Training interpolation (random, bilinear, bicubic default: "random")')
    parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
                        help='Dropout rate (default: 0.)')
    parser.add_argument('--drop-connect', type=float, default=None, metavar='PCT',
                        help='Drop connect rate, DEPRECATED, use drop-path (default: None)')
    parser.add_argument('--drop-path', type=float, default=None, metavar='PCT',
                        help='Drop path rate (default: None)')
    parser.add_argument('--drop-block', type=float, default=None, metavar='PCT',
                        help='Drop block rate (default: None)')

    # Batch norm parameters (only works with gen_efficientnet based models currently)
    parser.add_argument('--bn-tf', action='store_true', default=False,
                        help='Use Tensorflow BatchNorm defaults for models that support it (default: False)')
    parser.add_argument('--bn-momentum', type=float, default=None,
                        help='BatchNorm momentum override (if not None)')
    parser.add_argument('--bn-eps', type=float, default=None,
                        help='BatchNorm epsilon override (if not None)')
    parser.add_argument('--sync-bn', action='store_true',
                        help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')
    parser.add_argument('--dist-bn', type=str, default='',
                        help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")')
    parser.add_argument('--split-bn', action='store_true',
                        help='Enable separate BN layers per augmentation split.')

    # Model Exponential Moving Average
    parser.add_argument('--model-ema', action='store_true', default=False,
                        help='Enable tracking moving average of model weights')
    parser.add_argument('--model-ema-force-cpu', action='store_true', default=False,
                        help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
    parser.add_argument('--model-ema-decay', type=float, default=0.9998,
                        help='decay factor for model weights moving average (default: 0.9998)')

    # Misc
    parser.add_argument('--seed', type=int, default=42, metavar='S',
                        help='random seed (default: 42)')
    parser.add_argument('--log-interval', type=int, default=50, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--recovery-interval', type=int, default=0, metavar='N',
                        help='how many batches to wait before writing recovery checkpoint')
    parser.add_argument('-j', '--workers', type=int, default=4, metavar='N',
                        help='how many training processes to use (default: 4)')
    parser.add_argument('--num-gpu', type=int, default=1,
                        help='Number of GPUS to use')
    parser.add_argument('--save-images', action='store_true', default=False,
                        help='save images of input bathes every log interval for debugging')
    parser.add_argument('--amp', action='store_true', default=False,
                        help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
    parser.add_argument('--apex-amp', action='store_true', default=False,
                        help='Use NVIDIA Apex AMP mixed precision')
    parser.add_argument('--native-amp', action='store_true', default=False,
                        help='Use Native Torch AMP mixed precision')
    parser.add_argument('--channels-last', action='store_true', default=False,
                        help='Use channels_last memory layout')
    parser.add_argument('--pin-mem', action='store_true', default=False,
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no-prefetcher', action='store_true', default=False,
                        help='disable fast prefetcher')
    parser.add_argument('--output', default='', type=str, metavar='PATH',
                        help='path to output folder (default: none, current dir)')
    # Read by main() as args.max_history when it builds the checkpoint saver; without this
    # option the attribute only exists if a YAML config happens to define it.
    parser.add_argument('--max-history', type=int, default=10, metavar='N',
                        help='maximum number of checkpoints to keep (default: 10)')
    parser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',
                        help='Best metric (default: "top1"')
    parser.add_argument('--tta', type=int, default=0, metavar='N',
                        help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
    parser.add_argument("--local_rank", default=0, type=int)
    parser.add_argument('--use-multi-epochs-loader', action='store_true', default=False,
                        help='use the multi-epochs-loader to save time at the beginning of every epoch')


    parser.add_argument("--init_method", default='env://', type=str)
    
    
    
    def _parse_args():
        """Parse command-line arguments, optionally seeding defaults from a YAML file.

        ``config_parser`` extracts ``-c/--config`` first. When a file is given, its
        key/values are installed as defaults on the main ``parser``, so options passed
        explicitly on the command line still take precedence over the file.

        Returns:
            tuple: ``(args, args_text)`` where ``args`` is the parsed namespace and
            ``args_text`` is a YAML dump of that namespace, kept so it can be written
            to the output directory later.
        """
        args_config, remaining = config_parser.parse_known_args()
        if args_config.config:
            with open(args_config.config, 'r') as f:
                # safe_load returns None for an empty document; treat that as "no
                # overrides" instead of failing on `**None`.
                cfg = yaml.safe_load(f) or {}
            parser.set_defaults(**cfg)

        # The main arg parser parses the rest of the args, the usual
        # defaults will have been overridden if config file specified.
        args = parser.parse_args(remaining)

        # Cache the args as a text string to save them in the output dir later
        args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
        return args, args_text
    
    
    def main():
        """Train and evaluate a model end to end.

        Steps, in order: parse arguments, configure single-process or distributed
        execution, build the model and print its parameter / MAC counts, select the
        AMP mode, create the optimizer, optionally resume from a checkpoint, set up
        EMA and the distributed wrappers, build the train / eval loaders and loss
        functions, then run the epoch loop (train, validate, step the LR schedule,
        append to the summary file, save a checkpoint).

        The output directory, ``args.yaml``, ``summary.csv`` and checkpoints are only
        written by processes whose ``--local_rank`` is 0. A ``KeyboardInterrupt``
        during the epoch loop stops training; the best metric recorded so far is
        still logged afterwards.

        Raises:
            SystemExit: with status 1 when the training or validation folder is missing.
        """
        setup_default_logging()
        args, args_text = _parse_args()

        args.prefetcher = not args.no_prefetcher
        args.distributed = False
        if 'WORLD_SIZE' in os.environ:
            args.distributed = int(os.environ['WORLD_SIZE']) > 1
            if args.distributed and args.num_gpu > 1:
                _logger.warning(
                    'Using more than one GPU per process in distributed mode is not allowed.Setting num_gpu to 1.')
                args.num_gpu = 1

        args.device = 'cuda:0'
        args.world_size = 1
        args.rank = 0  # global rank
        if args.distributed:
            args.num_gpu = 1
            args.device = 'cuda:%d' % args.local_rank
            torch.cuda.set_device(args.local_rank)
            args.world_size = int(os.environ['WORLD_SIZE'])
            args.rank = int(os.environ['RANK'])
            torch.distributed.init_process_group(
                backend='nccl', init_method=args.init_method, rank=args.rank, world_size=args.world_size)
            # Re-read both values from the initialized process group.
            args.world_size = torch.distributed.get_world_size()
            args.rank = torch.distributed.get_rank()
        assert args.rank >= 0

        if args.distributed:
            _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
                         % (args.rank, args.world_size))
        else:
            _logger.info('Training with a single process on %d GPUs.' % args.num_gpu)

        # Offset the seed by the global rank so each process gets a different seed.
        torch.manual_seed(args.seed + args.rank)

        # The parsed args namespace is forwarded to the model factory as an extra keyword.
        extra_dict = {'args': args}
        model = create_model(
            args.model,
            pretrained=args.pretrained,
            num_classes=args.num_classes,
            drop_rate=args.drop,
            drop_connect_rate=args.drop_connect,
            drop_path_rate=args.drop_path,
            drop_block_rate=args.drop_block,
            global_pool=args.gp,
            bn_tf=args.bn_tf,
            bn_momentum=args.bn_momentum,
            bn_eps=args.bn_eps,
            checkpoint_path=args.initial_checkpoint,
            **extra_dict)

        # ---- parameter / MAC profiling ----
        # Done before this function moves the model to CUDA, using a CPU dummy input.
        if args.local_rank == 0:
            print(model)
        from thop import profile
        from thop.vision.basic_hooks import zero_ops
        if hasattr(model, 'default_cfg'):
            default_cfg = model.default_cfg
            input_size = [1] + list(default_cfg['input_size'])
        else:
            input_size = [1, 3, 224, 224]
        dummy_input = torch.randn(input_size)

        # BatchNorm2d is counted as zero ops via the custom_ops override.
        _, params = profile(model, inputs=(dummy_input, ), custom_ops={nn.BatchNorm2d: zero_ops})
        print('params:', params)
        from torchprofile import profile_macs
        macs = profile_macs(model, dummy_input)
        print('model flops:', macs, 'input_size:', input_size)

        if args.local_rank == 0:
            _logger.info('Model %s created, param count: %d' %
                         (args.model, sum(m.numel() for m in model.parameters())))

        data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)

        num_aug_splits = 0
        if args.aug_splits > 0:
            assert args.aug_splits > 1, 'A split of 1 makes no sense'
            num_aug_splits = args.aug_splits

        if args.split_bn:
            assert num_aug_splits > 1 or args.resplit
            model = convert_splitbn_model(model, max(num_aug_splits, 2))

        # ---- mixed precision selection ----
        use_amp = None
        if args.amp:
            # for backwards compat, `--amp` arg tries apex before native amp
            if has_apex:
                args.apex_amp = True
            elif has_native_amp:
                args.native_amp = True
        if args.apex_amp and has_apex:
            use_amp = 'apex'
        elif args.native_amp and has_native_amp:
            use_amp = 'native'
        elif args.apex_amp or args.native_amp:
            _logger.warning("Neither APEX or native Torch AMP is available, using float32. "
                            "Install NVIDA apex or upgrade to PyTorch 1.6")

        if args.num_gpu > 1:
            if use_amp == 'apex':
                _logger.warning(
                    'Apex AMP does not work well with nn.DataParallel, disabling. Use DDP or Torch AMP.')
                use_amp = None
            model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
            assert not args.channels_last, "Channels last not supported with DP, use DDP."
        else:
            model.cuda()
            if args.channels_last:
                model = model.to(memory_format=torch.channels_last)

        optimizer = create_optimizer(args, model)

        amp_autocast = suppress  # do nothing
        loss_scaler = None
        if use_amp == 'apex':
            model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
            loss_scaler = ApexScaler()
            if args.local_rank == 0:
                _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
        elif use_amp == 'native':
            amp_autocast = torch.cuda.amp.autocast
            loss_scaler = NativeScaler()
            if args.local_rank == 0:
                _logger.info('Using native Torch AMP. Training in mixed precision.')
        else:
            if args.local_rank == 0:
                _logger.info('AMP not enabled. Training in float32.')

        # optionally resume from a checkpoint
        resume_epoch = None
        if args.resume:
            resume_epoch = resume_checkpoint(
                model, args.resume,
                optimizer=None if args.no_resume_opt else optimizer,
                loss_scaler=None if args.no_resume_opt else loss_scaler,
                log_info=args.local_rank == 0)

        model_ema = None
        if args.model_ema:
            # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
            model_ema = ModelEma(
                model,
                decay=args.model_ema_decay,
                device='cpu' if args.model_ema_force_cpu else '',
                resume=args.resume)

        if args.distributed:
            if args.sync_bn:
                assert not args.split_bn
                try:
                    if has_apex and use_amp != 'native':
                        # Apex SyncBN preferred unless native amp is activated
                        model = convert_syncbn_model(model)
                    else:
                        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
                    if args.local_rank == 0:
                        _logger.info(
                            'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                            'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
                except Exception:
                    # Best effort: training continues without SyncBN, but keep the
                    # traceback in the log so the cause is not lost.
                    _logger.exception('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1')
            if has_apex and use_amp != 'native':
                # Apex DDP preferred unless native amp is activated
                if args.local_rank == 0:
                    _logger.info("Using NVIDIA APEX DistributedDataParallel.")
                model = ApexDDP(model, delay_allreduce=True)
            else:
                if args.local_rank == 0:
                    _logger.info("Using native Torch DistributedDataParallel.")
                model = NativeDDP(model, device_ids=[args.local_rank], find_unused_parameters=True)
            # NOTE: EMA model does not need to be wrapped by DDP

        lr_scheduler, num_epochs = create_scheduler(args, optimizer)
        start_epoch = 0
        if args.start_epoch is not None:
            # a specified start_epoch will always override the resume epoch
            start_epoch = args.start_epoch
        elif resume_epoch is not None:
            start_epoch = resume_epoch
        if lr_scheduler is not None and start_epoch > 0:
            lr_scheduler.step(start_epoch)

        if args.local_rank == 0:
            _logger.info('Scheduled epochs: {}'.format(num_epochs))

        # ---- training dataset ----
        # 'cifar100' must be tested first because 'cifar10' is a substring of it.
        if 'cifar100' in args.data:
            dataset_train = torchvision.datasets.CIFAR100(root=args.data, train=True, download=False, transform=None)
        elif 'cifar10' in args.data:
            dataset_train = torchvision.datasets.CIFAR10(root=args.data, train=True, download=False, transform=None)
        else:
            train_dir = os.path.join(args.data, 'train')
            if not os.path.exists(train_dir):
                _logger.error('Training folder does not exist at: {}'.format(train_dir))
                raise SystemExit(1)
            dataset_train = Dataset(train_dir)

        collate_fn = None
        mixup_fn = None
        mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
        if mixup_active:
            mixup_args = dict(
                mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
                prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
                label_smoothing=args.smoothing, num_classes=args.num_classes)
            if args.prefetcher:
                assert not num_aug_splits  # collate conflict (need to support deinterleaving in collate mixup)
                collate_fn = FastCollateMixup(**mixup_args)
            else:
                mixup_fn = Mixup(**mixup_args)

        if num_aug_splits > 1:
            dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)

        train_interpolation = args.train_interpolation
        if args.no_aug or not train_interpolation:
            train_interpolation = data_config['interpolation']
        loader_train = create_loader(
            dataset_train,
            input_size=data_config['input_size'],
            batch_size=args.batch_size,
            is_training=True,
            use_prefetcher=args.prefetcher,
            no_aug=args.no_aug,
            re_prob=args.reprob,
            re_mode=args.remode,
            re_count=args.recount,
            re_split=args.resplit,
            scale=args.scale,
            ratio=args.ratio,
            hflip=args.hflip,
            vflip=args.vflip,
            color_jitter=args.color_jitter,
            auto_augment=args.aa,
            num_aug_splits=num_aug_splits,
            interpolation=train_interpolation,
            mean=data_config['mean'],
            std=data_config['std'],
            num_workers=args.workers,
            distributed=args.distributed,
            collate_fn=collate_fn,
            pin_memory=args.pin_mem,
            use_multi_epochs_loader=args.use_multi_epochs_loader,
            repeated_aug=args.repeated_aug
        )

        # ---- evaluation dataset ----
        if 'cifar100' in args.data:
            dataset_eval = torchvision.datasets.CIFAR100(root=args.data, train=False, download=False, transform=None)
        elif 'cifar10' in args.data:
            dataset_eval = torchvision.datasets.CIFAR10(root=args.data, train=False, download=False, transform=None)
        else:
            # Accept either a 'val' or a 'validation' sub-folder.
            eval_dir = os.path.join(args.data, 'val')
            if not os.path.isdir(eval_dir):
                eval_dir = os.path.join(args.data, 'validation')
                if not os.path.isdir(eval_dir):
                    _logger.error('Validation folder does not exist at: {}'.format(eval_dir))
                    raise SystemExit(1)
            dataset_eval = Dataset(eval_dir)

        loader_eval = create_loader(
            dataset_eval,
            input_size=data_config['input_size'],
            batch_size=args.validation_batch_size_multiplier * args.batch_size,
            is_training=False,
            use_prefetcher=args.prefetcher,
            interpolation=data_config['interpolation'],
            mean=data_config['mean'],
            std=data_config['std'],
            num_workers=args.workers,
            distributed=args.distributed,
            crop_pct=data_config['crop_pct'],
            pin_memory=args.pin_mem,
        )

        # ---- loss functions ----
        if args.jsd:
            assert num_aug_splits > 1  # JSD only valid with aug splits set
            train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda()
        elif mixup_active:
            # smoothing is handled with mixup target transform
            train_loss_fn = SoftTargetCrossEntropy().cuda()
        elif args.smoothing:
            train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
        else:
            train_loss_fn = nn.CrossEntropyLoss().cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()

        # ---- output directory and checkpoint saver (local rank 0 only) ----
        eval_metric = args.eval_metric
        best_metric = None
        best_epoch = None
        saver = None
        output_dir = ''
        if args.local_rank == 0:
            output_base = args.output if args.output else './output'
            exp_name = '-'.join([
                datetime.now().strftime("%Y%m%d-%H%M%S"),
                args.model,
                str(data_config['input_size'][-1])
            ])
            output_dir = get_outdir(output_base, 'train', exp_name)
            # A lower value is better only when the tracked metric is the loss.
            decreasing = eval_metric == 'loss'
            # max_history: maximum number of checkpoints to keep. Fall back to 10 when
            # neither the command line nor the YAML config defines it.
            max_history = getattr(args, 'max_history', 10)
            saver = CheckpointSaver(
                model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,
                checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing,
                max_history=max_history)
            with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
                f.write(args_text)

        # ---- epoch loop ----
        try:
            for epoch in range(start_epoch, num_epochs):
                if args.distributed:
                    loader_train.sampler.set_epoch(epoch)

                train_metrics = train_epoch(
                    epoch, model, loader_train, optimizer, train_loss_fn, args,
                    lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
                    amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn)

                if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                    if args.local_rank == 0:
                        _logger.info("Distributing BatchNorm running means and vars")
                    distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

                eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast)

                if model_ema is not None and not args.model_ema_force_cpu:
                    if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                        distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
                    ema_eval_metrics = validate(
                        model_ema.ema, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)')
                    # When EMA is validated, its metrics replace the raw model's for
                    # scheduling, the summary and checkpoint selection.
                    eval_metrics = ema_eval_metrics

                if lr_scheduler is not None:
                    # step LR for next epoch
                    lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

                # Only the process that owns the output directory writes the summary;
                # other ranks have output_dir == '' and would write into the CWD.
                if args.local_rank == 0:
                    update_summary(
                        epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
                        write_header=best_metric is None)

                if saver is not None:
                    # save proper checkpoint with eval metric
                    save_metric = eval_metrics[eval_metric]
                    best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)

        except KeyboardInterrupt:
            # Allow a manual stop; fall through to report the best result so far.
            pass
        if best_metric is not None:
            _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
            
    
            
    def train_epoch(
            epoch, model, loader, optimizer, loss_fn, args,
            lr_scheduler=None, saver=None, output_dir='', amp_autocast=suppress,
            loss_scaler=None, model_ema=None, mixup_fn=None):
        """Train ``model`` for one pass over ``loader`` and return the average loss.

        Args:
            epoch: Index of the current epoch; used for the mixup cut-off, the
                per-update scheduler counter, log lines and recovery snapshots.
            model: Network to optimize; put into training mode here.
            loader: Iterable of ``(input, target)`` batches that supports ``len()``.
            optimizer: Optimizer whose ``zero_grad``/``step`` drive the update.
            loss_fn: Callable ``loss_fn(output, target)`` returning the loss.
            args: Parsed options. This function reads ``mixup_off_epoch``,
                ``prefetcher``, ``channels_last``, ``distributed``, ``world_size``,
                ``clip_grad``, ``log_interval``, ``local_rank``, ``save_images``
                and ``recovery_interval``.
            lr_scheduler: Optional scheduler; ``step_update`` is called per batch.
            saver: Optional checkpoint saver used for recovery snapshots.
            output_dir: Directory that receives sample-batch images when
                ``args.save_images`` is set.
            amp_autocast: Context-manager factory wrapped around the forward pass.
            loss_scaler: Optional callable invoked in place of the explicit
                backward / gradient-clip / optimizer-step sequence.
            model_ema: Optional EMA tracker; ``update(model)`` is called per batch.
            mixup_fn: Optional callable applied to ``(input, target)`` when the
                prefetcher is not in use.

        Returns:
            OrderedDict with a single ``'loss'`` entry (running average loss).
        """
        # Past the configured epoch, switch mixup off on whichever object applies it.
        if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
            if args.prefetcher and loader.mixup_enabled:
                loader.mixup_enabled = False
            elif mixup_fn is not None:
                mixup_fn.mixup_enabled = False

        second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
        batch_time_m = AverageMeter()
        data_time_m = AverageMeter()
        losses_m = AverageMeter()
        top1_m = AverageMeter()
        top5_m = AverageMeter()

        model.train()

        end = time.time()
        last_idx = len(loader) - 1
        num_updates = epoch * len(loader)
        for batch_idx, (inputs, target) in enumerate(loader):
            last_batch = batch_idx == last_idx
            data_time_m.update(time.time() - end)
            if not args.prefetcher:
                inputs, target = inputs.cuda(), target.cuda()
                if mixup_fn is not None:
                    inputs, target = mixup_fn(inputs, target)
            if args.channels_last:
                inputs = inputs.contiguous(memory_format=torch.channels_last)

            with amp_autocast():
                output = model(inputs)
                loss = loss_fn(output, target)

            # Targets with a class dimension (per-class score vectors) are reduced
            # to class indices; targets that are already 1-D indices are used as
            # is. Calling argmax on 1-D targets would collapse the whole batch to
            # one scalar and report a meaningless accuracy.
            hard_target = target.argmax(dim=-1) if target.dim() > 1 else target
            acc1, acc5 = accuracy(output, hard_target, topk=(1, 5))
            if args.distributed:
                acc1 = reduce_tensor(acc1, args.world_size)
                acc5 = reduce_tensor(acc5, args.world_size)

            # In distributed mode the loss meter is fed only at log intervals
            # (below), using the loss reduced across processes.
            if not args.distributed:
                losses_m.update(loss.item(), inputs.size(0))

            optimizer.zero_grad()
            if loss_scaler is not None:
                loss_scaler(
                    loss, optimizer, clip_grad=args.clip_grad, parameters=model.parameters(), create_graph=second_order)
            else:
                loss.backward(create_graph=second_order)
                if args.clip_grad is not None:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)
                optimizer.step()

            torch.cuda.synchronize()

            top1_m.update(acc1.item(), output.size(0))
            top5_m.update(acc5.item(), output.size(0))

            if model_ema is not None:
                model_ema.update(model)
            num_updates += 1

            batch_time_m.update(time.time() - end)
            if last_batch or batch_idx % args.log_interval == 0:
                lrl = [param_group['lr'] for param_group in optimizer.param_groups]
                lr = sum(lrl) / len(lrl)

                if args.distributed:
                    reduced_loss = reduce_tensor(loss.data, args.world_size)
                    losses_m.update(reduced_loss.item(), inputs.size(0))

                if args.local_rank == 0:
                    # A single-batch loader has last_idx == 0; report 100% rather
                    # than dividing by zero.
                    progress = 100. * batch_idx / last_idx if last_idx > 0 else 100.
                    _logger.info(
                        'Train: {} [{:>4d}/{} ({:>3.0f}%)]  '
                        'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f})  '
                        'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s  '
                        '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                        'LR: {lr:.3e}  '
                        'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
                            epoch,
                            batch_idx, len(loader),
                            progress,
                            loss=losses_m,
                            batch_time=batch_time_m,
                            rate=inputs.size(0) * args.world_size / batch_time_m.val,
                            rate_avg=inputs.size(0) * args.world_size / batch_time_m.avg,
                            lr=lr,
                            data_time=data_time_m))
                    _logger.info(
                        'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f})  '
                        'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(top1=top1_m, top5=top5_m))

                    if args.save_images and output_dir:
                        torchvision.utils.save_image(
                            inputs,
                            os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
                            padding=0,
                            normalize=True)

            if saver is not None and args.recovery_interval and (
                    last_batch or (batch_idx + 1) % args.recovery_interval == 0):
                saver.save_recovery(epoch, batch_idx=batch_idx)

            if lr_scheduler is not None:
                lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)

            end = time.time()
            # end for

        if hasattr(optimizer, 'sync_lookahead'):
            optimizer.sync_lookahead()

        return OrderedDict([('loss', losses_m.avg)])
    
    
    def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix=''):
        """Evaluate ``model`` on ``loader`` and return averaged loss / top-1 / top-5.

        The model is put into eval mode and every batch runs under
        ``torch.no_grad()``. When ``args.tta`` is greater than 1, outputs are
        averaged over consecutive groups of that size and the targets are
        subsampled with the same stride before loss and accuracy are computed.
        In distributed mode the loss and accuracies are passed through
        ``reduce_tensor`` before being recorded.

        Returns:
            OrderedDict with keys ``'loss'``, ``'top1'`` and ``'top5'``.
        """
        time_meter = AverageMeter()
        loss_meter = AverageMeter()
        acc1_meter = AverageMeter()
        acc5_meter = AverageMeter()

        model.eval()

        tick = time.time()
        final_idx = len(loader) - 1
        with torch.no_grad():
            for step, (images, labels) in enumerate(loader):
                is_final = step == final_idx

                # Without the prefetcher, batches are moved to the GPU here.
                if not args.prefetcher:
                    images = images.cuda()
                    labels = labels.cuda()
                if args.channels_last:
                    images = images.contiguous(memory_format=torch.channels_last)

                with amp_autocast():
                    logits = model(images)
                # Some models return several outputs; only the first is scored.
                if isinstance(logits, (tuple, list)):
                    logits = logits[0]

                # augmentation reduction
                tta_factor = args.tta
                if tta_factor > 1:
                    logits = logits.unfold(0, tta_factor, tta_factor).mean(dim=2)
                    labels = labels[::tta_factor]

                batch_loss = loss_fn(logits, labels)
                acc1, acc5 = accuracy(logits, labels, topk=(1, 5))

                loss_value = batch_loss.data
                if args.distributed:
                    loss_value = reduce_tensor(loss_value, args.world_size)
                    acc1 = reduce_tensor(acc1, args.world_size)
                    acc5 = reduce_tensor(acc5, args.world_size)

                torch.cuda.synchronize()

                loss_meter.update(loss_value.item(), images.size(0))
                acc1_meter.update(acc1.item(), logits.size(0))
                acc5_meter.update(acc5.item(), logits.size(0))

                now = time.time()
                time_meter.update(now - tick)
                tick = time.time()

                # Only rank 0 logs, and only on the last batch or at the log interval.
                if args.local_rank != 0:
                    continue
                if not (is_final or step % args.log_interval == 0):
                    continue
                _logger.info(
                    '{0}: [{1:>4d}/{2}]  '
                    'Time: {batch_time.val:.3f} ({batch_time.avg:.3f})  '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                    'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f})  '
                    'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
                        'Test' + log_suffix, step, final_idx, batch_time=time_meter,
                        loss=loss_meter, top1=acc1_meter, top5=acc5_meter))

        results = OrderedDict()
        results['loss'] = loss_meter.avg
        results['top1'] = acc1_meter.avg
        results['top5'] = acc5_meter.avg
        return results
    
    
    # Run training only when this file is executed as a script, not on import.
    if __name__ == '__main__':
        main()
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102
    • 103
    • 104
    • 105
    • 106
    • 107
    • 108
    • 109
    • 110
    • 111
    • 112
    • 113
    • 114
    • 115
    • 116
    • 117
    • 118
    • 119
    • 120
    • 121
    • 122
    • 123
    • 124
    • 125
    • 126
    • 127
    • 128
    • 129
    • 130
    • 131
    • 132
    • 133
    • 134
    • 135
    • 136
    • 137
    • 138
    • 139
    • 140
    • 141
    • 142
    • 143
    • 144
    • 145
    • 146
    • 147
    • 148
    • 149
    • 150
    • 151
    • 152
    • 153
    • 154
    • 155
    • 156
    • 157
    • 158
    • 159
    • 160
    • 161
    • 162
    • 163
    • 164
    • 165
    • 166
    • 167
    • 168
    • 169
    • 170
    • 171
    • 172
    • 173
    • 174
    • 175
    • 176
    • 177
    • 178
    • 179
    • 180
    • 181
    • 182
    • 183
    • 184
    • 185
    • 186
    • 187
    • 188
    • 189
    • 190
    • 191
    • 192
    • 193
    • 194
    • 195
    • 196
    • 197
    • 198
    • 199
    • 200
    • 201
    • 202
    • 203
    • 204
    • 205
    • 206
    • 207
    • 208
    • 209
    • 210
    • 211
    • 212
    • 213
    • 214
    • 215
    • 216
    • 217
    • 218
    • 219
    • 220
    • 221
    • 222
    • 223
    • 224
    • 225
    • 226
    • 227
    • 228
    • 229
    • 230
    • 231
    • 232
    • 233
    • 234
    • 235
    • 236
    • 237
    • 238
    • 239
    • 240
    • 241
    • 242
    • 243
    • 244
    • 245
    • 246
    • 247
    • 248
    • 249
    • 250
    • 251
    • 252
    • 253
    • 254
    • 255
    • 256
    • 257
    • 258
    • 259
    • 260
    • 261
    • 262
    • 263
    • 264
    • 265
    • 266
    • 267
    • 268
    • 269
    • 270
    • 271
    • 272
    • 273
    • 274
    • 275
    • 276
    • 277
    • 278
    • 279
    • 280
    • 281
    • 282
    • 283
    • 284
    • 285
    • 286
    • 287
    • 288
    • 289
    • 290
    • 291
    • 292
    • 293
    • 294
    • 295
    • 296
    • 297
    • 298
    • 299
    • 300
    • 301
    • 302
    • 303
    • 304
    • 305
    • 306
    • 307
    • 308
    • 309
    • 310
    • 311
    • 312
    • 313
    • 314
    • 315
    • 316
    • 317
    • 318
    • 319
    • 320
    • 321
    • 322
    • 323
    • 324
    • 325
    • 326
    • 327
    • 328
    • 329
    • 330
    • 331
    • 332
    • 333
    • 334
    • 335
    • 336
    • 337
    • 338
    • 339
    • 340
    • 341
    • 342
    • 343
    • 344
    • 345
    • 346
    • 347
    • 348
    • 349
    • 350
    • 351
    • 352
    • 353
    • 354
    • 355
    • 356
    • 357
    • 358
    • 359
    • 360
    • 361
    • 362
    • 363
    • 364
    • 365
    • 366
    • 367
    • 368
    • 369
    • 370
    • 371
    • 372
    • 373
    • 374
    • 375
    • 376
    • 377
    • 378
    • 379
    • 380
    • 381
    • 382
    • 383
    • 384
    • 385
    • 386
    • 387
    • 388
    • 389
    • 390
    • 391
    • 392
    • 393
    • 394
    • 395
    • 396
    • 397
    • 398
    • 399
    • 400
    • 401
    • 402
    • 403
    • 404
    • 405
    • 406
    • 407
    • 408
    • 409
    • 410
    • 411
    • 412
    • 413
    • 414
    • 415
    • 416
    • 417
    • 418
    • 419
    • 420
    • 421
    • 422
    • 423
    • 424
    • 425
    • 426
    • 427
    • 428
    • 429
    • 430
    • 431
    • 432
    • 433
    • 434
    • 435
    • 436
    • 437
    • 438
    • 439
    • 440
    • 441
    • 442
    • 443
    • 444
    • 445
    • 446
    • 447
    • 448
    • 449
    • 450
    • 451
    • 452
    • 453
    • 454
    • 455
    • 456
    • 457
    • 458
    • 459
    • 460
    • 461
    • 462
    • 463
    • 464
    • 465
    • 466
    • 467
    • 468
    • 469
    • 470
    • 471
    • 472
    • 473
    • 474
    • 475
    • 476
    • 477
    • 478
    • 479
    • 480
    • 481
    • 482
    • 483
    • 484
    • 485
    • 486
    • 487
    • 488
    • 489
    • 490
    • 491
    • 492
    • 493
    • 494
    • 495
    • 496
    • 497
    • 498
    • 499
    • 500
    • 501
    • 502
    • 503
    • 504
    • 505
    • 506
    • 507
    • 508
    • 509
    • 510
    • 511
    • 512
    • 513
    • 514
    • 515
    • 516
    • 517
    • 518
    • 519
    • 520
    • 521
    • 522
    • 523
    • 524
    • 525
    • 526
    • 527
    • 528
    • 529
    • 530
    • 531
    • 532
    • 533
    • 534
    • 535
    • 536
    • 537
    • 538
    • 539
    • 540
    • 541
    • 542
    • 543
    • 544
    • 545
    • 546
    • 547
    • 548
    • 549
    • 550
    • 551
    • 552
    • 553
    • 554
    • 555
    • 556
    • 557
    • 558
    • 559
    • 560
    • 561
    • 562
    • 563
    • 564
    • 565
    • 566
    • 567
    • 568
    • 569
    • 570
    • 571
    • 572
    • 573
    • 574
    • 575
    • 576
    • 577
    • 578
    • 579
    • 580
    • 581
    • 582
    • 583
    • 584
    • 585
    • 586
    • 587
    • 588
    • 589
    • 590
    • 591
    • 592
    • 593
    • 594
    • 595
    • 596
    • 597
    • 598
    • 599
    • 600
    • 601
    • 602
    • 603
    • 604
    • 605
    • 606
    • 607
    • 608
    • 609
    • 610
    • 611
    • 612
    • 613
    • 614
    • 615
    • 616
    • 617
    • 618
    • 619
    • 620
    • 621
    • 622
    • 623
    • 624
    • 625
    • 626
    • 627
    • 628
    • 629
    • 630
    • 631
    • 632
    • 633
    • 634
    • 635
    • 636
    • 637
    • 638
    • 639
    • 640
    • 641
    • 642
    • 643
    • 644
    • 645
    • 646
    • 647
    • 648
    • 649
    • 650
    • 651
    • 652
    • 653
    • 654
    • 655
    • 656
    • 657
    • 658
    • 659
    • 660
    • 661
    • 662
    • 663
    • 664
    • 665
    • 666
    • 667
    • 668
    • 669
    • 670
    • 671
    • 672
    • 673
    • 674
    • 675
    • 676
    • 677
    • 678
    • 679
    • 680
    • 681
    • 682
    • 683
    • 684
    • 685
    • 686
    • 687
    • 688
    • 689
    • 690
    • 691
    • 692
    • 693
    • 694
    • 695
    • 696
    • 697
    • 698
    • 699
    • 700
    • 701
    • 702
    • 703
    • 704
    • 705
    • 706
    • 707
    • 708
    • 709
    • 710
    • 711
    • 712
    • 713
    • 714
    • 715
    • 716
    • 717
    • 718
    • 719
    • 720
    • 721
    • 722
    • 723
    • 724
    • 725
    • 726
    • 727
    • 728
    • 729
    • 730
    • 731
    • 732
    • 733
    • 734
    • 735
    • 736
    • 737
    • 738
    • 739
    • 740
    • 741
    • 742
    • 743
    • 744
    • 745
    • 746
    • 747
    • 748
    • 749
    • 750
    • 751
    • 752
    • 753
    • 754
    • 755
    • 756
    • 757
    • 758
    • 759
    • 760
    • 761
    • 762
    • 763
    • 764
    • 765
    • 766
    • 767
    • 768
    • 769
    • 770
    • 771
    • 772
    • 773
    • 774
    • 775
    • 776
    • 777
    • 778
    • 779
    • 780
    • 781
    • 782
    • 783
    • 784
    • 785
    • 786
    • 787
    • 788
    • 789
    • 790
    • 791
    • 792
    • 793
    • 794
    • 795
    • 796
    • 797
    • 798
    • 799
    • 800
    • 801
    • 802
    • 803
    • 804
    • 805
    • 806
    • 807
    • 808
    • 809
    • 810
    • 811
    • 812
    • 813
    • 814
    • 815
    • 816
    • 817
    • 818
    • 819
    • 820
    • 821
    • 822
    • 823
    • 824
    • 825
    • 826
    • 827
    • 828
    • 829
    • 830
    • 831
    • 832
    • 833
    • 834
    • 835

    第二个示例

    def main():
        """Entry point: configure and run the full train/validate loop on an NPU.

        Steps, in order: parse arguments and optional wandb logging; set up the
        single-process or distributed (HCCL) device context; resolve the AMP
        mode; build the model, optimizer, optional EMA copy and DDP wrapper;
        create the LR schedule, datasets, mixup/AugMix helpers and data loaders;
        pick the loss functions; create the checkpoint saver on local rank 0;
        then loop over epochs calling ``train_one_epoch`` and ``validate``,
        stepping the scheduler, writing the summary CSV and saving checkpoints.
        A ``KeyboardInterrupt`` ends the loop quietly, after which the best
        metric (if any) is logged.
        """
        setup_default_logging()
        args, args_text = _parse_args()

        if args.log_wandb:
            if has_wandb:
                wandb.init(project=args.experiment, config=args)
            else:
                _logger.warning("You've requested to log metrics to wandb but package not found. "
                                "Metrics not being logged to wandb, try `pip install wandb`")

        args.prefetcher = not args.no_prefetcher
        # WORLD_SIZE is optional: a plain single-process launch does not set it,
        # so fall back to 1 instead of raising KeyError.
        args.world_size = int(os.environ.get('WORLD_SIZE', 1))
        args.distributed = args.world_size > 1
        args.rank = 0  # global rank
        if args.distributed:
            args.device = 'npu:%d' % args.local_rank
            torch.npu.set_device(args.local_rank)
            torch.distributed.init_process_group(backend='hccl', world_size=args.world_size, rank=args.local_rank)
            args.world_size = torch.distributed.get_world_size()
            args.rank = torch.distributed.get_rank()
            _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
                         % (args.rank, args.world_size))
        else:
            args.device = f'npu:{args.local_rank}'
            torch.npu.set_device(args.device)
            _logger.info('Training with a single process on 1 GPUs.')
        assert args.rank >= 0

        # resolve AMP arguments based on PyTorch / Apex availability
        use_amp = None
        if args.amp:
            # `--amp` chooses native amp before apex (APEX ver not actively maintained)
            if has_native_amp:
                args.native_amp = True
            elif has_apex:
                args.apex_amp = True
        if args.apex_amp and has_apex:
            use_amp = 'apex'
        elif args.native_amp and has_native_amp:
            use_amp = 'native'
        elif args.apex_amp or args.native_amp:
            _logger.warning("Neither APEX or native Torch AMP is available, using float32. "
                            "Install NVIDA apex or upgrade to PyTorch 1.6")

        random_seed(args.seed, args.rank)

        model = create_model(
            args.model,
            pretrained=args.pretrained,
            num_classes=args.num_classes,
            drop_rate=args.drop,
            drop_connect_rate=args.drop_connect,  # DEPRECATED, use drop_path
            drop_path_rate=args.drop_path,
            drop_block_rate=args.drop_block,
            global_pool=args.gp,
            bn_tf=args.bn_tf,
            bn_momentum=args.bn_momentum,
            bn_eps=args.bn_eps,
            scriptable=args.torchscript,
            checkpoint_path=args.initial_checkpoint)
        if args.num_classes is None:
            assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
            args.num_classes = model.num_classes  # FIXME handle model default vs config num_classes more elegantly

        if args.local_rank == 0:
            _logger.info(
                f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}')

        data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)

        # setup augmentation batch splits for contrastive loss or split bn
        num_aug_splits = 0
        if args.aug_splits > 0:
            assert args.aug_splits > 1, 'A split of 1 makes no sense'
            num_aug_splits = args.aug_splits

        # enable split bn (separate bn stats per batch-portion)
        if args.split_bn:
            assert num_aug_splits > 1 or args.resplit
            model = convert_splitbn_model(model, max(num_aug_splits, 2))

        # move model to the NPU, enable channels last layout if set
        model.npu()
        if args.channels_last:
            model = model.to(memory_format=torch.channels_last)

        # setup synchronized BatchNorm for distributed training
        if args.distributed and args.sync_bn:
            assert not args.split_bn
            if has_apex and use_amp != 'native':
                # Apex SyncBN preferred unless native amp is activated
                model = convert_syncbn_model(model)
            else:
                model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
            if args.local_rank == 0:
                _logger.info(
                    'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                    'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')

        if args.torchscript:
            assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
            assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
            model = torch.jit.script(model)

        optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args))

        # setup automatic mixed-precision (AMP) loss scaling and op casting
        amp_autocast = suppress  # do nothing
        loss_scaler = None
        if use_amp == 'apex':
            model, optimizer = amp.initialize(model, optimizer, opt_level='O2', loss_scale=64, combine_grad=True)
            loss_scaler = ApexScaler()
            if args.local_rank == 0:
                _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
        elif use_amp == 'native':
            amp_autocast = torch.npu.amp.autocast
            loss_scaler = NativeScaler()
            if args.local_rank == 0:
                _logger.info('Using native Torch AMP. Training in mixed precision.')
        else:
            if args.local_rank == 0:
                _logger.info('AMP not enabled. Training in float32.')

        # optionally resume from a checkpoint
        resume_epoch = None
        if args.resume:
            resume_epoch = resume_checkpoint(
                model, args.resume,
                optimizer=None if args.no_resume_opt else optimizer,
                loss_scaler=None if args.no_resume_opt else loss_scaler,
                log_info=args.local_rank == 0)

        # setup exponential moving average of model weights, SWA could be used here too
        model_ema = None
        if args.model_ema:
            # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
            model_ema = ModelEmaV2(
                model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None)
            if args.resume:
                load_checkpoint(model_ema.module, args.resume, use_ema=True)

        # setup distributed training
        if args.distributed:
            # Native DDP is used unconditionally here (no Apex DDP path).
            # NOTE: EMA model does not need to be wrapped by DDP
            if args.local_rank == 0:
                _logger.info("Using native Torch DistributedDataParallel.")
            model = NativeDDP(model, device_ids=[args.local_rank], broadcast_buffers=False)  # can use device str in Torch >= 1.1

        # setup learning rate schedule and starting epoch
        lr_scheduler, num_epochs = create_scheduler(args, optimizer)
        start_epoch = 0
        if args.start_epoch is not None:
            # a specified start_epoch will always override the resume epoch
            start_epoch = args.start_epoch
        elif resume_epoch is not None:
            start_epoch = resume_epoch
        if lr_scheduler is not None and start_epoch > 0:
            lr_scheduler.step(start_epoch)

        if args.local_rank == 0:
            _logger.info('Scheduled epochs: {}'.format(num_epochs))

        # create the train and eval datasets
        dataset_train = create_dataset(
            args.dataset,
            root=args.data_dir, split=args.train_split, is_training=True,
            batch_size=args.batch_size, repeats=args.epoch_repeats)
        dataset_eval = create_dataset(
            args.dataset, root=args.data_dir, split=args.val_split, is_training=False, batch_size=args.batch_size)

        # setup mixup / cutmix
        collate_fn = None
        mixup_fn = None
        mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
        if mixup_active:
            mixup_args = dict(
                mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
                prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
                label_smoothing=args.smoothing, num_classes=args.num_classes)
            if args.prefetcher:
                assert not num_aug_splits  # collate conflict (need to support deinterleaving in collate mixup)
                collate_fn = FastCollateMixup(**mixup_args)
            else:
                mixup_fn = Mixup(**mixup_args)

        # wrap dataset in AugMix helper
        if num_aug_splits > 1:
            dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)

        # create data loaders w/ augmentation pipeline
        train_interpolation = args.train_interpolation
        if args.no_aug or not train_interpolation:
            train_interpolation = data_config['interpolation']
        loader_train = create_loader(
            dataset_train,
            input_size=data_config['input_size'],
            batch_size=args.batch_size,
            is_training=True,
            use_prefetcher=args.prefetcher,
            no_aug=args.no_aug,
            re_prob=args.reprob,
            re_mode=args.remode,
            re_count=args.recount,
            re_split=args.resplit,
            scale=args.scale,
            ratio=args.ratio,
            hflip=args.hflip,
            vflip=args.vflip,
            color_jitter=args.color_jitter,
            auto_augment=args.aa,
            num_aug_splits=num_aug_splits,
            interpolation=train_interpolation,
            mean=data_config['mean'],
            std=data_config['std'],
            num_workers=args.workers,
            distributed=args.distributed,
            collate_fn=collate_fn,
            pin_memory=args.pin_mem,
            use_multi_epochs_loader=args.use_multi_epochs_loader
        )

        loader_eval = create_loader(
            dataset_eval,
            input_size=data_config['input_size'],
            batch_size=args.validation_batch_size_multiplier * args.batch_size,
            is_training=False,
            use_prefetcher=args.prefetcher,
            interpolation=data_config['interpolation'],
            mean=data_config['mean'],
            std=data_config['std'],
            num_workers=args.workers,
            distributed=args.distributed,
            crop_pct=data_config['crop_pct'],
            pin_memory=args.pin_mem,
        )

        # setup loss function
        if args.jsd:
            assert num_aug_splits > 1  # JSD only valid with aug splits set
            train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).npu()
        elif mixup_active:
            # smoothing is handled with mixup target transform
            train_loss_fn = SoftTargetCrossEntropy().npu()
        elif args.smoothing:
            train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).npu()
        else:
            train_loss_fn = nn.CrossEntropyLoss().npu()
        validate_loss_fn = nn.CrossEntropyLoss().npu()

        # setup checkpoint saver and eval metric tracking
        eval_metric = args.eval_metric
        best_metric = None
        best_epoch = None
        saver = None
        output_dir = None
        if args.local_rank == 0:
            if args.experiment:
                exp_name = args.experiment
            else:
                exp_name = '-'.join([
                    datetime.now().strftime("%Y%m%d-%H%M%S"),
                    safe_model_name(args.model),
                    str(data_config['input_size'][-1])
                ])
            output_dir = get_outdir(args.output if args.output else './output/train', exp_name)
            # For a loss metric, lower is better; for anything else, higher is better.
            decreasing = True if eval_metric == 'loss' else False
            saver = CheckpointSaver(
                model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,
                checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing, max_history=args.checkpoint_hist)
            with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
                f.write(args_text)

        try:
            for epoch in range(start_epoch, num_epochs):
                if args.distributed and hasattr(loader_train.sampler, 'set_epoch'):
                    loader_train.sampler.set_epoch(epoch)

                train_metrics = train_one_epoch(
                    epoch, model, loader_train, optimizer, train_loss_fn, args,
                    lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
                    amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn)

                if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                    if args.local_rank == 0:
                        _logger.info("Distributing BatchNorm running means and vars")
                    distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

                eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast)

                if model_ema is not None and not args.model_ema_force_cpu:
                    if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                        distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
                    ema_eval_metrics = validate(
                        model_ema.module, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)')
                    # When an EMA model is evaluated, its metrics replace the raw model's.
                    eval_metrics = ema_eval_metrics

                if lr_scheduler is not None:
                    # step LR for next epoch
                    lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

                if output_dir is not None:
                    update_summary(
                        epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
                        write_header=best_metric is None, log_wandb=args.log_wandb and has_wandb)

                if saver is not None:
                    # save proper checkpoint with eval metric
                    save_metric = eval_metrics[eval_metric]
                    best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)

        except KeyboardInterrupt:
            # Ctrl-C stops training early; fall through to report the best result so far.
            pass
        if best_metric is not None:
            _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
    
    
    def train_one_epoch(
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102
    • 103
    • 104
    • 105
    • 106
    • 107
    • 108
    • 109
    • 110
    • 111
    • 112
    • 113
    • 114
    • 115
    • 116
    • 117
    • 118
    • 119
    • 120
    • 121
    • 122
    • 123
    • 124
    • 125
    • 126
    • 127
    • 128
    • 129
    • 130
    • 131
    • 132
    • 133
    • 134
    • 135
    • 136
    • 137
    • 138
    • 139
    • 140
    • 141
    • 142
    • 143
    • 144
    • 145
    • 146
    • 147
    • 148
    • 149
    • 150
    • 151
    • 152
    • 153
    • 154
    • 155
    • 156
    • 157
    • 158
    • 159
    • 160
    • 161
    • 162
    • 163
    • 164
    • 165
    • 166
    • 167
    • 168
    • 169
    • 170
    • 171
    • 172
    • 173
    • 174
    • 175
    • 176
    • 177
    • 178
    • 179
    • 180
    • 181
    • 182
    • 183
    • 184
    • 185
    • 186
    • 187
    • 188
    • 189
    • 190
    • 191
    • 192
    • 193
    • 194
    • 195
    • 196
    • 197
    • 198
    • 199
    • 200
    • 201
    • 202
    • 203
    • 204
    • 205
    • 206
    • 207
    • 208
    • 209
    • 210
    • 211
    • 212
    • 213
    • 214
    • 215
    • 216
    • 217
    • 218
    • 219
    • 220
    • 221
    • 222
    • 223
    • 224
    • 225
    • 226
    • 227
    • 228
    • 229
    • 230
    • 231
    • 232
    • 233
    • 234
    • 235
    • 236
    • 237
    • 238
    • 239
    • 240
    • 241
    • 242
    • 243
    • 244
    • 245
    • 246
    • 247
    • 248
    • 249
    • 250
    • 251
    • 252
    • 253
    • 254
    • 255
    • 256
    • 257
    • 258
    • 259
    • 260
    • 261
    • 262
    • 263
    • 264
    • 265
    • 266
    • 267
    • 268
    • 269
    • 270
    • 271
    • 272
    • 273
    • 274
    • 275
    • 276
    • 277
    • 278
    • 279
    • 280
    • 281
    • 282
    • 283
    • 284
    • 285
    • 286
    • 287
    • 288
    • 289
    • 290
    • 291
    • 292
    • 293
    • 294
    • 295
    • 296
    • 297
    • 298
    • 299
    • 300
    • 301
    • 302
    • 303
    • 304
    • 305
    • 306
    • 307
    • 308
    • 309
    • 310
    • 311
    • 312
    • 313
    • 314
    • 315
    • 316
    • 317
    • 318
    • 319
    • 320
    • 321
    • 322
    • 323
    • 324
    • 325
    • 326
    • 327
    • 328
    • 329
    • 330

    第三个示例

    def main():
        """Run the whole training job: setup, per-epoch train/validate loop, checkpointing.

        Steps, in order:
          1. Parse arguments and initialise (optionally distributed, HCCL backend) NPU state.
          2. Resolve which AMP implementation to use (Apex, native, or none).
          3. Build the model, optionally load extra weights from ``args.resume_pth``,
             apply split-BN / sync-BN / TorchScript conversions, and move it to the NPU.
          4. Create optimizer, loss scaler, EMA model, DDP wrapper and LR scheduler.
          5. Build train / eval datasets and loaders, including mixup / AugMix wrappers.
          6. Select the training and validation loss functions.
          7. On local rank 0 only, create the output directory, checkpoint saver and
             ``args.yaml`` dump.
          8. Loop over epochs: train, validate (EMA metrics replace the plain ones when an
             on-device EMA model exists), step the scheduler, write the summary row and
             save a checkpoint.

        A ``KeyboardInterrupt`` during the epoch loop stops training quietly; the best
        metric seen so far is still logged.
        """
        setup_default_logging()
        args, args_text = _parse_args()

        args.prefetcher = not args.no_prefetcher
        args.distributed = False
        if 'WORLD_SIZE' in os.environ:
            args.distributed = int(os.environ['WORLD_SIZE']) > 1
        args.device = 'npu:0'
        args.world_size = 1
        args.rank = 0  # global rank
        if args.distributed:
            args.device = 'npu:%d' % args.local_rank
            torch.npu.set_device(args.local_rank)
            torch.distributed.init_process_group(backend='hccl', init_method='env://')
            args.world_size = torch.distributed.get_world_size()
            args.rank = torch.distributed.get_rank()
            _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
                         % (args.rank, args.world_size))
        else:
            _logger.info('Training with a single process on 1 GPUs.')
        assert args.rank >= 0

        # resolve AMP arguments based on PyTorch / Apex availability
        use_amp = None
        if args.amp:
            # `--amp` chooses native amp before apex (APEX ver not actively maintained)
            if has_native_amp:
                args.native_amp = True
            elif has_apex:
                args.apex_amp = True
        if args.apex_amp and has_apex:
            use_amp = 'apex'
        elif args.native_amp and has_native_amp:
            use_amp = 'native'
        elif args.apex_amp or args.native_amp:
            _logger.warning("Neither APEX or native Torch AMP is available, using float32. "
                            "Install NVIDA apex or upgrade to PyTorch 1.6")

        # offset the seed by rank so each process gets a different random stream
        torch.manual_seed(args.seed + args.rank)

        # The model is built with the same arguments whether or not extra weights are
        # loaded afterwards, so construct it once and only branch on the weight load.
        model = create_model(
            args.model,
            pretrained=args.pretrained,
            num_classes=args.num_classes,
            drop_rate=args.drop,
            drop_connect_rate=args.drop_connect,  # DEPRECATED, use drop_path
            drop_path_rate=args.drop_path,
            drop_block_rate=args.drop_block,
            global_pool=args.gp,
            bn_tf=args.bn_tf,
            bn_momentum=args.bn_momentum,
            bn_eps=args.bn_eps,
            scriptable=args.torchscript,
            checkpoint_path=args.initial_checkpoint)
        if args.pretrained_load:
            # NOTE(review): torch.load unpickles the file; only point resume_pth at trusted checkpoints.
            checkpoint = torch.load(args.resume_pth, map_location='cpu')
            model.load_state_dict(checkpoint)
        if args.num_classes is None:
            assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
            args.num_classes = model.num_classes  # FIXME handle model default vs config num_classes more elegantly

        if args.local_rank == 0:
            _logger.info(
                f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}')

        data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)

        # setup augmentation batch splits for contrastive loss or split bn
        num_aug_splits = 0
        if args.aug_splits > 0:
            assert args.aug_splits > 1, 'A split of 1 makes no sense'
            num_aug_splits = args.aug_splits

        # enable split bn (separate bn stats per batch-portion)
        if args.split_bn:
            assert num_aug_splits > 1 or args.resplit
            model = convert_splitbn_model(model, max(num_aug_splits, 2))

        # move model to the NPU, enable channels last layout if set
        model.npu()
        if args.channels_last:
            model = model.to(memory_format=torch.channels_last)

        # setup synchronized BatchNorm for distributed training
        if args.distributed and args.sync_bn:
            assert not args.split_bn
            if has_apex and use_amp != 'native':
                # Apex SyncBN preferred unless native amp is activated
                model = convert_syncbn_model(model)
            else:
                model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
            if args.local_rank == 0:
                _logger.info(
                    'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                    'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')

        if args.torchscript:
            assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
            assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
            model = torch.jit.script(model)

        optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args))

        # setup automatic mixed-precision (AMP) loss scaling and op casting
        amp_autocast = suppress  # do nothing
        loss_scaler = None
        if use_amp == 'apex':
            model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
            loss_scaler = ApexScaler()
            if args.local_rank == 0:
                _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
        elif use_amp == 'native':
            amp_autocast = torch.npu.amp.autocast
            loss_scaler = NativeScaler()
            if args.local_rank == 0:
                _logger.info('Using native Torch AMP. Training in mixed precision.')
        else:
            if args.local_rank == 0:
                _logger.info('AMP not enabled. Training in float32.')

        # optionally resume from a checkpoint
        resume_epoch = None
        if args.resume:
            resume_epoch = resume_checkpoint(
                model, args.resume,
                optimizer=None if args.no_resume_opt else optimizer,
                loss_scaler=None if args.no_resume_opt else loss_scaler,
                log_info=args.local_rank == 0)

        # setup exponential moving average of model weights, SWA could be used here too
        model_ema = None
        if args.model_ema:
            # Important to create EMA model after npu(), DP wrapper, and AMP but before SyncBN and DDP wrapper
            model_ema = ModelEmaV2(
                model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None)
            if args.resume:
                load_checkpoint(model_ema.module, args.resume, use_ema=True)

        # setup distributed training
        if args.distributed:
            if has_apex and use_amp != 'native':
                # NOTE(review): the log line says APEX DDP, but the wrapper actually used here is
                # torch.nn.parallel.DistributedDataParallel with broadcast_buffers=False.
                if args.local_rank == 0:
                    _logger.info("Using NVIDIA APEX DistributedDataParallel.")
                model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], broadcast_buffers=False)
            else:
                if args.local_rank == 0:
                    _logger.info("Using native Torch DistributedDataParallel.")
                model = NativeDDP(model, device_ids=[args.local_rank])  # can use device str in Torch >= 1.1
            # NOTE: EMA model does not need to be wrapped by DDP

        # setup learning rate schedule and starting epoch
        lr_scheduler, num_epochs = create_scheduler(args, optimizer)
        start_epoch = 0
        if args.start_epoch is not None:
            # a specified start_epoch will always override the resume epoch
            start_epoch = args.start_epoch
        elif resume_epoch is not None:
            start_epoch = resume_epoch
        if lr_scheduler is not None and start_epoch > 0:
            lr_scheduler.step(start_epoch)

        if args.local_rank == 0:
            _logger.info('Scheduled epochs: {}'.format(num_epochs))

        # create the train and eval datasets
        dataset_train = create_dataset(
            args.dataset,
            root=args.data_path, split=args.train_split, is_training=True,
            batch_size=args.batch_size, repeats=args.epoch_repeats)
        dataset_eval = create_dataset(
            args.dataset, root=args.data_path, split=args.val_split, is_training=False, batch_size=args.batch_size)

        # setup mixup / cutmix
        collate_fn = None
        mixup_fn = None
        mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
        if mixup_active:
            mixup_args = dict(
                mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
                prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
                label_smoothing=args.smoothing, num_classes=args.num_classes)
            if args.prefetcher:
                assert not num_aug_splits  # collate conflict (need to support deinterleaving in collate mixup)
                collate_fn = FastCollateMixup(**mixup_args)
            else:
                mixup_fn = Mixup(**mixup_args)

        # wrap dataset in AugMix helper
        if num_aug_splits > 1:
            dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)

        # create data loaders w/ augmentation pipeline
        train_interpolation = args.train_interpolation
        if args.no_aug or not train_interpolation:
            train_interpolation = data_config['interpolation']
        loader_train = create_loader(
            dataset_train,
            input_size=data_config['input_size'],
            batch_size=args.batch_size,
            is_training=True,
            use_prefetcher=args.prefetcher,
            no_aug=args.no_aug,
            re_prob=args.reprob,
            re_mode=args.remode,
            re_count=args.recount,
            re_split=args.resplit,
            scale=args.scale,
            ratio=args.ratio,
            hflip=args.hflip,
            vflip=args.vflip,
            color_jitter=args.color_jitter,
            auto_augment=args.aa,
            num_aug_splits=num_aug_splits,
            interpolation=train_interpolation,
            mean=data_config['mean'],
            std=data_config['std'],
            num_workers=args.workers,
            distributed=args.distributed,
            collate_fn=collate_fn,
            pin_memory=args.pin_mem,
            use_multi_epochs_loader=args.use_multi_epochs_loader
        )

        loader_eval = create_loader(
            dataset_eval,
            input_size=data_config['input_size'],
            batch_size=args.validation_batch_size_multiplier * args.batch_size,
            is_training=False,
            use_prefetcher=args.prefetcher,
            interpolation=data_config['interpolation'],
            mean=data_config['mean'],
            std=data_config['std'],
            num_workers=args.workers,
            distributed=args.distributed,
            crop_pct=data_config['crop_pct'],
            pin_memory=args.pin_mem,
        )

        # setup loss function
        if args.jsd:
            assert num_aug_splits > 1  # JSD only valid with aug splits set
            train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).npu()
        elif mixup_active:
            # smoothing is handled with mixup target transform
            train_loss_fn = SoftTargetCrossEntropy().npu()
        elif args.smoothing:
            train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).npu()
        else:
            train_loss_fn = nn.CrossEntropyLoss().npu()
        validate_loss_fn = nn.CrossEntropyLoss().npu()

        # setup checkpoint saver and eval metric tracking
        eval_metric = args.eval_metric
        best_metric = None
        best_epoch = None
        saver = None
        # stays empty on every rank except local rank 0; used below as the "may write files" flag
        output_dir = ''
        if args.local_rank == 0:
            if args.experiment:
                exp_name = args.experiment
            else:
                exp_name = '-'.join([
                    datetime.now().strftime("%Y%m%d-%H%M%S"),
                    safe_model_name(args.model),
                    str(data_config['input_size'][-1])
                ])
            output_dir = get_outdir(args.output if args.output else './output/train', exp_name)
            # a lower value is better only when the tracked metric is the loss
            decreasing = eval_metric == 'loss'
            saver = CheckpointSaver(
                model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,
                checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing, max_history=args.checkpoint_hist)
            with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
                f.write(args_text)

        try:
            for epoch in range(start_epoch, num_epochs):
                if args.distributed and hasattr(loader_train.sampler, 'set_epoch'):
                    loader_train.sampler.set_epoch(epoch)

                train_metrics = train_one_epoch(
                    epoch, model, loader_train, optimizer, train_loss_fn, args,
                    lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
                    amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn)

                if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                    if args.local_rank == 0:
                        _logger.info("Distributing BatchNorm running means and vars")
                    distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

                eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast)

                if model_ema is not None and not args.model_ema_force_cpu:
                    if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                        distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
                    ema_eval_metrics = validate(
                        model_ema.module, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)')
                    # EMA metrics drive the scheduler, summary and checkpoint selection
                    eval_metrics = ema_eval_metrics

                if lr_scheduler is not None:
                    # step LR for next epoch
                    lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

                # Only the rank that owns an output directory writes the summary. Without this
                # guard the other ranks would append to ./summary.csv in the working directory
                # and rewrite the header each epoch, since their best_metric is never set.
                if output_dir:
                    update_summary(
                        epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
                        write_header=best_metric is None)

                if saver is not None:
                    # save proper checkpoint with eval metric
                    save_metric = eval_metrics[eval_metric]
                    best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)

        except KeyboardInterrupt:
            # deliberate: allow Ctrl-C to end training and still report the best result below
            pass
        if best_metric is not None:
            _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
    
    
    def train_one_epoch(
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102
    • 103
    • 104
    • 105
    • 106
    • 107
    • 108
    • 109
    • 110
    • 111
    • 112
    • 113
    • 114
    • 115
    • 116
    • 117
    • 118
    • 119
    • 120
    • 121
    • 122
    • 123
    • 124
    • 125
    • 126
    • 127
    • 128
    • 129
    • 130
    • 131
    • 132
    • 133
    • 134
    • 135
    • 136
    • 137
    • 138
    • 139
    • 140
    • 141
    • 142
    • 143
    • 144
    • 145
    • 146
    • 147
    • 148
    • 149
    • 150
    • 151
    • 152
    • 153
    • 154
    • 155
    • 156
    • 157
    • 158
    • 159
    • 160
    • 161
    • 162
    • 163
    • 164
    • 165
    • 166
    • 167
    • 168
    • 169
    • 170
    • 171
    • 172
    • 173
    • 174
    • 175
    • 176
    • 177
    • 178
    • 179
    • 180
    • 181
    • 182
    • 183
    • 184
    • 185
    • 186
    • 187
    • 188
    • 189
    • 190
    • 191
    • 192
    • 193
    • 194
    • 195
    • 196
    • 197
    • 198
    • 199
    • 200
    • 201
    • 202
    • 203
    • 204
    • 205
    • 206
    • 207
    • 208
    • 209
    • 210
    • 211
    • 212
    • 213
    • 214
    • 215
    • 216
    • 217
    • 218
    • 219
    • 220
    • 221
    • 222
    • 223
    • 224
    • 225
    • 226
    • 227
    • 228
    • 229
    • 230
    • 231
    • 232
    • 233
    • 234
    • 235
    • 236
    • 237
    • 238
    • 239
    • 240
    • 241
    • 242
    • 243
    • 244
    • 245
    • 246
    • 247
    • 248
    • 249
    • 250
    • 251
    • 252
    • 253
    • 254
    • 255
    • 256
    • 257
    • 258
    • 259
    • 260
    • 261
    • 262
    • 263
    • 264
    • 265
    • 266
    • 267
    • 268
    • 269
    • 270
    • 271
    • 272
    • 273
    • 274
    • 275
    • 276
    • 277
    • 278
    • 279
    • 280
    • 281
    • 282
    • 283
    • 284
    • 285
    • 286
    • 287
    • 288
    • 289
    • 290
    • 291
    • 292
    • 293
    • 294
    • 295
    • 296
    • 297
    • 298
    • 299
    • 300
    • 301
    • 302
    • 303
    • 304
    • 305
    • 306
    • 307
    • 308
    • 309
    • 310
    • 311
    • 312
    • 313
    • 314
    • 315
    • 316
    • 317
    • 318
    • 319
    • 320
    • 321
    • 322
    • 323
    • 324
    • 325
    • 326
    • 327
    • 328
    • 329
    • 330
    • 331
    • 332
    • 333
    • 334
    • 335
    • 336

    剩下的详见:
    https://programtalk.com/python-more-examples/timm.data.AugMixDataset/

  • 相关阅读:
    Java 定时任务-最简单的3种实现方法
    Mybatis的学习(一)只Mybatis的介绍和使用
    Redis持久化实战
    bootstrap-fileinput拦截文件上传处理失败,根据后台返回数据处理
    鸿蒙3.0将删除谷歌代码,只是为让国产系统更纯粹
    一幅长文细学GaussDB(二)——数据库基础知识
    09.16 周六|Move Dev Meetup上海站火热报名中
    大数据高级开发工程师——Spark学习笔记(9)
    江协科技51单片机学习- p16 矩阵键盘
    测试平台前端部署
  • 原文地址:https://blog.csdn.net/hhhhhhhhhhwwwwwwwwww/article/details/127987371