• Facebook的ZeRO算法原理及简单代码实验(小显卡训大模型)


    1、MP效果已经很好了,为什么还要ZeRO?

    模型并行(MP:https://zhuanlan.zhihu.com/p/366906920): 将模型横向或垂直分割,将计算和参数划分到每一层,跨多个设备,需要每一层之间的重要通信。在 GPU 之间通信带宽高的单个节点内工作良好,但跨节点工作会较慢。
    ZeRO:同样是跨多个设备,将模型和参数放到不同设备,但是通信量却大大减少。

    2、那ZeRO具体是怎么优化的呢?

    2.1 优化器回顾

    SGD: 没有动量概念
    在这里插入图片描述

    Adam优化器:Adam在SGD基础上,为每个参数梯度增加了一阶动量(momentum)和二阶动量(variance)

    在这里插入图片描述

    2.2 显存去哪了

    1. 模型状态(model states): 模型参数(fp16)、模型梯度(fp16)和Adam状态(fp32的模型参数备份,fp32的momentum和fp32的variance)。
    2. 剩余状态(residual states): 除了模型状态之外的显存占用,包括激活值(activation)、各种临时缓冲区(buffer)以及无法使用的显存碎片(fragmentation)。
      问题一:
      GPT-2(2B)在混合精度训练的情况下,放到显卡需要的内存是多少?

    2.3 ZeRO原理

    针对模型状态的存储优化,ZeRO使用的方法是分片,即每张卡只存 1/N 的模型状态量,这样系统内只维护一份模型状态。
    ZeRO-1:optimizer分片
    ZeRO-2:optimizer + Gradient分片
    ZeRO-3:optimizer + Gradient + model分片
    在这里插入图片描述

    VS DDP

    在这里插入图片描述

    1.3.1 ZeRO-1

    解决的问题:Optimizer state的冗余。
    没有即没有将模型本身进行分,也没有将Gradient进行分片,而是只将优化器进行分片。
    过程动画:https://zhuanlan.zhihu.com/p/394064174 (原文链接)
    动画链接(可能会过期):https://vdn6.vzuu.com/SD/bd03b0bc-ef95-11eb-8ee1-ce96bf022449.mp4?pkey=AAWvE2ChU9kHMLO_n5M8CJeSDHqfRZRy2dMrPU4eOnfzHOWOaGYoGlxJIEAzhzJT4Fgsk-wW1oESc3ngsFZcCFRM&c=avc.0.0&f=mp4&pu=078babd7&bu=078babd7&expiration=1660109505&v=ks6
    在这里插入图片描述

    • 训练过程与DDP类似。forward过程由每个rank的GPU独自完整的完成,然后进行backward过程。在backward过程中,梯度通过allReduce进行同步。
    • Optimizer state 使用贪心策略基于参数量进行分片,以此确保每个rank几乎拥有相同大小的优化器内存。
    • 每个rank只负责更新当前优化器分片的部分,由于每个rank只有分片的优化器state,所以当前rank忽略其余的state。
    • 在更新过后,通过广播或者allGather的方式确保所有的rank都收到最新更新过后的模型参数。
    • ZeRO-1 非常适合使用类似Adam进行优化的模型训练,因为Adam拥有额外的参数m(momentum)与v(variance),特别是FP16混合精度训练。
    • ZeRO-1 不适合使用SGD类似的优化器进行模型训练,因为SGD只有较少的参数内存,并且由于需要更新模型参数,导致额外的通讯成本。

    1.3.2 ZeRO-2

    解决的问题:gradient 的冗余。
    为了减少梯度Gradient冗余以此进一步节省内存,ZeRO-2提出gradient sharding,在FairScale里称之为Sharded Data Parallel(SDP)。相比与ZeRO-1, ZeRO-2除了对optimizer state进行切分,还对Gradient进行了切分。

    • 像ZeRO-1 一样将optimizer的参数进行分片,并安排在不同的rank上。
    • 在backward过程中,gradients被reduce操作到对应的rank上,取代了all-reduce,以此减少了通讯开销。
    • 每个rank独自更新各自负责的参数。
    • 在更新操作之后,广播或allGather保证所有的ranks接受到更新后的参数。
      在这里插入图片描述

    1.3.3 ZeRO-3
    解决的问题:model参数分割。
    为了进一步节省更多的内存,ZeRO-3提出进行模型参数的分片。类似以上两种分片方式,ranks仅负责模型参数的切片。可以进行参数切片的原因主要有以下两点:

    • AllReduce操作可以被拆分为Reduce与allgather操作的结合。
    • 模型的每一层拥有该层的完整参数,并且整个层能够直接被一个GPU装下。所以计算前向的时候,除了当前rank需要的层之外,其余的层的参数可以抛弃。
      过程动画:https://vdn6.vzuu.com/SD/2a0e318c-ef96-11eb-9cd5-6ad0d31fb0b0.mp4?pkey=AAWtWJs2zkZXzKwAdUt7rtD8scwCEPuOBq7Pn7VLpGstNISogXgsust_iUhM7RpuZG1rzqTPelWdmGfW_PpBk0lw&c=avc.0.0&f=mp4&pu=078babd7&bu=078babd7&expiration=1660046771&v=ks6
      在这里插入图片描述

    1.4 ZeRO通信分析

    • DDP:reduce-scatter + all-gather,分别需要Ψ的通信量,每gpu共计消耗2Ψ通信量。
    • Pos、Pos+g:reduce-scatter + all-gather,通信量每gpu共计消耗2Ψ。
    • Pos+g+p:需要对梯度进行一次reduce-scatter操作(因为每个gpu各自负责部分参数的更新,因此不需要对梯度进行all-gather操作),对参数需要进行正向和反向两次传递,所以需要消耗2Ψ通信量,共计每gpu消耗3Ψ通信量。

    问题二:请分析分析MP @方恺齐 的通信量?

    1.5 FSDP源码分析

    https://github.com/pytorch/pytorch/blob/b91ff5e361623685799b8ef725a91b756685a9ae/torch/distributed/fsdp/fully_sharded_data_parallel.py#L462
    代码问题:
    1、FSDP对模型参数是怎么进行操作和存储的,和普通的DDP有什么不同?
    将参数拉平,并存储了每个参数的size
    2、FSDP分别是在哪个函数实现模型参数分割和合并的?
    _sharded_parameters
    3、对于模型参数不能均分的情况,FSDP采用了什么策略?
    pad
    https://github.com/pytorch/pytorch/blob/e81664449559f95d0b8d0fe57d66544a0ab84fe8/torch/distributed/fsdp/fully_sharded_data_parallel.py#L3237

    1.6 实际应用

    1.6.1 实验代码

    import argparse
    import os
    
    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp
    import torch.nn as nn
    import torch.optim as optim
    from fairscale.optim.oss import OSS
    from torch.nn.parallel import DistributedDataParallel as DDP
    from fairscale.nn.data_parallel import ShardedDataParallel as SDP 
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
    from torch.cuda.amp import autocast
    from torch.cuda.amp import GradScaler
    
    
    scaler = GradScaler()
    
    
    class MyModel(nn.Module):
    
        def __init__(self, vocab_size, embed_dim, inner_dim, hidden_dim, num_choices, nlayers=2):
            super().__init__()
            self.nlayers = nlayers
            self.embed = nn.Embedding(vocab_size, embed_dim)
            self.linear = nn.Linear(embed_dim, hidden_dim)
            self.fn = nn.Sequential(nn.Linear(hidden_dim, inner_dim), 
                                        nn.ReLU(),
                                        nn.Linear(inner_dim, hidden_dim))
            self.drop = nn.Dropout(0.1)
            self.classifier = nn.Linear(hidden_dim, num_choices)
            
        def forward(self, input_ids):
            embed = self.embed(input_ids)
            v = self.linear(embed)
            v = self.fn(v)
            last_token_hidden = v[:, -1]
            last_token_hidden = self.drop(last_token_hidden)
            logits = self.classifier(last_token_hidden)
            return logits
    
    
    def initialize_distributed(args):
        """Initialize torch.distributed."""
    
    # Manually set the device ids.
    device = args.rank % torch.cuda.device_count()
    print(f'rank = {args.rank} || local_rank = {args.local_rank}')
    if args.local_rank is not None:
        device = args.local_rank
    torch.cuda.set_device(device)
    init_method = 'tcp://'
    master_ip = os.getenv('MASTER_ADDR', 'localhost')
    master_port = os.getenv('MASTER_PORT', '6000')
    init_method += master_ip + ':' + master_port
    print(f'init_method = {init_method}')
    dist.init_process_group(backend='nccl',
                            world_size=args.world_size,
                            rank=args.rank,
                            init_method=init_method)
    dist.all_reduce(torch.zeros(1).cuda())
    
    parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model')
    parser.add_argument('--autocast', action='store_true', help='Run in pytorch autocast mode.')
    parser.add_argument('--zero3', action='store_true', help='Run in pytorch autocast mode.')
    parser.add_argument('--zero2', action='store_true', help='Run in pytorch autocast mode.')
    parser.add_argument('--zero1', action='store_true', help='Run in pytorch autocast mode.')
    
    parser.add_argument('--local_rank',
                        type=int,
                        default=None,
                        help='local rank passed from distributed launcher')
    
    args = parser.parse_args()
    
    args.rank = int(os.getenv('RANK', '0'))
    args.world_size = int(os.getenv("WORLD_SIZE", '1'))
    
    initialize_distributed(args)
    
    total_steps = 10000000
    batch_size = 1
    vocab_size = 20000
    data_len = 512
    embed_dim = 10000
    inner_dim = 10000
    hidden_dim = 20000
    num_choices = 10000
    loss_fn = nn.CrossEntropyLoss()
    
    device = "cuda:{}".format(torch.cuda.current_device())
    
    model = MyModel(vocab_size, embed_dim, inner_dim, hidden_dim, num_choices)
    model.to(device)
    
    n_all_param = sum([p.nelement() for p in model.parameters()])
    
    if args.rank == 0:
        print(f'n_all_param: {n_all_param}')
        for k, v in model.named_parameters():
            print(f'rank: {args.rank} --- {k} shape: {v.shape}')
    
    if args.zero1:
        base_optimizer_arguments = {'lr':0.05}
        base_optimizer = torch.optim.Adam
        optimizer = OSS(
            params=model.parameters(),
            optim=base_optimizer,
            **base_optimizer_arguments)
        model = DDP(model)
    
    elif args.zero2:
        base_optimizer_arguments = {'lr':0.05}
        base_optimizer = torch.optim.Adam 
        optimizer = OSS(
            params=model.parameters(),
            optim=base_optimizer,
            **base_optimizer_arguments)
        model = SDP(model, optimizer)
        if args.autocast:
            from fairscale.optim.grad_scaler import ShardedGradScaler
            scaler = ShardedGradScaler()
    
    elif args.zero3:
        print(f'zero3 ----')
        optimizer = optim.Adam(model.parameters(), lr=0.05)
        model = FSDP(model, mixed_precision=True)
    
    else:
        optimizer = optim.Adam(model.parameters(), lr=0.05)
        model = DDP(model)
    
    if args.rank == 0:
        for k, v in model.named_parameters():
            print(f'rank: {args.rank} --- {k} shape: {v.shape}')
    
    for step in range(total_steps):
        model.zero_grad()
        data = torch.randint(vocab_size, (batch_size, data_len)).to(device) 
        labels = torch.randint(2, [batch_size]).to(device) 
        if args.autocast:
            with autocast():
                logits = model(data) 
            loss = loss_fn(logits, labels)
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            scaler.step(optimizer)
            scaler.update()
        else:
            logits = model(data) 
            loss = loss_fn(logits, labels)
            loss.backward()
            optimizer.step()
    
        if args.rank == 0:
            print(f'step: {step} loss: {loss.item()}')
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102
    • 103
    • 104
    • 105
    • 106
    • 107
    • 108
    • 109
    • 110
    • 111
    • 112
    • 113
    • 114
    • 115
    • 116
    • 117
    • 118
    • 119
    • 120
    • 121
    • 122
    • 123
    • 124
    • 125
    • 126
    • 127
    • 128
    • 129
    • 130
    • 131
    • 132
    • 133
    • 134
    • 135
    • 136
    • 137
    • 138
    • 139
    • 140
    • 141
    • 142
    • 143
    • 144
    • 145
    • 146
    • 147
    • 148
    • 149
    • 150
    • 151
    • 152
    • 153
    • 154
    • 155
    • 156

    运行上述代码脚本: CUDA_VISIBLE_DEVICES=0,1 OMP_NUM_THREADS=3 python -W ignore -m torch.distributed.launch --nproc_per_node 2 --master_addr 127.0.0.1 --master_port 8927 fairscale_fsdp.py

    结果:
    在这里插入图片描述

    选答题:
    GPT-2(1.5B)在ZeRO-1,ZeRO-2,ZeRO-3模式,2张显卡的情况下,每张显卡内存分别是多少?

    1.6.2 基于Fairscale库的使用代码:

    ZeRO-1

    from fairscale.optim.oss import OSS
    from torch.nn.parallel import DistributedDataParallel as DDP
    
    
    base_optimizer_arguments = {'lr':0.05}
    base_optimizer = torch.optim.Adam
    
    optimizer = OSS(
            params=model.parameters(),
            optim=base_optimizer,
            **base_optimizer_arguments)
    model = DDP(model)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12

    ZeRO-2

    from fairscale.optim.oss import OSS
    from fairscale.nn.data_parallel import ShardedDataParallel as SDP 
    
    base_optimizer_arguments = {'lr':0.05}
    base_optimizer = torch.optim.Adam 
    optimizer = OSS(
        params=model.parameters(),
        optim=base_optimizer,
        **base_optimizer_arguments)
    model = SDP(model, optimizer)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10

    ZeRO-3

    from fairscale.optim.oss import OSS
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
    
    optimizer = optim.Adam(model.parameters(), lr=0.05)
    model = FSDP(model, mixed_precision=True)
    
    • 1
    • 2
    • 3
    • 4
    • 5

    1.6.3 补充部分

    下面实验是基于上面的1.6.1实验代码进行的
    将整个模型包在一个FSDP

    model = FSDP(model)
    
    • 1

    在这里插入图片描述

    将每个参数分别包一个FSDP

    model.embed = FSDP(model.embed)
    model.linear = FSDP(model.linear)
    model.fn = FSDP(model.fn)
    model.classifier = FSDP(model.classifier)
    
    • 1
    • 2
    • 3
    • 4

    在这里插入图片描述

    选答题:怎么将模型参数恢复呢?

  • 相关阅读:
    【UE 材质】制作加载图案(2)
    AIGC时代:未来已来
    OpenHarmony实战开发-如何实现防盗链应用功能。
    mysql 备库重做
    vue 组件封装 综合案例2
    python后端相关知识点汇总(十二)
    【eCharts】第三部分 在同一个容器中展示多个图表
    开源模型应用落地-工具使用篇-向量数据库(三)
    消息发送超过时间限制如何撤回?
    每天一道算法题(四)——移动零(将数组中的零移到最后面)
  • 原文地址:https://blog.csdn.net/weixin_43922901/article/details/126246309