分布式训练中模型的保存,特别是大模型,常常需要耗费很多的时间,降低了整体的 GPU 利用率。针对这类问题,幻方 AI 进行了攻关,优化过往深度学习模型单机训练保存的方法,研发出分布式 checkpoint 方案,大幅度降低模型保存与加载上的开销。
分布式 checkpoint
当我们进行分布式模型训练,特别是在训练大模型时,保存 checkpoint 需要较长的时间,这不仅浪费集群的计算资源,并且给集群整体的任务调度带来管理成本。为此幻方 AI 提供了一个分布式保存 checkpoint 的功能。
该功能的基本原理是:假设有 N 块 GPU,我们把模型参数和优化器参数切分成 N 个部分,然后每个分布 rank 把对应的部分写入文件系统;在读取的时候我们从文件系统中读出所有 checkpoint 拼接出完整的模型参数和优化器参数。除了模型参数和优化器参数,其他的信息会由 rank 0 进行保存。
使用方法:
from hfai.checkpoint import save, load
model = MyModel()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
others = {'epoch': epoch, 'step': step+1}
save('latest.pt', model, optimizer, others=others)
state = load('latest.pt', map_location='cpu')
epoch, step = state['epoch'], state['step']
model.load_state_dict(ckpt['model'])
optimizer.load_state_dict(ckpt['optimizer'])
集群上实际测试性能如下:
可以看到,随着并行程度的增加,使用hfai.checkpoint
进行模型保存加载的耗时越来越少。
自动断点训练
对于幻方萤火集群上的训练,需要接收集群统一的调度信号进行训练任务的管理。这里幻方 AI 提供了一个 hfai.checkpoint.init
函数帮助用户进行断点训练,该函数会自动加载上次保存的模型、优化器等状态,返回上次训练的 epoch 和 step,我们可以通过 epoch 和 step 进行优雅断点训练。同时,我们会向 model 注册一个成员函数 try_save
,通过这个函数可以在打断训练之前保存训练的状态。
使用方法:
model = MyModel()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
start_epoch, start_step, others = hfai.checkpoint.init(model, optimizer, ckpt_path='latest.pt')
for epoch in range(start_epoch, epochs):
for step, (x, y) in enumerate(dataloader):
if step < start_step:
continue
output = model(x)
loss_fn(y, output).backward()
model.try_save(epoch, step, others=None)
通过上述封装,您可以在代码中省去很多断点操作,简单方便地将代码适配幻方萤火系统,进一步降低门槛。