hfai.pl | 兼具萤火集群的优化特性的 PyTorch Lightning

Freja    November 17, 2022

Pytorch Lightning(简称 pl) 是在 PyTorch 基础上进行封装的库,它能帮助开发者脱离 PyTorch 一些繁琐的细节,专注于核心代码的构建,在 PyTorch 社区中备受欢迎。本文为大家介绍如何在 PyTorch Lightning 上轻松的应用幻方萤火集群的系列优化特性。

集群环境适配

若集群每个计算节点有 x 张 GPU。用户提交任务时需选定节点数量 N,则该任务可获得 N*x 个 GPU。每个进程中全局环境变量含义如下:

  • world_size (hfai): 节点数量,用 N 表示
  • rank (hfai): 节点 id,用 n 表示,n 属于 0 ~ N-1
  • local_rank (hfai): GPU id, 用 k 表示,k 属于 0 ~ x-1

与 PyTorch init_process_group 的变量有如下对应:

  • world_size (PyTorch): 进程数量,每个 GPU 对应一个进程,因此总进程数目为总 GPU 数目,计算方法为 N*x
  • rank (PyTorch): 进程 id,利用节点 id 和 GPU id 可以计算出进程的 id,计算方法为 n*x+k

在一般的 PyTorch 代码中,我们需要进行如下的分布式初始化:

ip = os.environ.get("MASTER_ADDR", "127.0.0.1")
port = os.environ.get("MASTER_PORT", "2223")
hosts = int(os.environ.get("WORLD_SIZE", 1))  # number of nodes
rank = int(os.environ.get("RANK", 0))  # node id
gpus = torch.cuda.device_count()  # gpus per node
dist.init_process_group(
    backend="nccl", 
    init_method=f"tcp://{ip}:{port}", 
    world_size=hosts*gpus, 
    rank=rank*gpus+local_rank
)
torch.cuda.set_device(local_rank)

pl 对分布式初始化进行了封装,用户不需要自己初始化,只需指定节点数量和 GPU 数量即可。为了适配到幻方萤火集群,我们提供了 HFAIEnvironment 环境插件,只需要几行代码即可将 pl 的分布式参数无缝适配到萤火集群当中。

from hfai.pl import HFAIEnvironment
trainer = pytorch_lightning.Trainer(
    gpus=x, num_nodes=N,
    plugins=[HFAIEnvironment()] # 定义 HFai 环境并作为插件输入
)

绑定 numa

为了规避跨 numa 传输带来的带宽损耗,建议对训练代码进行 numa 绑定,在一般的 PyTorch 代码中,我们通过使用 hfai.multiprocessing 启动多进程,并指定 bind_numa=True 来进行 numa 的绑定,代码如下:

import torch
import hfai

def main(gpu_id):
    torch.cuda.set_device(gpu_id)
    # ......

if __name__ == "__main__":
    ngpus = hfai.utils.num_gpus()  # 调用 cuda 函数会导致子进程产生错误
    hfai.multiprocessing.fork(main, args=(), nprocs=ngpus, bind_numa=True)

pl 中对多进程的启动进行了封装,用户只需要在 strategy 中指定使用 ddp 或者 ddp_spawn 就可以指定多进程的启动方式。因此我们提供了两种新的 strategyddp_bind_numaddp_spawn_bind_numa,用户选择这两种 strategy,就可以在指定多进程启动方式的同时绑定 numa,使用方式如下所示:

# 使用 ddp_bind_numa 或者 ddp_spawn_bind_numa
trainer = pytorch_lightning.Trainer(strategy="ddp_spawn_bind_numa")

使用 hfreduce

hfreduce 是幻方 AI 自研的高性能多卡并行通信工具,其能够更高效的在多显卡之间交换梯度信息,加速模型训练。

在一般的 PyTorch 代码中,我们通过使用 hfai.nn.parallel.DistributedDataParallel 替换 PyTorch 自带的 torch.nn.parallel.DistributedDataParallel 即可使用 hfreduce。

由于 pl 中对 DDP 的初始化进行了封装,因此我们提供了 hfreduce_bind_numahfreduce_spawn_bind_numa 两种 strategy。用户在初始化 Trainer 的时候进行如下指定即可。

# 使用 hfreduce_bind_numa 或者 hfreduce_spawn_bind_numa
trainer = pl.Trainer(strategy="hfreduce_bind_numa")

自动打断重启

幻方萤火平台采用任务级分时调度的底层设计,给每个任务分配集群运行时间片。按照分时调度的方案规则,训练任务会被中断、并自动调起。因此在该平台上运行的代码需要在训练循环中进行打断信号的捕获和参数保存,以便之后恢复训练。断点保存的示例代码如下:

for epoch in range(epochs):
    for step in range(len(data_batch)):
        if hfai.distributed.get_rank() == 0 and gpu_id == 0: # 在0号节点的0号进程上接收集群调度信息
            if hfai.client.receive_suspend_command(): 
                model.save() # 保存模型、迭代器等参数到文件
                time.sleep(5) # 最多预留5秒完成断点保存,之后会被强制打断
                hfai.client.go_suspend() # 发送准备好被打断的信号

pl 对训练循环进行了封装,因此我们提供了 ModelCheckpointHF 作为回调,在收到打断信号的 step 进行模型的保存,在每次任务启动时,检测是否需要从上一个断点恢复训练,使用方法如下所示:

from hfai.pl import ModelCheckpointHF
output_dir = 'hfai_out'
cb = ModelCheckpointHF(dirpath=output_dir)
trainer = pytorch_lightning.Trainer(callbacks=[cb]) # 自动处理集群打断信号

ckpt_path = f'{output_dir}/{cb.CHECKPOINT_NAME_SUSPEND}.ckpt' # 检查是否有断点模型被保存
ckpt_path = ckpt_path if os.path.exists(ckpt_path) else None
trainer.fit(
    model_module,
    ckpt_path=hfai_suspend_ckpt_path # 自动恢复训练
)

使用 hfai 优化算子

幻方 AI 依托萤火二号集群,对 PyTorch 框架进行了深度优化,结合萤火集群的特点,对一些常用的 AI 算子重新研发,提升效率,进一步提升了模型整体的训练效率。

在 pl 框架中,我们只需要增加如下一行代码,就可以将 PyTorch 中的算子转换成 hfai 优化后的算子:

import pytorch_lightning as pl
class ToyNetModule(pl.LightningModule):
    # ...
pl_module = ToyNetModule()
model_module = nn_to_hfai(pl_module) # 将算子转换为 hfai 算子

使用 FFRecord

为了在 PyTorch Lightning 中使用 ffrecord 的 Dataloader,我们需要在 Dataloader 设置 skippable=False:

from ffrecord.torch import Dataset, DataLoader

class MyDataset(Dataset)
    # ...

dataset = MyDataset(...)
dataloader = DataLoader(dataset, batch_size, num_workers=num_workers, skippable=False)

使用示例

上面介绍了如何将 pl 框架的代码融入萤火集群的各种优化特性当中,下面提供一个完整的示例帮助大家理解:

from hfai.pl import HFAIEnvironment
from hfai.pl import ModelCheckpointHF
import pytorch_lightning as pl

class ToyNetModule(pl.LightningModule):
    ...

output_dir = 'hfai_out' # 模型保存路径
cb = ModelCheckpointHF(dirpath=output_dir) # 可以接收集群打断信号的回调类
trainer = pl.Trainer(
    gpus=x, # 每个节点 x 个 GPU
    num_nodes=N, # 节点数量
    strategy="ddp_bind_numa",  # 支持 ddp_bind_numa, ddp_spawn_bind_numa, hfreduce_bind_numa, hfreduce_spawn_bind_numa
    plugins=[HFAIEnvironment()],  # 自动适配 HFAI 分布式环境
    callbacks=[cb] # 自动处理集群打断信号
)
model_module = nn_to_hfai(ToyNetModule()) # 将算子转换为 hfai 算子

ckpt_path = f'{output_dir}/{cb.CHECKPOINT_NAME_SUSPEND}.ckpt' # 判断之前是否保存了断点模型
ckpt_path = ckpt_path if os.path.exists(ckpt_path) else None
trainer.fit(
    model_module,
    ckpt_path=ckpt_path # 自动恢复训练
)

通过上面的简单适配,PyTorch Lightning 就能够应用幻方萤火集群的系列优化特性啦。


本文作者: Freja


您可以转载、不违背作品原意地摘录及引用本技术博客的内容,但必须遵守以下条款: 署名 — 您应当署名原作者,但不得以任何方式暗示幻方为您背书,亦不会对幻方的权利造成任何负面影响。 非商业性使用 — 您不得将本技术博客内容用于商业目的。 禁止演绎 — 如果基于该内容改编、转换、或者再创作,您不得公开或分发被修改内容,该内容仅可供个人使用。