注意力机制的变体之MLA 原创

发布于 2024-10-15 13:54
浏览
0收藏

本文介绍注意力机制的变体-MLA。

MLA(Multi-head Latent Attention),是由杭州深度求索人工智能在DeepSeekV2提出的一种注意力机制变体。MLA主要旨在解决推理过程中由于attention机制中KV Cache占用过多内存而导致的性能瓶颈问题。为此,MLA引入了低秩KV压缩技术,有效减少了KV Cache的大小,从而缓解了这一问题。

有兴趣小伙伴可以看官方技术报告的介绍:​​https://arxiv.org/pdf/2405.04434v2​

原理介绍

注意力机制的变体之MLA-AI.x社区

上图为MHA、GQA、MQA、MLA的原理对比图。从上图可知传统Transformer采用MHA,但KV Cache在推理过程中可能成为性能瓶颈。MQA和GQA虽然在一定程度上可以减少KV Cache的占用,但其效果通常不如MHA。MLA通过低秩的Key-Value联合压缩技术,不仅实现了比MHA更优的效果,还大幅减少了所需的KV Cache大小。

具体来说,MLA通过低秩联合压缩key和value来减少kv cache。从注意力机制的步骤来分析:

  • 通过输入x乘以不同的矩阵参数Wq、Wk、Wv得到不同的QKV向量
  • 在转换到QKV向量时候,将x乘以一个低秩矩阵,得到低阶矩阵表示
  • 再通过一个高阶矩阵来恢复原来的特征空间。由于矩阵是模型的权重参数已经保存,所以只需要保存一个低秩的潜层特征就可以恢复成KV,而不是像之前需要同时缓存KV。

代码实现


bsz, q_len, _ = hidden_states.size()
        
# 计算压缩后的Q,再还原成高维
# [B, q_len, hidden_size]
# 即[B, q_len, num_head * q_head_dim]
q = self.w_uq(self.q_a_layernorm(self.w_dq(hidden_states)))
# [B, num_head, q_len, q_head_dim]
q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
# 包含当前位置可用上下文的长度
kv_seq_len = q.size(-2)
if past_key_value is not None:
    if self.layer_idx is None:
        raise ValueError(
            f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
            "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
            "with a layer index."
        )
    kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
# 得到当前压缩后的kv, c_t^{kv}
# [B, q_len, d_c]
compressed_kv = self.w_dkv(hidden_states)

# 将当前位置之前的压缩后的kv拼接到前面
if past_key_value is not None:
    # 得到的应该是[B, kv_seq_len, d_c], c^{kv}
    compressed_kv = past_key_value.update(compressed_kv)
# 计算得到k^C和v^C
# [B, num_head, kv_seq_len, q_head_dim]
k = self.w_uk(compressed_kv).view(bsz, -1, self.num_heads, self.q_head_dim).transpose(1, 2)
v = self.w_uv(compressed_kv).view(bsz, -1, self.num_heads, self.q_head_dim).transpose(1, 2)

# 注意力权重
# [B, num_head, q_len, kv_seq_len]
attn_weights = (
    torch.matmul(q, k.transpose(2, 3)) * self.softmax_scale
)
...
attn_weights = nn.functional.softmax(
    attn_weights, dim=-1, dtype=torch.float32
).to(query_states.dtype)
attn_weights = nn.functional.dropout(
    attn_weights, p=self.attention_dropout, training=self.training
)
# [B, num_head, q_len, q_head_dim]
attn_output = torch.matmul(attn_weights, v)
...

以上为MLA的核心部分代码实现,里面有相应的代码注释。


本文转载自公众号瓦力算法学研所,作者:喜欢瓦力的卷卷

原文链接:​https://mp.weixin.qq.com/s/dWZk8TBY89re207ZL3GjfA​​​

©著作权归作者所有,如需转载,请注明出处,否则将追究法律责任
收藏
回复
举报
回复
相关推荐