本文经自动驾驶之心公众号授权转载,转载请联系出处。
端侧LLM毫无疑问会成为各手机厂商在2024年的主战场。从国内各手机厂透露的信息来看,大家几乎都把希望寄托在了芯片厂身上,自身能做的、会做的工作太少。希望苹果的工作对国内厂商们有启发、借鉴意义。
论文链接:LLM in a flash: Efficient Large Language Model Inference with Limited Memory
1. Flash Memory and DRAM
在移动端设备中(如手机),DRAM可理解为“运行时内存”,Flash Memory可理解为“存储空间”。做一个简单的类比,在PC中,DRAM对应于内存;Flash Memory对应于硬盘存储(注意:仅仅是对应于,实现方案并不一样)。
在通常的LLM推理阶段,LLM都是直接加载到DRAM中的。一个7B半精度LLM,完全加载进DRAM所需的存储空间超过14GB。考虑到目前主流手机的DRAM最高也就16GB的水平,在端侧直接使用DRAM来加载7B LLM面临巨大挑战。
图1给出了一个移动端标准的存储结构示意图。
图1 移动端存储结构示意图
Flash Memory的特点是大存储,低带宽。也就是说,Flash Memory可存储的内容多(图中的100GB),但数据传输速率低(图中的1GB/s)。而DRAM的特点是小存储,高带宽。
现在的问题是:模型大小 > DRAM,所以无法将模型全部加载进DRAM。
苹果的解决方案是将LLM放在Flash Memory中,在每次需要进行推理时,仅仅将部分必要参数加载到DRAM中。
苹果的整个方案重点解决两个问题:
- 如何快速识别出哪些模型参数是必要的
- 考虑到由Flash memory到DRAM的带宽较低,如何加快由Flash memory到DRAM的传输效率
论文中从三个不同方面做了尝试,下面分别介绍。
2. 减少数据传输量
这部分介绍论文中采用的三种降低数据传输量的方法。
2.1 方法一:Selective Persistence Strategy
对于常见的LLM而言,它的模型参数主要由Attention参数和MLP参数两部分构成,其中Attention参数占比约为1/3,MLP参数占比约为2/3。除此,还有参数量级可忽略不计的Embedding层的参数。
因为Attention参数量相对较少,所以苹果的方案是将Attention参数和Embedding层的参数直接加载到DRAM中。
这就是所谓的Selective Persistence Strategy,其意为:有选择性地把部分参数常驻在DRAM中。而这部分常驻的参数就是Attention参数和Embedding参数。原因是因为它们占比较小。
2.2 方法二:Anticipating ReLU Sparsity
这里主要借鉴了DejaVu的思路:MLP层的输出只有不到10%的值是激活状态(不为0)。一般把这种现象称为稀疏性。稀疏性越强,则非激活状态的值就越多。
所以我们也可把这句话“MLP层的输出只有不到10%的值是激活状态”简写作“MLP层的稀疏性超过90%”。
要注意,此处的稀疏性一般称为“Contextual Sparsity”。也就是说,MLP层的哪些神经元会激活,与当前的输入相关。
苹果照搬了DejaVu的方法,使用一个两层MLP来预测哪些神经元会激活。方法也很简单,假设神经元个数为4096,只需要将MLP的输出层的大小设为4096,并为每一个输出使用sigmoid来做一个二分类即可(“选择”or“不选择”)。
注意1:不同Transformer层使用的预测模型不同。
注意2:同一个Transformer层中的MLP一般有两层。它们的激活神经元始终保持相同。
在能够准确预测的前提下,每次在推理时动态加载预测为激活神经元对应的参数即可。
这里有对DejaVu详细介绍的文章:[ICML'23] DejaVu:LLM中的动态剪枝
2.3 方法三:Sliding Window
根据2.2小节中介绍的稀疏性可知,在每一次LLM进行前向推理时,它都需要使用模型预测每一个MLP层中激活神经元的编号,并将所需的神经元所对应的权重由Flash memory加载到DRAM中。
因为LLM的推理阶段是逐token进行的,这意味着在生成不同token的时候,需要加载到DRAM中的MLP的参数也不同。
用一个简单的例子来说明这个基础概念,只考虑第 层的Transformer模块。在处理当前token 时,该层使用模型预测MLP会激活的神经元编号,假设为{0, 1, 3, 5},并将其对应的参数从Flash memory加载到DRAM中,然后进行推理。
在处理下一个token 时,将 从DRAM中删除,再使用模型预测MLP会激活的神经元编号,假设为{0, 2, 3, 6},并将其对应的参数从Flash memory加载到DRAM中,然后进行推理。
注意到在我们的例子中,两次前向推理时,第 层的Transformer结构中MLP有部分被预测为激活的神经元是重叠的:{0, 3}。所以实际上在进行第二次前向推理时,没有必要把完全从DRAM中删除,而是将其中编号为{1, 5}神经元对应的参数删除,再将编号为{2, 6}的神经元对应的参数读入即可。这样可以减少I/O的总开销。
这就是Sliding Window的核心思想:保留处理过去k个token时的激活神经元所对应的参数在DRAM中,并在处理当前token时只对:1)部分多余的参数进行删除;2)缺少的参数进行加载。图2是原文中的示意图。
图2 Sliding Window示意图
图中上图表示在处理当前token “Was”之前,前5个token(k=5)的激活神经元(淡蓝色偏绿部分)。图中下图表示在处理当前token “Was”之时,需要新加入的神经元(蓝色部分)和需要删除的神经元(分红部分)。
Sliding Window的核心假设是LLM在处理相邻token时产生的稀疏性具有相似性。原文没有仔细分析和论证这个假设。
3 提高传输吞吐量
3.1 Bundling Columns and Rows
通常LLM中的MLP层包含两个全连层。在忽略激活函数的情况下,这两个全连层可以写为:
在2.2小节中曾经提到,稀疏性预测是对MLP中两个全连层同时进行的。也就是说,如果我们预测结果是第一个全连层中0号神经元不会被激活,那么该预测结果同样也适用于第二个全连层:第二个全连层的0号神经元也不会被激活。
对于第一个全连层的参数矩阵,第i个神经元对应于它的第i列;对于第二个全连层的参数矩阵,第i个神经元对应于它的第i行。
当第i个神经元被预测为激活时,需要同时读取的第i列和的第i行。所以为了提高读取速度,可以将的每一列和的对应行拼接起来存储,如下图所示:
图3 将两个全连层的列与行拼接存储
图3中的Up Proj指的是第一个全连层,对应于上文参数矩阵;Down Proj指第二个全连层,对应于上文参数矩阵。
这样做的好处是原本需要两次I/O,现在只需要一次了。虽然总的数据读取量并没有变,但读取大块、连续的数据会比读取大量小块、非连续数据更加高效,因此整体传输吞吐量提升了。
3.2 Bundling Based on Co-activation
这是一个原文尝试过,但被验证为无效的策略。既然原文提到了,所以这里也罗列出来。
原文中猜测某些神经元之间可能存在一些紧密联系。比如对于两个神经元a和b,当a激活时,b也会激活(或者当b激活时,a也会激活)。
因此可以通过分析来找到每个神经元的“closest friend”(与该神经元同时激活频率最高的其它某个神经元)。然后在存储Flash memory中存储时,也将它们的参数拼接存在一起。这样的读取效率更高。
但该方法之所以无效,主要原因是可能会存在某个神经元i,它是其它很多神经元的“closest friend”。这样导致的问题则是神经元i被额外传输了太多次,导致实际的I/O成本增加了。
4 Optimized Data Management in DRAM
虽然DRAM的数据读取速度比Flash memory快很多,但当需要对其中数据进行大量、高频读写时,它的时耗仍然不可忽略。在本文介绍的内容中,对DRAM的读写主要发生在对MLP层中所需神经元对应参数的删除与新增(参考图2)。
为此,论文中设计了一种特殊的数据结构来对DRAM中的数据做精细化管理。该数据结构的核心变量如下:
- Matrix:按照“Bundling Columns and Rows”的方法存储激活神经元所对应的参数
- bias:激活神经元所对应的bias参数
- num_used:激活神经元的个数
- last_k_active:过去k个token所对应的激活神经元编号
- Pointer:当前行参数对应的神经元编号
图4 Optimized Data Management in DRAM
通过预分配一个足够大的空间,可以避免因反复分配而导致的额外开销。下面来说明基于该数据结构的一些操作的高效实现方法。
该矩阵中的行对应的是当前存储在DRAM中激活神经元的参数。前文提到(2.3小节),当处理新的token时,需要将不会被激活的神经元删除,并添加新的会被激活的神经元。所以最重要的两个操作是“删除”和“新增”。
当需要删除某个神经元时(如图4中左图标红部分,对应的是编号为10的神经元),只需将num_rows的数量减1,并将最后一行Copy至被删除行,结果如图4的中图所示。虚线框表示当前有效的数据。
当需要“新增”时,直接将其对应的参数由Flash memory中copy至该矩阵中即可,无需额外分配存储空间。
5. 实验结果
苹果这篇paper的主要关注点在于:让LLM在运行时内存受限的情况下能高效地跑起来。所以论文的实验主要对比了各种情况下I/O导致的时耗,如下图所示。
图5 实验结果
图5中的实验使用的是OPT 6.7B模型,半精度。表中第一行和第二行都是基准baseline。第一行假设初始模型全部在Flash memory中,那么为了完成一次完整的推理,需要将模型全部加载一遍,整个I/O耗时2130ms。
第二行对应于假设模型有一半参数提前在DRAM中的整个加载耗时。
第三行到第五行对应于应用了Predictor(2.2小节)、Windowing(2.3小节)和Bundling(3.1小节)后对应的耗时。
效率提升非常明显。