深度学习中的注意力机制革命:MHA、MQA、GQA至DeepSeek MLA的演变

人工智能
本文将系统梳理这一发展脉络,深入剖析MHA、MQA、GQA等变体的核心思路与实现方法。

在深度学习领域,注意力机制已然成为现代大模型的核心基石。从最初的多头注意力(MHA,Multi-Head Attention)到如今的多查询注意力(MQA,Multi-Query Attention)、分组查询注意力(GQA,Grouped-Query Attention),再到DeepSeek提出的创新性多头潜在注意力(MLA,Multi-Head Latent Attention)方法,这一演变历程不仅是技术发展的脉络,更是对效率与性能极致追求的生动写照。本文将系统梳理这一发展脉络,深入剖析MHA、MQA、GQA等变体的核心思路与实现方法。

图片

一、Multi-Head Attention

图片

多头注意力(Multi-Head Attention,MHA)是Transformer模型架构中的一个核心组件,它允许模型在处理输入序列时能够同时关注来自不同位置的不同表示子空间的信息。

MHA通过将输入向量分割成多个并行的注意力“头”,每个头独立地计算注意力权重并产生输出,然后将这些输出通过拼接和线性变换进行合并以生成最终的注意力表示。

下面来看下计算公式:

1. 输入变换:输入序列首先通过三个不同的线性变换层,分别得到查询(Query)、键(Key)和值(Value)矩阵。这些变换通常是通过全连接层实现的。

图片

2. 分头:将查询、键和值矩阵分成多个头(即多个子空间),每个头具有不同的线性变换参数。

3. 注意力计算:对于每个头,都执行一次缩放点积注意力(Scaled Dot-Product Attention)运算。具体来说,计算查询和键的点积,经过缩放、加上偏置后,使用softmax函数得到注意力权重。这些权重用于加权值矩阵,生成加权和作为每个头的输出。

图片

4. 拼接与融合:将所有头的输出拼接在一起,形成一个长向量。然后,对拼接后的向量进行一个最终的线性变换,以整合来自不同头的信息,得到最终的多头注意力输出。

图片

作为最早提出的注意力机制方法,多头注意力机制存在的问题:

  1. 计算复杂度高:多头注意力机制的计算复杂度与输入序列长度的平方成正比(图片),这使得在处理长序列时计算量显著增加。例如,对于长度为1000的序列,计算复杂度将达到图片,这在实际应用中可能导致训练和推理速度变慢。
  2. 内存占用大:在多头注意力机制中,每个头都需要独立存储查询(Query)、键(Key)和值(Value)矩阵,这导致内存消耗显著增加。对于大规模模型,尤其是在长序列任务中,KV缓存的大小会线性增长,成为内存瓶颈。
  3. 特征冗余:多头注意力机制中,不同头可能学习到相似的特征,导致特征冗余。这种冗余不仅浪费计算资源,还可能降低模型的泛化能力。
  4. 模型解释性差:多头注意力机制的内部工作机制较为复杂,每个头的具体功能难以直观理解,降低了模型的可解释性。尽管可以通过注意力权重可视化来理解模型关注的输入信息,但这种解释性仍然有限。
  5. 过拟合风险:由于多头注意力机制增加了模型的参数量和复杂度,尤其是在数据量有限的情况下,模型可能会过度拟合训练数据。
  6. 推理效率低:在自回归模型中,每个解码步骤都需要加载解码器权重以及所有注意力的键和值,这不仅计算量大,还对内存带宽要求高。随着模型规模的扩大,这种开销会进一步增加,使得模型扩展变得困难。

二、Multi-Query Attention

针对MHA存在的问题,Google提出了多查询注意力(Multi-Query Attention,MQA)。MQA的设计初衷是为了在保持Transformer模型性能的同时,显著提升计算效率和降低内存占用。

在MHA中,输入分别经过图片的变换之后,都切成了n份(n=头数),维度也从图片降到了图片,分别进行attention计算再拼接。MQA的做法很简单,在线性变换之后,只对Q进行切分(和MHA一样),而K、V则直接在线性变换的时候把维度降到了图片(而不是切分变小),然后这n个Query头分别和同一个K、V进行attention计算,之后把结果拼接起来。

简单来说,就是MHA中,每个注意力头的K、V不一样,而MQA中每个注意力头的K、V一样,值共享,其他步骤和MHA一样。

图片

简单看下公式:

1. 查询(Query)保持多头设计:

图片

2. 键(Key)和值(Value)共享一组矩阵:

图片

3. 计算注意力输出:

图片

下图是论文中MHA和MQA的对比结果,可以看到由于共享了多个头的参数,限制了模型的表达能力,MQA虽然能好地支持推理加速,但是在效果上略比MHA差一点,但相比其他修改hidden size或者head num的做法效果都好。

图片

图片

MQA通过共享键(K)和值(V)矩阵的设计,显著降低了计算复杂度和内存占用,同时保持了较好的性能表现。这种设计特别适合长序列任务、资源受限的设备以及需要快速推理的场景。但是MQA对于所有query全部共享同一个key、value可能会限制每个查询头捕捉不同特征的能力,进而影响模型的整体表达能力和灵活性。

三、Grouped-Query Attention

MQA对效果有点影响,MHA缓存又存不下,Google又继续提出了一个折中的办法组查询注意力(Grouped-Query Attention,GQA),既能减少MQA效果的损失,又相比MHA需要更少的缓存。

图片

简单看下公式:

1. 将头分为g组,每组有h/g个头。对于每组i:

图片

2. 计算每个组的注意力输出并拼接:

图片

来看下结果:

图片

看表中2/3/4行对比,GQA的速度相比MHA有明显提升,而效果上比MQA也好一些,能做到和MHA基本没差距。文中提到,这里的MQA和GQA都是通过average pooling从MHA初始化而来,然后进行了少量的训练得到的。

下面是Llama2技术报告中做的MHA、MQA、GQA效果对比,可以看到效果确实很不错。

图片

四、Multi-Head Latent Attention

图片

在最新的DeepSeek论文中,为解决MHA在高计算成本和KV缓存方面的局限性,提出改进的多头潜在注意力(Multi-Head Latent Attention,MLA),旨在提高Transformer模型在处理长序列时的效率和性能。

MLA的技术创新主要是采用低秩联合压缩键值技术,优化键值(KV)矩阵,显著减少了内存消耗并提高了推理效率。

具体来说,MLA通过低秩联合压缩键值(Key-Value),将它们压缩为一个潜在向量(latent vector),从而大幅减少所需的缓存容量,还降低了计算复杂度。在推理阶段,MHA需要缓存独立的键(Key)和值(Value)矩阵,这会增加内存和计算开销。而MLA通过低秩矩阵分解技术,显著减小了存储的KV(Key-Value)的维度,从而降低了内存占用。

MLA利用低秩压缩技术,使得DeepSeek的KV缓存减少了93.3%。来看下公式:

1. KV联合低秩压缩

MLA模型通过低秩压缩对键(keys)和值(values)进行联合压缩,以减少KV缓存的大小。其核心公式为:

图片

  • 图片是key和value的压缩潜在向量,图片表示KV压缩维度;
  • 图片是降维投影矩阵,图片分别是key和value的升维投影矩阵。

2. Q的低秩压缩

为了降低训练时的激活内存占用,MLA对query(Q)也进行了低秩压缩。其核心公式为:

图片


  • 图片是query的压缩潜在向量,图片表示query压缩的维度;
  • 图片是query的降维投影矩阵,图片是query的升维投影矩阵。

3. RoPE 与低秩 KV 压缩不兼容问题-解耦 RoPE 策略

RoPE 对keys和queries都是位置敏感的。如果对键图片应用 RoPE,则会有一个与位置相关的 RoPE 矩阵。 这种情况下,图片在推理过程中不能再被吸收进图片,因为一个与当前生成的 token 相关的 RoPE 矩阵会存在于图片之间,矩阵乘法不遵循交换律。因此,必须在推理过程中重新计算所有前缀 token 的键,这将显著阻碍推理效率。 为了解决这个问题,论文提出了解耦 RoPE 策略,该策略使用额外的多查询注意力(MQA)图片和共享的键图片来携带 RoPE,其中图片表示解耦query和key的head_dim。

在解耦 RoPE 策略下,MLA 执行以下计算:

图片

  • 其中图片图片是分别生成解耦query和key的矩阵;

下面附上代码:

import torch
import torch.nn as nn
import math
class MLA(nn.Module):
    def __init__(self, d_model=512, down_dim=128, up_dim=256, num_heads=8, rope_head_dim=26, dropout_prob=0.1):
        super(MLA, self).__init__()
        
        self.d_model = d_model
        self.down_dim = down_dim
        self.up_dim = up_dim
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.rope_head_dim = rope_head_dim
        self.v_head_dim = up_dim // num_heads    
        # 初始化kv联合以及q对应的dow,up projection
        self.down_proj_kv = nn.Linear(d_model, down_dim) # W^{DKV}
        self.up_proj_k = nn.Linear(down_dim, up_dim)# W^{UK}
        self.up_proj_v = nn.Linear(down_dim, up_dim) # W^{UV}
        self.down_proj_q = nn.Linear(d_model, down_dim) #W^{DQ}
        self.up_proj_q = nn.Linear(down_dim, up_dim) # W^{UQ}  
        # 初始化解耦的q,k进行MQA计算的映射矩阵
        self.proj_qr = nn.Linear(down_dim, rope_head_dim * num_heads)
        self.proj_kr = nn.Linear(d_model, rope_head_dim*1)
        #初始化解耦的q,k对应的rope类,因为头的数量不同,初始化2个实例
        self.rope_q = RotaryEmbedding(rope_head_dim * num_heads, num_heads)
        self.rope_k = RotaryEmbedding(rope_head_dim, 1)     
        # Dropout and final linear layer
        self.dropout = nn.Dropout(dropout_prob)
        self.fc = nn.Linear(num_heads * self.v_head_dim, d_model)
        self.res_dropout = nn.Dropout(dropout_prob)
    def forward(self, h, mask=None):
        bs, seq_len, _ = h.size()
       # setp1 :低秩转换
        c_t_kv = self.down_proj_kv(h)
        k_t_c = self.up_proj_k(c_t_kv)
        v_t_c = self.up_proj_v(c_t_kv)
        c_t_q = self.down_proj_q(h)
        q_t_c = self.up_proj_q(c_t_q)    
        
        #step2:解耦的q,k进行MQA计算,同时引入ROPE
        #q_t_r,k_t_r施加rope时均扩展了n_h_r维度->[bs,n_h_r,seq_len,rope_head_dim]
        q_t_r = self.rope_q(self.proj_qr(c_t_q))
        k_t_r = self.rope_k(self.proj_kr(h))    
        
        #step3:拼接step1,step2得到的q,k,进行sdpa计算
        #q_t_c扩展出num_heads为4维,以便于和q_t_r拼接
        q_t_c = q_t_c.reshape(bs, seq_len, self.num_heads, -1).transpose(1, 2)
        #head_dim,rope_head_dim拼接
        q = torch.cat([q_t_c, q_t_r], dim=-1)
        #k_t_c扩展出num_heads为4维,以便于和k_t_r拼接
        k_t_c = k_t_c.reshape(bs, seq_len, self.num_heads, -1).transpose(1, 2)
        #k_t_r为MQA,n_h_k_r=1,为了和q_t_r计算,需要在n_h_k_r维度复制
        #k_t_r:[bs,n_h_r_k,seq_len,rope_head_dim]->[bs,num_heads,seq_len,rope_head_dim]
        k_t_r=k_t_r.repeat(1,self.num_heads,1,1)
        #head_dim,rope_head_dim拼接
        k = torch.cat([k_t_c, k_t_r], dim=-1)  
        # 注意力计算,[bs,num_heads,seq_len,seq_len]
        scores = torch.matmul(q, k.transpose(-1, -2))
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        scores = torch.softmax(scores / (math.sqrt(self.head_dim) + math.sqrt(self.rope_head_dim)), dim=-1)
        scores = self.dropout(scores)
        #v_t_c和scores计算,扩展出num_heads维度
        v_t_c = v_t_c.reshape(bs, seq_len, self.num_heads, self.v_head_dim).transpose(1, 2)
        output = torch.matmul(scores, v_t_c)
        #压缩num_head,送入最终统一映射层
        output = output.transpose(1, 2).reshape(bs, seq_len, -1)
        output = self.fc(output)
        output = self.res_dropout(output)
        return output
bs, seq_len, d_model = 4, 10, 512
h = torch.randn(bs, seq_len, d_model)
mla = MLA(d_model=d_model)
output = mla(h)

DeepSeek没有给出MLA与其他几个注意力机制对比的实验结果,但是结果导向来看,MLA的KV缓存大幅减少,大幅提高模型推理速度,在减少资源消耗的同时,保持甚至提升模型性能。

五、总结

从MHA到MQA、GQA,再到MLA,注意力机制的演变展示了在效率与性能之间不断优化的轨迹。MLA通过创新的KV缓存压缩和恢复机制,实现了在资源消耗、推理速度和模型性能之间的最佳平衡,为大语言模型的高效部署和应用提供了新的可能性。

[1]MHA: Attention Is All You Need(https://arxiv.org/pdf/1706.03762)

[2]MQA: Fast Transformer Decoding: One Write-Head is All You Need(https://arxiv.org/pdf/1911.02150)

[3]GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints(https://arxiv.org/pdf/2305.13245)

[4]MLA:https://github.com/deepseek-ai/DeepSeek-V3/blob/main/DeepSeek_V3.pdf


责任编辑:庞桂玉 来源: 小白学AI算法
相关推荐

2025-02-10 00:00:55

MHAValue向量

2025-01-16 09:20:00

AI论文模型

2024-04-03 14:31:08

大型语言模型PytorchGQA

2024-10-31 10:00:39

注意力机制核心组件

2024-06-28 08:04:43

语言模型应用

2025-02-25 10:03:20

2018-08-26 22:25:36

自注意力机制神经网络算法

2020-09-17 12:40:54

神经网络CNN机器学习

2023-05-05 13:11:16

2025-02-19 15:30:00

模型训练数据

2025-02-14 11:22:34

2024-09-19 10:07:41

2024-12-09 00:00:10

2025-02-25 10:21:15

2025-02-24 11:31:33

2024-02-19 00:12:00

模型数据

2017-08-03 11:06:52

2011-07-07 13:12:58

移动设备端设计注意力

2024-11-04 10:40:00

AI模型

2024-12-04 09:25:00

点赞
收藏

51CTO技术栈公众号