大家好,我是来自 NVIDIA GPU 计算专家团队的陶砺,很高兴今天有机会在这里跟大家分享一下我和我的同事陈庾,在 Swin Transformer 这个视觉大模的型训练和推理优化上的一些工作。其中一些的方法与策略,在其他的模型训练、推理的优化上都可以使用,来提高模型的吞吐、提升 GPU 的使用效率、加快模型的迭代。
我会介绍 Swin Transformer 模型的训练部分的优化,在推理优化部分的工作,将由我的同事来做详细的介绍
这里是我们今天分享的目录,主要分为四个部分,既然是针对特定模型进行的优化,那么我们首先会简单介绍一下 Swin Transformer 模型。然后,我会结合 profiling 的工具,也就是 nsight system 对训练的流程进行分析和优化。在推理部分,我的同事会给出推理优化的策略和方法,包含较为细节的 cuda 层面的优化。最后,是今天优化内容的一个总结。
首先是第一部分,也就是 Swin Transformer 的介绍。
1. Swin Transformer 简介
从模型的名称我们可以看出,这是一个基于 transformer 的模型,我们先对 transformer 进行一下简单的回顾。
Transformer 模型从 attention is all you need 这篇文章中被提出后,在自然语言处理领域的很多任务上大放异彩。
Transformer 模型的核心就是所谓的注意力机制,也就是 attention mechanism。对于注意力模块,通常的输入是 query,key 和 value 三个张量。通过 query 和 key 的作用,加上 softmax 的计算,可以得到通常被称为 attention map 的注意力结果,根据 attention map 中的数值的高低,模型就可以学习到需要更加注意 value 中的哪些区域,或者说模型可以学习到,value 中的哪些数值对我们的任务有很大的帮助。这就是最基础的单头注意力模型。
我们通过增加这样单头注意力的模块的数量,也就可以构成常见的多头注意力模块。常见的 encoder、decoder 都是基于这样的多头注意力模块搭建的。
很多模型通常包含了 self-attention,cross-attention 这两种注意力模块,或者是一个或多个模块的堆叠。如著名的 BERT 就是由多个 encoder 模块组成,现在大热的 diffusion 模型通常同时包含了 self-attention 和 cross-attention。
在 Swin Transformer 之前, Vision Transformer (ViT) 首先将 transformer 应用到了计算机视觉领域。ViT 的模型结构,如下图左侧所示,ViT 会将一个图像分割成一系列的 patch,每一个 patch 类比于自然语言处理中的 token,然后通过一个 Transformer-based 的 encoder 对这一系列 patch 进行 encode,最后得到可用于分类等任务的 feature。
而来到 Swin Transformer,它引入了 window attention 的概念,不同于 ViT 对整个图像进行 attention,Swin Transformer 会先将图像划分成若干个 window,然后仅对 window 内部的 patch 进行 attention,从而减少计算量。
为了弥补 window 带来的边界问题,Swin Transformer 进一步引入 window shift 的操作。同时为了使得模型有更丰富的位置信息,还在 attention 时引入了 relative position bias。其实这里的 window attention 和 window shift,就是 Swin Transformer 中的 Swin 名称的由来。
这里给出的是 Swin Transformer 的网络结构,大致的一个网络结构和传统的 CNN 如 ResNet 十分相近。
可以看到整个网络结构被划分为多个 stage,在不同 stage 中间,会有对应的降采样的过程。每个 stage 的分辨率是不一样的,从而形成了一个分辨率金字塔,这样也使得每个 stage 的计算复杂程度也逐渐降低。
然后每个 stage 中会有若干个 transformer block。每一个 transformer block 中,就会用到上面提到的 window attention 模块。
接下来,我们从具体操作的角度来对 Swin Transformer 进行解构。
可以看到,一个 transformer block 中涉及到三大部分,第一部分是 window shift/partition/reverse 的 window 相关的操作,第二部分是 attention 计算,第三部分是 FFN 计算;而 attention 和 FFN 部分又可以进一步细分为若个 op,最终我们可以将整个模型细分为几十个 op 的组合。
这样的算子划分对于我们进行性能分析,定位性能瓶颈以及开展加速优化而言,都是非常重要的。
以上就是第一部分的介绍。接下来,我们来介绍一下在训练上我们进行的一些优化工作,特别的,我们结合 profiling 工具,也就是 nsight system,对整体的训练流程做一个分析和优化。
2. Swin Transformer 训练优化
对于大模型的训练而言,通常会用到多卡、多节点的计算资源。针对 Swin Transformer,我们发现卡间通讯的开销占比会相对较少,随着卡数的增长,整体速度的提升几乎呈现线性的增长,所以在这里,我们优先对单 GPU 上的计算瓶颈进行分析和优化。
nsight system 是一个系统层面的性能分析工具,通过这个工具,我们可以很方便的看到模型的各个模块的 GPU 的使用情况,是否存在数据等待等可能存在的性能瓶颈和优化空间,可以便于我们合理的规划 CPU、GPU 之间的负载。
nsight system 可以捕捉到 CUDA,以及一些 gpu 计算库如 cublas,cudnn,tensorRT 等调用的核(kernel)函数的调用和运行情况,以及可以方便用户添加一些标记,来统计标记范围内对应 gpu 的运行情况。
一个标准的模型优化流程如下图所示,我们对模型进行 profiling,拿到性能分析报告,发现性能优化点,然后有针对性的去做性能调优。
这里是一个 nsight system 的界面,我们可以很清晰地看到核函数的发射,也就是 kernel launch;核函数的运行,也就是这里的 runtime 部分。对于具体的核函数,我们可以看到在整个流程里的时间占比,以及 gpu 是否存在空闲等信息。在添加完 nvtx 标记之后,我们可以看到模型前向,反向所需要的时间。
在前向部分,如果放大,我们也可以清晰地看到具体每个 SwinTransformer Block 的计算需要的时间。
我们首先通过 nsight system 性能分析工具来看一下整个 baseline 的性能表现,下图中展示的就是 FP32 的 baseline,可以看到它的 GPU 利用率是很高的,而其中占比最高的是矩阵乘的 kernel。
那么对于矩阵乘法而言,我们的一个优化手段,就是充分利用 tensor core 进行加速。
我们知道 NVIDIA 的 GPU 内有 cuda core 和 tensor core 这样的硬件资源,tensor core 是专门为了矩阵乘法的加速的模块。我们可以考虑直接采用 tf32 tensor core 或者混合精度下,采用 fp16 tensor core。要知道,使用 fp16 的 tensor core 在矩阵乘法上的吞吐,会比 tf32 要高,对比纯 fp32 的矩阵乘也会有很高的加速效果。
在此,我们采用了混合精度的方案。通过采用 torch.cuda.amp 的混合精度的模式,我们可以取得了 1. 63 倍的吞吐提升。
在 profiling 的结果里也能够很清晰地看到,原本占最高的矩阵乘,经过优化后,在整个 timeline 中的占比降到了 11.9%。至此,占比较高的 kernel 都是 elementwise kernel。
对于 elementwise kernel,我们首先要了解哪里会用到 elementwise 的 kernel。
Elementwise kernel 里,比较常见的 unrolled elementwise kernel 和 vectorized elementwise kernel。其中 unrolled elementwise kernel 广泛存在于一些有偏置的卷积,或者线性层中,以及一些保证数据在内存连续性的 op 中。
vectorized elementwise kernel 则经常出现在一些激活函数,如 ReLU 的计算中。如果想要减少这里大量的 elementwise kernel,一个常见的做法是做算子融合,比如矩阵乘法中,我们可以通过将 elementwise的操作与矩阵乘法的算子融合在一起,来降低这部分的时间开销。
对于算子融合,一般而言可以为我们带来两个好处:
一个是减少 kernel launch 的开销,如下图所示,两个 cuda kernel 的执行需要两次 launch,那样可能会导致 kernel 之间存在 gap,使得 GPU 空闲,那么如果我们将两个 cuda kernel 融合成一个 cuda kernel,一方面节省了一次 launch,同时也可以避免 gap 的产生。
另外一个好处是减少了 global memory 的访问,因为 global memory 的访问是非常耗时的,而两个独立的 cuda kernel 之间要进行结果传递,都需要通过 global memory,将两个 cuda kernel 融合成一个 kernel,我们可以在寄存器或者 share memory 上进行结果传递,从而避免了一次 global memory 写和读,提升性能。
对于算子融合,我们第一步是采用现成的 apex 库来进行 Layernorm 和 Adam 中操作的融合,可以看通过简单的指令替换,我们可以使能 apex 的 fused layernorm 和 fused Adam,从而使得加速从 1.63 倍提升至 2.11 倍。
从 profling 的日志我们也可以看到,经过算子融合之后,elementwise kernel 在这个 timeline 的占比大幅降低,矩阵乘法重新成为时间占比最大的 kernel。
除了利用现有的 apex 库,我们也进行了手工的融合算子开发。
通过观察 timeline,以及对模型的理解,我们发现 Swin Transformer 中有特有的 window 相关操作,如 window partition/shift/merge 等,这里的一次 window shift,需要调用两个 kernel,并在 shift 完成之后调用 elementwise 的 kernel。并且,attention 模块前如果需要做一次这样的操作,那么之后会有对应的 reverse 操作。这里单单 window shift 调用的 roll_cuda_kernel 就在整个 timeline 中占比 4.6%。
刚才提到的这些操作,其实只是对数据进行了划分,即对应的数据会被划分到一个 window 中去,对应的原始代码如下图所示。
我们发现,这部分的操作其实本质上只是 index mapping,因此,我们对这一部分进行的融合算子开发。开发的过程,我们需要掌握 CUDA 编程的相关知识,并且编写算子的前向计算和反向计算的相关代码。
如何向 pytorch 中引入自定义算子,官方给出了教程,我们可以按照教程编写 CUDA 代码,编译好后就可以作为一个模块引入原始的模型。可以看到,通过引入我们的定制化融合算子,我们可以将加速比进一步提升至 2.19 倍。
接下来展示的是,我们对 mha 部分的融合工作。
Mha 部分是 transformer 模型中一个占比很大的模块,因此对它的优化往往可以带来较大的加速效果。从图中可以看到,在没有进行算子融合之前,mha 部分的操作占比为 37.69%,其中包括了不少 elementwise 的 kernel。如果我们能够将相关操作融合成一个独立的 kernel,并具有更快的速度,加速比可以得到进一步提升。
对于 Swin Transformer,这部分的模块除了 query,key 和 value 外,mask 和 bias 都是以 tensor 的形式传入的,我们开发了 fMHA 这样的一个模块,可以将原本的若干 kernel 融合起来。从 fMHA 这个模块涉及到的计算来看,针对 Swin Transformer 中遇到的一些 shape,该模块都有比较显著的提升。
模型用上 fMHA 模块后,我们可以将加速比进一步提升 2. 85 倍。上述是我们在单卡上取得的训练加速效果,那么我们来看一下单机 8 卡的训练情况,可以看到,通过上述优化,我们可以将训练吞吐从 1612 提升至 3733,取得 2.32 倍的加速。
对于训练优化而言,加速比我们希望越高越好,对应的,我们也希望加速后的性能能够与加速前保持一致。
叠加上上述若干加速方案后,可以看到,模型的收敛性与原始的 baseline 保持一致,优化前后的模型的收敛、精度的一致性,在 Swin-Tiny,Swin-Base 以及 Swin-Large 上都得到了验证。
关于训练部分,一些其他的加速策略包括 CUDA graph、multi-stream 等,都能对 Swin Transformer 的性能有进一步提升;其他方面,目前我们介绍的是使用混合精度的方案,也就是 Swin Transformer 官方 repo 采用的策略;使用纯 fp16 的方案(即 apex O2 模式)可以达到更快的加速效果。
虽然 Swin 对通信的要求不高,但是对于多节点大模型的训练,相比于原始的分布式训练,使用合理的策略去隐藏通信的开销,能够在多卡训练上获得进一步的收益。
接下来,有请我的同事来介绍一下我们在推理上的加速方案和效果。
3. Swin Transformer 推理优化
大家好,我是来自英伟达 GPU 计算专家团队的陈庾,非常感谢陶砺在训练加速上的介绍,接下来由我来介绍一下推理上的加速。
跟训练一样,推理的加速离不开算子融合这一方案。不过相对于训练而言,在推理上进行算子融合有更好的灵活性,主要体现有两点:
- 推理上的算子融合不需要考虑反向,所以 kernel 开发过程中不需要考虑保存计算梯度所需要的中间结果;
- 推理过程允许预处理,我们可以对一些只需要一次计算便可重复使用的操作,提前算好,保留结果,每次推理时直接调用从而避免重复计算。
在推理侧,我们可以进行不少的算子融合,这里给出的是我们在 Transformer 模型中常见的一些算子融合的 pattern 以及实现相关 pattern 所需要用到的工具。
首先,我们单独列出矩阵乘法和卷积,是因为有一大类算子融合是围绕他们进行的,对于矩阵乘法相关的融合,我们可以考虑采用 cublas,cutlass,cudnn 这三个库;对于卷积,我们可以采用 cudnn 或者 cutlass。那么对于矩阵乘法的算子融合而言,在 Transformer 模型中,我们归纳为 gemm + elementwise 的操作,比如 gemm + bias, gemm + bias + 激活函数等,这一类的算子融合,我们可以考虑直接调用 cublas 或 cutlass 来实现。
此外,如果我们 gemm 之后的 op 操作比较复杂,比如 layernorm,transpose 等,我们可以考虑将 gemm 和 bias 分开,然后把 bias 融合到下一个 op 中,这样可以更为容易地调用 cublas 来实现简单的矩阵乘法,当然这种 bias 和下一个 op 进行融合的 pattern 一般是需要我们手写 cuda kernel 来实现。
最后,有一些特定 op,同样需要我们以手写 cuda kernel 的方式进行融合,比如 layernorm + shift + window partition。
由于算子融合需要我们比较巧妙地设计 cuda kernel,所以我们一般建议先通过 nsight system 性能分析工具对整体 pipeline 进行分析,优先针对热点模块进行算子融合优化,以达到性能和工作量的平衡。
那么在众多的算子融合优化中,我们挑选了两个加速效果比较明显的算子进行介绍。
首先是 mha 部分的算子融合,我们将 position bias lookup 这一操作提前到预处理部分,从而避免每次推理时都进行 lookup。
然后将 batch gemm,softmax,batch gemm 融合成一个独立的 fMHA kernel,同时我们把 transpose 相关的操作融合到了 fMHA kernel I/O 操作中,通过一定的数据读写的 pattern 来避免显式的 transpose 操作。
可以看到,融合后该部分取得了 10 倍的加速,而端到端也取得了 1.58 倍的加速。
另一个我想介绍一下的算子融合是 QKV gemm + bias 的融合。
gemm 和 bias 的融合是一个十分常见的融合手段,在这里为了配合我们前面提到的 fMHA kernel,我们需要对 weight 和 bias 提前进行格式上的变换。
我之所以在这里选择介绍这个算子融合,也正是因为这种提前变换体现了我们前面提到的,推理上进行算子融合的灵活性,我们可以对模型的推理流程做一些不影响其精度的变化,从而实现更好算子融合 pattern,取得更好的加速效果。
最后,通过 QKV gemm+bias 的融合,我们可以进一步取得 1.1 倍的端到端加速。
下一个优化手段是矩阵乘法 padding。
在 Swin Transformer 的计算中,有时候我们会遇到主维为奇数的矩阵乘法,这时候并不利于我们的矩阵乘法 kernel 进行向量化读写,从而使得 kernel 的运行效率变低,此时我们可以考虑对参与运算的矩阵主维进行 padding 操作,使其变为 8 的倍数,这样一来,矩阵乘 kernel 就可以以 alignment=8,一次读写 8 个元素的方式来进行向量化读写,提升性能。
如下表所示,我们将 n 从 49 padding 到 56 后,矩阵乘法的 latency 从 60.54us 下降为 40.38us,取得了 1.5 倍的加速比。
下一个优化手段是巧用 half2 或者 char4 这样的数据类型。
以下的代码是一个 half2 优化的示例,它实现的是一个简单的加 bias 再加残差这样的算子融合操作,可以看到通过使用 half2 数据类型,相对于 half 数据类,我们可以将 latency 从 20.96us 下降到 10.78us,加速 1.94 倍。
那么采用 half2 数据类型一般有什么好处呢?主要有三点:
第一个好处是向量化读写可以提升 memory 的带宽利用效率并降低访存指令数;如下图右侧所示,通过 half2 的使用,访存指令减少了一半,同时 memory 的 SOL 也有显著提升;
第二个好处是结合 half2 专有的高吞吐的数学指令,可以减低 kernel 的 latency。这两点都已经体现在了这个示例程序中;
第三个好处是在进行 reduction 相关 kernel 开发时,采用 half2 数据类型意味着一个 cuda 线程同时处理两个元素,可以有效减少空闲的线程数,也可以减少线程同步的 latency。
下一个优化手段是巧用寄存器数组。
在我们进行 layernorm 或者 softmax 等 Transformer 模型常见的算子优化时,我们经常需要在一个 kernel 中多次使用同一个输入数据,那么相对于每次都从 global memory 读取,我们可以采用寄存器数组来缓存数据,从而避免重复读取 global memory。
由于寄存器是每个 cuda 线程独占的,所以在进行 kernel 设计时,我们需要提前设定好每个 cuda 线程所需要缓存的元素个数,从而开辟对应大小的寄存器数组,并且在分配每个 cuda 线程所负责元素时,需要确保我们可以做到合并访问,如下图右上侧所示,当我们有 8 个线程时,0 号线程可以处理 0 号元素,当我们有 4 个线程是,0 号线程则处理 0 号和 4 号元素,如此类推。
我们一般建议可以采用模板函数的方式,通过模板参数来控制每 个cuda 线程的寄存器数组大小。
此外,在使用寄存器数组时,需要保证我们的下标是常量,如果是循环变量作为下标,我们应该尽量保证可以进行循环展开,这样可以避免编译器将数据放到了 latency 很高的 local memory 中,如下图所示,我们在循环条件中添加限制,通过 ncu report 可以看到,避免了 local memory 的使用。
最后一个我想介绍优化手段是 INT8 量化。
INT8 量化是推理加速非常重要的加速手段,对于 Transformer based 的模型而言,INT8 量化可以在减少显存消耗的同时带来更好的性能。
而对于 Swin 来说,通过结合合适的 PTQ 或 QAT 量化方案,可以在取得良好加速的同时,保证量化精度。一般我们进行 int8 量化,主要是对矩阵乘法或者卷积进行量化,比如 int8 矩阵乘法中,我们会先将原始的 FP32 或 FP16 的 input 和 weight 量化为 INT8 然后再进行 INT8 矩阵乘法,累加到 INT32 数据类型上,这是我们会进行反量化操作,得到 FP32 或 FP16 的结果。
比较常见调用 INT8 矩阵乘法的工具是 cublasLt,为了可以取得更好的性能,我们有必要深入地了解一下 cublasLt api 的一些特性。
cublasLt 对于 int8 矩阵乘法,提供了两种输出类型,分别是下图左侧所示,以 INT32 输出,或者下图右侧所示,以 INT8 输出,图中蓝框所示的 cublasLt 的计算操作。
可以看到相对于 INT32 输出而言, INT8 输出会多了一对反量化和量化操作,这样一来一般会带来更多的精度损失,但是由于 INT8 输出,在写出到 global memory 时相对 INT32 输出少了 3/4 的数据量,性能会更好,所以这里面存在着精度和性能 tradeoff。
那么对于 Swin Transformer 而言,我们发现配合 QAT,以 INT8 输出会在取好的加速比的前提下,保证精度,因为我们采用了 INT8 输出的方案。
另外,关于 cublasLt 中 INT8 矩阵乘法,还需要考虑数据的布局问题,cublasLt 支持两种布局,一种 IMMA-specific 的布局,其中涉及到一些比较复杂的格式,而且在这种布局只支持 NT-gemm,另外一种是常规的列优先的布局,在此布局下支持 TN-gemm。
一般来说,采用列优先的布局,会更有利于整个 pipeline 代码的开发,因为如果我们用 IMMA-specific 布局的话,我们为了兼容这种布局可能需要很多额外的操作,以及上下游 kernel 也需要为这种特殊布局做兼容。但是在一些尺寸的矩阵乘法上,IMMA-specific 布局可能会有更好的性能,所以如果我们要尝试搭建 int8 推理的话,建议咱们可以先做一些 benchmark,以便更好地从性能和开发难易程度做取舍。
在 FasterTransformer 中我们采用了 IMMA-specific 布局。所以接下来,我们以 IMMA-specific 布局为例,简单介绍了一下 cublasLt int8 矩阵乘法的基本搭建流程,以及一些开发技巧。
cublasLt int8 矩阵乘法的基本搭建流程,一共可以分为 5 步:
- 首先我们需要创建句柄和乘法描述符;
- 接下来我们为每个矩阵创建一个矩阵描述符;
- 因为一般我们的输入都是常规 layout 的,所以我们需要对常规布局的矩阵进行布局转换,使其变为 IMMA-specific 的布局;
- 然后再进行 int8 矩阵乘法,得到结果之后,我们可以考虑继续用这个结果进行下游的乘法计算,这样可以避免转变会常规布局的开销;
- 只有最后一个矩阵乘法的结果,我们需要转换常规布局以便输出。
上述介绍了 IMMA-specific 布局下的搭建流程,可以看到里面会有不少限制。为了避免这些限制对性能的影响,我们在 Faster Transformer 中采用了以下技巧:
- 首先 IMMA-specific 布局对矩阵是有特定的尺寸要求,为了避免推理过程中需要额外分配空间的操作,我们会提前分配好符合 IMMA-specific 布局尺寸的 buffer;
- 然后,由于 weight 可以一次处理重复使用,所以我们会提前对 weight(相当于乘法中的 B 矩阵)进行布局变换,避免在推理过程中反复变换 weight;
- 第三个技巧是,对于不得不进行特殊布局变换的 A 和 C,我们会把变换和上游或下游 op 进行算子融合,以便隐藏这部分的开销;
- 最后一点,是与布局无关,而是 int8 矩阵乘法必有的量化和反量化的操作,我们同样会采用算子融合的方式,把它的 latency 隐藏起来。
以下是我们在 Faster Transformer 中采用的的 INT8 流程的示意图,可以看到,所有矩阵乘都变为了 int8 数据类型,每个 int8 矩阵乘法前后都会插入对应的量化和反量化节点,然后对于加 bias,加残差或 layernorm 等操作,我们还是保留原始的 FP32 或 FP16 数据类型,当然它的 I/O 可能是 int8 的,从而会比 FP16 或 FP32 I/O 性能要好。
这里展示的是 Swin Transformer int8 量化的精度情况,通过 QAT 我们可以保证精度损失在千分之 5 以内。
而在 PTQ 那一列,我们可以看到 Swin-Large 的掉点比较严重,一般对应掉点严重的问题,我们都可以考虑采用减少一些量化节点的方式来提升量化精度,当然这样可能会带来加速效果的减弱。
在 FT 中,我们可以通过禁用 FC2 和 PatchMerge 中 int 8 矩阵乘法的 int8 输出前的反量化和量化结点(即采用 int32 输出),来进一步提升量化精度,可以看到在此优化操作下,swin-large 的 PTQ 精度也明显提升了。
接下来是我们推理侧取得的加速效果,我们分别在不同型号的 GPU T4、A10、A100 上进行了跟 pytorch FP16 实现的性能对比。
其中下图左侧是优化后跟 pytorch 的 latency 对比,右图为优化后 FP16 下跟 pytorch 以及 INT8 优化跟 FP16 优化的加速比。可以看到,通过优化,在 FP16 精度上,我们可以取得,相对于 pytorch 2.82x ~ 7.34x 的加速,结合 INT8 量化,我们可以在此基础上进一步取得 1.2x ~ 1.5x 的加速。
4. Swin Transformer 优化总结
最后,我们总结一下,本次分享中我们介绍了如何通过 nsight system 性能分析工具发现性能瓶颈,然后针对性能瓶颈,介绍了一系列训练推理加速技巧,其中包括 1. 混合精度训练 / 低精度推理,2. 算子融合,3. cuda kernel 优化技巧 :如矩阵补零,向量化读写,巧用寄存器数组等,4. 推理优化上采用一些预处理,来完善我们的计算流程;我们也介绍了 multi-stream,cuda graph 的一些应用 。
结合上述优化,我们在训练上,以 Swin-Large 模型为例取得了单卡 2.85x 的加速比,8 卡 2.32x 的加速比;在推理上,以 Swin-tiny 模型为例,在 FP16 精度下取得了 2.82x ~ 7.34x 的加速比,结合 INT8 量化,进一步取得 1.2x ~ 1.5x 的加速比。
上述视觉大模型训练与推理的加速方法都已经在百度百舸 AI 异构计算平台的 AIAK 加速功能中实现,欢迎大家使用。