在大模型推理领域,投机采样是一种被广泛使用的无损加速算法。近期一些投机采样的工作将大模型的上下文信息(例如 hidden states 和 KV cache)引入草稿模型,可以充分利用大模型的知识来提升加速比,但这类算法也会带来训练和解码的上下文不一致问题。此外,我们也发现现有算法在训练和解码的目标上也存在一定的不一致现象。小红书中台算法团队提出的 HASS 算法在目标和上下文上对齐了草稿模型的训练和解码阶段,达到了普通推理速度的 2.81~4.05 倍,相比 SOTA 方法 EAGLE-2 提升 8%~20%,相关技术已应用在小红书实际业务场景中。
论文地址
https://arxiv.org/pdf/2408.15766
01 背景
生成式大语言模型(LLMs)在各种任务上表现出令人惊叹的能力。然而,由于其固有的自回归解码机制,人们难以在这些模型上高效推理,这限制了它们在时间敏感场景中的应用。投机采样技术通过利用额外的资源来增加并发性,提供了一种大模型推理加速的解决方案。
投机采样(Speculative Sampling)
投机采样是一种先起草再验证的解码范式。在每一步解码时,先高效地生成多个草稿 token,再使用目标 LLM 并行地验证这些 token 来加速推理。 表示想要加速推理的目标 LLM, 表示基于前缀 从目标 LLM 生成下一个 token 的条件概率分布(简写为 )。 表示一个更高效的草稿模型, 表示基于前缀 从草稿模型生成下一个 token 的条件概率分布(简写为 )。投机采样分为如下3步:
1.使用更高效的草稿模型 来生成 个草稿 token。
2.使用目标 LLM 来并行地验证这些草稿 token 以及它们从 被生成的概率,接受能使得输出分布和 一致的所有草稿 token。
3.如果某个草稿 token 被拒绝后,从修正后的分布中采样一个额外的 token 替代它;如果所有的草稿 token都被接受,额外增加一个新的 token。
具体验证过程如下:从 中采样一个草稿 token ,如果 则接受 ;否则将以 的概率拒绝并从修正分布 中重新采样一个 token 接受。
经证明,对于任意的 和 ,如此得到的 token 总是与目标 LLM 分布一致。目标 LLM 的每一次前向推理至少产生一个新的 token,而至多产生 个新的 token,生成的个数取决于目标 LLM 和草稿模型的对齐程度。
投机采样的实际性能取决于两个因素:草稿模型的解码成本及其与目标 LLM 的对齐程度。为了获得与目标 LLM 高度对齐的高效草稿模型,近期的工作提出利用目标 LLM 的上下文信息。例如,EAGLE 使用目标 LLM 的 hidden states 作为草稿模型的输入特征。然而,这些方法在训练和解码阶段引入了不一致的上下文,如图 2 所示。在训练期间,草稿模型总是能获取到目标 LLM 在先前时间步的 hidden states。但在解码期间,草稿模型却无法获取到未被验证时间步的目标 LLM 的 hidden states,这导致了训练和解码阶段的上下文不一致。这一问题可以看作是投机采样中在特征层面的 exposure bias。
训练和解码阶段之间还存在目标上的不一致。在解码阶段,草稿模型的目标是生成目标 LLM 会赋予高概率的 token。在这种情况下,草稿模型应更关注于召回这些高概率 token,而对它们之间的具体顺序则可以稍微放松。另外,大部分 LLM 在应用时采取核采样或 top-k 采样。在这些解码策略中,高概率 token 对输出起着更重要的作用。因此,为了获得高效的草稿模型,它的训练目标应考虑到解码阶段的这些特性。据我们所知,现有的涉及训练草稿模型的投机采样方法普遍忽视了这些解码目标。
02 方法
为解决上述的训练和解码阶段不一致问题,我们提出了协调投机采样(HASS),旨在通过训练阶段学习协调的表征来解决上述问题。我们的方法包含两部分:(1)为了让草稿模型在训练阶段感知到解码目标,HASS 将推荐系统中的排序蒸馏思想扩展到投机采样,即协调目标蒸馏;(2)为了解决训练和解码间的上下文不一致,我们提出了一种多步的对齐训练策略,即协调上下文对齐。结合这两部分,HASS 显著提高了 LLM 的推理速度。在无需额外推理开销的情况下,也保持了草稿模型训练的高效。
协调目标蒸馏(Harmonized Objective Distillation)
HASS 通过引入推荐系统中的排序蒸馏思想,优先考虑草稿模型解码时更重要的一些 token。具体来说,排序蒸馏的目标是训练学生模型,使其对教师模型中排名靠前的项赋予更高的排序。在投机采样中,草稿模型是学生模型,而目标 LLM 是教师模型。具有类似特性的草稿模型在解码阶段将获得更高的接收率。设 K 个概率最高的 token 组成的集合为 ,其中 代表整个词汇表。HASS 在训练时使用以下的 Top-K 蒸馏损失:
其中 和 分别表示目标 LLM 和草稿模型预测下一个词的条件概率分布。在结合 EAGLE 时,训练阶段可以从目标 LLM 的 hidden states 中获取 ,这意味着结合 Top-K 损失训练有着和 EAGLE 一样的训练效率。
协调上下文对齐(Harmonized Context Alignment)
HASS 采用了多步的对齐训练策略,使草稿模型在训练和解码阶段的上下文保持一致。具体来说,HASS 将训练过程分为 n 步,使草稿模型能够利用与解码阶段一致的上下文特征。过程如下:
- 第一步与 EAGLE 的训练相同。在时间步 t+1,草稿模型以目标LLM的特征 作为输入并生成草稿模型特征 。这一步中,注意力掩码与因果掩码一致,不做修改。
- 第二步利用了来自第一步的特征。在时间步 t+1 的自注意力机制中,使用 来生成 query。key 和 value 由 生成,其中 表示拼接操作, 表示早于时间步 t 的特征。注意力掩码被修改以确保 看到的前一个特征始终是 ,如图 3 中的“HASS Training Step 2“所示。
- 对于第 j 步(j ≥ 3),前一步生成的特征 用于生成时间步 t+1 的query,而 key 和 value 由 生成。
HASS 的训练开销是 EAGLE 的 n 倍,但解码开销不变。后续实验证明,HASS 的加速效果在 n 值较小时就会收敛,因此是训练高效的,具体实现请参考论文的附录部分。
03 实验
主要实验
如表 1、2 所示,HASS 在所有的数据集和目标 LLM 上都表现出了最高的接受长度和最优的加速比。大部分方法在 HumanEval 数据集上加速效果最好,因为代码生成任务中的固定模版对于草稿模型更易生成从而加速。尽管 PLD 和 Lookahead 无需训练,但是它们的性能都显著弱于 EAGLE、EAGLE-2 和 HASS。
协调目标蒸馏的消融实验
我们改变了 Top-K 损失的 K 和权重,结果如图 4 所示。使用 Top-K 损失训练(权重大于 0)时,总是能提升草稿模型的接受长度。当 K 值很小时(K=1)会导致性能下降,可能是因为草稿模型过度关注概率最高的 token 而忽视了其他潜在 token。在 K=5 时,草稿模型的接受长度最大。
我们还尝试了更多关注高概率 token 的损失函数以替换 Top-K 损失,结果如表 3 所示。BiLD 损失在 T=0 时表现最好,Top-K 损失在 T=1 时表现最好。总体上,Top-K 损失的表现最好。
协调上下文对齐的消融实验
我们改变了协调上下文对齐的对齐步数,将用 Top-K 损失训练后的 EAGLE-2 权重作为基准,结果如表 4 所示。在不使用协调上下文对齐时(EAGLE-2+Top-K),草稿模型的效果最差。用 3 或 4 步协调上下文对齐训练的草稿模型总体上能获得最优的接受长度。当对齐步数增加到 5 步时,接受长度反而会下降,这可能是因为草稿模型的能力有限,当过度关注后几步的 token 生成时就会导致在前几步的预测精度下降。
我们画出了 HASS 和 EAGLE-2 在每一步生成token时的接受率曲线,如图 5 所示。可见在后几步生成 token 时,HASS 的接受率显著高于 EAGLE-2,验证了协调上下文对齐的有效性。
但在 LLaMA2-Chat 13B 和 LLaMA3-Instruct 70B 上,HASS 的第一步接受率相比 EAGLE-2 下降了。这可能是因为草稿模型关注后几步的 token 生成而忽视了第一步的,但第一步的接受率对于接受长度非常关键。因此我们考虑调整训练时每一步对齐的损失权重,来强调前几步的重要性。具体的,我们对于第 j 步的训练损失乘上权重 ,结果如表 5 和图 6 所示。当 从 1.0 降到 0.5 时,草稿模型的接受长度不断提高。其在第一步的接受率也对应增长,而后几步的接受率有所下降。当 下降到 0.3 时,训练过程过分强调了第一步 token 生成,导致了接受长度下降。我们将在多步对齐间取得平衡的探索留到后续工作中。
04 作者简介
- 乐凡
小红书中台算法工程师,目前主要负责大语言模型的相关研究和应用。
- 晓丹
小红书中台算法工程师,目前主要负责大语言模型的相关研究和应用。
- 特图
小红书中台算法基础模型方向负责人,主要研究方向:多模态大模型 x 内容分发技术。
- 瑞格
小红书中台算法团队负责人。