随着Transformer模型的发展,近些年多模态模型获得了长足的发展,使得不同任务不同领域可以实现特征的打通,变换出很多新奇好玩的场景。其中非常热门的就是让AI学会看文作图,即文字生成图像,如OpenAI的CLIP模型,其基于带文字的图像数据集上,训练出很惊艳的效果。
然而,收集这些带文字的图像数据集成本非常高。最近字节在Arxiv上发表了一项文本生成图像 (text2img) 的工作,其利用对抗网络GAN改造CLIP模型,使得 CLIP-GEN 可以不依赖带文字描述的图片数据集,直接使用无文本图像数据集进行训练,通过预训练好的 CLIP 模型建立起文本和图像的映射关系。在很多实验数据中表明,它的效果比VQGAN-CLIP要真实,尤其是泛化能力还比不少用大量“文本-图像”数据对训练出来的模型要好很多。
幻方AI最近复现了该项工作,并通过幻方自研的 3FS、hfreduce、算子,对模型训练和推导进行优化。我们在hfai数据仓库中开源了训练数据,模型代码,旨在帮助研究者和开发者们降低研究门槛。
论文标题:CLIP-GEN: Language-Free Training of a Text-to-Image Generator with CLIP
论文地址:https://arxiv.org/abs/2203.00386
模型仓库:https://github.com/HFAiLab/clip-gen
模型介绍
下图是 CLIP-GEN 模型的整体结构:
CLIP-GEN 模型主要由两部分组成,第一部分是一个 VQGAN 模型,用来学习如何把图像编码成一系列的图像标记(image tokens),通过这些图像标记解码还原成一张图片;第二部分是一个 condition transformer 模型,用来学习如何把文字的 CLIP embedding 映射到图像标记(image tokens)中。
训练、推理的过程分为三步:
- 预训练 VQGAN:输入图像,把图像编码到码本空间,然后再从码本空间解码为图像。经过预训练之后我们就能够把图像表示成一个图像标记(image tokens) 下的离散序列。
- 训练 condition transformer:通过第一步的预训练,我们已经能够通过图像标记(image tokens) 来生成图片了,在此基础上,我们训练一个 condition transformer,其旨在学习如何把图像的 CLIP embedding 映射到图像标记。训练的过程中 VQGAN 的参数保持不变。
- 文字生成图片:由于在 CLIP 中,文字和图像共享同一个嵌入空间,我们可以直接把文字的 CLIP embedding 作为 condition transformer 的输入映射到图像标记上,然后通过 VQGAN 的 decoder 来生成图片
数据集
对于 CLIP-GEN 的训练,我们采用了 COCO Caption 数据集,包含 20 万张图文对(训练的过程中没有使用文本)。我们把 COCO Caption 数据集转换成 ffrecord 格式,整合到了 hfai 数据仓库中,可以直接通过以下方式直接使用:
import hfai
dataset = hfai.datasets.CocoCaption(split='train', transform=transform)
有关更多内容,可以访问 hfai 官方文档:https://doc.hfai.high-flyer.cn/index.html
模型训练与优化加速
幻方AI复现了GLIP-GEN模型,并验证其效果。通过幻方自研的 3FS、hfreduce、算子等优化工具,对模型训练和推导进行优化和加速,具体的包括:
- hfai ddp:采用 hfreduce 优化多机多卡通信
- hfai nn: 重构深度学习算子,提升性能
- hfai datasets: 采用高效数据样本格式 ffrecord,充分发挥 3FS 存储带宽性能
下面进行详细描述。
hfai DDP 通信加速
hfai DDP 内部采用了幻方自研的 hfreduce 高性能通讯框架,能有效提升模型的训练速度,使用方法只需要修改一行代码:
# from torch.nn.parallel import DistributedDataParallel
from hfai.nn.parallel import DistributedDataParallel
# ...... initialize model
model = DistributedDataParallel(model, device_ids=[local_rank])
hfai nn 算子加速
为了进一步提升模型训练速度,我们可以使用 hfai.nn 里的高性能算子,相比于 PyTorch 能带来明显的提升,使用方法只需要增加一行代码:
import hfai
model = hfai.nn.to_hfai(model) # 自动替换为 hfai 高性能算子
使用说明
幻方AI将所复现的模型和优化的方法都进行了开源,统一归集到 hfai 模型仓库 (https://github.com/HFAiLab/hfai-models) 中,欢迎大家来 star。
-
下载 CLIP 预训练模型:下载 CLIP 后放至
pretrained/clip_vit_b32.pt
,该预训练模型来自 OpenAI. -
在 COCO 上训练 VQGAN:通过
hfai python
提交任务至萤火集群hfai python train_vqgan.py --ds coco -- -n 1 -p 30
-
在 COCO 上训练 Conditional GPT:通过
hfai python
提交任务至萤火集群hfai python train_gpt.py --ds coco --vqgan_ckpt /path/to/vqgan/ckpt -- -n 4 -p 30
训练结果
我们来看看训练完成后,一些文本生成图像的效果。
可以看到,不利用带文本的图像数据集,CLIP-GEN 所生成的效果还是非常逼真的。
体验总结
CLIP-GEN 将对抗网络 GAN 用于改造 CLIP 模型,使得 CLIP-GEN 可以不依赖带文字描述的图片数据集,直接使用无文本图像数据集进行训练,这极大降低了数据收集的成本,推动了该领域研究的发展。通过预训练好的 CLIP 模型建立起文本和图像的映射关系,在很多实验数据中表明,CLIP-GEN 的效果比 VQGAN-CLIP 要真实,尤其是泛化能力还比不少用大量“文本-图像”数据对训练出来的模型要好很多。
综合体验打分如下:
-
研究指数:★★★★
该模型是多模态领域的最新研究成果,降低了数据收集的成本,推动了该领域的发展。
-
开源指数:★★★
代码没有开源,但所依赖的方法有其他开源版本,容易复现。
-
门槛指数:★★★
数据规模大,模型适中,适合多级多卡数据并行训练。一般单卡训练难度比较大。
-
通用指数:★★★★
该方法适用于多模态研究场景,能在很多类似场景下应用。
-
适配指数:★★★★★
依赖简单,很容易与幻方AI的训练优化工具结合,提效明显。
幻方 AI 紧跟 AI 研究的前沿浪潮,致力于用领先算力助力AI落地与价值创造,欢迎各方数据研究者与开发者们一同共建。