SFT loss计算的那些坑,完美避开!!!
SFT 可以说是 LLM 的基本操作了,如果只是想把 SFT 跑起来是非常简单的,只需要构造 input_ids 和 labels,然后就可以把训练跑起来。然而,这样的训练效率实际上非常低。
所以在训练时,通常有两个加速方法:
- 多轮合并
- packing
无论是哪种方法,加速后都需要保证 loss 和原来是等价的。本文主要介绍这两种加速方法,以及 loss 计算时遇到的问题。
1.多轮合并
假设我们有一个对话,其中 user 和 bot 交互了 3 轮,我们可以构建三个样本:
input_ids 就是对应的 token id,labels 输入部分(白色)使用 -100,输出部分(绿色)使用 input_ids。
这样计算的 loss 可以表示为:
其中 l_i 表示第 i 个样本的 loss,n_i 表示第 i 个样本输出的 token 数量 (对应绿色部分)。
这样除了训练比较慢,没有什么别的问题。因为不同样本之间有很多重复计算的前缀,实际上这部分计算一次就行。
2.加速计算
如果将三轮三个样本合并成一个样本,可以尝试这种构造形式。
因为存在 causal attention mask,所以每个 token 只能看到前面的 token,计算上和之前是等价的。
但是这样有一个坑:如果还是按照刚才的方式构建 input_ids 和 labels (白色用-100,绿色用input_ids)loss 计算是有问题的。
pytorch CrossEntropyLoss 计算 loss 按照下面的方法,默认是"mean"。
所以我们会得到这样的 loss:
当不同轮次的输出长度不同时,这种 loss 和刚才的不等价。多轮对话中输出较短的权重被降低了,输出较长的被提高了。所以结果就是短输出的数据训练不够充分。
3.Packing
假设我们有两个对话,第一个是单轮对话,第二个是三轮对话。
正确的 loss:
其中 l_ij 表示第 i 个样本第 j 轮对话的 loss,n_ij 同理。
问题:真实场景中的训练集文本长度长短不一,Padding 后矩阵非常稀疏,只有不到一半是有效计算。
加速计算:
将所有样本拼接成一条,并且加入 attention mask 保证后面的样本看不见前面的 token。
比如在 flash attention 中,可以调用 flash_attn_varlen_qkvpacked_func,并传入 cu_seqlens 参数。
和之前一样,如果不修改 loss 计算方法,packing 的样本之间会存在因为长度不同,导致训练不充分的问题。
4.正确方法
一般情况下,loss 计算会经历三次平均:
- micro batch 维度,分母是这个 micro batch 中的所有 label 不是 -100 的 token 数
- DP 维度,分母是 DP size (和GPU数量相关)
- 梯度累加维度,分母是梯度累加数
我们这里要做的就是禁用这三个平均,统一用这个 global batch 的对话轮数作为分母。
在新版 megatron 框架中,开启开关 --calculate-per-token-loss 即可禁用 DP 和梯度累加的平均,然后修改 loss_func。
每个 micro batch 都需要返回这个 micro batch 的轮数,最后框架会自动将所有轮数求和,作为分母。对于分子,需要除以这个轮次的token 数。
正确实现代码如下(loss_token_num, turn_num 是在构建 data 的时候构建的):
def loss_func(output_tensor, loss_mask, loss_token_num, turn_num):
losses = output_tensor.view(-1).float()
loss_mask = loss_mask.view(-1).float()
loss_token_num = loss_token_num.view(-1).float()
# label: [-100, -100, a, a, a, -100, b, b, -100, -100, c, c, c, -100, -100]
# loss_mask: [0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0]
# losses: [a0, a1, a2, a3, a4, b0, b1, b2, c0, c1, c2, c3, c4, d0, d1]
# losses * loss_mask = [0, 0, a2, a3, a4, 0, b1, b2, 0, 0, c2, c3, c4, 0, 0]
# loss_token_num: [3, 3, 3, 3, 3, 2, 2, 2, 3, 3, 3, 3, 3, 1, 1]
# losses * loss_mask / loss_token_num = [0, 0, a2/3, a3/3, a4/3, 0, b1/2, b2/2, 0, 0, c2/3, c3/3, c4/3, 0, 0]
# sum = 1/3 (a2 + a3 + a4) + 1/2 (b1 + b2) + 1/3 (c2 + c3 + c4)
loss = torch.sum(losses * loss_mask / loss_token_num)
loss_and_turn_num = torch.cat([loss.view(1), turn_num.view(1)])
# Reduce loss for logging.
loss_and_turn_num = loss_and_turn_num.clone().detach()
torch.distributed.all_reduce(loss_and_turn_num, group=mpu.get_data_parallel_group())
# 新版返回结构,开启 calculate_per_token_loss 开关后,返回三个值
# 第一个是反向传播实际使用的 loss, 所有 packing 的 loss 求和
# 第二个是 turn_num, 优化器状态更新时会使用对这个值求和然后缩放梯度
# 第三个是用于日志打印的 loss, 包含两个值,第一个是所有 loss 求和作为分子,第二个是所有 turn_num 求和作为分母
return loss, turn_num, {"lm loss": (loss_and_turn_num[0], loss_and_turn_num[1])}
5.总结
在 SFT 时,如果要加速,需要注意:
- 不同样本之间是等价的;
- 不同轮次之间也是等价的。
在合并多轮 / packing 时,需要修改 loss 计算方法,为每个 token 设置正确的权重,并且关闭 DP / 梯度累加的平均。