
字节 TileLink:编译生成高效的计算和通信 Overlap Kernel
一、背景
笔者之前的文章(万字综述 LLM 训练中的 Overlap 优化:字节 Flux 等 7 种方案)中详细介绍过各种计算与通信 Overlap 的方案,这里进一步介绍字节最近发表的 TileLink,其中提到的大部分工作已经包含在我们之前的综述中,建议优先阅读,比如 CoCoNet、Centauri、Flux 等。
对应的论文:[2503.20313] TileLink: Generating Efficient Compute-Communication Overlapping Kernels using Tile-Centric Primitives [1]
二、摘要
大规模深度学习模型通常需要分布式系统以实现高效的训练与推理,分布式模型执行的基础构建模块是层内并行算子。提升层内并行算子性能的最有效方法在于实现计算与通信的 Overlap。这种 Overlap 可通过算子分解(Operator Decomposition)或 Kernel 融合(Fusion)两种方式达成:
- Operator Decomposition 虽易于实现,但性能往往欠佳。
- 将通信 Kernel 与计算 Kernel 相融合则需深厚的专业知识且易出错。
本文中,作者提出 TileLink,旨在高效编译并生成计算-通信 Overlap 执行的 Kernel。TileLink 由前端(Frontend)和后端(Backend)构成:
- 在前端,系统通过以 Tile 为中心的原语将通信与计算的设计空间解耦并建立关联。
- 在后端,将这些原语转换为底层指令,整合通信与计算组件以实现 Overlap 执行。
实验表明,TileLink 相较于非 Overlap 基线实现了 1.17x 至 20.76x 的加速,并在 GPU 上达到了与当前最优 Overlap 执行库相当的性能水平。
三、引言
3.1 北大 Centauri
北大在 [ASPLOS 24.04] Centauri: Enabling Efficient Scheduling for Communication-Computation Overlap in Large Model Training via Communication Partitioning [2] 中介绍了 Centauri 框架,其构建了一个由三个固有抽象维度组成的切分空间:原语替换、拓扑感知组切分及工作负载切分。这些维度共同构成了一个全面的优化空间,用于高效 Overlap。为确定通信与计算的高效 Overlap,作者将混合并行训练中的调度任务分解为 OP、Layer 和模型三个层次。
如下图 Figure 3 所示,Centauri 的工作流程包含两个核心环节:通信切分与层次调度。以 DP 与 FSDP 混合并行训练为例:
- 通信切分:通过考量三个基本维度,生成潜在切分空间,并为每种集合通信选择高效策略。
- 层次调度:在上述全面但较大的切分空间下,优化整图的 Overlap 调度成为一项复杂的任务,为了简化复杂的调度任务,作者将复杂的混合并行集合通信分解为三个层次,每个集合通信被分配至特定调度层级。各层级选取开销较低的切分与调度方案,旨在实现整体优化 Overlap 方案。
有一系列类似 Centauri 的算子分解方法,其核心是:将通信和计算 Kernel 拆分为更小规模的同构 Kernel,随后将其分配到多个通信-计算 Kernel 对中。这些拆分后的小 Kernel 可被调度到不同的 Stream 上,使得通信 Kernel 和计算 Kernel 能同时对切分的数据分片进行操作。
然而,类似上述算子分解的方法有一些局限性:
- 分解后的 Kernel 间的同步机制需要 Host 端介入,会在运行中引入不可忽略的开销。
- L2 Cache 利用率降低、资源量化效率不足,导致分解后的 Kernel 性能可能出现恶化。
这里的资源量化效率不足(Resource Quantization Inefficient)是指计算资源切分不均衡等导致的浪费,如下图 Stream-K([2301.03598] Stream-K: Work-centric Parallel Decomposition for Dense Matrix-Matrix Multiplication on the GPU [3])中提到的问题:
3.2 字节 Flux
字节在 [2406.06858] FLUX: Fast Software-based Communication Overlap On GPUs Through Kernel Fusion [4] 中提出 Flux,旨在通过依赖计算隐藏 GPU 间的通信时延。Flux 将通信和计算操作分解为更细粒度的操作,并进一步融合成更大的 Kernel,从而在不损害 Kernel 效率的前提下有效隐藏通信。在融合 Kernel 的情况下,Flux 有望重叠高达 96% 的通信时间。
如下图 Figure 5 展示 Flux 中 ReduceScatter 里 Overlap 与其他方案的差异。现有 Overlap 方案 Tm 理论上可能比原始方法 Tc 执行得更快,但通常情况下,Tm 仍慢于原始 GEMM 操作时间 Tg。主要原因在于,将一个 GEMM Kernel 拆分为一系列较小的 GEMM Kernel 会降低 GPU GEMM 的执行效率。GEMM 通常需要合理大小的矩阵才能充分利用 GPU 的计算能力。这些具有数据依赖性的小型 GEMM 操作序列进一步阻碍了 GEMM Kernel 通过 GPU 多路复用技术并行运行,因此,Tensor 并行度越高,GPU 上的 GEMM 效率越低。
相比之下,作者提出的技术不存在上述限制。作者的 Overlap 方案 Tf 能够在极小开销下实现与原始 GEMM 操作 Tg 相当的性能。其细粒度分解策略完美契合现代 GPU 设计特性,即通过上下文切换的 Warp 和数百个在 SM 间并发活跃的 Warp 来隐藏延迟,如下图 Figure 5 底部所示。最终,作者的方法在不影响 GEMM 计算效率的前提下,仅在执行末尾引入少量通信开销。
然而,虽然这种方式实现的 Kernel 效率很高,但是开发成本同样很高,尤其是针对不同场景、模型可能需要开发特定的 Kernel。DeepSeek 可以做深度的 DeepEP、DualPipe 等优化的一个前提就是其模型、硬件相对恒定,可以一劳永逸。
四、方案
4.1 概览
本文工作主要聚焦于层内并行,为了说明 TileLink 的优势,作者以 MLP 的 Tensor Parallelism(TP) 为例,如下图 Figure 1 所示,其实现包含 AllGather + GEMM(AG+GEMM)与 GEMM+ ReduceScatter(GEMM + RS),其配置与 LLaMA-7B 一致:
如下图 Table 2 所示,采用不同技术方案的性能进行对比,其中 Non-Overlap 为直接使用 cuBLAS 和 NCCL 的无 Overlap 方案;Decomposition 则为采用算子分解技术。可以看出,Decomposition 是性能最差的,Fusion 方案在 AG + GEMM 中最优,TileLink 在 GEMM + RS 中最优,同时 AG + GEMM 与 FLUX 性能接近(约达 99%)。同时,FLUX 需要 2000 行 CUDA 代码,而 TileLink 仅需 200 行 Python 代码,编程效率提升 10x。
PS:之前的 CoCoNet([ASPLOS 22] [2105.05720] Breaking the Computation and Communication Abstraction Barrier in Distributed Machine Learning Workloads [5]) 和 Dist-Einsum(Overlap Communication with Dependent Computation via Decomposition in Large Deep Learning Models | OpenReview [6]) 也可以生成 Overlap Kernel,但是其只能生成特定 Overlap Pattern 的算子,不够灵活。
4.2 前端原语(Frontend Primitives)
4.2.1 解耦设计空间
设计计算+通信融合 Kernel 存在两种方式:一种是将两部分优化选择紧密耦合;另一种是解耦计算 Kernel 和 通信 Kernel 设计。本文的 TileLink 选择后者,因为其解构设计空间能为 Kernel 设计提供更多灵活性,从而可能获得更优性能。(PS:是否也有可能丧失联合设计的优势,比如更加均衡的资源分配?)
解耦设计空间分为 3 个子空间:
- 分块尺寸(Tile Size):如下图 Figure 2a 所示,通信组件每次传输 128x128 的 Tile;计算组件每次处理 128x256 的 Tile。Tile 的大小与使用的处理核心数相关,比如通信组件占用更多核心时使用较小的 Tile 可以更充分的利用全部核心资源;反之,核心数较少时更大的 Tile 更加高效。
- 分块顺序(Tile Order) :如下图 Figure 2b 所示,通信组件可以与计算组件采用不同的分块顺序。分块顺序的选择也存在权衡:若计算组件等待多个 Rank 的数据分块,则能在处理更大数据块时获得更好的 Cache 效率,但可能等待时间变长;反之,仅等待单个 Rank 的数据分块,可提早开始计算,但整体计算效率可能降低。图中例子为通信采用 Ring 顺序而每次迭代等待 2 个 Rank 数据。
- 资源映射(Resource Mapping):如下图 Figure 2c 所示,通信与计算组件可映射到不同单元或相同单元。比如,如果通信组件使用 Copy Engine(DMA),可以避免与计算组件的资源冲突,但需要承担 Host 带来的额外开销;但是如果采用计算核心执行数据拷贝,则可以消除 Host 开销,但可能引发资源冲突,这适用于计算组件无法充分利用所有处理核心的场景。
4.2.2 Tile 为中心的基础原语
解耦通信与计算的设计空间也会引入同步的挑战。由于这两个组件采用不同的分块尺寸、分块顺序和资源映射方案,实现二者同步需要进行复杂的底层编程并插入通信指令。以 GPU 为例,要求使用诸如 ld.global.acquire 和 red.release 等特殊指令。然而,这类指令的编程模型和代码生成编译器的工作机制存在根本性差异,现有编译器普遍缺乏对内存一致性模型的原生支持。
- ld.global.acquire:获取语义,确保之后的操作不会提前执行,确保读取的变量是最新的,防止 CPU 或其他 GPU 线程的旧值污染数据。
- red.release:释放语义,确保之前的写入对其他线程可见,确保数据在此操作之前全部写入,防止写入乱序执行。
- 这两个指令通常用于同步机制,特别是生产者-消费者、互斥锁、信号量等场景,以保证不同 GPU 线程间的正确通信。
为解决上述问题,TileLink 提供了一套以 Tile 为中心的基础原语。这些原语引入了内存一致性(Memory Consistency)语义,并遵循编译器采用的 Tile 级抽象,与现有框架提供的以算子为中心的原语形成显著区别。如下图 Figure 3 所示,TileLink 原语分为信号原语(Siganl Primitive)和数据原语(Data Primitive)两大类,每类均包含 Device-side 原语和 Host-Side 原语两个子类。
涉及的所有原语如下表 Table 3 所示:
4.2.3 信号原语
信号原语:旨在管理通信和计算之间的屏障,包括:
- producer(peer)_tile_notify:生产者或 Peer 通知
- consumer(peer)_tile_wait:消费者或 Peer 等待
- rank_notify(wait):Rank 通知和等待
在 Device-side:
- producer_tile_notify 和 consumer_tile_wait 适用于生产者-消费者关系,例如 AllGather 与 GEMM 运算中各 Tile 的交互;
- peer_tile_notify 和 peer_tile_wait 主要用于跨不同 Rank 的同一算子 Tile,使用户能够构建多样化的 Tile 执行顺序。
在 Host-side:
- rank_notify 和 rank_wait 用于管理 Copy Engine 和计算核心间的同步屏障。当通信任务映射至 Copy Engine 时,这些原语可有效协调通信与计算间的 Tile 执行顺序。如上图 Figure 3a 所示。
Notify 原语需通过 Mode Argument 或 Rank argument 明确待通知的远端 Rank 范围。TileLink 为 Mode Agrument 提供两种选项:p2p 和 broadcast。
- p2p 仅通知单个目标 Rank,其数值由给定 Tile 标识(tile_id)在全局张量视图中的偏移量计算得出;
- broadcast 则向所有 Rank 发送通知信号。
内存一致性:在并行执行过程中,不同进程/线程执行的内存操作可能以非一致顺序对其他进程/线程可见。内存一致性模型通过设定约束条件,确保各进程/线程观测到的操作顺序不存在歧义。信号原语提供了严格的内存一致性语义:
- 通知类原语具有释放语义(release semantics),保证所有在 producer(peer)_tile_notify 和 rank_notify 之前的内存访问操作不得被重排到这些通知原语之后;
- 等待类原语则具有获取语义(acquire semantics),确保所有在 consumer(peer)_tile_wait 和 rank_wait 之后的内存访问操作不得被重排到这些等待原语之前。
这种严格的内存一致性约束在后端编译阶段同样需要予以考虑。
4.2.4 数据原语
数据原语促进了数据传输过程,主要包括 tile_push(pull)_data 和 rank_copy_data 两类原语。这些原语精确控制着传输数据的资源映射与 Tile 大小。
- Device-side 的 tile_push(pull)_data 原语将通信映射至处理核心。
- Host-side 的 rank_copy_data 原语则将通信映射至 Copy Engine。
数据传输存在拉取(pull)与推送(push)两种模式,各自适配不同的同步机制:
- 在 pull 模式下,生产者从所有其他 Rank 读取数据,并通过本地屏障通知其消费者;
- 与之相反,push 模式允许生产者将本地数据写入所有其他 Rank,同时向远端消费者发送数据到达通知。
如上图 Figure 3b 清晰展示了两种模式的差异。模式选择可能影响性能表现,具体取决于数据形态、分块策略及可用硬件资源等要素。值得注意的是,rank_copy_data 原语通过 P2P 复制技术支持双模式运行,其数据传输方向由源指针与目标指针的排列顺序显式指定。
4.3 后端映射(Backend Mapping)
TileLink 后端负责将通信与计算组件共同编译为底层设备代码。为实现分布式系统的代码生成,TileLink 采用了一种以计算单元为核心的映射技术,该技术能够将通信模块与计算模块进行关联整合。
TileLink 采用以 Tile 为中心的映射方法,将前端原语编译为底层代码。以 Tile 为中心的映射包含三个组成部分:
- 形状映射(shape mapping):将每个 tile_id 与特定的 Tensor Shape Tile 相关联。
- Rank 映射(rank mapping):将每个 tile_id 与 Device Rank 相关联。
- 通道映射(channel mapping):为每个tile_id 分配通信屏障(communication barrier)。
作者分别用 fS、fR、fC 表示这三种映射。根据工作负载类型的不同,应采用不同的映射函数。作者将不同映射划分为两类:
静态映射(Static Mapping):指可在编译时确定的映射关系,通常用于数据分片策略固定的场景,例如 Tensor-Parallel MLP 和 Sequence-Parallel Self-Attention。作者采用仿射运算(Affine Operation)处理静态映射(此时 fS、fR、fC 均为仿射函数)。以包含 R 个设备(每 Rank 对应 C 个通道/屏障)的系统上执行 AllGather(pull 模式)+ GEMM(问题规模 M×N×K)为例:生产者 AllGather 操作的 Tile 尺寸为 Tmp × Tnp,输入 Tensor 沿 M 维分片。给定生产者 Tile 的 tile_idp,其形状范围、源 Rank 及通道可通过以下公式计算。类似地,可以计算出从消费者 tile_idc 到形状范围、 Rank 和通道的映射关系:
动态映射(Dynamic Mapping):是指在运行时计算的映射关系,这对于具有动态数据分片需求的工作负载至关重要。例如,在 MoE 数据分片策略中,动态路由决定了数据分布,每个 tile 可能需要来自其他任意 Rank 的 Token。在编译时无法确定需要从哪些 Rank 收集数据或在哪个通道等待屏障同步。因此,必须在运行时计算这些映射关系。为支持动态映射,TileLink 将这些映射转换为查找表,其值可在运行时填充,而对这些查找表的访问操作则在编译时确定。从形式化角度来看,动态映射如下所示(其中 fS_low,fS_high,fR 和 fC 是查找表,其值在运行中动态调整):
内存一致性编译:在后端编译过程中,前端具有内存一致性语义的原语被编译为相应的设备指令(如 ld.global.acquire 和 red.release)。然而,直接翻译这些原语并不足以确保内存一致性。对于大多数计算 Kernel,采用多级流水线技术来提升负载-计算平衡并优化整体性能。将原始程序编译为多级流水线版本需要进行算子重排,在此过程中某些内存访问操作可能会意外地被重排至 TileLink 原语之前或之后。为解决这一问题,TileLink 在其原语与后续 load/store 操作之间建立了严格的数据依赖关系,从而确保其原语能够通过流水线处理阶段被正确重排序和展开。
其他编译优化:除上述技术外,TileLink 还采用单设备优化策略以实现高性能,该策略在已有研究中得到充分论证。优化主要体现在内存优化与流水线优化两方面:
- 内存优化通过自动分配片上寄存器缓存和计算用共享存储缓冲区,对全局缓冲区的数据访问进行合并操作,并重构共享存储器访问模式以避免存储体冲突;
- 流水线优化则通过重组数据 load/store 操作与计算任务,构建多级流水线架构。其中,本地数据拷贝被映射至专用异步引擎(如 GPU 的 TMA),而计算任务则被分配至高性能运算单元(如 GPU 的 Tensor Core)。
4.4 Kernel 设计
为展示 TileLink 的灵活性与普适性,作者阐释了如何为 GEMM + Ring ReduceScatter、AllGather + MoE 以及 AllGather KV + Self Attention 机制设计 Overlap 计算 Kernel。这三个案例具有代表性:它们分别采用了不同的分片顺序(Ring 和 All2All)、不同的映射策略(静态与动态)以及不同的硬件资源(Device-side 和 Host-side)。
如下图 Figure 4 展示了 GEMM + Ring ReduceScatter Kernel 的伪代码实现,该案例采用静态映射策略,演示了生产者-消费者和 P2P 双向通信的编程范式。
- 其中计算与通信均采用 SM,分配了 20 个 SM 专用于通信(见第 1 行)。
- 生产者 GEMM 将部分计算结果存储于本地 Tensor,并通过 producer_tile_notify 通知消费者(第 9 行)。
- 消费者 ReduceScatter 通过 consumer_tile_wait(第 16 行)等待生产者就绪。
- 一旦数据可用即执行 local reduce 操作(第 20 行),并将部分结果通过 tile_push_data 传递给前序节点(第 24 行)。
- 节点间的信号控制通过 peer_tile_wait(第 19 行)和 peer_tile_notify(第 26 行)原语实现。
如下图 Figure 5 展示了 AllGather + MoE 的伪代码实现。
- 同样采用 20 个 SM 处理通信任务(第 1 行)。值得注意的是,MoE 需要基于动态路由(输入中的 topk_ids)为每个 token 选择专家,必须采用动态映射。因此使用 table 数据结构存储形状映射、Rank 映射及通道映射的查找表。所有相关原语均需以 table 为参数,以确保 TileLink 能基于动态映射生成正确代码。
- 此外,load 原语需要借助 table 中的形状映射来收集当前分片所需的正确 token(第 11 行)及其对应的 topk_ids(第 12 行)。
如下图 Figure 6 展示了 AllGather KV + Self Attention(序列并行)的伪代码。本案例中通信操作通过 Copy Engine 实现,采用 Host 原语来触发 Copy Engine。通信与计算分别在两个独立的流上执行:
- 通信部分通过 rank_copy_data 原语完成,其分块尺寸为 KV Cache 序列长度(S)除以总 Rank 数(WORLD_SIZE)。
- 计算部分则采用不同的分块尺寸。通过基于分块的 Kernel 映射机制,确保通信与计算环节间的屏障操作正确执行。
4.5 实现
TileLink 基于 Triton,使用 Python 语言实现。作者在 Python 层面实现了以计算块为中心的原语操作,从而扩展了 Triton 的语言特性,而面向计算块的映射机制则通过 Python 抽象语法树(AST)转换实现。其实现方案可轻松适配至 TVM、MLIR 等其他编译器框架。
如下图 Figure 7 所示,编译器输入为融合 TileLink 原语与 Triton 原生原语的纯 Python 程序。通过特殊参数 BlockChannel 为计算和通信提供以计算块为核心的映射上下文,BlockChannel 封装了分布式映射元数据,包括当前进程 Rank、总 Rank 数、同步屏障配置及生产者/消费者计算块关系等。
- Python 程序经解析生成 AST 后转换为 Triton 中间表示(IR),在此过程中 BlockChannel 参数被分解,利用其内嵌元数据构建面向计算块的映射关系,TileLink 原语则转换为 Triton 的 ElementwiseInlineAsmOp 操作。
- 随后 Triton IR 被进一步降级为 Triton GPU IR 和 TileLink 新增的 Distributed IR,后者用于将通过 ElementwiseInlineAsmOp 表达的特殊指令转换为 LLVM IR,最终编译为适用于 NVIDIA GPU 的 PTX 代码。
- 通过将 LLVM IR 转换为目标架构特定的底层汇编,可支持更多后端硬件。
- 运行时:
- 采用 NVSHMEM 初始化分布式执行环境并分配共享内存。
- 生成的代码在所有进程上启动以执行并发计算与通信。
- 运行结束后正确释放共享内存空间。
五、评估
如下图 Figure 8 所示,作者在 8xH00 集群上测试:
- 对 AG+GEMM 场景,Async-TP PyTorch 由于分解后的 GEMM 运算规模过小无法充分占用设备资源,未能实现加速效果。FLUX 凭借高度优化的实现取得了最高加速比(相较于 cuBLAS + NCCL 达1.34x)。TileLink 同样实现了优于 cuBLAS + NCCL 的加速效果(1.27x),达到 FLUX 性能的 94.5%。
- 对于 GEMM + ReduceScatter 场景,TileLink 展现出最佳性能:较 cuBLAS + NCCL 提升 1.25x,较 Async-TP PyTorch 提升 2.22x,较 FLUX 提升 1.28x。
如下图 Figure 9 所示,MoE 层相较于 MLP 层复杂度显著提升,在编译阶段需进行动态映射。该层可分解为两个核心部分:AG + Gather + Group GEMM 与 Group GEMM + Scatter + Topk Reduce + RS。这两类算子可融合为 Group GEMM Kernel,vLLM 已实现此类融合运算。
- 在第一部分:TileLink 凭借通信-计算 Overlap 优化,在 vLLM 基础上进一步实现 1.51x 平均加速。
- 在第二部分:TileLink 相较 vLLM 获得 1.31x 平均加速,较 CUTLASS + NCCL 组合提升 10.56x。
- 需特别指出,FLUX、Async-TP PyTorch 等现有库均不支持 MoE 层 Overlap 执行,而 TileLink 凭借灵活的原语体系与动态映射机制实现了该功能支持。
如下图 Figure 10 所示,作者针对 16K 到 128K 序列长度的 Self Attention 机制进行了评估。实验表明,在所有序列长度条件下,TileLink 方案相较 PyTorch 非 Overlap 实现(Torch)与RingAttention(RingAttn)均展现出稳定的加速优势。经量化分析,TileLink 平均可获得 5.04x 于Torch、1.97x 于 RingAttn 的性能提升。
作者将 TileLink 集成至 PyTorch 框架,并在 H800 集群上对 8 种不同 LLM 进行端到端性能评估。
- 首先在单节点(8×H800 GPU)环境下进行测试,结果如下图 Figure 11 左半部分所示。前五种为 Dense 模型,后三种为 MoE。其中 Qwen1.5 采用 MoE 共享专家机制,通过将 MLP 层与 MoE 层合并来实现共享专家支持。实验设置 Batch Size 为 4、序列长度 8192。结果表明, TileLink 相较 PyTorch 实现平均 1.32x 加速。Dense 模型平均加速比为 1.20x,与单层 MLP 加速效果一致——尽管 Self Attention 获得显著加速,但端到端性能仍由 MLP 层主导。MoE 模型平均加速比为 1.54x,低于单层 MoE 加速效果,因其 MLP 层与 MoE 层各占约 50% 执行时间,最终加速比介于二者之间。
- 在多节点部署评估中,鉴于节点间带宽限制,采用节点内 TP 与节点间 DP 的混合策略。双节点(各 8×H800 GPU)测试结果与单节点基本一致(Batch 规模倍增),整体加速比为 1.29x,因节点间通信开销略有下降。
六、参考链接
- https://arxiv.org/abs/2503.20313
- https://dl.acm.org/doi/10.1145/3620666.3651379
- https://arxiv.org/abs/2301.03598
- https://arxiv.org/abs/2406.06858
- https://arxiv.org/abs/2105.05720
- https://openreview.net/forum?id=MIJtDiMUX9
本文转载自AI闲谈,作者:AI闲谈
