量化到1 bit的LLM还能再突破?
这次,他们对激活值下手了!
近日,BitNet系列的原班人马推出了新一代架构:BitNet a4.8,为1 bit大模型启用了4位激活值:
图片
论文地址:https://arxiv.org/pdf/2411.04965
众所周知,激活值量化通常是比较难办的。
本次的BitNet a4.8采用混合量化和稀疏化策略,来减轻异常通道引入的量化误差。
简单来说就是,对注意力层和FFN层的输入采用4位量化,同时用8位整数稀疏化中间状态。
大量实验表明,BitNet a4.8在相同的训练成本下,实现了与前代BitNet b1.58相当的性能,同时因为可以吃到4位(INT4/FP4)内核的计算红利,实现了更快的推理速度。
BitNet a4.8仅激活55%的参数,并支持3 bit KV cache,进一步提升了大规模LLM部署和推理的效率。
BitNet a4.8
图片
模型架构
模型的整体架构如图1所示,BitNet a4.8采用了与BitNet b1.58相同的布局。
作者使用BitLinear替换注意力(MHA)和前馈网络(FFN)中的线性投影,以从头开始学习1.58 bit权重。对于激活值,采用混合量化和稀疏化策略来减轻异常值维度引入的误差。
图片
图2说明了模型大小为7B的BitNet b1.58中,每个模块输入的分布。
注意力层和FFN层的输入通常类似高斯分布,而在FFN下采样之前的激活值和注意力中的输出投影中,发现了很多异常值通道和大量接近零的条目(全精度LLM也有类似观察结果)。
图片
如图3所示,直接将低位量化应用于这些中间状态会引入很大的量化误差。
因此,作者使用Q-Sparse的稀疏化方法,将这些中间状态保持在8位(同时消除了计算瓶颈)。
对于自注意层的输出投影,使用sparsify-then-quantize函数:
两个Q分别表示权重W和激活X的量化函数,M是掩码,根据激活X的绝对值取topK,⊙是元素乘法。
具体来说,权重量化和激活值量化函数可以表述为:
对于FFN,这里采用squared ReLU和门控线性单元(GLU)来进一步提高激活的稀疏性:
根据初步实验的结果,使用squared ReLU时,下采样输入的稀疏性超过了80%,且对性能的影响最小。
此外,作者还观察到gate + squared ReLU的输出也表现出高激活稀疏性(7B模型为67.5%)。通过首先计算gate projection,然后仅在非零通道上执行up projection,可以进一步减少推理的计算量。
相比之下,attention和FFN的输入中包含的异常值特征要少得多,可以使用absmean函数将激活值量化为4位整数:
模型训练
初始化
BitNet a4.8使用BitNet b1.58的权重开始训练,分为W1.58A8与W1.58A4两阶段。
第一阶段使用8位激活和GLU + squared ReLU训练模型;第二阶段采用上面介绍过的混合量化和稀疏化。
图片
BitNet a4.8只需少量训练,即可快速适应4bit位宽和稀疏激活,同时性能损失可以忽略不计。
梯度近似
作者使用直通估计器(STE)对BitNet a4.8进行梯度逼近,使用混合精度训练来更新参数。
图片
这里直接绕过了不可微函数,包括反向传播过程中的量化函数和topK稀疏函数。对于混合精度训练,保持全精度latent weight来累积参数更新。
模型量化
浮点量化提供了比基于整数的量化更宽的动态范围,这对于处理激活值的长尾分布至关重要。
研究人员将FFN下采样层的输入保留为8位整数,其他激活值使用MinMax量化器量化为FP4:
公式中E和M分别表示指数和尾数部分的位宽。这里采用E2M1格式,因为它的动态范围更大。
实验
本文将BitNet a4.8、BitNet b1.58,以及各种参数量大小的FP16精度LLaMA进行了比较。
其中的1.58 bit模型,遵循BitNet b1.58的训练方案,采用了两阶段权重衰减和学习率调度。
图片
所有模型都使用RedPajama数据集中的100B token进行训练,以确保公平比较。
对于BitNet a4.8,作者首先使用95B token来训练8位激活值的模型。然后重用优化器状态,并使用5B token进行混合量化和稀疏化的训练。实验将topK设置为50%(attention的输出投影位置)。
作者使用lm-evaluation-harness工具包,评估模型在一系列语言任务上的zero-shot准确性,包括ARC-Easy(ARCe)、ARCChallenge(ARCc)、Hellaswag(HS)、Winogrande(WGe)和PIQA(PQ)。另外还测试了在C4数据集(测试集)上的困惑度。
主要结果
图片
表1总结了BitNet a4.8、BitNet b1.58和FP16 LLaMA的详细测试结果。
全精度(FP16)LLaMA和BitNet b1.58之间的性能差距,随着模型大小的增长而缩小。对于7B模型,BitNet b1.58在语言模型困惑度和任务的平均准确性方面与LLaMA相当。
此外,相比于BitNet b1.58,BitNet a4.8的平均精度几乎没有损失。
图片
表2展示了各种大小的BitNet a4.8、BitNet b1.58 和 FP16 LLaMA中每个模块的详细稀疏性(使用C4验证集上的非嵌入参数计算)。
值得注意的是,BitNet a4.8的稀疏性明显高于BitNet b1.58和LLaMA。
比如在7B模型中,BitNet a4.8的整体稀疏性达到了44.5%,只有3.4B的活跃参数。down projection层的输入显示出特别高的稀疏性,且中间状态分布以零为中心。
此外,gate projection的输出非常稀疏,导致了up projection的高稀疏性(因为只需要在从Gate中选择非零通道来执行投影)。
具体来说,对于7B BitNet a4.8,Gate和up projection的稀疏率分别为67.5%和12.0%。
图片
表3显示了BitNet a4.8在3B和7B模型大小下,low-bit attention的详细情况。模型使用4位KV或QKV头,精度损失可忽略不计,同时KV cache可以量化为3位整数。
low-bit attention对于高效的长序列建模至关重要,它减少了KV cache的内存占用和IO,并加速了注意力计算。
在本文的实验中,作者采用RoPE后量化。使用absmax函数将QKV头直接量化为无符号整数,无需任何校准数据集。
对于3 bit KV量化,研究人员将bos token的头保留为4 bit,因为它包含更多的异常值特征。
消融实验
图片
图4显示了700M BitNet a4.8的训练损耗曲线,比较了使用完整的INT4/FP4量化,以及本文的混合量化和稀疏化。
完整的INT4量化会导致发散,而混合架构在训练困惑度方面明显优于完整的FP4架构。
使用RedPajama数据集中25B token,来进行模型的第一阶段训练,采用absmean和MinMax量化器分别进行完整的INT4和FP4量化。
对于完整的INT4量化,由于其输入具有更大的异常值,这里设置β = 2*mean(|X|)。
图片
接下来为1.3B BitNet a4.8的down projection层输入,设置不同的量化或激活函数。
所有模型都使用RedPajama数据集中的50B token进行第一阶段训练。为了确保公平比较,其他激活值都保留在8位。
图5显示了这些模型的训练损失曲线。Squared ReLU的训练困惑度比Swish略好,同时实现了更高的稀疏性。
此外,对down projection的输入应用FP4量化会导致性能显著下降,而将INT4激活与STE一起使用会导致发散。
参考资料: