全面超越ViT,美团、浙大等提出视觉任务统一架构VisionLLAMA

人工智能 新闻
沿袭 ViT 的研究思路,我们能否借助创新性的 LLaMA 架构,真正实现语言和图像的架构统一?

半年多来,Meta 开源的 LLaMA 架构在 LLM 中经受了考验并大获成功(训练稳定、容易做 scaling)。

沿袭 ViT 的研究思路,我们能否借助创新性的 LLaMA 架构,真正实现语言和图像的架构统一?

在这一命题上,最近的一项研究 VisionLLaMA 取得了进展。VisionLLaMA 在图像生成(包含 Sora 依赖的底层的 DIT)和理解(分类、分割、检测、自监督)等多个主流任务上相较于原 ViT 类方法提升显著。

图片

  • 论文标题:VisionLLaMA: A Unified LLaMA Interface for Vision Tasks
  • 论文地址:https://arxiv.org/abs/2403.00522
  • 代码地址:https://github.com/Meituan-AutoML/VisionLLaMA

该研究在统一图像和语言架构方面的尝试,可以复用 LLM 社区在 LLaMA 上的训练(稳定且有效的 scaling)、部署等一系列成果。

研究背景

大语言模型是当前学术界研究的热点,其中,LLaMA 是最具影响力和代表性的工作之一,许多最新的研究工作都基于该架构开展,各种应用的解决方案大都建立在该系列的开源模型之上。在多模态模型的进展中,其中许多方法都依赖 LLaMA 进行文本处理、并依赖类似 CLIP 的视觉 transformer 进行视觉感知。同时,许多工作致力于加快 LLaMA 的推理速度、降低 LLaMA 的存储成本。总而言之,LLaMA 现在是事实上最通用、最重要的大语言模型架构。

图片

LLaMA 架构的成功使得本文作者提出了一个简单而有趣的设想:该架构是否可以在视觉模态上同样成功?如果答案是肯定的,那么视觉模型和语言模型都可以使用相同的统一架构,并从为 LLaMA 设计的各种动态部署技术中受益。然而,这是一个复杂的问题,因为这两种模态之间存在一些明显的差异。

首先存在维度差异:文本序列是一维的,而视觉任务需要处理两个或更多维度的数据;其次存在结构差异:许多视觉任务依赖于金字塔结构的骨干网络以获得更好的性能,而 LLaMA 是一个结构上朴素的编码器;第三,需要有效处理不同分辨率的图像和视频输入。

本文旨在解决这些挑战,并弥合不同模态之间的架构差距,具体为提出适应视觉任务的 LLaMA 架构,解决与模态差异相关的难题,并实现通过一种统一的方法对视觉和语言数据进行处理。

本文主要贡献如下:

1. 本文提出 VisionLLaMA,一种类似于 LLaMA 的视觉 transformer 架构,以减少语言和视觉之间的架构差异。

2. 本文研究了使 VisionLLaMA 适应常见视觉任务的方法,包括图像理解和创建(图 1)。本文研究了两种广为人知的视觉架构方案(常规结构和金字塔结构),并评估它们在监督和自监督学习场景下的性能。此外,本文还提出了 AS2DRoPE(即自动缩放 2D RoPE),它将旋转位置编码从 1D 扩展到 2D,并利用插值缩放来适应任意分辨率。

3. 在精确的评估下,VisionLLaMA 在图像生成、分类、语义分割和目标检测等许多代表性任务中明显优于当前主流且被精确微调过的视觉 transformer。大量实验表明,VisionLLaMA 与现有视觉 transformer 相比具有更快的收敛速度和更好的性能。

VisionLLaMA 总体架构设计

图片

常规 Transformer

本文提出的常规 VisionLLaMA 遵循 ViT 的流程,并且尽可能保留 LLaMA 的架构设计。对于一张图像,首先将其变换并展平为一个序列,然后在序列的开头添加一个类别 token,整个序列通过 L 个 VisionLLaMA block 进行处理。与 ViT 不同,VisionLLaMA 不向输入序列添加位置编码,因为 VisionLLaMA 的 block 包含位置编码。具体来说,该 block 与标准的 ViT block 有两点不同:具有位置编码(RoPE)的自注意力和 SwiGLU 激活。本文仍然使用 LayerNorm 而不是 RMSNorm,因为本文通过实验发现前者表现更好(参见表 11g)。block 的结构如图 2 (a) 所示。本文发现在视觉任务中直接应用 1D RoPE 不能很好地推广到不同的分辨率上,因此将其扩展到二维形式:

图片

金字塔结构 Transformer

将 VisionLLaMA 应用于类似 Swin 的基于窗口的 transformer 非常简单,因此本文选择在更强的基线 Twins 上探索如何构建强大的金字塔结构 transformer。Twins 的原始架构利用了条件位置编码、以局部 - 全局注意力的形式进行交错的局部 - 全局信息交换。这些组件在各种 transformer 中十分常见,这意味着在各类 transformer 变体中应用 VisionLLaMA 并不困难。

本文的目标不是发明一种全新金字塔结构的视觉 transformer ,而是如何在现有设计的基础上调整 VisionLLaMA 的基本设计,因此本文遵循对架构和超参数进行最少修改的原则。遵循 ViT 的命名方式,两个连续的 block 可以写为:

图片

其中 LSA 是组内的局部自注意力操作,GSA 是通过与每个子窗口中的代表性键值交互而进行的全局子采样的注意力。本文移除了金字塔结构 VisionLLaMA 中的条件位置编码,因为 AS2DRoPE 中已经包含了位置信息。此外,还移除了类别 token,并在分类头之前使用 GAP(全局平均池化),该设置下的 block 结构如图 2 (b) 所示。

超越序列长度限制的训练或推理

将一维 RoPE 拓展到二维:对不同的输入分辨率进行处理是视觉任务中的常见需求。卷积神经网络使用滑动窗口机制来处理可变长度。与之相比,大多数视觉 transformer 应用局部窗口操作或插值,例如 DeiT 在不同分辨率上训练时采用双三次插值;CPVT 使用基于卷积的位置编码。本文中评估了 1D RoPE 的性能,发现其在 224×224 分辨率上拥有最高的精度,然而当分辨率上升到 448×448 时,精度急剧下降甚至为 0。因此,本文将一维 RoPE 扩展到二维。对于多头自注意力机制,二维 RoPE 在不同头之间共享。

位置插值有助于二维 RoPE 更好地泛化:受一些使用插值来扩展 LLaMA 的上下文窗口的工作启发,在更高分辨率的参与下,VisionLLaMA 采用类似方式扩展二维上下文窗口。与具有扩大的固定上下文长度的语言任务不同,目标检测等视觉任务通常在不同的迭代中处理不同的采样分辨率。本文使用 224×224 的输入分辨率对小模型进行训练,并在不重新训练的情况下评估更大分辨率的性能,指引本文能够更好的应用内插值或外差值策略。经过实验,本文选择应用基于 “锚点分辨率” 的自动缩放插值(AS2DRoPE)。对 H × H 的方形图像和 B × B 的锚点分辨率进行处理的计算方式如下:

图片

这种计算方式效率高并且不会引入额外的成本。如果训练分辨率保持不变,AS2DRoPE 会退化为 2 维 RoPE。

图片

由于需要将位置信息添加到汇总的键值中,本文对于金字塔结构设置下的 GSA 进行了特殊处理。这些子采样的键值是通过特征图上的抽象生成的。本文使用内核大小为 k×k 且步长为 k 的卷积。如图 3 所示,生成的键值的坐标可以表示为采样特征的平均值。

实验结果

本文全面评估了 VisionLLaMA 在图像生成、分类、分割和检测等任务上的有效性。默认情况下,本文所有模型均在 8 个 NVIDIA Tesla A100 GPU 上进行训练。

图像生成

基于 DiT 框架的图像生成:本文选择在 DiT 框架下应用 VisionLLaMA,因为 DiT 是使用视觉 Transformer 和 DDPM 进行图像生成的代表性工作。本文用 VisionLLaMA 替换了 DiT 原来的视觉 transformer,同时保持其他组件与超参数不变。该实验证明了 VisionLLaMA 在图像生成任务上的通用性。与 DiT 相同,本文设置 DDPM 的 sample steps 为 250,实验结果如表 1 所示。与大多数方法保持一致,FID 被视为主要指标,并在其他次要指标上例如 sFID、Precision/Recall、Inception Score 进行评估。结果表明,VisionLLaMA 在各种模型尺寸上都显着优于 DiT。本文还将 XL 模型的训练步数扩展到 2352k,以评估本文的模型是否具有更快的收敛优势,或者在更长的训练周期设置下仍然表现更好。DiT-LLaMA-XL/2 的 FID 比 DiT-XL/2 低 0.83,表明 VisionLLaMA 不仅具有更好的计算效率,而且比 DiT 具有更高的性能。图 1 中展示了使用 XL 模型生成的一些示例。

图片

基于 SiT 框架的图像生成:SiT 框架显著提高了使用视觉 transformer 生成图像的性能。本文用 VisionLLaMA 替换 SiT 中的视觉 transformer,以评估更好的模型架构带来的收益,本文将其称为 SiT-LLaMA。实验保留了 SiT 中其余所有设置与超参数,所有模型都使用相同数量的步骤进行训练,在所有实验中都使用线性插值(linear interpolant)和快速模型(velocity model)。为了进行公平比较,本文还重新运行已发布的代码,并使用 250 steps 的 SDE 采样器(Euler)对 50k 256×256 图像进行采样,结果如表 2 中所示。SiT-LLaMA 在各种容量级别的模型中均优于 SiT。与 SiT-L/2 相比,SiT-LLaMA-L/2 降低了 5.0 FID,其幅度大于新框架带来的提升(4.0 FID)。本文还在表 13 中展示了更高效的 ODE 采样器 (dopri5),与本文方法的性能差距仍然存在。可以得出与与 SiT 论文中的类似的结论:SDE 比其对应的 ODE 具有更好的性能。

图片

ImageNet 上的图像分类

  • 全监督训练

本节重点关注模型在 ImageNet-1K 数据集上的全监督训练,排除其他数据集或蒸馏技巧的影响,所有模型均使用 ImageNet-1K 训练集进行训练,并在表 3 中展示了在验证集上的准确性结果。

图片

常规视觉 Transformer 的比较:DeiT3 是当前最先进的常规视觉 transformer,它提出了一种特殊的数据增强并执行广泛的超参数搜索以提高性能。DeiT3 对超参数敏感并且容易出现过拟合,用 GAP(全局平均池化)替换类别 token 会导致 DeiT3-Large 模型在经过 800 个 epoch 训练后准确率下降 0.7%。因此,本文在常规 transformer 中使用类别 token 而不是 GAP。结果如表 3 中所示,其中 VisionLLaMA 取得了与 DeiT3 相当的 top-1 精度。单一分辨率上的准确性并不能提供全面的比较,本文还评估了不同图像分辨率的性能,结果如表 4 所示。对于 DeiT3,本文使用双三次插值来进行可学习的位置编码。尽管这两个模型在 224×224 分辨率下具有相当的性能,但当分辨率增加时,差距会扩大,这意味着本文的方法在不同分辨率下具有更好的泛化能力,这对于目标检测等许多下游任务来说至关重要。

图片

金字塔结构的视觉 transformer 比较:本文使用与 Twins-SVT 相同的架构,详细配置列于表 17。本文移除了条件位置编码,因为 VisionLLaMA 已经包含一种旋转位置编码。因此,VisionLLaMA 是一种无卷积架构。本文沿用 Twins-SVT 中的包含超参数在内的所有设置,与 Twins-SVT 保持一致,本文不使用类别 token,而是应用 GAP。结果如表 3 所示,本文的方法在各个模型级别上都实现了与 Twins 相当的性能,并且始终优于 Swin。

  • 自监督训练

本文使用 ImageNet 数据集评估自监督视觉 transformer 的两种常见方法,同时将训练数据限制为 ImageNet-1K,移除了任何使用 CLIP、DALLE 或蒸馏等可以提高性能的组件,本文的实现基于 MMPretrain 框架,利用 MAE 框架并使用 VisionLLaMA 替换编码器,同时保持其他组件不变。该对照实验能够评估本文方法的有效性。此外,本文使用与所比较方法相同的超参数设置,在这种设置下,与强大的基线相比依然实现了显着的性能提升。

Full fine-tuning 设置:在当前设置下,模型首先使用预训练的权重进行初始化,然后使用完全可训练的参数进行额外的训练。VisionLLaMA-Base 在 ImageNet 上经过 800 个 epoch 的训练,达到了 84.0% 的 top-1 准确率,比 ViT-Base 提高了 0.8%。本文的方法训练速度比 SimMIM 快约 3 倍。本文还将训练周期增加到 1600,以验证 VisionLLaMA 能否在足够的训练资源下保持优势。VisionLLaMA-Base 在 MAE 变体中取得了新的 SOTA 结果,top-1 准确率达到 84.3%,比 ViT-Base 提高了 0.9%。考虑到 full fine-tuning 具有性能饱和风险,本文方法的提升十分显着。

Linear probing:最近的一项工作认为线性探测度量(linear probing metric)是对表示性学习更加可靠的评估。在当前设置下,模型由 SSL 阶段的预训练权重初始化。然后,在训练过程中,除了分类器头之外,整个骨干网络都被冻结。结果如表 5 所示:在训练成本为 800 个 epoch 的情况下,VisionLLaMA-Base 的性能优于 ViTBase-MAE 4.6%。它还超过了训练了 1600 个 epoch 的 ViT-Base-MAE。当 VisionLLaMA 训练 1600 个 epoch 时,VisionLLaMA-Base 达到了 71.7% 的 top1 准确率。本文方法还扩展到 VisionLLaMA-Large,相比 ViT-Large 提高了 3.6%。

图片

ADE20K 数据集上的语义分割

  • 全监督训练

按照 Swin 的设置,本文在 ADE20K 数据集上使用语义分割来评估本文方法的有效性。为了进行公平比较,本文限制基线模型仅使用 ImageNet-1K 进行预训练。本文使用 UperNet 框架,并用金字塔结构 VisionLLaMA 替换主干网络。本文的实现基于 MMSegmentation 框架。模型训练步数设置为 160k,全局 batch size 为 16。结果如表 6 中所示,在相近的 FLOP 下,本文的方法比 Swin 和 Twins 的性能高出 1.2% mIoU 以上。

图片


  • 自监督训练

本文使用 UperNet 框架在 ADE20K 数据集上进行语义分割,用 VisionLLaMA 替换 ViT 主干,同时保持其他组件和超参数不变。本文的实现基于 MMSegmentation,结果如表 7 所示。对于 800 个 epoch 的预训练组,VisionLLaMA-B 将 ViT-Base 显着提升了 2.8% mIoU。本文方法还明显优于其他一些改进,例如引入额外的训练目标或特征,这些方法会给训练过程带来额外的开销并降低训练速度。相比之下,VisionLLaMA 仅涉及基础模型的替换,并且具有快速的训练速度。本文进一步评估了 1600 个较长预训练 epoch 的性能,VisionLLaMA-B 在 ADE20K 验证集上实现了 50.2% mIoU,这使得 ViT-B 的性能提高了 2.1% mIoU。

图片

COCO 数据集上的目标检测

  • 全监督训练

本文评估了金字塔结构 VisionLLaMA 在 COCO 数据集上的目标检测任务的性能。本文使用 Mask RCNN 框架并用金字塔结构 VisionLLaMA 替换主干网络,类似于 Swin 的设置,该金字塔结构 VisionLLaMA 在 ImageNet-1K 数据集上预训练了 300 个 epoch。因此,本文的模型具有与 Twins 相同数量的参数和 FLOP。该实验能够用于验证本文方法在目标检测任务上的有效性。本文的实现基于 MMDetection 框架,表 8 中展示了标准的 36 个 epoch 训练周期 (3×) 的结果,本文的模型优于 Swin 和 Twins。具体来说,VisionLLaMA-B 比 Swin-S 高出 1.5% 的 box mAP 和 1.0% mask mAP。与更强的基线 Twins-B 相比,本文的方法具有在 box mAP 上高出 1.1% ,在 mask mAP 上高出 0.8% 的优势。

图片

  • 自监督训练

本文应用基于 ViTDet 框架的 VisionLLaMA,该框架利用常规视觉 transformer 来实现与对应金字塔结构视觉 transformer 相当的性能。本文使用 Mask RCNN 检测器,并用 VisionLLaMA-Base 模型替换 vit-Base 主干网络,该模型使用 MAE 预训练 800 轮。原始的 ViTDet 收敛缓慢,需要专门的训练策略,例如更长的训练周期才能实现最佳性能。在训练过程中,本文发现 VisionLLaMA 在 30 个 epoch 后达到了相似的性能,因此,本文直接应用标准的 3x 训练策略。本文方法的训练成本仅为基线的 36%。与所比较方法不同,本文方法不进行最佳超参数搜索。结果如表 9 所示,VisionLLaMA 在 Box mAP 上优于 ViT-B 0.6%,在 mask mAP 上优于 ViT-B 0.8%。

图片

消融实验与讨论

消融实验

本文默认选择在 ViT-Large 模型上进行消融实验,因为本文观察到该模型在多次运行中产生的方差较小。

图片

FFN 和 SwiGLU 的消融:本文用 SwiGLU 替换 FFN ,结果如表 11a 中所示。由于明显性能差距,本文选择使用 SwiGLU 以避免对 LLaMA 架构引入额外的修改。

归一化策略的消融:本文对 transformer 中两种广泛使用的归一化方法 RMSNorm 和 LayerNorm 进行了比较,结果如表 11g 中所示。后者具有更好的最终性能,这表明重新居中不变性(re-centering invariance)在视觉任务中也很重要。本文还计算了每次迭代花费的平均时间用来衡量训练速度,其中 LayerNorm 仅比 RMSNorm 慢 2%。因此,本文选择 LayerNorm 而不是 RMSNorm 以获得更均衡的性能。

部分位置编码:本文使用 RoPE 调整全部 channel 的比率,结果如表 11b 中所示,结果表明将比率设置在小阈值上即可获得良好的性能,不同的设置之间没有观察到存在显着的性能差异。因此,本文保留 LLaMA 中的默认设置。 

基础频率:本文对基础频率进行更改与比较,结果如表 11c 中所示,结果表明,性能对于大范围的频率来说是稳健的。因此,本文保留 LLaMA 中的默认值以避免部署时的额外特殊处理。 

每个注意力头之间共享位置编码:本文发现,在不同头之间共享相同的 PE(每个头中的频率从 1 到 10000 变化)比独立的 PE(所有通道中的频率从 1 到 10000 变化)要好,结果如表 11d 所示。 

特征抽象策略:本文在大参数规模的模型(-L)上比较了两种常见的特征提取策略:类别 token 和 GAP ,结果如表 11e 中所示,使用类别 token 比 GAP 更好,这与 PEG [13] 中所得到的结论不同。然而,两种方法的训练设置截然不同。本文还使用 DeiT3-L 进行了额外的实验,得到了类似的结论。本文进一步评估 “小型”(-S)和 “基础”(-B)模型的性能。有趣的是,在小模型中观察到了相反的结论,有理由怀疑 DeiT3 中使用的较高丢弃路径率(drop-path rate)使得诸如 GAP 之类的无参数抽象方法(parameter-free abstraction)难以达到应有的效果。 

位置编码策略:本文还在金字塔结构 VisionLLaMA-S 上评估了其他绝对位置编码策略,例如可学习位置编码 和 PEG。由于存在强大的基线,本文使用 “小” 模型,结果显示在表 11f 中:可学习的 PE 不会提高性能,PEG 将基线从 81.6% 略微提高到 81.8%。出于三个原因,本文并没有将 PEG 作为基本组成部分。首先,本文尝试对 LLaMA 进行最小程度的修改。其次,本文的目标是为 ViT 等各种任务提出一种通用方法。对于像 MAE 这样的屏蔽图像框架(masked image frameworks),PEG 增加训练成本,并可能损害下游任务上的性能。原则上,可以在 MAE 框架下应用稀疏 PEG,但会引入部署不友好的算子。稀疏卷积是否与其密集版本一样包含足够的位置信息仍然是一个未解决的问题。第三,无模态束缚的设计为进一步研究涵盖文本和视觉之外的其他模态铺平了道路。 

对输入尺寸的敏感性:在未训练的前提下,本文进一步比较了增大分辨率和常用分辨率的性能,结果如表 12 中所示。这里使用了金字塔结构 transformer,因为其在下游任务中比对应的非层次结构版本更受欢迎。1D-RoPE 的性能因分辨率变化而受到严重影响并不奇怪。α = 2 的 NTK-Aware 插值实现了与 2D-RoPE 类似的性能,2D-RoPE 实际上是 NTKAware (α = 1)。AS2DRoPE 展示出了在较大分辨率上的最佳性能。

图片

讨论

收敛速度:对于图像生成,本文研究了不同训练步数下的表现,分别在 100k、200k、300k 和 400k 次迭代时存储权重来计算保真度指标。由于 SDE 明显慢于 ODE,因此本文选择使用 ODE 采样器。表 10 中的结果表明 VisionLLaMA 在所有模型上的收敛速度都比 ViT 快得多。具有 30 万次训练迭代的 SiT-LLaMA 性能甚至优于具有 40 万次训练次数的的基线模型。

图片

本文还与图 4 中 ImageNet 上使用 DeiT3-Large 全监督训练 800 个 epoch 的 top-1 精度进行了比较,表明 VisionLLaMA 比 DeiT3-L 收敛得更快。本文进一步比较了 MAE 框架下 ViT-Base 模型的 800 个 epoch 的训练损失,并在图 5 中进行了说明。VisionLLaMA 在开始时具有较低的训练损失,并将该趋势保持到最后。

图片

责任编辑:张燕妮 来源: 机器之心
相关推荐

2012-07-02 10:45:38

国产CPU龙芯MIPS

2024-09-10 14:00:00

英伟达架构AI

2021-03-01 10:01:22

开发技能编码

2022-07-20 22:53:44

CCNNSOTACNN 架构

2010-07-29 23:05:57

思科城市云

2024-07-09 13:06:52

2012-11-08 15:20:29

AMDARM数据中心

2015-04-27 13:54:10

2020-06-18 10:46:12

IBM侯淼

2024-10-29 14:10:00

AI模型

2022-04-18 15:56:49

AI模型系统

2024-05-16 17:58:30

线程任务线程通讯线程池

2024-04-17 13:20:29

2009-04-16 12:31:42

交换数据中心H3C

2012-12-07 11:32:33

Exchange 20

2021-12-14 15:59:38

数据模型推理

2010-10-08 13:53:02

Silverlight

2021-11-11 10:37:23

Memblaze

2022-11-08 15:05:49

模型参数

2023-02-22 09:53:55

架构芯片
点赞
收藏

51CTO技术栈公众号