修改一行代码就能实现高效微调!上海交大&腾讯开源SaRA:兼顾原始生成和下游任务
文章链接:https://arxiv.org/pdf/2409.06633
项目链接:https://sjtuplayer.github.io/projects/SaRA/
1.引言
SaRA是一种针对预训练扩散模型的高效微调方法。通过微调预训练扩散模型中的无效参数,赋予模型对下游任务的处理能力。SaRA能够显著节省计算显存开销与代码复杂度,仅修改一行训练代码即可实现微调过程。该方法的核心创新在于:
- 参数重要性分析:SaRA首先对预训练模型中的参数重要性进行分析,发现预训练扩散模型中绝对值最小的10%至20%的参数在生成过程中的作用微乎其微。并且这些参数的无效性并非模型固有属性,而是由于训练过程中的不稳定性导致。
- 稀疏低秩训练:基于上述发现,SaRA提出利用这些暂时无效的参数,通过优化稀疏权重矩阵来学习特定任务的知识。为了避免过拟合,SaRA采用了基于核范数的低秩稀疏训练方案,有效约束了学习过程中的参数秩。
- 渐进式参数调整策略:SaRA设计了一种参数重调整策略,通过在微调过程中重定向可训练参数,确保几乎所有参数都能有效地贡献于新任务的学习。
- 非结构化反向传播策略:SaRA提出了一种新颖的反向传播策略,显著降低了微调过程中的内存成本。
SaRA在多个下游任务上进行了广泛的实验验证,包括基模型能力提升、下游数据微调、图像定制化、可控视频生成等。实验结果表明SaRA不仅能够提升基础模型在原始任务的生成能力,在下游任务中,能兼顾下游任务的学习以及预训练先验的维护,实现优越的模型微调效果。
2. 参数重要性分析
2.1 预训练模型中的无效参数
图 1:Stable Diffusion预训练模型参数分布与小参数对生成结果的影响
2.2 无效参数的潜在有效性
2.1中导致无效参数的原因可能有两个:一是由于模型结构设计的原因,这些参数天生就是冗余、无效的参数,因此无法在训练过程中起到作用,另外一个原因则可能是由于模型训练过程中的随机性,导致这些参数恰好在训练结束的时候停留在0附近。因此,作者进一步对参数无效的原因展开研究。选取了Stable Diffusion在FFHQ的预训练模型,标记了初始权重最小的1%参数,将该模型继续在FFHQ上训练,并在训练过程中实时跟踪这1%参数的变化,结果如图 2所示,可见,随着训练的进行,初始的小参数(蓝色线条)逐渐跳出了1%的阈值,而初始大于1%阈值的参数,大部分跌入了1%以内,并且小于该阈值的参数总量始终维持在1%左右,证明了在训练过程中,所有参数都以一定的概率跳出或者跌入1%阈值中,说明初始的小参数是由训练过程的随机性导致的,因此,可以在微调过程中利用这些暂时无效的参数,赋予模型针对下游任务的生成能力。
图 2:训练过程中权重绝对值小于初始1%阈值θ_t的参数分布变化
3. 方法介绍
3.1 稀疏矩阵训练
3.2 基于核范数的低秩约束
3.3 渐进式参数调整
在模型的微调过程中,由于训练的随机性,仍然会存在部分参数停留在阈值以下,尤其是微调过程的总轮次往往较少,导致最终存在一部分的参数仍然无效。如图 2 所示,初始的小参数在训练初期会快速跳出阈值,而后期的趋势逐渐放缓,当微调轮次较少时,可训练参数中可能存在15%的参数仍然无效。因此,SaRA提出渐进式的参数调整策略,在微调的前半阶段,首先对初始的无效参数进行训练,使得大部分的无效参数跳出阈值,而在后半阶段,再次对剩余的无效参数进行训练,使其快速跳出阈值。通过这种分阶段的渐进式训练策略,SaRA可以更有效地利用这些无效参数,提高模型在新任务上的性能。
3.4 非结构化反向传播策略
图3:非结构化梯度反向传播
4.实验效果
为了验证方法的有效性,SaRA在多个主流与下游的任务上进行了验证,包含基模型提升、下游数据集微调、图像定制化与可控视频生成。
4.1 基模型提升
SaRA主要致力于将预训练模型中的无效参数利用起来,赋予模型更强大的生成能力,这与一般微调方法仅针对下游任务设计的理念不尽相同。因此,SaRA可以用来提升预训练模型在原本任务上的生成能力。实验选取了在ImageNet、FFHQ、CelebA-HQ上预训练的Stable Diffusion,利用SaRA在相应数据集对模型进行进一步的微调,以完全利用模型中的无效参数,结果如图4所示,可以看出,SaRA能够稳定地提升基模型的生成能力(降低约5%的FID)。
图4:基模型在原始任务上的微调
4.2下游数据集微调
在下游数据集微调实验中,将SaRA应用于多个不同的数据集,并在不同StableDiffusion版本(1.5,2.0,3.0)与参数规模(50M,20M,5M)下进行了训练。数据集包括BarbieCore, Cyberpunk, Elementfire, Expedition, Hornify五个风格,结果如图5所示,可见SaRA取得了学习到了最丰富的数据特征,同时能够保持语义与文本的一致性。此外,实验还计算了生成数据的FID,与文本的CLIP Score,以及一个归一化指标VLHI同时衡量FID与CLIP Score,定量结果如表1所示,可见,SaRA在不同版本的Stable Diffusion以及不同的参数量下,均取得了最好的表现。
图5:不同微调方法在下游数据集微调的表现。
表1:不同微调方法在下游数据集微调的定量表现。
4.3 图像定制任务
图像定制化通过从少量几张图像中学习到共有的对象特征,然后将该对象生成到新的图片中。Dreambooth作为一种主流的图像定制化,需要微调扩散模型实现对目标特征的捕捉,因此,SaRA可以直接用于Dreambooth的微调过程。实验比较了不同微调方法在DreamBooth上的表现,定性结果如图6所示,可见,SaRA在成功捕捉目标对象特征的同时,还较好地维护了生成图像语义与文本的一致性。表2计算了不同方法在三个定制化数据集上的定量表现,可以看出,SaRA同时兼顾了特征捕捉与图文一致性,展现了在定制化任务重的优秀表现。
图6:不同微调方法在Dreambooth上的定性表现
表2:不同微调方法在Dreambooth上的定量表现
4.4 视频生成任务
SaRA不仅在图像生成任务中大展身手,在视频生成任务重也同样能取得较好的结果。实验将不同微调方法应用在视频生成模型AnimateDiff上,在不同运镜数据集下进行微调(镜头放大、缩小、右移)。结果如图7所示,其他的微调方法在视频生成任务中展现出一定的过拟合与内容崩溃的问题,相较之下,SaRA在微调过程中展现出丰富的运镜理解能力,同时较好地维护了模型的先验,保证了生成视频的质量以及与文本的一致性。
图7:不同微调方法在可控运镜的视频生成上的表现
4.5 计算资源比较
SaRA引入了非结构化梯度回传策略,有效解决了基于参数选择的微调方法中遇到的显存开销大的问题。图8的实验比较了LT-SFT(一种基于参数选择的方法)、LoRA以及SaRA在Stable Diffusion 2.0用不同Batch size训练过程中的显存开销与训练时间。可以看出,SaRA比LT-SFT减少了固定的9.2GB显存占用(对应所有参数的梯度占用空间),在Batch Size较小时(<=4)节省了45%的显存。而LoRA随着Batchsize的增大,显存占用急速上升。SaRA在Batch Size=16时比LoRA节省了52%的显存占用,并且节省了49%的训练时间。
图8:基于参数选择的方法、LoRA、SaRA在不同batch size下的显存开销与训练时间
4.6 训练参数分析
一个好的微调方法,在微调过程中,应该能够学习到更多的任务相关的信息,同时最大化保留预训练权重的先验知识。因此,作者实验分析了SaRA与LoRA在Expedition数据集上微调后的学习到的参数ΔP与预训练权重P之间的关系。表3通过F范数量化了ΔP前r维子空间与预训练权重P子空间的相关性,可见SaRA学习到的参数与P相关性更小,说明相较于LoRA学习到了更多的下游任务的知识。此外,还计算了放大因子Amplification Factor,量化了ΔP对P中未强调的特征方向的放大倍数,同样证明了SaRA对新知识更强的学习能力。
表3:SaRA与LoRA训练参数ΔP,与预训练参数P的关系
图9计算了SaRA与LoRA训练后的模型参数ΔP+P与预训练参数P的前r维子空间的相似性,可以看出,SaRA的相似性在95%以上,而LoRA维持在80%附近,证明了融合SaRA训练参数的模型,能够更好地维护预训练权重的先验知识。
图9:SaRA与LoRA训练后的模型参数ΔP+P与预训练参数P的关系
结论
本文提出了 SaRA,一种新颖的参数高效微调方法,该方法充分利用了预训练模型中绝对值最小的无效参数。作者提出了一种基于核范数的低秩损失,以约束学习到的稀疏矩阵的秩,从而避免模型过拟合。此外,设计了一种渐进式参数调整策略,进一步提高了微调参数的有效性。最后,提出了一种新型的非结构化反向传播方法,大大节省了参数微调过程中的内存开销,同时也能降低其他选择性 PEFT 方法的内存成本。大量实验证明了本文方法的有效性,它在保持预训练模型的先验信息的同时,实现了最佳的拟合能力。此外,作者高效地封装了本文的方法,使其只需修改一行代码即可实现,这大大增强了代码在其他模型和任务中的易用性和适应性。
本文转自 AI生成未来 ,作者:AI生成未来