码农知识堂 - 1000bd
  •   Python
  •   PHP
  •   JS/TS
  •   JAVA
  •   C/C++
  •   C#
  •   GO
  •   Kotlin
  •   Swift
  • torch.nn.DataParallel类


    参考   torch.nn.DataParallel类 - 云+社区 - 腾讯云

    class torch.nn.DataParallel(module, device_ids=None, output_device=None, dim=0)[source]

    Implements data parallelism at the module level.

    This container parallelizes the application of the given module by splitting the input across the specified devices by chunking in the batch dimension (other objects will be copied once per device). In the forward pass, the module is replicated on each device, and each replica handles a portion of the input. During the backwards pass, gradients from each replica are summed into the original module.

    The batch size should be larger than the number of GPUs used.

    See also: Use nn.DataParallel instead of multiprocessing

    Arbitrary positional and keyword inputs are allowed to be passed into DataParallel but some types are specially handled. tensors will be scattered on dim specified (default 0). tuple, list and dict types will be shallow copied. The other types will be shared among different threads and can be corrupted if written to in the model’s forward pass.

    The parallelized module must have its parameters and buffers on device_ids[0] before running this DataParallel module.

    Warning

    In each forward, module is replicated on each device, so any updates to the running module in forward will be lost. For example, if module has a counter attribute that is incremented in each forward, it will always stay at the initial value because the update is done on the replicas which are destroyed after forward. However, DataParallel guarantees that the replica on device[0] will have its parameters and buffers sharing storage with the base parallelized module. So in-place updates to the parameters or buffers on device[0] will be recorded. E.g., BatchNorm2d and spectral_norm() rely on this behavior to update the buffers.

    Warning

    Forward and backward hooks defined on module and its submodules will be invoked len(device_ids) times, each with inputs located on a particular device. Particularly, the hooks are only guaranteed to be executed in correct order with respect to operations on corresponding devices. For example, it is not guaranteed that hooks set via register_forward_pre_hook() be executed before all len(device_ids) forward() calls, but that each such hook be executed before the corresponding forward() call of that device.

    Warning

    When module returns a scalar (i.e., 0-dimensional tensor) in forward(), this wrapper will return a vector of length equal to number of devices used in data parallelism, containing the result from each device.

    Note

    There is a subtlety in using the pack sequence -> recurrent network -> unpack sequence pattern in a Module wrapped in DataParallel. See My recurrent network doesn’t work with data parallelism section in FAQ for details.

    Parameters

    • module (Module) – module to be parallelized

    • device_ids (list of python:int or torch.device) – CUDA devices (default: all devices)

    • output_device (int or torch.device) – device location of output (default: device_ids[0])

    Variables

    ~DataParallel.module (Module) – the module to be parallelized

    Example:

    1. >>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2])
    2. >>> output = net(input_var) # input_var can be on any device, including CPU

  • 相关阅读:
    合宙昆仑镜LCD驱动测试
    Jackson ObjectNode JsonNode --> FastJson JSONObject
    C++11之线程库
    基于A*、RBFS 和爬山算法求解 TSP问题(Matlab代码实现)
    Flink的API分层、架构与组件原理、并行度、任务执行计划
    数组传参及 &数组
    【已解决】UE5 plugin ‘xxx‘ failed to load because module ‘xxx‘ could not be found.
    设置Oracle环境变量
    Git操作指南:子模块、用户名修改和Subtree
    virtualbx_vagrant
  • 原文地址:https://blog.csdn.net/weixin_36670529/article/details/101776756
  • 最新文章
  • 攻防演习之三天拿下官网站群
    数据安全治理学习——前期安全规划和安全管理体系建设
    企业安全 | 企业内一次钓鱼演练准备过程
    内网渗透测试 | Kerberos协议及其部分攻击手法
    0day的产生 | 不懂代码的"代码审计"
    安装scrcpy-client模块av模块异常,环境问题解决方案
    leetcode hot100【LeetCode 279. 完全平方数】java实现
    OpenWrt下安装Mosquitto
    AnatoMask论文汇总
    【AI日记】24.11.01 LangChain、openai api和github copilot
  • 热门文章
  • 十款代码表白小特效 一个比一个浪漫 赶紧收藏起来吧!!!
    奉劝各位学弟学妹们,该打造你的技术影响力了!
    五年了,我在 CSDN 的两个一百万。
    Java俄罗斯方块,老程序员花了一个周末,连接中学年代!
    面试官都震惊,你这网络基础可以啊!
    你真的会用百度吗?我不信 — 那些不为人知的搜索引擎语法
    心情不好的时候,用 Python 画棵樱花树送给自己吧
    通宵一晚做出来的一款类似CS的第一人称射击游戏Demo!原来做游戏也不是很难,连憨憨学妹都学会了!
    13 万字 C 语言从入门到精通保姆级教程2021 年版
    10行代码集2000张美女图,Python爬虫120例,再上征途
Copyright © 2022 侵权请联系2656653265@qq.com    京ICP备2022015340号-1
正则表达式工具 cron表达式工具 密码生成工具

京公网安备 11010502049817号