RowParallelLinear in CogView


I'm a beginner writing these notes to record what I'm learning, in the hope that they also help others who are just starting out. Corrections from more experienced readers are very welcome. (Content will be removed immediately upon any copyright concern.)

Contents

I. Principle

II. Code walkthrough

 1. __init__

(1) Parameter descriptions

(2) Initialization, similar to ColumnParallelLinear (the weight is just the transpose of the column-parallel layout)

2. forward (also much like ColumnParallelLinear)


✨ Note: for the ColumnParallelLinear details referenced below, see the earlier post

ColumnParallelLinear in CogView — tt丫's blog on CSDN

I. Principle

Put simply, this is a linear transformation whose weight matrix is partitioned by rows across the model-parallel shards.

Weight: A = [A_1; A_2; ...; A_p], split along its first dimension (rows), where p is the number of partitions, i.e. the number of GPUs;

Bias: b;

Input: X = [X_1, ..., X_p], split along its second dimension (columns);

Output: Y;

Expression: Y = XA + b = X_1 A_1 + X_2 A_2 + ... + X_p A_p + b (the sum of the per-partition products, plus the bias; see the small numerical check below).
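To make this concrete, here is a minimal numerical sketch (plain single-device PyTorch, no real model parallelism; the sizes are made up) showing that the sum of the per-partition products X_i A_i reproduces the full product XA:

    import torch

    torch.manual_seed(0)
    p = 4                        # number of partitions (GPUs)
    X = torch.randn(8, 12)       # input; its second dimension gets split
    A = torch.randn(12, 6)       # weight; its first dimension gets split

    X_parts = torch.chunk(X, p, dim=1)   # X = [X_1, ..., X_p]
    A_parts = torch.chunk(A, p, dim=0)   # A = [A_1; ...; A_p] (row split)

    # The sum of the partial products equals the full product.
    Y = sum(Xi @ Ai for Xi, Ai in zip(X_parts, A_parts))
    assert torch.allclose(Y, X @ A, atol=1e-5)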


II. Code walkthrough

(Code location: model/mpu/layers)

 1. __init__

(1) Parameter descriptions

• input_size: first dimension of matrix A;
• output_size: second dimension of matrix A;
• bias: whether to add a bias term;
• input_is_parallel: if True, we assume the input is already split across the GPUs and do not split it again; if False, the layer splits it itself;
• init_method: method used to initialize the weights;
• stride: for the strided linear layers;
• keep_master_weight_for_test: added for testing and should be set to False; it returns the master weights used for initialization;
class RowParallelLinear(torch.nn.Module):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        init_method: method to initialize weights. Note that bias is always
                     set to zero.
        stride: For the strided linear layers.
        keep_master_weight_for_test: This was added for testing and should be
                                     set to False. It returns the master weights
                                     used for initialization.
    """
    def __init__(self, input_size, output_size, bias=True,
                 input_is_parallel=False,
                 init_method=init.xavier_normal_, stride=1,
                 keep_master_weight_for_test=False):
        super(RowParallelLinear, self).__init__()
        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.input_is_parallel = input_is_parallel
(2) Initialization, similar to ColumnParallelLinear (the weight is just the transpose of the column-parallel layout)

        # Divide the weight matrix along the last dimension.
        world_size = get_model_parallel_world_size()  # number of model-parallel processes (by default there is only one process group)
        self.input_size_per_partition = divide(input_size, world_size)  # size of each rank's weight partition
        # Parameters.
        # Note: torch.nn.functional.linear performs XA^T + b and as a result
        # we allocate the transpose.
        self.weight = Parameter(torch.Tensor(self.output_size,
                                             self.input_size_per_partition))
        self.weight.model_parallel = True
        # Bias
        if bias:
            self.bias = Parameter(torch.Tensor(self.output_size))
            # Always initialize bias to zero.
            with torch.no_grad():
                self.bias.zero_()
        else:
            self.register_parameter('bias', None)
        # Initialize the weight and split it across partitions.
        self.master_weight = _initialize_affine_weight(
            self.weight, self.output_size, self.input_size,
            self.input_size_per_partition, 1, init_method,
            stride=stride, return_master_weight=keep_master_weight_for_test)
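As a quick sanity check of the shapes involved (a sketch with made-up sizes, not code from the repo): with input_size = 1024, output_size = 256 and world_size = 4, each rank's weight has shape (output_size, input_size_per_partition) = (256, 256), and F.linear then produces that rank's partial product X_i A_i:

    import torch
    import torch.nn.functional as F

    input_size, output_size, world_size = 1024, 256, 4
    input_size_per_partition = input_size // world_size       # what divide() returns: 256

    x_i = torch.randn(8, input_size_per_partition)             # this rank's input shard X_i
    w_i = torch.randn(output_size, input_size_per_partition)   # this rank's weight shard (transposed layout)

    partial = F.linear(x_i, w_i)   # X_i A_i, shape (8, output_size); the ranks later all-reduce these
    print(partial.shape)           # torch.Size([8, 256])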

2. forward (also much like ColumnParallelLinear)

    def forward(self, input_):
        # Set up backprop all-reduce.
        if self.input_is_parallel:  # the input is already split across the GPUs as (X_1, ..., X_p)
            input_parallel = input_
        else:  # otherwise split it here
            input_parallel = scatter_to_model_parallel_region(input_)
        # Matrix multiply.
        output_parallel = F.linear(input_parallel, self.weight)  # partial product X_i A_i
        # All-reduce across all the partitions.
        output_ = reduce_from_model_parallel_region(output_parallel)  # sum the partial results from every rank and give every rank the full result (e.g. GPU 0 holds X_1 A_1 before the reduce, and every GPU holds XA afterwards)
        # Bias
        if self.bias is not None:
            output = output_ + self.bias  # Y = XA + b
        else:
            output = output_
        return output
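Putting the pieces together, here is a minimal single-process simulation of the forward flow (a sketch, not the real distributed path: torch.chunk stands in for scatter_to_model_parallel_region, and a plain Python sum stands in for reduce_from_model_parallel_region; the sizes are illustrative). Note that the bias is added exactly once, after the reduction:

    import torch
    import torch.nn.functional as F

    world_size, input_size, output_size = 4, 16, 8
    x = torch.randn(2, input_size)
    full_weight = torch.randn(output_size, input_size)   # master weight, laid out as (out, in)
    bias = torch.zeros(output_size)

    # "scatter": split the input along its last dimension, one shard per rank
    x_shards = torch.chunk(x, world_size, dim=-1)
    # split the weight the same way, matching the row split of A
    w_shards = torch.chunk(full_weight, world_size, dim=-1)

    # each rank computes its partial product X_i A_i
    partials = [F.linear(xi, wi) for xi, wi in zip(x_shards, w_shards)]
    # "all-reduce": sum the partial results, then add the (non-parallelized) bias once
    y = sum(partials) + bias

    assert torch.allclose(y, F.linear(x, full_weight, bias), atol=1e-5)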

Corrections and comments are welcome in the comment section, thank you!

Original article: https://blog.csdn.net/weixin_55073640/article/details/126467082