LLM | 利用分布匹配蒸馏技术快速合成图像
一、结论写在前面
最近的一些方法已经显示出将昂贵的扩散模型蒸馏到高效的单步生成器中的前景。其中,分布匹配蒸馏(DMD)能够生成与教师模型在分布上匹配的单步生成器,即蒸馏过程不强制与教师模型的采样轨迹一一对应。然而,为了在实践中确保稳定训练,DMD需要使用教师模型通过多步确定性采样器生成的大量噪声-图像对计算一个额外的回归损失。这不仅在大规模文本到图像合成中计算代价高昂,而且还限制了学生模型的质量,使其过于紧密地绑定到教师模型的原始采样路径。
论文提出了DMD2,一组技术来解决这一限制并改进DMD训练。首先,论文消除了回归损失和昂贵数据集构建的需求。论文表明,由此导致的不稳定性是由于"假"评价器无法充分准确地估计生成样本的分布。因此,论文提出了一种两时间尺度更新规则作为补救措施。其次,论文将GAN损失整合到蒸馏过程中,区分生成样本和真实图像。这使论文能够在真实数据上训练学生模型,从而缓解教师模型中不完美的"真实"分数估计,并因此提高质量。第三,论文介绍了一种新的训练程序,允许学生模型进行多步采样,并通过在训练时模拟推理时的生成器样本来解决之前工作中训练-推理输入不匹配的问题。
总的来说,论文的改进在单步图像生成上设置了新的基准,ImageNet-64×64上的FID分数为1.28,零样本COCO 2014上为8.35,尽管推理成本降低了500倍,但仍优于原始教师模型。此外,论文展示了论文的方法可以生成百万像素级的图像,通过对SDXL进行蒸馏,在少步方法中展现出卓越的视觉质量,并超过了教师模型。论文发布了论文的代码和预训练模型。
二、论文的简单介绍
2.1 论文的背景
扩散模型在视觉生成任务中达到了前所未有的质量水平。但其采样过程通常需要数十次迭代去噪步骤,每次迭代都需通过神经网络进行前向传播。这使得高分辨率文本到图像的合成既缓慢又昂贵。为解决这一问题,研究者们开发了多种蒸馏方法,旨在将一个教师扩散模型转化为一个高效、仅需少数步骤的学生生成器。然而,这些方法往往导致生成质量下降,因为学生模型通常是通过损失函数来学习教师模型的成对噪声到图像映射,但在完美模仿教师行为方面存在困难。
图1:由论文的4步生成器从SDXL蒸馏得到的1024x1024样本。请放大查看细节
尽管如此,应该注意的是,旨在匹配分布的损失函数,如GAN [21] 或 DMD [22] 损失,并不负担精确学习从噪声到图像的具体路径的复杂性,因为它们的目标是在分布上与教师模型对齐——通过最小化学生和教师输出分布之间的Jensen-Shannon(JS)或近似的Kullback-Leibler(KL)散度。
特别是DMD [22]在蒸馏Stable Diffusion 1.5时展现出了最先进的结果,但相比基于GAN的方法[23-29]来说,它受到的研究还不够深入。一个可能的原因是DMD仍然需要一个额外的回归损失来确保稳定训练。反过来,这就需要通过运行教师模型的全部采样步骤来创建数百万个噪声-图像对,对于文本到图像合成来说代价是特别高昂的。回归损失也抵消了DMD的无配对分布匹配目标的关键优势,因为它导致学生的质量被教师所限制。
图2:由论文从SDXL蒸馏出的4步生成器产生的1024x1024样本。请放大查看细节
2.2 改进的分布匹配蒸馏
论文重新审视了DMD算法中的多个设计选择,并确定了显著的改进。
图3:论文的方法将一个成本高昂的扩散模型(灰色,右侧)提炼成一个一步或多步生成器(红色,左侧)。论文的训练交替进行两个步骤:1. 使用隐式分布匹配目标的梯度(红色箭头)和GAN损失(绿色)优化生成器;2. 训练一个得分函数(蓝色)来模拟生成器产生的“假”样本的分布,以及一个GAN判别器(绿色)来区分假样本和真实图像。学生生成器可以是一步或多步模型,如图所示,具有中间步骤输入
2.2.1 移除回归损失:真正的分布匹配与更易于大规模训练
DMD[22]中使用的回归损失[16]确保了模式覆盖和训练稳定性,但正如论文在第3节中讨论的,它使得大规模蒸馏变得繁琐,并且与分布匹配的理念相冲突,因此本质上限制了蒸馏生成器的性能,使其无法超越教师模型。论文的第一个改进是移除这个损失。
2.2.2 通过双时间尺度更新规则稳定纯分布匹配
从DMD中简单地省略回归目标,如式(3)所示,会导致训练不稳定并显著降低质量(见表3)。例如,论文观察到生成的样本的平均亮度以及其他统计量波动很大,没有收敛到一个稳定点(见附录C)。论文将这种不稳定性归因于假扩散模型wfake中的近似误差,它没有准确跟踪假得分,因为它是在生成器的非平稳输出分布上动态优化的。这导致了近似误差和偏置生成器梯度(如[30]中也讨论的)。
论文采用Heusel等人[59]启发下的双时间尺度更新规则来解决这一问题。具体而言,论文以不同频率训练ufake和生成器G,确保ufake精确跟踪生成器的输出分布。论文发现,每进行一次生成器更新,进行5次虚假评分更新,不使用回归损失,能够提供良好的稳定性,并且在ImageNet上与原始DMD的质量相匹配(见表3),同时实现更快的收敛。进一步的分析包含在附录C中。
2.2.3 利用GAN损失和真实数据超越教师模型
到目前为止,论文的模型在不需昂贵的数据集构建的情况下,实现了与DMD相当的训练稳定性和性能(见表3)。然而,蒸馏生成器与教师扩散模型之间仍存在性能差距。论文推测这一差距可能归因于DMD中使用的真实评分函数的近似误差,这会传递到生成器并导致次优结果。由于DMD的蒸馏模型从未接受过真实数据训练,因此无法从这些误差中恢复。
论文通过在论文的流程中加入额外的GAN目标来解决这一问题,其中判别器被训练以区分真实图像和论文生成器生成的图像。通过使用真实数据训练,GAN分类器不受教师模型的限制,可能使论文的学生生成器在样本质量上超越它。论文将GAN分类器整合到DMD中遵循极简主义设计:论文在虚假扩散去噪器的瓶颈层之上添加了一个分类分支(见图3)。分类分支和UNet中的上游编码器特征通过最大化标准非饱和GAN目标进行训练:
其中D是判别器,是前向扩散过程(即噪声注入),其噪声水平对应于时间步。
2.2.4 多步生成器
通过提出的改进措施,论文能够在ImageNet和COCO上匹配教师扩散模型的性能(见表1和表5)。然而,论文发现,对于像SDXL这样的大规模模型,由于模型容量有限以及学习从噪声到高度多样化和详细图像的直接映射的复杂优化景观,将其提炼成一步生成器仍然具有挑战性。这促使论文将DMD扩展以支持多步采样。
论文固定一个预定的调度,使用个时间步,在训练和推理期间保持不变。在推理过程中,每个步骤交替进行去噪和噪声注入步骤,遵循一致性模型[9],以提高样本质量。
2.2.5 避免训练/推理不匹配的多步生成器模拟
论文通过在训练期间用当前学生生成器运行几个步骤产生的噪声合成图像替换噪声真实图像来解决这个问题,类似于论文的推理流程。这是可行的,因为与教师扩散模型不同,论文的生成器只运行几步。然后,论文的生成器对这些模拟图像进行去噪,输出由提出的损失函数进行监督。使用噪声合成图像避免了不匹配并提高了整体性能(见第5.3节)。
2.2.6汇总一切
总之,论文的蒸馏方法解决了DMD 对预计算噪声-图像对的严格要求。它进一步整合了GAN的优势,并支持多步生成器。如图3所示,从预训练的扩散模型开始,论文交替优化生成器Gθ以最小化原始分布匹配目标以及GAN目标,并优化使用假数据的去噪分数匹配目标和GAN分类损失来优化假分数估计器μfake。为确保假分数估计准确且稳定,尽管是在线优化,论文以比生成器更高的频率(5步比1步)更新它。
2.3论文的效果
论文使用几个基准评估论文的方法DMD2,包括在ImageNet-64x64 上的条件类图像生成,以及在COCO 2014 上使用各种教师模型进行文本到图像合成。
2.3.1 类别条件图像生成
表1比较了论文的模型与最近在ImageNet-64x64上的基准模型。通过单次前向传播,论文的方法显著超越了现有的蒸馏技术,甚至超越了使用ODE采样器的教师模型[52]。论文将这一显著性能归功于移除了DMD的回归损失,这消除了由ODE采样器施加的性能上限,以及论文增加的GAN项,这减轻了教师扩散模型分数近似误差的不利影响。
2.3.2文本到图像合成
论文在零样本COCO 2014上评估DMD2的文本到图像生成性能。论文的生成器通过蒸馏SDXL和SD v1.5进行训练,分别使用来自L.AION-Aesthetics[58]的300万提示的子集。此外,论文还从LAION-Aesthetic收集了50万张图像作为GAN判别器的训练数据。表2总结了SDXL模型的蒸馏结果。论文的4步生成器产生了高质量且多样的样本,实现了19.32的FID分数和0.332的CLIP分数,与教师扩散模型在图像质量和提示一致性上相媲美。
为了进一步验证论文方法的有效性,论文进行了一项广泛的用户研究,比较论文的模型输出与教师模型和现有蒸馏方法的输出。论文使用了LADD之后的PartiPrompts的一个子集128个提示。对于每次比较,论文要求一组随机的五名评估者选择视觉上更吸引人的图像,以及更好地代表文本提示的图像。关于人类评估的详细信息包含在附录H中。如图5所示,论文的模型在用户偏好上远超基线方法。值得注意的是,论文的模型在图像质量上超越其教师模型的样本占比达到249%,并且在提示对齐上达到可比性,同时需要的前向传播次数减少了25倍(4 vs 100)。
图5:用户研究比较了论文蒸馏的模型与其教师模型以及竞争性的蒸馏基线[23, 27, 31]。所有蒸馏模型使用4个采样步骤,教师模型使用50个。论文的模型在图像质量和提示对齐方面均达到了最佳性能
2.3.3 消融研究
表3在ImageNet上消融了论文提出的方法的不同组件。简单地从原始DMD中移除ODE回归损失导致由于训练不稳定而降级的FID为3.48(见附录C中的进一步分析)。然而,结合论文的双时间尺度更新规则(TTUR)缓解了这种性能下降,无需额外的数据集构建即可匹配DMD基线性能。添加论文的GAN损失进一步将FID提高了1.1个点。论文集成的方案超越了仅使用GAN(没有分布匹配目标)的性能,并且将双时间尺度更新规则添加到单独的GAN中并没有改善它,突出了在统一框架中结合分布匹配与GAN的有效性。
在表4中,论文分析了GAN项、分布匹配目标和反向模拟(对将SDXL模型蒸馏为四步生成器的影响。定性结果如图7所示。在没有GAN损失的情况下,论文的基线模型生成的图像过度饱和且过度平滑(图7第三列)。同样,移除分布匹配目标将论文的方法简化为纯粹的基于GAN的方法,这面临着训练稳定性的挑战。
此外,纯粹的基于GAN的方法也缺乏一种自然的方式来整合无分类器指导,这对于高质量的文本到图像合成至关重要。因此,在背景中,戴着太阳镜的骆驼坐在太空船甲板上的照片,尽管基于GAN的方法通过紧密匹配真实分布实现了最低的FID,但它们在文本对齐和美学质量方面显著表现不佳(图7第二列)。同样,省略反向模拟会导致图像质量下降,这一点通过退化的块FID分数得以体现。。
图6:论文的模型、SDXL教师模型以及选定竞争方法[23, 27, 31]之间的视觉对比。所有蒸馏模型使用4个采样步骤,而教师模型使用50个采样步骤并结合无分类器指导。所有图像均使用相同的噪声和文本提示生成。论文的模型生成的图像在真实感和文本对齐方面表现更优。(放大查看细节。)更多对比见附录图10
论文标题:Improved Distribution Matching Distillation for Fast Image Synthesis
论文链接:https://arxiv.org/pdf/2405.14867
本文转载自 AI帝国,作者: 无影寺