Jamba-1.5:大规模混合Transformer-Mamba模型
一、结论写在前面
论文标题:Jamba-1.5: Hybrid Transformer-Mamba Models at Scale
论文链接:https://arxiv.org/pdf/2408.12570
模型:https://huggingface.co/ai21labs
论文介绍了Jamba-1.5,基于Jamba架构的新型指令调优大型语言模型。Jamba是一种混合Transformer-Mamba专家混合架构,能够在不同上下文长度下提供高吞吐量和低内存使用,同时保持与Transformer模型相同或更好的质量。
论文发布了两种模型尺寸:Jamba-1.5-Large,具有940亿活跃参数,以及Jamba-1.5-Mini,具有120亿活跃参数。这两种模型都针对多种对话和指令遵循能力进行了微调,并且具有256Ktoken的有效上下文长度,是开放权重模型中最大的。
为了支持成本效益高的推理,论文引入了ExpertsInt8,一种新颖的量化技术,允许在处理256K token上下文时,将Jamba-1.5-Large适配到具有8张80GB GPU的机器上,而不损失质量。在学术和聊天机器人基准测试中评估时,Jamba模型取得了优异的成绩,同时提供了高吞吐量,并在长上下文基准测试中超越了其他开放权重模型。
二、论文的简单介绍
2.1 论文的背景
论文介绍了Jamba-1.5,两个基于论文的Jamba架构[的新型大语言模型,可供公众使用。Jamba-1.5-Mini是论文早期Jamba版本的更新和指令调优版本。与其较小的同类产品一样,Jamba-1.5-Large是一种混合架构,结合了Transformer和Mamba层,以及专家混合(MoE)模块。
自Jamba推出以来,类似的努力已经证实了在高达8B参数规模上结合Transformer和状态空间模型的优势。Jamba-1.5-Large在更大规模上展示了这种架构的优势。它具有94B活跃参数,总共398B参数。即使在这个大尺寸下,由于Jamba架构的效率以及论文开发的一种新颖量化技术,该模型可以在处理256Ktoken上下文时适配到具有8张80GB GPU的单台机器上。
Jamba-1.5-Mini和Jamba-1.5-Large均为经过指令微调的模型,通过Post-training赋予了多种能力。论文在广泛的基准测试中评估发现,它们在性能上与同尺寸模型相当,同时得益于Jamba架构的高效性。特别是,Jamba-1.5模型在长上下文评估中表现突出,使其成为唯一在RULER
基准测试中有效长度达到256K的模型,同时实现了KV缓存内存减少10倍以及卓越的吞吐量和延迟。
这些模型已公开可用:
Jamba-1.5-Mini: https://huggingface.co/ai21labs/AI21-Jamba-1.5-Mini
Jamba-1.5-Large: https://huggingface.co/ai21labs/AI21-Jamba-1.5-Large
2.2 模型架构
Jamba-1.5-Large基于Jamba ,这是一种混合解码器架构,融合了Transformer层 与Mamba层(一种状态空间模型(state-space model,SSM)),并附加了混合专家(MoE)模块 。详见 [24] 对该架构的详细描述。
在开发Jamba [24] 的过程中,论文发现Transformer、Mamba和MoE元素的结合有助于平衡吞吐量、内存使用和质量的需求。Jamba-1.5-Large在大规模上展现了这种灵活性。
Jamba-1.5-Large 遵循相同的 Jamba 结构,但容量更大。它拥有 94B 活跃参数和 398B 总参数。它包含 9 个块,每个块具有以下规格:
•l= 8每个block包含 8 层。
•a : m=1 : 7注意力层与 Mamba 层的比例。在论文的 Jamba 研究中,这一比例被发现是最佳的 [ 2 4 ],后续工作 [6, 37] 也证实了类似比例的成功.
•每隔 e=2 层使用 MoE 替代单一 MLP。共有 n=1 6 个专家,每个token选择前 K=2 个。
•隐藏状态维度为 8192。
•注意力查询头数为 64,KV 头数为 8。
表 1 将 Jamba-1.5 模型与公开可用的类似尺寸模型进行了比较。Jamba-1.5-Mini 的活跃参数数量与 Mixtral 8x7B 相近,而 Jamba-1.5-Large 的活跃参数数量介于 LLaMA-3.1-70B 和 Mistral-Large-2 之间。同时,论文的两个 Jamba 模型在 KV 缓存内存使用(256K token)方面远小于所有其他模型,相较于各自的对应模型,大约减少了近一个数量级的内存使用。
通过这些设置以及论文的专用量化(第 3.1 节),Jamba-1.5-Large 可以在单台配备 8 块 80GB GPU 的机器上提供服务,上下文长度可达 256K token。
表 1:Jamba-1.5-Mini、Jamba-1.5-Large 与近期开放模型在总可用参数、活跃参数及长上下文 KV 缓存内存方面的比较。Jamba-1.5-Mini 和 Jamba-1.5-Large 在 KV 缓存内存需求上提供了显著的减少。
对于这次发布,论文还尝试了Mamba-2 [6],这是Mamba的一个更快且改进的版本,据报道其性能优于单独的Mamba和Transformers。然而,如图1所示,论文发现,在混合架构中,Mamba-1-Attention组合的性能优于Mamba-2-Attention,因此论文在Jamba-1.5-Large中采用了Mamba-1。(论文还发现混合架构的性能优于纯Mamba-2。)论文推测这是因为Mamba-2相对于Mamba-1的一些优势,特别是能够使用更大的状态大小,在论文将全注意力层交错放置在Mamba层之间时,其重要性有所降低,因为这些全注意力层能够从整个上下文中汇聚信息。
图1:Mamba-1、Mamba-2、Mamba-1-Attention和Mamba-2-Attention在训练100B tokens模型上的比较。尽管Mamba-2在没有注意力机制的情况下优于Mamba-1,但混合的Mamba-1-Attention表现更佳。
2.3 服务考虑与改进
论文分享了一些见解和改进措施,以实现Jamba模型在大规模上的高效服务。
2.3.1 专家Int8量化
为了支持Jamba-1.5-Large的高效服务,论文开发了一种新的量化技术,论文称之为ExpertsInt8。论文观察到,超过85%的模型权重位于MoE层中,超过90%位于MoE或MLP层中。论文希望量化这些权重,同时仍然享受快速BF16内核的好处。为此,论文将MoE和MLP权重量化为INT8,以INT8格式保存,并在实际计算前将其反量化回BF16。重要的是,反量化步骤直接在vLLM[18]的融合moe内核内部进行。这样,反量化过程几乎不增加额外开销,甚至导致比BF16更低的延迟。论文已经将修改后的融合moe内核贡献给了vLLM。
论文的ExpertsInt8方法具有多个优点
•首先,它速度快;量化仅在模型加载时花费几秒钟。
•其次,与vLLM中的大多数其他技术不同,它不依赖于校准,校准可能需要数小时或数天,并且可能不稳定。
•第三,论文仍然可以使用BF16来保存大型激活。
•第四,它可以在A100 GPU上使用,而FP8仅在H100上可用。
•最后,论文的量化在延迟上与FP8匹配,同时超越其他量化技术,且没有质量损失。
图2比较了使用Jamba-1.5-Mini、Jamba-1.5-Large以及两个Mixtral模型(8x78B和8x22B)的不同量化技术的延迟。在IH100 GPU上,ExpertsInt8与FP8的延迟相匹配。在A100上,由于FP8不可用,ExpertsInt8是一种有吸引力的技术,大大优于GPTQ。结合上述ExpertsInt8的优势,这使得它成为服务大型MoE模型的有吸引力的量化技术。
图2:不同量化技术的比较,展示了在1024个token上下文和128个token解码条件下的端到端延迟。ExpertsInt8与FP8表现相似,同时快速且易于应用,仍允许BF16激活,并且适用于A100 GPU,而FP8在这些GPU上不可用。
2.3.2 激活损失
在预训练过程中,论文发现某些激活,即特定专家的输出以及最后Mamba层的输出,对于某些输入token,其幅度逐渐增加,最终达到高达4 \times 10^9的值。尽管论文没有发现这对预训练本身造成伤害,预训练是在BF16精度下进行的,但激活的幅度可能在推理过程中引起数值问题,因为某些量化库仅支持FP16精度的激活,其最大范围为64K。
为了缓解这些担忧,论文增加了一个“激活损失”项,与前向传播中激活的均方成正比,并有一个可配置的α因子,惩罚较大的激活值。通过实验论文发现,即使\alpha值高达至少10^-3,这种辅助损失对训练也没有影响。对于Jamba-1.5-Large,论文使用了α=10^-5,这足以将激活值降低到可接受的范围(最大2K-3K)。此外,添加这种辅助损失几乎立即减少了激活值,使其仅在训练结束时添加,而不影响训练速度和质量。
为了验证这种方法,论文使用FP16激活对模型进行了全面评估,并获得了与BF16评估相同的结果,没有任何NaN/溢出。
2.4吞吐量和延迟分析
得益于混合Jamba架构,论文的Jamba-1.5模型提供了出色的吞吐量和延迟。图3和4分别展示了Jamba-1.5-Mini和Jamba-1.5-Large的情况。如图所示,论文的模型在延迟和吞吐量方面远优于类似规模的模型。它们在长上下文中的优势尤为明显,存在显著差距。重要的是,Jamba-1.5-Large即使在长上下文中也能高效运行,而大型LLaMA3-405B在相同硬件上无法运行。
图3:Jamba-1.5-Mini与其他模型在延迟和吞吐量方面的比较。所有测量均在2个A100 80GB GPU上进行,批量大小为1,输出长度为512个token。Jamba-1.5-Mini表现出更好的延迟,尤其是在大型上下文中,输出token吞吐量仅略有下降。
图4:Jamba-1.5-Large与其他模型在延迟和吞吐量方面的比较。所有测量均在8块A100 80GB GPU上进行,批量大小为1,输出长度为512个token。Jamba-1.5-Large在大型上下文中表现出更好的延迟,输出token吞吐量仅略有下降。LLaMA-3.1-405B的结果截断至64K,因为该模型在8块80GB GPU上无法适应超过100Ktoken的上下文长度。
2.5 训练
2.5.1 训练基础设施和数据
Jamba-1.5-Large在NVIDIA H100 GPU上使用论文自有的专有框架进行训练,该框架包括FSDP、张量并行、序列并行和专家并行。对于后者,论文采用了MegaBlocks
2.5.2训练阶段
该模型分三个阶段进行训练。在预训练阶段,首先在2024年3月更新的自有数据集上进行训练。论文的预训练数据集是公开可用的
网页文档、代码、书籍和科学文章的混合体。
论文的预处理流程包括解析、质量过滤和去重。为了充分利用公开可用数据,论文开发了自己的自有解析器,并使用它提取文本和格式。确切的数据混合是通过各种消融实验确定的。这一阶段包括多语言数据,重点是以下语言:
英语、西班牙语、法语、葡萄牙语、意大利语、荷兰语、德语、阿拉伯语和希伯来语。然后,在中间训练阶段进行了一小段时间的高比例长文档训练,以强调其远程能力。最后,模型进行了Post-training。
2.5.3 Post-training
论文的Post-training方法旨在同时实现两个目标:(i) 赋予模型各种技能和对话能力;(ii) 保留预训练尤其是中间训练的长上下文能力。这两个目标部分存在冲突,因为大多数可用的Post-training数据集由相对较短的示例组成。
鉴于这些考虑,论文的Post-training过程包括在高质量对话数据、技能特定数据和长上下文数据上的监督微调。混合这些不同类型的数据旨在保留长上下文能力并获取所需技能。如以下评估所示,论文发现论文的模型在长上下文评估中表现非常出色。
在进行监督微调时,论文大量使用合成数据,这在最近的基石模型中很常见,并且反映了论文构建结构化数据以构建复合AI系统的方法。论文开发了多个不同的数据合成流程,针对不同的模型能力。所有流程都采用以下模式:(i)在目标分布中采样或生成提示;(ii)从语言模型生成响应;(iii)根据自动验证和评分对响应进行质量过滤或排序;以及(iv)后期编辑以去除伪影并适应所需的格式。论文为构成最终数据混合的不同数据管道使用不同的模型、提示、采样、过滤和编辑方法。
论文根据一系列主要是内部的自动指标选择了最终的训练配方(数据混合和超参数)。Jamba-1.5模型都使用相同的控制标记和格式模板进行微调,论文将其作为HlF兼容的标记器和聊天模板的一部分提供;详见模型卡。
论文提供了几个合成数据生成的显著例子:
基于表格的问答。论文生成表格数据和伴随的问答对,如论文在表格理解工作[20]中所展示的。然后,论文使用语言模型将表格转换为自然语言段落。论文生成的训练示例包括针对给定表格中特定行或列的文本的提取、聚合和归因任务。
文档问答。给定一个文档,论文引导语言模型生成单段落和多段落的问题-答案对。有时,论文通过添加类似文本来嵌入这些示例于更长的上下文中,以鼓励带有归属的长上下文理解。
工具使用。论文以开源的Glaive函数调用数据集为起点,通过各种启发式方法和输出模式的验证进行过滤。为了支持并行函数调用,论文首先为Glaive中的每个函数生成多个有效的参数分配。接着,论文从相同函数和不同函数中抽取这些有效参数分配的子集,以生成对应函数调用集合的用户请求。最后,论文引导一个函数调用语言模型响应这些生成的用户请求,并仅保留函数调用匹配原始参数分配的响应。
可引导性。论文定义了一组易于验证的指令,并合成了包含通用文档草拟任务及一个或多个约束条件的提示。论文从语言模型中生成这些提示的完成结果,并基于细粒度指令的验证和通用奖励模型进行拒绝采样。为了支持系统消息中的指令,论文选择了多个共享细粒度指令的此类提示。
2.5.4 一些观察
论文分享了从Jamba-l.5开发过程中得出的一些观察。尽管这些观察并未完全深入探讨,但论文希望它们能启发社区进一步研究这些问题。
首先,尽管论文仅包含了一小部分非英语数据,且仅针对特定技能在微调阶段进行了处理,但论文的Jamba-1.5模型在多种语言上表现相当出色。如前所述,论文在预训练阶段确实包含了多语言数据。因此,论文推测模型能够在主要使用英语进行微调时利用预训练阶段学到的知识。
其次,论文高效的Jamba架构降低了在长上下文上进行微调的成本,使得在给定预算下能够进行更多实验。因此,论文能够在微调阶段尝试多种不同的训练方案。
最后,尽管像PPO [33]或DPO [29]这样的偏好调优算法改进了模型输出与人类意图之间的一致性,但论文发现,精心生成的合成数据、数据过滤和监督微调的组合对于获得强大的微调模型至关重要。
2.6 评估
虽然论文认为基准测试仅部分相关于实际应用的成功和用户满意度,但论文仍报告了关键公共基准的结果。首先,论文报告了标准学术基准的结果。然后,论文在聊天机器人基准上评估模型。最后,论文对Jamba-1.5-Large进行了多项长上下文评估和多语言评估。
论文与近期同尺寸范围内的开放权重模型进行了比较:与Jamba-1.5-Large相比,有LLaMA-3.1 70B和Mistral-Large-2-123B;与Jamba-1.5-Mini相比,有LLaMA-3.1-8B和Gemma-2-9B。
2.6.1 学术基准
论文报告了一系列标准学术基准的结果:MMLU、MMLU-Pro、GPQA、ARC-Challence、BBII和HumanEval 。论文还评估了IFEval指令遵循数据集和BFCL v1函数调用数据集。最后,论文在RealToxicity和TruthfulQA上报告了安全评估结果。
表2将Jamba-1.5-Large与几个公开可用且规模相当的模型进行了比较。所有结果均来自官方来源或由论文评估,如表中所示。论文观察到,Jamba-1.5模型在包括知识、推理、指令遵循和功能调用能力在内的标准学术基准上,与近期公开可用的最先进模型表现相当。论文还观察到与文献中报告的安全指标相似。
重要的是,如上所述,Jamba-1.5模型在实现这些结果的同时,提供了更好的吞吐量和延迟。
表2:Jamba-1.5模型在获得与同等规模模型相似性能的同时,享受到了更好的吞吐量和延迟。
2.6.2 聊天机器人评估
论文评估了Jamba-1.5模型在两个聊天机器人场景中的表现:Arena-Hard ,一组500个具有挑战性的用户查询,使用GPT4-Turbo作为评判标准,以及WildBench,使用GPT4-Turbo作为评判标准并进行了长度偏差缓解。如表3所示,Jamba-1.5模型在这些评估中取得了优异的结果,其中Jamba-1.5-Large超过了LLaMA-3.1 70B,但略逊于Mistral-Large-2 123B,后者拥有大约30%更多的活跃参数。
表3:Jamba-1.5模型与类似大小模型在聊天机器人基准测试中的比较。Jamba-1.5模型在性能相似的情况下,具有更好的吞吐量和延迟。" 由论文进行的评估。
2.6.3 长上下文评估
发布的模型能够处理长达256K个标记的上下文长度。在本节中,论文对其在测试其长上下文能力的合成和自然主义基准上进行了评估。
2.6.3.1 RULER
论文在RULER基准上进行了评估,这是一组13个合成任务,旨在评估语言模型的长上下文能力。RULER包括8种针在草堆中的检索任务变体,包括多个‘needles’[2]。它还包括一个变量跟踪任务,其中应返回一系列变量绑定,两个聚合任务,其中一个需要返回最常见的单词,以及两个问答任务,其中从自然主义数据集[30, 41]中插入的段落来模拟长上下文。
结果展示在表4中。在所有公开和专有模型中,Jamba-1.5-Mini和Jamba-1.5-Large是唯一确认有效长度为256Ktoken的模型。Gemini-pro在原始RULER论文中报告了高达128K的良好结果。然而,尽管论文付出了很大努力,仍无法重现这些结果。论文检查了Gemini-pro的生成内容,发现该模型经常无法回答或生成拒绝。由于官方RULER结果来自预览版本,论文假设Gemini-pro自那时起经历了更新,这些更新损害了其在RULER上的性能。
表4:Jamba-1.5模型与其他公开和专有模型在RULER基准上的比较。其他模型的结果来自RULER Github。bigstar 由论文进行的评估。Jamba-1.5模型是唯一确认有效长度为256Ktoken的模型。
2.6.3.2 Infinite-BENCH
接下来,论文在cOBENCH数据集上进行评估,该数据集旨在评估语言模型的长上下文能力,平均长度为100K个标记。论文重点关注理解长篇小说的两个英语任务:问答(EN.QA)和多项选择问答(EN.MC)。如表5所示,Jamba-1.5模型在这种情况下表现非常出色,优于同样大小的LLaMA-3.1和Mistral-Large-2模型。(由于Gemma-2 9B的上下文窗口较短,仅为8K,因此未报告其结果。)
表5:Jamba-1.5模型在长上下文评估中优于同样大小的LLaMA-3和Mistral-Large-2模型。T评估由论文进行。
2.6.4 多语言能力
论文对Jamba-1.5在非英语语言中的能力进行了基本评估。特别是,论文报告了通过LM Evaluation Harness 分发的多语言MMLU数据集上的结果。表6显示了结果,其中Jamba-1.5-Mini与其比较点相比表现相似或更好。Jamba-1.5-Large略落后于其可比模型,但仍展现出良好的多语言能力。
表6:Jamba-1.5与其他模型在多语言MMLU数据集上的比较。
2.7 对齐与安全考量
论文模型对齐的方法是通过在模型行为与客户期望之间建立透明度来驱动的。论文的模型默认遵循基于论文参与行业标准机构、智库以及与客户直接经验的商业行为准则。论文认为这是一种持续且不断发展的合作关系。此外,企业有多种方式来控制模型行为,以反映其独特的价值观和文化,例如额外的培训和微调、系统消息和提示工程。总体而言,论文的AI行为准则基于以下目标:
•使模型行为和输出与公司价值观和规范的商业礼仪相一致。
•明确声明预期行为的条款,以便错误/漏洞易于识别。
•与客户合作,并将行为映射到他们的最佳实践。
•持续收集反馈,以监控并积极改进行为。
根据论文在OECD
任务组中的角色,该任务组旨在为应用G7广岛人工智能系统开发组织行为准则开发一个监控机制,论文将模型对齐工作与OECD基于价值观的AI原则相结合:包容性增长、可持续发展与福祉;以人为本的价值和公平性;透明度和可解释性;鲁棒性、安全性和安全性;以及问责制。
对于前四个原则,论文详细阐述了行为预期或准则,并提供了可用于训练/对齐和测试合规性的示例。问责原则侧重于Al21在承担模型行为责任中的角色。论文认为,这种问责主要通过与客户、监管机构和独立第三方的透明度和接触来体现。论文与经合组织(OECD)、斯坦福大学的HELM [23]和FMTI [3]以及此类文件的接触,展示了这一承诺,以及论文在FMTI中的高排名(截至2024年5月排名第二)。
论文创建了60个准则,这些准则与OECD原则相对应。这些准则被表述为论文的模型应避免的行为指令。完整列表将公开发布。