CLIP(Contrastive Language-Image Pre-Training,以下简称 CLIP)模型是 OpenAI 在 2021 年初发布的用于匹配图像和文本的预训练神经网络模型,可以说是近年来在多模态研究领域的经典之作。该模型直接使用大量的互联网数据进行预训练,在很多任务表现上达到了目前最佳表现(SOTA)。
本次我们尝试使用 Google 开源的 Conceptual Captions 数据集来训练 CLIP 模型,并对其效果进行一定的验证。
本次体验接入的系统是幻方AI团队研发的幻方萤火二号:自研智能分时调度、高效存储与网络通信,可以实现任务级别的弹性算力支持;配合API或者 JupyterLab 交互就可以轻松接入,利用起澎湃算力进行大规模深度学习训练。
源码地址:https://github.com/openai/CLIP | https://github.com/mlfoundations/open_clip
数据集:https://ai.google.com/research/ConceptualCaptions/download
模型仓库:https://github.com/HFAiLab/clip-gen
CLIP 模型简介
在目前图像分类领域中,我们训练的模型通常会遇到以下问题:
- 模型需要用到大量的格式化标注数据,这些标注数据获取通常成本高昂。
- 模型在当前数据集的效果比较好,但是可能模型的泛化能力较差,同时迁移到新的训练任务也比较困难
与此同时,互联网上面已经存在了大量的图像文本对(在网页中,开发者一般都会为图片添加一段文字备注),实际上这些素材可以作为已经标注好的数据集,利用这些数据集进行训练,既能解决获取标注数据成本高昂的问题,同时也因为互联网上的数据量比较大和数据本身差异较大,更容易让我们获得泛化能力较强的模型。
CLIP 模型就是基于上述概念,使用 OpenAI 收集到的 4 亿对图像文本对,分别将文本和图像进行编码,之后使用 metric learning 进行训练,通过计算 cosine similarities 计算图像和文本的匹配,通过最大化匹配的图像文本对的 cosine similarity 和最小化不匹配的文本对的 cosine similarity 来优化目标函数,它的核心流程比较容易理解,可以直接参考下述伪代码:
# image_encoder - ResNet or Vision Transformer
# text_encoder - CBOW or Text Transformer
# I[n, h, w, c] - minibatch of aligned images
# T[n, l] - minibatch of aligned texts
# W_i[d_i, d_e] - learned proj of image to embed
# W_t[d_t, d_e] - learned proj of text to embed
# t - learned temperature parameter
# extract feature representations of each modality
I_f = image_encoder(I) #[n, d_i]
T_f = text_encoder(T) #[n, d_t]
# joint multimodal embedding [n, d_e]
I_e = l2_normalize(np.dot(I_f, W_i), axis=1)
T_e = l2_normalize(np.dot(T_f, W_t), axis=1)
# scaled pairwise cosine similarities [n, n]
logits = np.dot(I_e, T_e.T) * np.exp(t)
# symmetric loss function
labels = np.arange(n)
loss_i = cross_entropy_loss(logits, labels, axis=0)
loss_t = cross_entropy_loss(logits, labels, axis=1)
loss = (loss_i + loss_t)/2
在预测阶段,也是通过一系列生成的文本对和目标图像,计算 cosine similarities 从而获取预测值:
CLIP 通过以上过程,取得的效果还是比较惊艳的。
OpenAI 官方开源了 CLIP 模型部分的代码,不过如果想完整的的复现训练,还需要写不少训练相关的代码,这里我们基于 github 上的另外一个开源实现,通过一定的修改,来完成在萤火平台上训练的整个过程。
数据预处理
下载和清洗
Conceptual Captions 数据集是 Google 从数十亿互联网网页提取的图像文本信息数据,并进行了若干种类型的过滤,使得数据集具有较高的质量和准确度,总共约 300 万个测试数据和 8000 多个验证数据。
它提供的是描述和图片的下载地址对,其数据格式如下:
title filepath
a man on a scooter , taking part in the rally http://l7.alamy.com/zooms/5e84d0b73a544fe8a8e5ac0d20052685/a-man-on-a-lambretta-scooter-taking-part-in-the-daily-express-rally-ddmmff.jpg
the dentist drill the tooth with a turbine http://l7.alamy.com/zooms/da764b930a6e477da792ddb8b83b1ecb/the-dentist-drill-the-tooth-with-a-turbine-hn3aph.jpg
...
我们可以将该数据集通过 python 脚本下载和统一存放,下载完之后,我们需要将该数据集进行清洗,将下载的空图片或者不完整的图片数据过滤,之后才可以使用。
这部分下载和清洗,我们简单编写 python 脚本即可完成,我们完全下载清洗后得到的 csv 文件格式为:
title filepath
a man on a scooter , taking part in the rally /demo/path/clip/cc_data_1/val/29/29.jpg
the dentist drill the tooth with a turbine /demo/path/clip/cc_data_1/val/19/19.jpg
...
实际上,这个数据集由于比较大,而且都是分散在各个 URL 而不是一个完整的压缩包, 我们下载会花费不少时间,累计接近 100G 的解压大小也会对我们的数据存放造成一定负担,不过幻方AI团队目前正在基于自研的高性能平行文件系统进行数据集仓库的建设,也就是说,对于一些常用的数据集和已经在集群上训练过的公开数据集,包括以上 Conceptual Captions 数据集,我们会经过处理后放在可访问的公开文件目录,后续在萤火使用相同的数据集进行训练或模型开发,就无需重复下载和预处理数据集了,直接使用即可,从而能极大减轻我们的数据负担。
在分布式训练开始之前,虽然我们已经可以使用清洗好的图像数据进行训练了,但如果我们频繁打开小文件,对训练自身的性能和整个集群的影响都会比较大,我们可以使用幻方AI自研的 FFRecord 来将多个小文件进行合并处理,从而减少训练时打开大量小文件的开销,同时对存储后端更加友好,我们会在下文对其进行介绍。
FFRecord 数据格式转换
FFRecord(FireFlyer Record)是幻方AI自研的简单高效的存储二进制记录的文件格式,它的一些特点包括:
- 合并多个文件,减少了训练时打开大量小文件的开销,对存储后端更加友好。
- 支持随机读取,可以适应不同的样本读取模式。
- 包含数据校验,保证读取的数据完整可靠。
同时,为了方便训练,我们还提供了专为 PyTorch 和 FFDataset(FireFlyer Dataset) 和 FFDataLoader(FireFlyer DataLoader),优势包括:
- 高效读取,使用 Linux Asynchronous I/O 的接口,充分利用了自研存储系统的随机读取性能优势。
- 简单易用,只需要对使用原有 PyTorch DataLoader 的代码进行简单的修改即可切换为 FFDataLoader。
- 使用灵活,完全兼容 PyTorch Dataset 和 DataLoader 的相关接口,如 Sampler 等。
数据预处理和 FFRecord 转换部分的代码,均可在萤火二号访问。
在转换为 FFRecord 之后,我们就可以开始调整开源代码,进行训练了。
开源代码调整
这里我们针对开源代码进行调整,这里的主要调整的内容包括:
- 参数调整,包括超参数、文件路径等
- 入口文件微调,引入 haienv、hfai
- 使用 FFDataset 代替原有的 Dataset
- 适配集群打断逻辑
step1: 参数调整部分比较简单,我们可以根据自己的需要调整数据集和验证集的文件路径,以及 batch-size 和 epoches 等参数。
这里我们增加了两个参数,分别用于表示训练集和验证集的 FFRecord 文件地址:
'--train-ddr-perfix', '/3fs-jd/prod/private/nxt/clip/cc_data_1_proc/train',
'--val-ddr-perfix', '/3fs-jd/prod/private/nxt/clip/cc_data_1_proc/val',
step2: 引入 haienv、hfai 的代码和我们的上一篇 informer 模型的调整比较类似,这里我们直接给出代码
import haienv
haienv.set_env("202111")
import os
import sys
sys.path.append(f'{os.getcwd()}/src')
step3: 为了提高训练性能,我们使用上文提到的 FFDataset 代替 torch 的 Dataset,这里我们给出一个实现代码:
class FFDataset(fftorch.Dataset):
def __init__(self, input_filename, transforms, img_key, caption_key, sep="\t", ddr_perfix=None):
logging.info(f'[FFDataset] Loading csv data from {input_filename}. use ffrecord: {ddr_perfix}')
df = pd.read_csv(input_filename, sep=sep)
self.images = df[img_key].tolist()
self.captions = df[caption_key].tolist()
self.transforms = transforms
self.ddr_perfix = ddr_perfix
super().__init__(self.ddr_perfix, check_data=True)
logging.info('[FFDataset] Done loading data.')
def __len__(self):
return len(self.captions)
def process(self, indexes, data):
'''
在每次读取样本数据的时候,
FFDataset 先从 FFRecord 文件中读取样本的二进制数据,
然后将原始二进制数据传递给 process 函数。
用户需要继承 ffrecord.torch.Dataset 并自定义 process 函数以对原始二进制数据进行处理。
'''
samples = []
try:
for bytes_ in data:
sample = deserialize(bytes_)
image_from_bytes = sample['image_bytes']
images = self.transforms(image_from_bytes)
texts = tokenize([str(sample['caption'])])[0]
samples.append((images, texts))
except:
raise
return samples
step4: 由于萤火的任务调度规则为分时分优先级调度,也就是说任务提交以后,有可能会被集群打断暂停,任务打断实际上会强行结束当前进程,后续再启动的时候就又会再次执行初始化逻辑,为了能够在打断后恢复继续训练,我们需要及时保存我们的 checkpoint,并且在任务恢复的情况下加载之前的历史记录继续训练,同时,我们应该在整个训练结束之后对模型进行保存。
这里笔者提供一个比较简单的方式,我们利用 whole_life_state 来进行标记,第一次 whole_life_state 是 0,如果是打断恢复的场景,这个时候 whole_life_state 就是 1,因此我们加载最新的 checkpoint 进行继续训练,这里给出关键的代码:
# main 函数里面处理 whole_life_state:
whole_life_state = hfai.get_whole_life_state()
print('whole_life_state:', whole_life_state)
if whole_life_state == 1:
# 只要不是第一次运行,我们就尝试加载最近的 checkpoint:
args.resume = os.path.join(args.checkpoint_path, f"epoch_latest.pt")
hfai.set_whole_life_state(1)
# 加载 checkpoint:
if args.resume is not None:
if os.path.isfile(args.resume):
checkpoint = torch.load(args.resume, map_location=loc)
start_epoch = checkpoint["epoch"]
sd = checkpoint["state_dict"]
model.load_state_dict(sd)
if optimizer is not None:
optimizer.load_state_dict(checkpoint["optimizer"])
logging.info(
f"=> loaded checkpoint resume:'{args.resume}' (epoch {checkpoint['epoch']})"
)
# ...
经过以上处理,我们就可以验证和运行我们的任务了。
测试运行并提交训练
幻方萤火提供的 hfai 命令行工具提供了测试运行的接口,我们可以通过 hfai python src/training/main.py
来进行测试运行(一般来说,我们建议测试运行没有明显问题之后再提交训练)
测试运行会在当前终端直接打印运行信息:
2021-12-22,22:34:14 | INFO | Rank 0 | Use GPU: 0 for training
2021-12-22,22:34:14 | INFO | Rank 4 | Use GPU: 4 for training
2021-12-22,22:34:14 | INFO | Rank 6 | Use GPU: 6 for training
2021-12-22,22:34:14 | INFO | Rank 3 | Use GPU: 3 for training
2021-12-22,22:34:14 | INFO | Rank 7 | Use GPU: 7 for training
2021-12-22,22:34:14 | INFO | Rank 1 | Use GPU: 1 for training
2021-12-22,22:34:14 | INFO | Rank 2 | Use GPU: 2 for training
2021-12-22,22:34:14 | INFO | Rank 5 | Use GPU: 5 for training
... ...
Train Epoch: 0 [0/2850879 (0%)] Loss: 8.067907 Data (t) 42.464 Batch (t) 72.875 LR: 0.000000 logit_scale 2.659
Train Epoch: 0 [3840/2850879 (1%)] Loss: 8.055143 Data (t) 0.008 Batch (t) 0.356 LR: 0.000001 logit_scale 2.659
我们可以看到,我们的代码已经可以在单机8卡的场景下测试运行了,不过,我们会发现由于数据集比较多,这个训练还是比较慢的,此时我们可以利用萤火的多机多卡进行并行训练。
我们在幻方萤火的管理页面提交任务运行,选择 6 个节点:
任务提交后,我们在幻方萤火的管理页面,可以比较方便地看到任务的日志和状态:
我们可以发现,在 6 节点(48 张显卡)并行训练的过程中,单个 epoch 训练大约 6 分钟,我们如果训练 30 个 epoch,加上初始化时间,大约 3 个多小时:
验证
针对 CLIP 模型的验证测试,openAI 提供了多种方式,我们可以比较方便地在萤火上对此进行复现,这里为了便于我们理解模型,我们展示一种比较可视化的方式对它进行简单的验证:
我们使用 skimage 的一些图片和文本信息,使用我们训练好的模型计算相似度,然后通过二维矩阵的方式打印出来:
我们把相似度最高的,标记成了黄色,我们可以从上图看出,我们训练的模型在该情况下表现尚可。
总结
CLIP 模型思路新颖同时比较简单,借助萤火集群和自研文件系统,我们可以比较高效地训练该模型,在训练的过程中,GPU 的平均使用率达到 80%+, GPU Memory 使用率持续在 10GB+。
基于目前的数据和模型,我们后续还可以进行进一步的探索。
综合体验打分如下:
-
研究新颖度: ★★
该模型对图像文本分类领域提出了一个较为新颖的方法,有一定的技术推进性。
-
开源指数: ★★★★
官方开源了模型代码,并且 github 上有一些更加通用的开源代码参考,代码逻辑清晰可读性高。
-
算力需求: ★★★
解压后训练数据大小为 84 G,近 300 万条训练数据,6 节点 48 张 A100 大约需要 3 小时
-
通用指数: ★★★
该模型虽然比较容易理解,但是在多模态领域还是比较有启发的,有可能会在一定程度上成为后续深入研究的 baseline
-
模型适配度: ★★★★
额外依赖不多,在使用下载好的数据集的情况下我们只需要简单改一下 DataSet 等代码即可运行
幻方 AI 紧跟 AI 研究的前沿浪潮,致力于用领先算力助力AI落地与价值创造,欢迎各方数据研究者与开发者们一同共建。