作者:王磊
更多精彩分享,欢迎访问和关注:https://www.zhihu.com/people/wldandan
在前面一系列AI工程文章中,我们先后介绍了如何启动一个机器学习项目、如何处理数据、以及如何设计开发网络。接下来,我将继续介绍如何有效进行模型训练。
模型训练对算力的要求较高,在当前的AI领域,训练的数据规模和模型参数正呈现越来越大的趋势,单个设备的算力已经无法支撑模型的高效训练。因此,需要引入分布式并行来解决单个设备算力不足的问题。
本篇文章将探讨为什么需要分布式并行、分布式并行的策略、以及MindSpore分布式并行的实现机制和优势。
随着 OpenAI GPT-3 发布开始,各大厂商先后推出自己的大模型,人工智能产业开始了新一轮的激烈角逐,而且有愈演愈烈之势。而打造大模型并非易事,需要消耗庞大的数据、算力、算法等各种软硬件资源。以GPT系列为例:
引自: 2022,大模型还能走多远
对于如此大规模的模型及训练数据,使用单卡的方式完全无法完成训练。以GPT-3模型训练为例 ,使用 8 张 V100 显卡,训练时长预计需要36 年, 512 张 V100显卡 ,训练时间接近 7 个月,而1024 张A100的训练时长可以减少到 1 个月。时间越长,意味着成本越高,大模型的训练可能是普通人难以负担的。因此,需要分布式并行的方式来增强算力、加速数据处理和模型训练。
业界主流的分布式并行类型包括数据并行、模型并行和混合并行三种,围绕切分的内容(数据切分、模型切分)来划分。
从MindSpore框架层面,它支持4种并行模式,包括数据并行模式、自动并行模式(融合数据并行和算子级模型并行)、半自动并行(手动配置算子切分策略)以及混合并行(手动切分模型)。
在实际的使用中,开发者可以通过context.set_auto_parallel_context()
来设置分布式训练的模式。
- # 数据并行模式
- context.set_auto_parallel_context(parallel_mode=context.ParallelMode.DATA_PARALLEL)
- # 半自动并行模式
- context.set_auto_parallel_context(parallel_mode=context.ParallelMode.SEMI_AUTO_PARALLEL)
- # 自动并行模式
- context.set_auto_parallel_context(parallel_mode=context.ParallelMode.AUTO_PARALLEL)
- # 混合并行模式
- context.set_auto_parallel_context(parallel_mode=context.ParallelMode.HYBRID_PARALLEL)
接下来介绍下MindSpore分布式并行实现的基本原理。
假设我们有8张GPU卡或者昇腾的NPU卡来训练图片分类的模型,训练的批量为160,那么每张卡上面分到的批量数据(min-batch)为20,每张卡基于样本数据完成训练。因为各张卡上处理的数据样本不同,所以获得的梯度会有些差别。因此,需要对梯度进行聚合(求和、均值)等计算来保持和单卡训练相同的结果,最后再更新参数。 具体的过程如下图:
mindspore.communication.init
接口可以完成通信的初始化工作。数据并行依赖于集合通信的操作,上面的例子中使用到了Broadcast以及AllReduce通信原语,其中Broadcast将数据分发到不同的卡,而AllReduce操作则完成不同卡上的梯度聚合操作,如下图:
模型并行从形式可以有自动、半自动方式,实际的并行过程中,需要通过层间并行(模型以层为单位切分到多个硬件设备)或者层内并行(每层的模型参数切分到多个硬件设备)的模型并行方式来解决。两者的差异如下图,层间并行每卡执行的网络模型存在差异,而层内并行不会改变网络结构,而是将每层的模型参数切分到不同的设备上实现并行。
由上图不难看出,模型并行的难度相比数据并行的难度更高,数据并行只拆分数据批量,单卡的网络结构并不改变,而模型并行需要以层或者每层的模型参数拆分到多卡运行。对于开发者来说,如果手动进行模型切分,需要解决如下几个难题:
MindSpore从框架层面提供了自动、半自动和混合并行训练方式,可以帮助开发者使用单机的脚本实现并行的算法逻辑,降低了分布式训练的难度,提升了训练性能。下图揭示了MindSpore中并行机制的原理,整个分布式并行的流程如下:
1. 并行策略配置:通过前端API接口设置自动并行策略,不同的并行策略(自动、半自动、混合)决定切分策略、并行模型配置,通过context.set_auto_parallel_context选择针对模型的并行策略。
2. 分布式算子和张量排布:自动并行的流程会对输入的ANF计算图进行遍历,以分布式算子为单位对张量进行切分建模,表示一个算子的输入输出张量如何分布到集群各个卡上(Tensor分布)。这种模型充分地表达了张量和设备间的映射关系,用户无需感知模型各切片放到哪个设备上运行,框架会自动调度分配。
(1) 张量排布:为了得到张量的排布模型,每个算子都具有切分策略,它表示算子的各个输入在相应维度的切分情况。通常情况下只要满足以2为基、均匀分配的原则,张量的任意维度均可切分。以下图为例,这是一个三维矩阵乘(BatchMatMul)操作,它的切分策略由两个元组构成,分别表示input和weight的切分形式。其中元组中的元素与张量维度一一对应,2^N为切分份数,1表示不切。当用户想表示一个数据并行切分策略时,即input的batch维度切分,其他维度不切,可以表达为strategy=((2^N, 1, 1),(1, 1, 1));当表示一个模型并行切分策略时,即weight的非batch维度切分,这里以channel维度切分为例,其他维度不切,可以表达为strategy=((1, 1, 1),(1, 1, 2^N));当表示一个混合并行切分策略时,其中一种切分策略为strategy=((2^N, 1, 1),(1, 1, 2^N))。依据切分策略,分布式算子中定义了推导算子输入张量和输出张量的排布模型的方法。这个排布模型由device_matrix,tensor_shape和tensor map组成,分别表示设备矩阵形状、张量形状、设备和张量维度间的映射关系。分布式算子会进一步根据张量排布模型判断是否要在图中插入额外的计算、通信操作,以保证算子运算逻辑正确。半自动并行模式中,需要用户对算子手动配置切分策略实现并行,这也是自动和半自动最大的差异点。
(2) 张量排布变换:当前一个算子的输出张量模型和后一个算子的输入张量模型不一致时,就需要引入计算、通信操作的方式实现张量排布间的变化。自动并行流程引入了张量重排布算法(Tensor Redistribution),可以推导得到任意排布的张量间通信转换方式。下面三个样例表示公式Z=(X×W)×V的并行计算过程, 即两个二维矩阵乘操作,体现了不同并行方式间如何转换。
在Figure 1中,第一个数据并行矩阵乘的输出在行方向上存在切分,而第二个模型并行矩阵乘的输入需要全量张量,框架将会自动插入AllGather算子实现排布变换。
在Figure 2中,第一个模型并行矩阵乘的输出在列方向上存在切分,而第二个数据并行矩阵乘的输入在行方向上存在切分,框架将会自动插入等价于集合通信中AlltoAll操作的通信算子实现排布变换。
在Figure 3中,第一个混合并行矩阵乘的输出切分方式和第二个混合并行矩阵乘的输入切分方式一致,所以不需要引入重排布变换。但由于第二个矩阵乘操作中,两个输入的相关维度存在切分,所以需要插入AllReduce算子保证运算正确性。
分布式算子、张量排布/变换是自动并行实现的基础,总体来说这种分布式表达打破了数据并行和模型并行的边界。从脚本层面上,用户仅需构造单机网络,即可表达并行算法逻辑,框架将自动实现对整图切分。
(3) 切分策略搜索算法:自动并行模式支持并行策略传播(Sharding Propagation),能够有效地降低用户手配算子切分策略的工作量,算法将切分策略由用户配置的算子向未配置的算子传播。为进一步降低用户手配算子切分策略的工作量,支持切分策略完全自动搜索。为此,围绕硬件平台构建相应的代价函数模型,计算出一定数据量、一定算子在不同切分策略下的计算开销,内存开销及通信开销。然后通过动态规划算法或者递归规划算法,以单卡的内存上限为约束条件,高效地搜索出性能较优的切分策略。策略搜索这一步骤代替了用户手动指定模型切分,在短时间内可以得到较高性能的切分方案,极大降低了并行训练的使用门槛。半自动并行,
3. 分布式自动微分:图分片的过程包含正向网络以及反向网络切分,传统的手动模型切分除了需要关注正向网络通信还需要考虑网络反向的并行运算,MindSpore通过将通信操作包装为算子,并利用框架原有的自动微分操作自动生成通信算子反向,所以即便在进行分布式训练时,用户同样只需关注网络的前向传播,真正实现训练的全自动并行。
自动并行从开发态来说开发者最友好,降低了分布式并行的门槛,我们将介绍基于MindSpore,使用自动并行的方式完成Resnet-50模型的分布式训练。
说明:严禁转载本文内容,否则视为侵权。