近年来在生成图像建模中,生成对抗网络(GAN)的应用越来越多。基于样式(style-based)的 GAN 可以生成不同层次的细节,大到头部形状、小到眼睛颜色,它在高保真图像合成方面实现了 SOTA,但其生成过程的计算复杂度却非常高,难以应用于智能手机等移动设备。
近日,一项专注于基于样式的生成模型的性能优化的研究引发了大家的关注。该研究分析了 StyleGAN2 中最困难的计算部分,并对生成器网络提出了更改,使得在边缘设备中部署基于样式的生成网络成为可能。该研究提出了一种名为 MobileStyleGAN 的新架构。相比于 StyleGAN2,该架构的参数量减少了约 71 %,计算复杂度降低约 90 %,并且生成质量几乎没有下降。
StyleGAN2(上)与 MobileStyleGAN(下)的生成效果对比。
论文作者已将 MobileStyleGAN 的 PyTorch 实现放到了 GitHub 上。
论文地址:
https://arxiv.org/pdf/2104.04767.pdf
项目地址:
https://github.com/bes-dev/MobileStyleGAN.pytorch
该实现所需的训练代码非常简单:
StyleGAN2(左)与 MobileStyleGAN(右)的生成效果展示。
下面我们来具体看一下 MobileStyleGAN 架构的方法细节。
MobileStyleGAN 架构
MobileStyleGAN 架构是在基于样式生成模型的基础上构建的,它包括映射网络和合成网络,前者采用的是 StyleGAN2 中的映射网络,该研究的重点是设计了一个计算高效的合成网络。
MobileStyleGAN 与 StyleGAN2 的区别
StyleGAN2 使用基于像素的图像表征,并旨在直接预测输出图像的像素值。而 MobileStyleGAN 使用基于频率的图像表征,旨在预测输出图像的离散小波变换 (DWT)。当应用到 2D 图像,DWT 将信道转换成四个大小相同的信道,这几个信道具有较低的空间分辨率和不同的频带。然后,逆向离散小波变换(IDWT) 从小波域重建基于像素的表征,如下图所示。
StyleGAN2 利用跳远生成器(skip-generator),通过对同一图像的多个分辨率的 RGB 值进行显式求和来形成输出图像。该研究发现,当在小波域中对图像进行预测时,基于跳远连接(skip connection)的预测头对生成图像的质量影响不大。因此,为了降低计算复杂度,该研究采用网络中最后一个块的单个预测头替换跳远生成器。但从中间块中预测目标图像对于稳定的图像合成具有重要意义。因此,该研究为每个中间块添加一个辅助预测头,根据目标图像的空间分辨率对其进行预测。
StyleGAN2 和 MobileStyleGAN 的预测头区别。
如下图所示,调制卷积包括调制、卷积和归一化(左)。深度可分离调制卷积也包括这些部分(中)。StyleGAN2 描述了用于权重的调制 / 解调,该研究分别将它们应用于输入 / 输出激活,这使得描述深度可分离调制卷积更加容易。
StyleGAN2 构造块使用 ConvTranspose(下图左)来 upscale 输入特征映射。而该研究在 MobileStyleGAN 构造块(下图右)中使用 IDWT 当作 upscale 函数。由于 IDWT 不包含可训练参数,该研究在 IDWT 层之后增加了额外的深度可分离调制卷积。
StyleGAN2 和 MobileStyleGAN 的完整构造块结构如下图所示:
基于蒸馏的训练过程
类似于此前的一些研究,该研究的训练框架也基于知识蒸馏技术。该研究将 StyleGAN2 作为教师网络,训练 MobileStyleGAN 来模仿 StyleGAN2 的功能,训练框架如下图所示。