• YOLOX源码之 wait_for_the_master


    主进程读取数据

    在函数 get_data_loader 中,下面这段代码的作用是在多节点分布式训练时,每个节点的主进程负责读取数据。

    1. if self.dataset is None:
    2. with wait_for_the_master():
    3. assert cache_img is None, \
    4. "cache_img must be None if you didn't create self.dataset before launch"
    5. self.dataset = self.get_dataset(cache=False, cache_type=cache_img)

    在 PyTorch 的分布式训练中,每个节点的主进程负责数据加载、模型初始化和一些其他的准备工作。这意味着在每个节点的主进程中,都会有一份数据加载的代码。

    这样做的好处是:

    1. 减轻主节点压力:每个节点的主进程可以独立地负责数据加载,减轻了主节点的负担,有助于更好地利用各个节点的资源。
    2. 数据分发效率高:在每个节点加载数据的情况下,数据可以直接在本地节点内分发给其他进程,避免了网络传输的开销,提高了数据加载的效率。 

    接下来我们看下函数 wait_for_the_master 的实现

    1. from contextlib import contextmanager
    2. @contextmanager
    3. def wait_for_the_master(local_rank: int = None):
    4. """
    5. Make all processes waiting for the master to do some task.
    6. Args:
    7. local_rank (int): the rank of the current process. Default to None.
    8. If None, it will use the rank of the current process.
    9. """
    10. if local_rank is None:
    11. local_rank = get_local_rank()
    12. if local_rank > 0:
    13. dist.barrier()
    14. yield
    15. if local_rank == 0:
    16. if not dist.is_available():
    17. return
    18. if not dist.is_initialized():
    19. return
    20. else:
    21. dist.barrier()

    @contextmanager

    @contextmanager是一个装饰器,用于定义上下文管理器(context manager)。上下文管理器可以用于创建一个资源的上下文,然后在进入和退出这个上下文时执行特定的操作,比如资源的获取和释放。 

    在python中要自定义一个上下文管理器,需要定义一个类,并实现其__enter__()和__exit()__方法。但使用装饰器@contextmanager可以更简洁的实现这点,具体来说,@contextmanager 装饰器可以将一个生成器函数转换成一个上下文管理器。生成器函数中的 yield 语句之前的代码会在进入上下文时执行,而 yield 语句之后的代码会在退出上下文时执行。

    dist.barrier()

    这里首先获取每个节点的local_rank,大于0说明不是主进程,dist.barrier() 是 PyTorch 中分布式通信库 torch.distributed 提供的一个同步操作,它的作用是在分布式环境中同步所有参与通信的进程,确保它们在某一点上同时到达了同步点。

    在分布式训练中,dist.barrier()的作用通常是用来同步各个进程的执行,以保证它们在某个重要的时刻处于同步状态。当一个进程调用了dist.barrier()后,它会被阻塞,直到所有参与通信的进程也都调用了dist.barrier(),此时所有进程才会解除阻塞,继续执行后续的代码。

    具体来说,dist.barrier() 的作用有以下几点:

    1. 同步数据加载:在数据加载完毕之后,可以使用 dist.barrier() 来确保所有进程都已经加载完数据,然后再开始训练。
    2. 同步模型初始化:在模型初始化完成之后,可以使用 dist.barrier() 来确保所有进程都已经初始化完成,然后再开始训练。
    3. 同步参数更新:在每个训练步骤中,可以使用 dist.barrier() 来确保所有进程都已经计算完梯度,并更新了参数,然后再进行下一步的计算。
    4. 同步模型评估:在模型评估阶段,可以使用 dist.barrier() 来确保所有进程都已经完成了评估任务,然后再进行汇总或其他后续操作。

    结合上面两段代码来看,在进入上下文后,每个节点的非主进程会阻塞在yield前的dist.barrier()处,而主进程则会执行self.get_dataset()读取数据,在每个节点的主进程都执行完self.get_dataset()后,会退出上下文,此时非主进程还是停留在yield前的dist.barrier()处,而主进程则会执行yield后的dist.barrier(),当所有进程都调用了dist.barrier()后,所有进程的阻塞被解除,继续执行后续的代码。

  • 相关阅读:
    2342.数位和相等数对的最大和
    RENIX_IPv6自动配置——网络测试仪实操
    2023第五届山东国际中医药产业展会,中医养生展,中医文化展
    【Shell脚本13】Shell 文件包含
    jmeter使用csv进行参数化及(运行后出现乱码错误解决)
    Unity Scene窗口获取鼠标位置
    JCE cannot authenticate the provider BC ,has unsign 异常排查解决
    Java8函数式编程-lambda表达式与stream流
    笔记本电脑没有声音?几招恢复声音流畅!
    【改造后序遍历】 98. 验证二叉搜索树
  • 原文地址:https://blog.csdn.net/ooooocj/article/details/139398994