CC's

Back

本文分为三部分:

  • 第一部分为官网 Distributed Data Parallel 设计思路翻译;
  • 第二部分为官网教程;
  • 第三部分为实际使用时的一些笔记。

Distributed Data Parallel 设计笔记#

torch.nn.parallel.DistributedDataParallel(DDP)透明1地执行分布式数据并行训练。该部分内容解释了 DDP 的运行原理及设计细节。

例子#

我们先从一个简单的 torch.nn.parallel.DistributedDataParallel 例子开始。这个例子使用了线性层作为本地模型,使用 DDP 包装后,分别运行了一次前向传播、反向传播以及一步优化。之后,本地模型的参数将被更新,其他线程的所有模型均应保持一致。

内部设计#

这一部分阐述了 torch.nn.parallel.DistributedDataParallel 的运行原理,深入了一次迭代中所有步骤的细节。

  • 预先准备: DDP 依赖于 c10d ProcessGroup 作为通信手段。所以在使用 DDP 前应用必须先创建 ProcessGroup 实例。
  • 构建 DDP: DDP 构建需要本地模型的引用,而后会把在 rank 0 进程上的本地模型的 state_dict() 广播到同组的其他进程,从而保证所有模型副本都从一致的状态开始训练。然后,每个 DDP 进程都会创建本地的 Reducer,用于在反向传播阶段同步梯度信息。为了优化通信效率,Reducer 将参数分成多个桶,每次仅处理一个桶。桶大小可以通过 DDP 的 bucket_cap_mb 参数调整。参数梯度到桶的映射是 DDP 在构造时根据桶大小和参数量确定的。模型参数大致上是按照 Model.parameters() 的逆向顺序保存到桶中的。这么做的原因是 DDP 希望梯度按照反向传播的顺序依次就绪。下图给出了一个简单的例子。注意 grad0grad1 都在 bucket1 中,其他两个梯度则在 bucket0 中。当然,这个假设不会总符合实际情况。由于 Reducer 不能在尽可能早的时间启动数据通信,这种情况将会拖慢 DDP 的反向传播速度。除了桶,Reducer 同样在构造时给每个参数注册了梯度钩子。这些狗子将会在反向传播过程中对应梯度就绪时触发。
  • 前向传播: DDP 获取输入并将其传递给本地模型,如果 find_unused_parameters 被设置为真,那么其还会分析本地模型的输出。这种方法允许我们仅在模型的一部分上反向传播,通过遍历模型输出的梯度图,将所有未使用参数标记为「就绪」状态,DDP 能够找出参与反向传播过程的参数。在反向传播过程,Reducer 将只等待未就绪状态的参数,但仍将 reduce 所有的桶。将参数梯度标记为就绪现在并不能让 DDP 自动跳过对应桶,但能够防止 DDP 在反向传播过程中无限等待不存在的梯度数据。注意,遍历梯度图将带来额外的代价,所以没有必要时应用不应该设置 find_unused_parameters
  • 反向传播: backward() 操作是在损失 Tensor 上执行的,这部分不在 DDP 的控制之下。DDP 使用梯度钩子来触发梯度同步。当一个梯度就绪后,它对应的 DDP 钩子也会被触发,DDP 就可以把参数梯度标记为就绪了。当一个桶中的所有梯度都就绪后,Reducer 就会开始在桶上执行异步 allreduce 过程,来计算所有进程中该部分梯度的均值。当所有的桶都完成后,Reducer 就会阻塞进程,等待所有 allreduce 完成操作。当所有步骤结束后,梯度均值将会被写回所有参数的 params.grad。到目前,所有 DDP 进程中对应参数的梯度都应该是一样的。
  • 优化: 从优化器的角度来看,它只会优化一个本地模型。不同 DDP 进程的模型副本都能保持同步,因为它们都从一个状态开始,并且每次优化迭代中它们的梯度都是一致的。


Distributed Data Parallel 入门#

DistributedDataParallel(DDP)实现了模块级别的数据并行,能够在多台机器上同步运行。使用 DDP 编写的应用会创建多个进程,并在每个进程创建 DDP 单例。DDP 使用 torch.distributed 包里的通信方法来同步梯度和缓存。更具体地讲,DDP 对 model.parameters() 的每个参数创建了梯度钩子,在反向传播过程中对应梯度完成计算时这些钩子将被触发。然后 DDP 使用这些信号来触发跨进程的梯度同步。你可以参考 DDP 设计笔记来了解更多的细节。

DDP 的推荐用法是对每个模型副本都创建一个进程,每个模型副本可以使用多个 device。DDP 进程能够在单个机器或者多台机器下部署,但 GPU 设备不能在多个进程间共享。这个教程从一个基本的 DDP 使用案例开始,逐步引入更多的高级案例,如 checkpoint 和融合模型并行到 DDP 数据并行。

DataParallelDistributedDataParallel 的区别#

在正式进入教程前,我们先解释为什么要使用更复杂一点的 DistributedDataParallel 而不是 DataParallel

  • 第一,DataParallel 是单进程多线程的,只能在一个机器上工作。DistributedDataParallel 则是多进程的,且能够在单机或多机上部署训练。即使是在单机上,DataParallel 通常也会慢于 DistributedDataParallel,因为跨进程的 GIL、模型分发以及额外的输入分发、输出汇总。
  • 回忆先前的教程,如果你的模型太大,那你可能要使用模型并行来在多个 GPU 上训练。DistributedDataParallel 能够适配模型并行,但 DataParallel 不行。当在 DDP 中使用模型并行时,每个 DDP 进程将使用模型并行,每个进程还将使用数据并行。
  • 如果你的模型需要跨机器训练,或者你的使用情景不能归纳到数据并行范式,请查看 RPC API 来了解更加通用的分布式训练支持。

基本使用样例#

要使用 DDP,你要先正确设置一个进程组(process group)。更多的细节请参阅结合 Pytorch 编写分布式应用

现在,我们可以创建一个简单的深度学习模块了,然后使用 DDP 包装这个模块,在随便给它一些输入。注意,DDP 在构造时会把 rank 0 进程的模型状态广播到其他进程,所以你不需要考虑不同 DDP 进程的模型的初始值是否一致。

可以看到,DDP 封装了底层的分布式通信细节,为我们提供了简洁的 API。梯度同步的通信在反向传播时进行,并覆盖原有的反向传播计算过程。当 backward() 返回时,param.grad 已经是同步的梯度张量了。在这个简单的例子中,DDP 只需要了了几行就能设置进程组。当在更复杂的情况下使用 DDP 时,需要注意一些警告。

处理进程异步问题#

在 DDP 中,构造函数、前向传播、反向传播都是分布式下的同步点。不同进程应该启动相同数量的同步点,并且以相同的顺序到达这些同步点,还应该在基本相同的时间到达这些同步点。否则,更快的进程可能更快到达同步点并超时等待更慢的进程。所以,用户需要负责平衡不同进程间的工作任务。某些时候,进程速度倾斜可能无法避免,比如网络出现延迟、资源竞争、抑或是一些不好预料的负载峰值。为了避免超时,请确保创建进程组时设置的 timeout 参数合理。

保存和读取 checkpoint#

一个常见的操作是使用 torch.savetorch.load 来保存、读取网络等模块的中间状态。在 DDP 下,一个优化是只让一个进程保存模型,然后在剩下的进程读取。如果使用这个优化,我们需要确保没有进程在保存前读取状态。另外,在读取模块状态时,你还需要提供合适的 map_location 参数,来防止进程误用其他进程的设备。如果 map_location 没有设置,torch.load 会先把数据读入 CPU,然后复制到它原来的设备里,这可能导致所有进程使用同一块设备。对于更复杂的错误恢复,请阅读 TorchElastic

同时使用 DDP 和模型并行#

DDP 也能和多 GPU 模型同时使用。DDP 封装的多 GPU 模型在大数据、大模型的情况下很有用。

class ToyMpModel(nn.Module):
    def __init__(self, dev0, dev1):
        super(ToyMpModel, self).__init__()
        self.dev0 = dev0
        self.dev1 = dev1
        self.net1 = torch.nn.Linear(10, 10).to(dev0)
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(10, 5).to(dev1)

    def forward(self, x):
        x = x.to(self.dev0)
        x = self.relu(self.net1(x))
        x = x.to(self.dev1)
        return self.net2(x)
python

当把一个多 GPU 模型传递给 DDP 时,你不应设置 device_idsoutput_device。输入和输出数据应存放在合适的设备上,这一步应由应用或者模型的 forward 方法完成。

使用 torch.distributed.run/torchrun 初始化 DDP#

我们可以使用 Pytorch Elastic 来初始化 DDP 代码,这能让我们的工作更加简单。我们还是使用上文的简单模型,然后创建一个名为 elastic_ddp.py 的文件。

然后,我们在所有节点上运行 torch elastic/torchrun 命令来初始化 DDP 任务:

torchrun --nnodes=2 --nproc_per_node=8 --rdzv_id=100 --rdzv_backend=c10d --rdzv_endpoint=$MASTER_ADDR:29400 elastic_ddp.py
bash

我们在两个机器上运行 DDP 脚本,每个机器有 8 个进程,这意味着我们同时在 16 个 GPU 上运行模型。注意 $MASTER_ADDR 在所有节点上必须相同。

torchrun 将启动 8 个进程,并启动 elastic_ddp.py 在每个节点上,不过用户同时也需要使用集群管理工具(比如 slurm)来在两个节点上运行命令。

例如,在一个 SLURM 集群上,我们可以使用这样的命令来设置 MASTER_ADDR

export MASTER_ADDR=$(scontrol show hostname ${SLURM_NODELIST} | head -n 1)
bash

然后我们可以直接把这些命令保存成脚本,然后使用 SLURM 命令来同时在集群上运行:srun --nodes=2 ./torchrun_script.sh。当然,这只是一个简单的例子,你可以使用自己的集群管理工具来启动任务。


使用经验#

各位读者需要意识到,在 DDP 里,每个进程都有着自己的模型、自己的输入、自己的输出、自己的 loss,DDP 只不过是在各个进程间维护了一致的模型初始状态和梯度,这些进程除此之外几乎是相互独立的。所以如果你想在 DDP 后做事,不但需要考虑是否要像 checkpoint 一样指定一个唯一的 rank 来做这件事,还需要自行同步所有进程上的数据。这涉及到分布式编程的常见问题,以下基本上都是在处理这些情况。好在 Pytorch 提供了丰富的分布式编程 API,我们得以简化很多操作。

后端的选择#

在初始化进程组的时候需要选择后端,这个后端的参数一般可以选择 nccl,如果要使用 CPU 训练的话,可以选择 gloo

数据获取及 batch size#

第一个需要注意的是,DDP 下的 batch size 是对每一个进程设置的,而不是总数。所以如果需要设置 batch_size 为 NN,你可能需要将其除以进程数,以保证梯度同步后是正确的。

第二个是在 DDP 中创建 DataLoader 的正确方法。DDP 中不存在输入的广播步骤,这意味着你必须从一开始就在不同的进程给模型使用不同的数据。为了达到这一点,可以使用 Pytorch 提供的 DistributedSampler

# train_dataset
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True, drop_last=True)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, sampler=train_sampler, pin_memory=True)
python

在每个 epoch 中,在进入 dataloader 的迭代器前需要执行 train_sampler.set_epoch(epoch) 来让 shuffle 正常工作。

validation 阶段不需要这项操作。虽然会导致所有进程跑一遍所有数据,但使用该 sampler 可能会导致数据被截断(先要平分到进程,而后平分到 batch)。

指标统计#

不但 checkpoint 需要考虑不同进程间的同步问题,指标统计显然也是必要的。

对于 loss,可以这样做:

loss = loss.clone().detach()
loss_mean = dist.reduce(loss, rank=0) / dist.get_world_size()
if dist.get_rank() == 0:
	# collect results into rank0
	print(f"epoch: {epoch}, loss: {loss_mean} ")
python

BatchNorm 层#

显然,由于 batch size 被按照进程大小切分,直接在每个进程里各玩各的 BatchNorm 是不合适的。为了在多进程间同步,你需要在 DDP 之前:

model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
python

随机数种子#

你需要考虑自己是否需要多个进程的随机数种子一致。必要的时候,使用 rank 来为不同进程设置不同的随机数种子。

random.seed(seed + dist.get_rank())
np.random.seed(seed + dist.get_rank())
torch.manual_seed(seed + dist.get_rank())
# 下面两行能维持最大的可复现,但会拖慢速度
# cudnn.deterministic = True
# cudnn.benchmark = False
python

另外不仅 DDP 的分布式训练会导致随机数种子的问题, DataLoader 中也有同类问题. 虽然 Pytorch 正确处理了自己的 seed, 但是不代表你也是, 例如如果你在数据读取时使用了 numpy 的随机数, 那就需要自行给不同进程设置不同的 numpy seed.


Footnotes#

  1. 对于用户(人、代码)「透明」,意味着该组件的介入不会对它们的现状产生任何影响,它们不用也不会察觉到该组件的存在。就算阿卡林坐在你前桌也不会挡住你看黑板,大概是这意思。

Pytorch 分布式训练技术
https://astro-pure.js.org/blog/pytorch-distributed-data-parallel
Author Cheng Chen
Published at 2022年6月24日
Comment seems to stuck. Try to refresh?✨