全参数微调(Full Parameter Fine-Tuning)的显存需求取决于多个因素,包括模型的大小、数据的批量大小(Batch Size)、优化器的状态存储以及是否使用混合精度训练等。以下是一个详细的分析:
模型参数大小
模型参数显存占用:模型的每个参数在显存中占用一定的空间。通常,单精度浮点数(FP32)占用4字节,半精度浮点数(FP16)占用2字节。
计算公式:
模型参数显存=模型参数数量×每个参数占用的字节数
示例:
如果模型有1.5亿个参数(如BERT-Base),使用FP32精度,显存占用为:
梯度存储
在反向传播中,每个参数的梯度也需要存储在显存中。
计算公式:
梯度显存=模型参数数量×每个参数占用的字节数
示例:
对于上述BERT-Base模型(FP32),梯度显存占用为:
优化器状态
常用的优化器(如Adam)会为每个参数存储额外的状态(如动量和方差估计)。
不同优化器的状态倍数如下:
AdamW (2 states): 8 Bytes per parameter
AdamW (bitsandbytes Quantized): 2 Bytes per parameter
SGD (1 state): 4 Bytes per parameter
计算公式:
优化器状态显存=模型参数数量×每个参数占用的字节数×优化器状态倍数
示例:
对于BERT-Base模型(FP32),优化器状态显存占用为:
激活值和临时变量
在前向和反向传播过程中,网络的激活值(中间层输出)和临时变量也会占用显存。
估算公式:
激活值显存≈模型参数数量×每个参数占用的字节数×2
示例:
对于BERT-Base模型(FP32),激活值显存占用为:
批量大小(Batch Size)
批量大小会显著影响显存占用。每个样本的输入、输出和中间激活值都需要存储。
估算公式:
Batch Size显存=Batch Size×(输入大小+输出大小+中间激活值大小)
示例:
假设输入为512个token的文本,每个token的嵌入维度为768(BERT-Base),Batch Size为32,则输入显存占用为:
总结公式
综合以上因素,全参数微调的显存需求估算公式为:
总显存需求=(模型参数显存+梯度显存+优化器状态显存+激活值显存)×精度倍数+Batch Size显存
示例:BERT-Base全参数微调(FP32)
- 模型参数显存:600MB
- 梯度显存:600MB
- 优化器状态显存:1200MB
- 激活值显存:1200MB
- Batch Size显存:假设为100MB(根据输入大小和Batch Size估算)
最终总显存需求:
600+600+1200+1200+100=3700MB≈3.7GB