随着多个行业朝着大规模的 3D 虚拟世界发展,能够生成大量的、高质量的、多样的 3D 内容的工具是非常被需要的。英伟达的最新工作 GET3D 希望训练更好的 3D 生成模型,来生成下游任务可以直接使用的、保真纹理和复杂几何细节的 3D 模型。
幻方 AI 最近对这项工作进行了整理和优化,在幻方萤火二号上复现了实验。通过幻方自研的 3FS、hfreduce、算子,对模型训练进行提速,从单机多卡的训练支持优化到多级多卡的训练支持,帮助研究者和开发者们降低使用门槛。本期文章将为大家详细描述。
论文标题:GET3D: A Generative Model of High Quality 3D Textured Shapes Learned from Images
原文地址:https://arxiv.org/pdf/2209.11163.pdf
项目主页:https://nv-tlabs.github.io/GET3D
模型仓库:https://github.com/HFAiLab/GET3D
模型介绍
3D 生成模型
最近,英伟达发布了最新的 GET3D 模型,可以通过 2D 图像训练,生成具有高保真纹理和复杂几何细节的 3D 形状。在此之前,学术界也有大量的 3D 生成模型,他们所生成的 3D 模型存在如下问题:
- 缺乏几何细节
- 缺乏纹理
- 在合成过程中只能使用神经渲染器,3D 软件中不方便使用
过去工作的方法以及生成的 3D 模型效果如下表所示,其中 Ours 代表 GET3D 模型的生成几何结果,Ours-tex 代表 GET3D 模型生成的带纹理结果:
GET3D
GET3D 希望克服上述模型存在的问题,生成具有丰富几何细节、带纹理的、下游 3D 软件中可以直接使用的 3D 模型。并且训练过程只需要 2D 的监督图像训练。
GET3D 包括生成器部分和辨别器部分,生成网路包括如下两个分支:
- 几何生成分支:可微的输出任意拓扑的表面几何结果
- 纹理生成分支:根据查询的表面点来产生纹理场(texture field),还可以扩展到表面的其他属性,比如材质
生成器生成 3D 模型及纹理后,通过一个有效的可微栅格器,将生成的带纹理 3D 模型投影到 2D 的高分辨率图片,并使用 2D 辨别器来区分生成器所生成模型的投影图像和真实图像、生成器所生成模型的投影轮廓和真实轮廓。整个过程都是可微分的,使得整个对抗训练可以从辨别器传递到两个分支。模型整体结构如下图所示。
GET3D 输入是两个采样的高斯分布(z1 和 z2),通过非线性映射网络得到两个隐含表示(w1 和 w2),非线性网络是 8 层的 MLP 网络,每层是 512 维和 leaky-ReLU 的激活函数。两个隐含表示的具体含义如下:
- 纹理隐含向量(w1):用来控制 3D 模型的形状
- 几何隐含向量(w2):用来控制 3D 模型的纹理
3D 几何生成器:GET3D 的几何生成器包含最近提出的可微分表面表征 DMTet,DMTet 将 3D 模型表示成一个可变形三角面片的四面体(tetrahedron)及每个顶点的符号距离场(signed distance field,SDF),通过移动顶点来使得几何形状变形,达到从四面体可微分的恢复 3D 模型表面的目的。通过使用 DMTet,生成器可以生成任意拓扑结构和类别的 3D 模型。
纹理生成器:GET3D 将纹理参数化为纹理场(texture field),纹理生成器以纹理隐含向量(w2)和几何隐含向量(w1)为条件,将具体的表面点,映射到 RGB 表示的颜色空间中。
可微分渲染和训练:GET3D 将生成的 3D 几何模型和纹理场,使用可微分渲染器渲染到 2D 的图片,再使用 2D 辨别器来监督辨别器的学习。辨别器的学习目标是更准确的分辨真实物体的图片和生成物体的渲染图片。
模型实践
幻方 AI 基于英伟达开源的 GET3D 代码,进行多级多卡的训练支持,开源至 HFAI GET3D ,并采用幻方一系列优化工具进行提速升级,包括 hfreduce 并行训练、hfai 集群训练断点挂起等功能。
数据集
GET3D 使用 ShapeNET 数据集进行试验,ShapeNet数据集是一个有丰富标注的、大规模的3D图像数据集。ShapeNetCore 是 ShapeNet 的一个子集,其中包括近 51,300 个独特的 3D 模型。它提供了 55 个常见的对象类别和注释。GET3D 使用 ShapeNetCoreV1 数据集作为训练及测试数据。
我们对 ShapeNET 数据进行了渲染,将 3D 模型全部投影为 2D 图片以备训练。
模型训练
萤火提供了统一的训练管理平台,可以将大量的训练任务依据优先级负载均衡到不同的GPU上执行,充分利用起计算算力。我们只需加入如下几行代码,便可以把训练任务提交给萤火集群运行。
-
登录幻方萤火二号,引入 hfai
import hfai
-
初始化 hfreduce 分布式参数
from hfai.nn.parallel import DistributedDataParallel model = DistributedDataParallel(model, device_ids=[local_rank]) optimizer = torch.optim.Adam(model.parameters(), lr=0.01) # training ... for epoch in range(epochs): for step, (x, y) in enumerate(dataloader): # training optimizer.zero_grad() output = model(x) loss_fn(y, output).backward() optimizer.step()
-
对于每一轮训练,加入接收集群调度的逻辑代码,并做好模型 checkpoint 的保存
start_epoch, start_step, others = hfai.checkpoint.init(model, optimizer, ckpt_path='latest.pt') for epoch in range(start_epoch, epochs): for step, (x, y) in enumerate(dataloader): if step < start_step: continue output = model(x) loss_fn(y, output).backward() model.try_save(epoch, step, others=None)
训练结果
我们申请 4 个节点,32 张 A100 显卡进行并行训练。通过hfai命令行工具提交任务,在萤火二号上,以车类别作为数据来训练 GET3D,每个 64 个 kimg 仅耗时 10 分钟左右,相比单卡性能提升了 32 倍以上。GPU 利用率在 95% 左右,充分利用其了显卡的计算性能。模型在训练到第 5000 个 kimg 基本达到了收敛状态。
训练效果可视化如下:
体验总结
GET3D 模型通过 2D 图像训练,生成具有高保真纹理和复杂几何细节的 3D 形状,解决了以往生成模型几何细节缺失、纹理缺失或者下游引擎无法使用等问题。能够生成大量的、高质量的、多样的 3D 内容,帮助多个行业更快速的朝着大规模的 3D 虚拟世界发展。幻方 AI 对 GET3D 代码适配到了多级多卡的训练,借助萤火平台对 GET3D 的训练进行了优化,最终达到了超过单卡 32 倍的训练速度。
综合体验打分如下:
-
研究指数:★★★★
该模型以低成本、2D 监督、端到端的方法实时生成 3D 模型,生成模型几何细节丰富、纹理自然。
-
开源指数:★★★★
数据处理和代码都已经开源,代码逻辑清晰、可读性高,HFAI 对其增加了多级多卡的训练支持。
-
门槛指数:★★
数据量、模型大小中等,普通高性能单卡即可运行,但是正常训练需要花费一周左右。
-
通用指数:★★★★
该方法使用 2D 领域生成网络的思想进行 3D 模型的生成,未来具有很大的拓展空间。
-
适配指数:★★★★★
该项目依赖简单,很容易与幻方AI的训练优化工具结合,提效明显。
幻方 AI 紧跟 AI 研究的前沿浪潮,致力于用领先算力助力AI落地与价值创造,欢迎各方数据研究者与开发者们一同共建。