• torch.expand()函数用法


    `torch.expand()` 是 PyTorch 中用于扩展张量(tensor)维度的函数。它用于在指定维度上复制数据,以匹配目标形状。这是一种常见的操作,通常用于广播(broadcasting)操作,以便在进行元素级操作时能够处理不同形状的张量。
    @[TOC](这里写目录标题)

    `torch.expand()` 的使用方法如下:

    ```python
    expanded_tensor = original_tensor.expand(target_shape)
    ```

    - `original_tensor`: 要扩展的原始张量。
    - `target_shape`: 目标形状,应该是一个元组或列表,描述了您要扩展到的最终形状。

    举个例子,如果有一个形状为 (3, 1) 的张量,我们可以使用 `torch.expand()` 来扩展它,使其变为 (3, 5) 的形状,如下所示:

    ```python
    import torch

    # 创建原始张量
    original_tensor = torch.Tensor([[1], [2], [3]])

    # 使用expand扩展形状
    expanded_tensor = original_tensor.expand(3, 5)

    print(original_tensor)
    print(expanded_tensor)

    #输出结果:
    #tensor([[1.],
    #        [2.],
    #        [3.]])
    #tensor([[1., 1., 1., 1., 1.],
    #        [2., 2., 2., 2., 2.],
    #        [3., 3., 3., 3., 3.]])

    ```

    这将生成一个形状为 (3, 5) 的张量,其中原始张量的值在第一个维度上被复制,以匹配目标形状。

    请注意,`torch.expand()` 并不会真正复制数据,它只是提供一个视图,因此不会增加内存使用。如果需要复制数据以创建新张量,可以使用 `torch.clone()` 或 `torch.copy()`。此外,要进行广播操作,通常可以直接使用运算符(例如 `+`,`*`),PyTorch 会自动执行广播,无需显式使用 `torch.expand()`。


    `torch.expand()` 主要用于扩展张量的维度以匹配目标形状,但还可以使用不同的参数来改变其行为,以满足其他需求。

    `torch.expand()` 可以接受额外的参数,以更精细地控制张量的扩展行为。其中一个重要的参数是 `stride`,可以用来指定在扩展过程中在某一维度上的步幅(stride)。以下是一些用法示例:

    1. **使用步幅扩展维度**:

       您可以使用 `stride` 参数来控制如何扩展维度。这是一种常见的情况,其中您可能希望以某种特定的间隔复制数据。

       ```python
       import torch

       original_tensor = torch.Tensor([[1, 2, 3]])
       expanded_tensor = original_tensor.expand(3, -1, -1)  # 扩展3倍,每隔1个元素复制一次
       ```

    2. **使用负数扩展维度**:

       当传递负数值给 `torch.expand()` 时,它将自动计算维度大小,以便适应目标形状。

       ```python
       import torch

       original_tensor = torch.Tensor([[1, 2, 3]])
       expanded_tensor = original_tensor.expand(-1, 4, -1)  # 扩展为3倍宽度,4倍高度
       ```

    这些附加参数可以让您更灵活地控制扩展的行为,以满足不同的需求。但请注意,`torch.expand()` 仍然不会复制数据,它只是提供了一个视图,以便在广播操作时使用。如果需要真正的数据复制,您可以使用 `torch.clone()` 或 `torch.copy()` 等操作。

  • 相关阅读:
    互联网黑话
    北斗导航 | 基于GPS/BDS多星座加权因子图优化的行人智能手机导航
    域名系统DNS
    数字化工厂系统有什么现实优势
    如何使用libavfilter库给pcm音频采样数据添加音频滤镜?
    2023 年全国大学生数学建模A题目-定日镜场的优化设计
    Vue中混入(mixin)的使用
    从零学习开发一个RISC-V操作系统(三)丨嵌入式操作系统开发的常用概念和工具
    论文阅读—— CEASC(cvpr2023)
    ​力扣解法汇总946-验证栈序列
  • 原文地址:https://blog.csdn.net/thy0000/article/details/134016260