SFT loss计算的那些坑,完美避开!!!

发布于 2024-12-11 10:48
浏览
0收藏

​SFT 可以说是 LLM 的基本操作了,如果只是想把 SFT 跑起来是非常简单的,只需要构造 input_ids 和 labels,然后就可以把训练跑起来。然而,这样的训练效率实际上非常低。

所以在训练时,通常有两个加速方法:

  • 多轮合并
  • packing

无论是哪种方法,加速后都需要保证 loss 和原来是等价的。本文主要介绍这两种加速方法,以及 loss 计算时遇到的问题。

1.多轮合并

假设我们有一个对话,其中 user 和 bot 交互了 3 轮,我们可以构建三个样本:

SFT loss计算的那些坑,完美避开!!!-AI.x社区

input_ids 就是对应的 token id,labels 输入部分(白色)使用 -100,输出部分(绿色)使用 input_ids。

这样计算的 loss 可以表示为:

SFT loss计算的那些坑,完美避开!!!-AI.x社区

其中 l_i 表示第 i 个样本的 loss,n_i 表示第 i 个样本输出的 token 数量 (对应绿色部分)。

这样除了训练比较慢,没有什么别的问题。因为不同样本之间有很多重复计算的前缀,实际上这部分计算一次就行。

2.加速计算

SFT loss计算的那些坑,完美避开!!!-AI.x社区

如果将三轮三个样本合并成一个样本,可以尝试这种构造形式。

因为存在 causal attention mask,所以每个 token 只能看到前面的 token,计算上和之前是等价的。

但是这样有一个坑:如果还是按照刚才的方式构建 input_ids 和 labels (白色用-100,绿色用input_ids)loss 计算是有问题的。

pytorch CrossEntropyLoss 计算 loss 按照下面的方法,默认是"mean"。

SFT loss计算的那些坑,完美避开!!!-AI.x社区

所以我们会得到这样的 loss:

SFT loss计算的那些坑,完美避开!!!-AI.x社区

当不同轮次的输出长度不同时,这种 loss 和刚才的不等价。多轮对话中输出较短的权重被降低了,输出较长的被提高了。所以结果就是短输出的数据训练不够充分。

3.Packing

假设我们有两个对话,第一个是单轮对话,第二个是三轮对话。

SFT loss计算的那些坑,完美避开!!!-AI.x社区

正确的 loss:

SFT loss计算的那些坑,完美避开!!!-AI.x社区

其中 l_ij 表示第 i 个样本第 j 轮对话的 loss,n_ij 同理。

问题:真实场景中的训练集文本长度长短不一,Padding 后矩阵非常稀疏,只有不到一半是有效计算。

加速计算:

SFT loss计算的那些坑,完美避开!!!-AI.x社区

将所有样本拼接成一条,并且加入 attention mask 保证后面的样本看不见前面的 token。

比如在 flash attention 中,可以调用 flash_attn_varlen_qkvpacked_func,并传入 cu_seqlens 参数。

和之前一样,如果不修改 loss 计算方法,packing 的样本之间会存在因为长度不同,导致训练不充分的问题。

4.正确方法

一般情况下,loss 计算会经历三次平均:

  • micro batch 维度,分母是这个 micro batch 中的所有 label 不是 -100 的 token 数
  • DP 维度,分母是 DP size (和GPU数量相关)
  • 梯度累加维度,分母是梯度累加数

我们这里要做的就是禁用这三个平均,统一用这个 global batch 的对话轮数作为分母。

在新版 megatron 框架中,开启开关 --calculate-per-token-loss 即可禁用 DP 和梯度累加的平均,然后修改 loss_func。

每个 micro batch 都需要返回这个 micro batch 的轮数,最后框架会自动将所有轮数求和,作为分母。对于分子,需要除以这个轮次的token 数。

正确实现代码如下(loss_token_num, turn_num 是在构建 data 的时候构建的):

def loss_func(output_tensor, loss_mask, loss_token_num, turn_num):
    losses = output_tensor.view(-1).float()
    loss_mask = loss_mask.view(-1).float()
    loss_token_num = loss_token_num.view(-1).float()
    # label: [-100, -100, a, a, a, -100, b, b, -100, -100, c, c, c, -100, -100]
    # loss_mask: [0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0]
    # losses: [a0, a1, a2, a3, a4, b0, b1, b2, c0, c1, c2, c3, c4, d0, d1]
    # losses * loss_mask = [0, 0, a2, a3, a4, 0, b1, b2, 0, 0, c2, c3, c4, 0, 0]
    # loss_token_num: [3, 3, 3, 3, 3, 2, 2, 2, 3, 3, 3, 3, 3, 1, 1]
    # losses * loss_mask / loss_token_num = [0, 0, a2/3, a3/3, a4/3, 0, b1/2, b2/2, 0, 0, c2/3, c3/3, c4/3, 0, 0]
    # sum = 1/3 (a2 + a3 + a4) + 1/2 (b1 + b2) + 1/3 (c2 + c3 + c4)
    loss = torch.sum(losses * loss_mask / loss_token_num)
    loss_and_turn_num = torch.cat([loss.view(1), turn_num.view(1)])
    # Reduce loss for logging.
    loss_and_turn_num = loss_and_turn_num.clone().detach()
    torch.distributed.all_reduce(loss_and_turn_num, group=mpu.get_data_parallel_group())
    # 新版返回结构,开启 calculate_per_token_loss 开关后,返回三个值
    # 第一个是反向传播实际使用的 loss, 所有 packing 的 loss 求和
    # 第二个是 turn_num, 优化器状态更新时会使用对这个值求和然后缩放梯度
    # 第三个是用于日志打印的 loss, 包含两个值,第一个是所有 loss 求和作为分子,第二个是所有 turn_num 求和作为分母
    return loss, turn_num, {"lm loss": (loss_and_turn_num[0], loss_and_turn_num[1])}

5.总结

在 SFT 时,如果要加速,需要注意:

  • 不同样本之间是等价的;
  • 不同轮次之间也是等价的。

在合并多轮 / packing 时,需要修改 loss 计算方法,为每个 token 设置正确的权重,并且关闭 DP / 梯度累加的平均。​

本文转载自​丁师兄大模型​,作者:Ethan Yan ​​

收藏
回复
举报
回复
相关推荐