Machine Learning | Developing a Large Model from Scratch: Model Pretraining


1. Parameter initialization

The parameter template is defined as follows:

from transformers import PretrainedConfig

class MyPretrainConfig(PretrainedConfig):
    model_type = "myllm"

    def __init__(
            self,
            dim: int = 512,
            n_layers: int = 8,
            n_heads: int = 16,
            n_kv_heads: int = 8,
            vocab_size: int = 6400,
            hidden_dim: int = None,
            multiple_of: int = 64,
            norm_eps: float = 1e-5,
            max_seq_len: int = 512,
            dropout: float = 0.0,
            flash_attn: bool = True,
            use_moe: bool = False,
            num_experts_per_tok=2,
            n_routed_experts=4,
            n_shared_experts: bool = True,
            scoring_func='softmax',
            aux_loss_alpha=0.01,
            seq_aux=True,
            norm_topk_prob=True,
            **kwargs,
    ):
        self.dim = dim
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        self.multiple_of = multiple_of
        self.norm_eps = norm_eps
        self.max_seq_len = max_seq_len
        self.dropout = dropout
        self.flash_attn = flash_attn
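        self.use_moe = use_moe                          # whether to use mixture-of-experts (MoE) layers
        self.num_experts_per_tok = num_experts_per_tok  # number of experts selected per token
        self.n_routed_experts = n_routed_experts        # total number of routed experts
        self.n_shared_experts = n_shared_experts        # whether to use shared experts
        self.scoring_func = scoring_func                # scoring function, 'softmax' by default
        self.aux_loss_alpha = aux_loss_alpha            # alpha (weight) of the auxiliary loss
        self.seq_aux = seq_aux                          # whether to compute the auxiliary loss at sequence level
        self.norm_topk_prob = norm_topk_prob            # whether to normalize the top-k probabilities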
        super().__init__(**kwargs)

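This relies on PretrainedConfig from the transformers library. The parameters of MyPretrainConfig are:

  • dim: int = 512: the model (embedding) dimension, default 512
  • n_layers: int = 8: the number of Transformer layers, default 8
  • n_heads: int = 16: the number of attention (query) heads, default 16
  • n_kv_heads: int = 8: the number of key/value heads, default 8; with fewer key/value heads than query heads, several query heads share one key/value head
  • vocab_size: int = 6400: the vocabulary size, default 6400
  • hidden_dim: int = None: the hidden dimension of the feed-forward layer, default None, in which case it is derived from dim and multiple_of
  • multiple_of: int = 64: the derived feed-forward hidden dimension is rounded up to a multiple of this value, default 64
  • norm_eps: float = 1e-5: the epsilon used in normalization, default 1e-5
  • max_seq_len: int = 512: the maximum sequence length, default 512
  • dropout: float = 0.0: the dropout probability, default 0.0
  • flash_attn: bool = True: whether to use Flash Attention, default True
  • use_moe: bool = False: whether to use mixture-of-experts (MoE) layers, default False
  • num_experts_per_tok=2: the number of experts selected per token, default 2
  • n_routed_experts=4: the total number of routed experts, default 4
  • n_shared_experts: bool = True: whether to use shared experts, default True
  • scoring_func='softmax': the scoring function, default 'softmax'
  • aux_loss_alpha=0.01: the alpha (weight) of the auxiliary loss, default 0.01
  • seq_aux=True: whether to compute the auxiliary loss at sequence level, default True
  • norm_topk_prob=True: whether to normalize the top-k probabilities, default True
  • **kwargs: any other keyword arguments, passed on to the parent class constructor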
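
How hidden_dim and multiple_of interact is easiest to see with numbers. The sketch below mirrors the rule that the FeedForward class in the appendix applies when hidden_dim is None; the helper name ffn_hidden_dim is made up for this illustration and is not part of the original code.

```python
def ffn_hidden_dim(dim: int, multiple_of: int) -> int:
    """Hypothetical helper: derive the feed-forward width when hidden_dim is None."""
    hidden_dim = 4 * dim                    # start from a 4x expansion of the model dimension
    hidden_dim = int(2 * hidden_dim / 3)    # keep two thirds of that width
    # round up to the next multiple of multiple_of
    return multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)


print(ffn_hidden_dim(512, 64))
```

With the default dim and multiple_of this yields 1408, which is the out_features of w1 and w3 in the model printout later in the article.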

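PretrainedConfig provides a template for pretraining parameters. Because every model is different, the configuration is usually shipped as a config file together with the model.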

2. Loading the preprocessed data

Load the data that was preprocessed in the previous article:

data_path_list = ['./pretrain_data.bin']
train_ds = PretrainDataset(data_path_list, max_length=max_seq_len, memmap=True)
train_sampler = None
num_workers = 16  # adjust to the number of CPU cores on your system
train_loader = DataLoader(
    train_ds,
    batch_size=batch_size,
    pin_memory=True,
    drop_last=False,
    shuffle=False,
    num_workers=num_workers,
    sampler=train_sampler
)

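PretrainDataset is the loading code. It exposes the token file as an array of fixed-length sequences (memory-mapped when memmap=True, otherwise read into memory) so that DataLoader can fetch samples from it: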

class PretrainDataset(Dataset):
    def __init__(self, data_path_lst, max_length=512, memmap=False):
        super().__init__()
        if memmap:
            # Open in binary mode and seek to the end to get the file size in bytes.
            with open(data_path_lst[0], 'rb') as f:
                nbytes = f.seek(0, 2)
            flen = nbytes // np.dtype('uint16').itemsize
            # Map the file read-only; each row is one sequence of max_length token ids.
            self.data = np.memmap(data_path_lst[0], dtype=np.dtype('uint16'), mode='r', shape=(flen // max_length, max_length))
        else:
            data_lst = []
            for data_path in data_path_lst:
                with open(data_path, 'rb') as f:
                    data = np.fromfile(f, dtype=np.uint16)
                    data_lst.append(data)
            data = np.concatenate(data_lst)
            # Drop the tail that does not fill a whole sequence.
            data = data[:max_length * (len(data) // max_length)]
            self.data = data.reshape(-1, max_length)
        print("memmap:{} train data.shape:{}".format(memmap, self.data.shape))
        print("loading finished.....")

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, index: int):
        sample = self.data[index]
        X = np.array(sample[:-1]).astype(np.int64)
        Y = np.array(sample[1:]).astype(np.int64)

        return torch.from_numpy(X), torch.from_numpy(Y)

Dataset here is the generic base class imported with from torch.utils.data import Dataset.
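
What __getitem__ returns can be illustrated without NumPy or PyTorch. The sketch below uses plain Python lists and a made-up helper make_samples to show how a flat token stream is cut into rows of max_length and how each row becomes an input X and a target Y shifted by one position; it is an illustration of the idea, not the article's loader.

```python
from typing import List, Tuple


def make_samples(tokens: List[int], max_length: int) -> List[Tuple[List[int], List[int]]]:
    """Cut a flat token stream into rows of max_length and build (X, Y) pairs."""
    n_rows = len(tokens) // max_length          # an incomplete last row is dropped
    samples = []
    for r in range(n_rows):
        row = tokens[r * max_length:(r + 1) * max_length]
        x = row[:-1]                            # inputs: every token except the last
        y = row[1:]                             # targets: the same row shifted left by one
        samples.append((x, y))
    return samples


stream = list(range(10, 21))                    # 11 toy token ids: 10, 11, ..., 20
pairs = make_samples(stream, max_length=5)
for x, y in pairs:
    print(x, "->", y)
```

Each target Y[i] is the token that follows X[i], so every sample contributes max_length - 1 next-token predictions.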

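3. Initializing the model

The model initialization borrows from the llama2.c code at https://github.com/karpathy/llama2.c/blob/master/model.py. It uses only the decoder part of the Transformer, i.e. a decoder-only architecture. The main logic is:

  • Initialization: create tok_embeddings, dropout, layers, the CausalLMOutputWithPast output object, and so on
  • forward: run the tokens through the layers and return the logits, plus the loss when targets are given

The code is as follows: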

class Transformer(PreTrainedModel):
    last_loss: Optional[torch.Tensor]

    def __init__(self, params: MyPretrainConfig):
        super().__init__(params)
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers

        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
        self.dropout = nn.Dropout(params.dropout)
        self.layers = torch.nn.ModuleList()
        for layer_id in range(params.n_layers):
            self.layers.append(TransformerBlock(layer_id, params))
        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)

        # share the unembedding parameters with the embedding parameters
        self.tok_embeddings.weight = self.output.weight # https://paperswithcode.com/method/weight-tying

        # some useful precompute for the RoPE relative positional embeddings
        freqs_cos, freqs_sin = precompute_freqs_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len)
        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
        self.register_buffer("freqs_sin", freqs_sin, persistent=False)

        # init all weights
        self.apply(self._init_weights)
        # apply special scaled init to the residual projections, per GPT-2 paper
        for pn, p in self.named_parameters():
            if pn.endswith('w3.weight') or pn.endswith('wo.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * params.n_layers))

        # Initialize attribute for the loss of the last forward call. This will be set if the forward is called with a targets tensor.
        self.last_loss = None
        self.OUT = CausalLMOutputWithPast()

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None) -> CausalLMOutputWithPast:
        _bsz, seqlen = tokens.shape
        h = self.tok_embeddings(tokens)
        h = self.dropout(h)
        freqs_cos = self.freqs_cos[:seqlen]
        freqs_sin = self.freqs_sin[:seqlen]

        for layer in self.layers:
            h = layer(h, freqs_cos, freqs_sin)
        h = self.norm(h)

        if targets is not None:
            # if we are given some desired targets also calculate the loss
            logits = self.output(h)
            self.last_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            # inference-time mini-optimization: only forward the output on the very last position
            logits = self.output(h[:, [-1], :]) # note: using list [-1] to preserve the time dim
            self.last_loss = None

        self.OUT.__setitem__('logits', logits)
        self.OUT.__setitem__('last_loss', self.last_loss)
        return self.OUT
...
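
The loss computed in forward is a cross-entropy over next-token predictions, with positions whose target equals ignore_index left out of the average. The following is a minimal pure-Python sketch of that computation (the function below is written for illustration and is not the PyTorch implementation); it also shows the loss value one would expect from a model that has no preference between tokens.

```python
import math
from typing import List


def cross_entropy(logits: List[List[float]], targets: List[int], ignore_index: int = -1) -> float:
    """Mean negative log-likelihood over positions whose target is not ignore_index."""
    total, count = 0.0, 0
    for row, target in zip(logits, targets):
        if target == ignore_index:
            continue                                  # skipped positions add nothing to the loss
        m = max(row)                                  # subtract the max for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_sum_exp - row[target]            # -log softmax(row)[target]
        count += 1
    return total / count


vocab_size = 6400
uniform = [[0.0] * vocab_size for _ in range(3)]      # identical logits for every token
loss = cross_entropy(uniform, targets=[5, 17, -1])    # the last position is ignored
print(round(loss, 3), round(math.log(vocab_size), 3))
```

With identical logits the loss equals ln(vocab_size), so a freshly initialized model with this vocabulary should start somewhere near ln(6400), and training should push the loss below that.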

The model is then instantiated from the class above and printed:

def init_model():
    def count_parameters(model):
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

    model = Transformer(lm_config).to(device)
    print(f'Total LLM parameters: {count_parameters(model) / 1e6:.3f} million')
    return model

model = init_model()
print(model)

The output is:

Transformer(
  (tok_embeddings): Embedding(6400, 512)
  (dropout): Dropout(p=0.0, inplace=False)
  (layers): ModuleList(
    (0-7): 8 x TransformerBlock(
      (attention): Attention(
        (wq): Linear(in_features=512, out_features=512, bias=False)
        (wk): Linear(in_features=512, out_features=256, bias=False)
        (wv): Linear(in_features=512, out_features=256, bias=False)
        (wo): Linear(in_features=512, out_features=512, bias=False)
        (attn_dropout): Dropout(p=0.0, inplace=False)
        (resid_dropout): Dropout(p=0.0, inplace=False)
      )
      (feed_forward): FeedForward(
        (w1): Linear(in_features=512, out_features=1408, bias=False)
        (w2): Linear(in_features=1408, out_features=512, bias=False)
        (w3): Linear(in_features=512, out_features=1408, bias=False)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (attention_norm): RMSNorm()
      (ffn_norm): RMSNorm()
    )
  )
  (norm): RMSNorm()
  (output): Linear(in_features=512, out_features=6400, bias=False)
)
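
The layer shapes in this printout are enough to work out the parameter count by hand. The sketch below does the arithmetic under two assumptions: none of the Linear layers has a bias (as printed), and the tied embedding/output weight is counted once, as model.parameters() does for shared parameters. The helper count_params is made up for this illustration.

```python
def count_params(dim=512, n_layers=8, n_heads=16, n_kv_heads=8,
                 vocab_size=6400, hidden_dim=1408) -> int:
    """Hypothetical helper: count trainable parameters from the layer shapes in the printout."""
    head_dim = dim // n_heads                       # 512 // 16 = 32
    kv_dim = n_kv_heads * head_dim                  # 8 * 32 = 256, the out_features of wk and wv
    attention = dim * dim + 2 * dim * kv_dim + dim * dim   # wq, wk, wv, wo
    feed_forward = 3 * dim * hidden_dim             # w1, w2, w3
    norms = 2 * dim                                 # attention_norm and ffn_norm weights
    per_layer = attention + feed_forward + norms
    embedding = vocab_size * dim                    # shared with the output layer, counted once
    final_norm = dim
    return embedding + n_layers * per_layer + final_norm


total = count_params()
print(f"{total:,} parameters ({total / 1e6:.3f} million)")
```

Most of the parameters sit in the eight Transformer blocks; the shared embedding matrix accounts for a little over three million.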

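Model initialization is not covered in detail here. A separate article in this series will analyze the llama2.c source code and explain how the model is built.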

4. Choosing the optimizer

After the model is initialized, choose the optimizer:

scaler = torch.cuda.amp.GradScaler(enabled=(dtype in ('float16', 'bfloat16')))
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

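4.1 GradScaler

In PyTorch, GradScaler performs gradient scaling for automatic mixed precision (AMP) training. Its main roles are:

  • Preventing gradient underflow: in mixed precision training, weights and activations may be held in a lower precision such as half precision (FP16). Gradients computed during backpropagation can then become too small to represent and underflow to zero. GradScaler multiplies the loss by a scale factor before the backward pass so that the gradients stay representable, and unscales them before the parameter update;
  • Supporting faster training: mixed precision reduces memory use and compute time. By adjusting the scale factor dynamically, GradScaler keeps training numerically stable while mixed precision delivers those gains;
  • Simplifying code: with GradScaler, developers do not have to manage the scale factor and the unscaling by hand;

In the training loop, scaler.scale(loss).backward() computes the gradients of the scaled loss, scaler.step(optimizer) updates the model parameters, and scaler.update() adjusts the scale factor, which keeps training stable and efficient.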
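
The underflow that gradient scaling guards against can be reproduced with the standard library alone. The sketch below rounds values to half precision with struct (format code 'e') and compares a tiny gradient stored directly with the same gradient stored after scaling. The scale factor 2**16 and the gradient value are toy assumptions for this illustration, and the code only imitates the idea, not GradScaler's internals.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE half-precision value."""
    return struct.unpack('<e', struct.pack('<e', x))[0]


grad = 1e-8                      # a tiny gradient value
scale = 2.0 ** 16                # toy scale factor (assumption for this sketch)

unscaled = to_fp16(grad)         # stored directly in FP16: underflows to zero
scaled = to_fp16(grad * scale)   # stored after scaling: still representable
recovered = scaled / scale       # unscale in higher precision before the optimizer step

print(unscaled, scaled, recovered)
```

The directly stored value is lost, while the scaled value survives the round trip and unscaling recovers the original gradient to within FP16 rounding error.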

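4.2 optimizer

The optimizer is a key component in deep learning: it updates the model parameters to minimize the loss function. Specifically, it:

  • Updates parameters: the optimizer uses the computed gradients to update the model parameters (weights and biases) so that the model fits the training data better;
  • Controls the learning rate: the learning rate is a hyperparameter that sets the size of each parameter update, i.e. the step the model takes toward the optimum in each iteration;
  • Implements different optimization algorithms: PyTorch provides several (such as SGD, Adam, and RMSprop), each with its own update rule. The choice of optimizer affects convergence speed and final performance;
  • Handles momentum and adaptive learning rates: some optimizers (such as Adam and RMSprop) use momentum, adaptive per-parameter learning rates, or both to speed up convergence and improve stability;
  • Supports regularization: some optimizers can integrate regularization such as L2 regularization (weight decay) to reduce overfitting;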
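
To make momentum and adaptive learning rates concrete, here is a minimal sketch of the standard Adam update for a single scalar parameter, written in plain Python rather than with torch.optim. The toy objective and the hyperparameter values are assumptions chosen for the illustration.

```python
import math


def adam_step(param, grad, m, v, t, lr=0.1, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update for a single scalar parameter; returns (param, m, v)."""
    m = beta1 * m + (1 - beta1) * grad            # momentum: running mean of gradients
    v = beta2 * v + (1 - beta2) * grad * grad     # running mean of squared gradients
    m_hat = m / (1 - beta1 ** t)                  # bias correction for the zero initialization
    v_hat = v / (1 - beta2 ** t)
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)   # per-parameter adaptive step
    return param, m, v


# toy problem: minimize f(x) = (x - 3)^2, whose gradient is 2 * (x - 3)
x, m, v = 0.0, 0.0, 0.0
history = []
for t in range(1, 51):
    x, m, v = adam_step(x, 2 * (x - 3), m, v, t)
    history.append(x)

print(history[0], history[-1])
```

The first step moves the parameter by almost exactly lr regardless of how large the gradient is, which is the adaptive scaling at work; later steps carry the parameter toward the minimum at 3.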

In the training iteration below, the optimizer uses the gradients of the loss to update the model parameters:

# backward pass
scaler.scale(loss).backward()

# gradient clipping and parameter update
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
scaler.step(optimizer)
scaler.update()

# zero the gradients
optimizer.zero_grad(set_to_none=True)

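5. Iterative training

Once the preprocessed data is loaded and the model and optimizer are initialized, training can start. The most important part of the loop is setting the learning rate for each step and updating the parameters from the loss. The code is as follows: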

for epoch in range(epochs):
    start_time = time.time()

    for step, (X, Y) in enumerate(train_loader):
        X = X.to(device)
        Y = Y.to(device)

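        # set the learning rate
        lr = get_lr(epoch * iter_per_epoch + step, epochs * iter_per_epoch)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # forward pass and loss computation
        with ctx:
            out = model(X, Y)
            loss = out.last_loss

        # backward pass; divide by accumulation_steps so the accumulated gradient is an average
        scaler.scale(loss / accumulation_steps).backward()

        # every accumulation_steps steps: clip gradients, update parameters, then clear gradients
        if (step + 1) % accumulation_steps == 0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()

            # zero the gradients only after an update, otherwise nothing accumulates
            optimizer.zero_grad(set_to_none=True)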

        if step % 100 == 0:
            spend_time = time.time() - start_time
            print(
                'Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.7f} epoch_Time:{}min:'.format(
                    epoch,
                    epochs,
                    step,
                    iter_per_epoch,
                    loss.item(),
                    optimizer.param_groups[-1]['lr'],
                    spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))
            model.eval()
            ckp = f'{save_dir}/pretrain_{lm_config.dim}.pth'
            state_dict = model.state_dict()
            torch.save(state_dict, ckp)
            model.train()
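
  • out = model(X, Y): the forward pass, which computes the output and the loss
  • scaler.scale(...).backward(): the backward pass, which computes the gradients; the parameters are updated once every accumulation_steps steps
  • model.eval() and model.train(): switch the model to evaluation mode and back to training mode around saving the current model to the specified folder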
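
Gradient accumulation relies on gradients adding up across backward calls until the optimizer step. The sketch below checks this on a toy one-parameter model in plain Python: accumulating the gradients of several micro-batches, each divided by accumulation_steps, gives the same gradient as one large batch. It assumes equal-sized micro-batches and a mean loss; the model, data and helper grad_mse are made up for the illustration.

```python
def grad_mse(w: float, xs, ts) -> float:
    """Gradient of mean((w * x - t)^2) with respect to w."""
    return sum(2 * (w * x - t) * x for x, t in zip(xs, ts)) / len(xs)


w = 0.5
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ts = [2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0]

# reference: one large batch of 8 samples
full_grad = grad_mse(w, xs, ts)

# accumulation: 4 micro-batches of 2 samples, each gradient divided by accumulation_steps
accumulation_steps = 4
micro = len(xs) // accumulation_steps
accumulated = 0.0                                   # plays the role of param.grad
for i in range(accumulation_steps):
    lo, hi = i * micro, (i + 1) * micro
    accumulated += grad_mse(w, xs[lo:hi], ts[lo:hi]) / accumulation_steps
    # no zeroing here: the gradient buffer is cleared only after the optimizer step

print(full_grad, accumulated)
```

If the buffer were cleared after every micro-batch, only the last micro-batch would contribute to the update, so the zeroing has to happen together with the optimizer step.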

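On a T4 GPU the training took me more than 30 hours; on a CPU it would take roughly four times as long. The complete code is in the appendix for anyone who wants to run it.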
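
The learning rate set at every step of the training loop follows a cosine decay. The sketch below restates the get_lr function from the appendix as a self-contained function: it takes learning_rate and warmup_iters as parameters instead of reading a global, and the name cosine_lr is made up for this illustration.

```python
import math


def cosine_lr(it: int, total_iters: int, learning_rate: float = 1e-4, warmup_iters: int = 0) -> float:
    """Cosine decay from learning_rate down to learning_rate / 10, with optional linear warmup."""
    min_lr = learning_rate / 10
    if it < warmup_iters:
        return learning_rate * it / warmup_iters   # linear warmup
    if it > total_iters:
        return min_lr                              # past the schedule: stay at the floor
    decay_ratio = (it - warmup_iters) / (total_iters - warmup_iters)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))   # goes from 1 down to 0
    return min_lr + coeff * (learning_rate - min_lr)


total = 1000
for it in (0, 250, 500, 750, 1000):
    print(it, f"{cosine_lr(it, total):.2e}")
```

The schedule starts at the full learning rate, passes the midpoint between the two bounds halfway through, and ends at one tenth of the initial value.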

Appendix

Complete code:

import os
import time
import math
import warnings
import inspect
import numpy as np
import torch
from torch import optim
from torch.utils.data import DataLoader
from contextlib import nullcontext
# Transformer is defined later in this file, so it is not imported from model.model
from torch.utils.data import Dataset
from transformers import PretrainedConfig
from typing import Any, Optional, Tuple
import torch.nn.functional as F
from torch import nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
os.environ["TOKENIZERS_PARALLELISM"] = "false"

warnings.filterwarnings('ignore')
basepath = "../datasets"

class MyPretrainConfig(PretrainedConfig):
    model_type = "myllm"

    def __init__(
            self,
            dim: int = 512,
            n_layers: int = 8,
            n_heads: int = 16,
            n_kv_heads: int = 8,
            vocab_size: int = 6400,
            hidden_dim: int = None,
            multiple_of: int = 64,
            norm_eps: float = 1e-5,
            max_seq_len: int = 512,
            dropout: float = 0.0,
            flash_attn: bool = True,
            num_experts_per_tok=2,
            n_routed_experts=4,
            n_shared_experts: bool = True,
            scoring_func='softmax',
            aux_loss_alpha=0.01,
            seq_aux=True,
            norm_topk_prob=True,
            **kwargs,
    ):
        self.dim = dim
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        self.multiple_of = multiple_of
        self.norm_eps = norm_eps
        self.max_seq_len = max_seq_len
        self.dropout = dropout
        self.flash_attn = flash_attn
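        self.num_experts_per_tok = num_experts_per_tok  # number of experts selected per token
        self.n_routed_experts = n_routed_experts        # total number of routed experts
        self.n_shared_experts = n_shared_experts        # whether to use shared experts
        self.scoring_func = scoring_func                # scoring function, 'softmax' by default
        self.aux_loss_alpha = aux_loss_alpha            # alpha (weight) of the auxiliary loss
        self.seq_aux = seq_aux                          # whether to compute the auxiliary loss at sequence level
        self.norm_topk_prob = norm_topk_prob            # whether to normalize the top-k probabilities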
        super().__init__(**kwargs)

class PretrainDataset(Dataset):
    def __init__(self, data_path_lst, max_length=512, memmap=False):
        super().__init__()
        if memmap:
            # Open in binary mode and seek to the end to get the file size in bytes.
            with open(data_path_lst[0], 'rb') as f:
                nbytes = f.seek(0, 2)
            flen = nbytes // np.dtype('uint16').itemsize
            # Map the file read-only; each row is one sequence of max_length token ids.
            self.data = np.memmap(data_path_lst[0], dtype=np.dtype('uint16'), mode='r', shape=(flen // max_length, max_length))
        else:
            data_lst = []
            for data_path in data_path_lst:
                with open(data_path, 'rb') as f:
                    data = np.fromfile(f, dtype=np.uint16)
                    data_lst.append(data)
            data = np.concatenate(data_lst)
            # Drop the tail that does not fill a whole sequence.
            data = data[:max_length * (len(data) // max_length)]
            self.data = data.reshape(-1, max_length)
        print("memmap:{} train data.shape:{}".format(memmap, self.data.shape))
        print("loading finished.....")

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, index: int):
        sample = self.data[index]
        X = np.array(sample[:-1]).astype(np.int64)
        Y = np.array(sample[1:]).astype(np.int64)

        return torch.from_numpy(X), torch.from_numpy(Y)
    
class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    freqs_cos = torch.cos(freqs)  # real part
    freqs_sin = torch.sin(freqs)  # imaginary part
    return freqs_cos, freqs_sin

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(shape)

def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cos: torch.Tensor,
    freqs_sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:

    # reshape xq and xk to match the complex representation
    xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
    xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)

    # reshape freqs_cos and freqs_sin for broadcasting
    freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
    freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)

    # apply rotation using real numbers
    xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
    xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
    xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
    xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos

    # flatten last two dimensions
    xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3)
    xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)

    return xq_out.type_as(xq), xk_out.type_as(xk)

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )

class Attention(nn.Module):
    def __init__(self, args: MyPretrainConfig):
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        assert args.n_heads % self.n_kv_heads == 0
        model_parallel_size = 1
        self.n_local_heads = args.n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads
        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
        self.attn_dropout = nn.Dropout(args.dropout)
        self.resid_dropout = nn.Dropout(args.dropout)
        self.dropout = args.dropout

        # use flash attention or a manual implementation?
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
            mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
            mask = torch.triu(mask, diagonal=1)
            self.register_buffer("mask", mask)

    def forward(
        self,
        x: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
    ):
        bsz, seqlen, _ = x.shape

        # QKV
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        # RoPE relative positional embeddings
        xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)

        # grouped multiquery attention: expand out keys and values
        xk = repeat_kv(xk, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
        xv = repeat_kv(xv, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)

        # make heads into a batch dimension
        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        xk = xk.transpose(1, 2)
        xv = xv.transpose(1, 2)

        # flash implementation
        if self.flash:
            output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None, dropout_p=self.dropout if self.training else 0.0, is_causal=True)
        else:
            # manual implementation
            scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
            assert hasattr(self, 'mask')
            scores = scores + self.mask[:, :, :seqlen, :seqlen]   # (bs, n_local_heads, seqlen, cache_len + seqlen)
            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
            scores = self.attn_dropout(scores)
            output = torch.matmul(scores, xv)  # (bs, n_local_heads, seqlen, head_dim)

        # restore time as batch dimension and concat heads
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)

        # final projection into the residual stream
        output = self.wo(output)
        output = self.resid_dropout(output)
        return output

class FeedForward(nn.Module):
    def __init__(self, dim: int, hidden_dim: int, multiple_of: int, dropout: float):
        super().__init__()
        if hidden_dim is None:
            hidden_dim = 4 * dim
            hidden_dim = int(2 * hidden_dim / 3)
            hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))

class TransformerBlock(nn.Module):
    def __init__(self, layer_id: int, args: MyPretrainConfig):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.attention = Attention(args)
        self.feed_forward = FeedForward(
            dim=args.dim,
            hidden_dim=args.hidden_dim,
            multiple_of=args.multiple_of,
            dropout=args.dropout,
        )
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(self, x, freqs_cos, freqs_sin):
        h = x + self.attention.forward(self.attention_norm(x), freqs_cos, freqs_sin)
        out = h + self.feed_forward.forward(self.ffn_norm(h))
        return out

class Transformer(PreTrainedModel):
    last_loss: Optional[torch.Tensor]

    def __init__(self, params: MyPretrainConfig):
        super().__init__(params)
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers

        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
        self.dropout = nn.Dropout(params.dropout)
        self.layers = torch.nn.ModuleList()
        for layer_id in range(params.n_layers):
            self.layers.append(TransformerBlock(layer_id, params))
        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)

        # share the unembedding parameters with the embedding parameters
        self.tok_embeddings.weight = self.output.weight # https://paperswithcode.com/method/weight-tying

        # some useful precompute for the RoPE relative positional embeddings
        freqs_cos, freqs_sin = precompute_freqs_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len)
        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
        self.register_buffer("freqs_sin", freqs_sin, persistent=False)

        # init all weights
        self.apply(self._init_weights)
        # apply special scaled init to the residual projections, per GPT-2 paper
        for pn, p in self.named_parameters():
            if pn.endswith('w3.weight') or pn.endswith('wo.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * params.n_layers))

        # Initialize attribute for the loss of the last forward call. This will be set if the forward is called with a targets tensor.
        self.last_loss = None
        self.OUT = CausalLMOutputWithPast()

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None) -> CausalLMOutputWithPast:
        _bsz, seqlen = tokens.shape
        h = self.tok_embeddings(tokens)
        h = self.dropout(h)
        freqs_cos = self.freqs_cos[:seqlen]
        freqs_sin = self.freqs_sin[:seqlen]

        for layer in self.layers:
            h = layer(h, freqs_cos, freqs_sin)
        h = self.norm(h)

        if targets is not None:
            # if we are given some desired targets also calculate the loss
            logits = self.output(h)
            self.last_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            # inference-time mini-optimization: only forward the output on the very last position
            logits = self.output(h[:, [-1], :]) # note: using list [-1] to preserve the time dim
            self.last_loss = None

        self.OUT.__setitem__('logits', logits)
        self.OUT.__setitem__('last_loss', self.last_loss)
        return self.OUT

    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
        # start with all of the candidate parameters
        param_dict = {pn: p for pn, p in self.named_parameters()}
        # filter out those that do not require grad
        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
        # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
        # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0}
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
        print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
        # Create AdamW optimizer and use the fused version if it is available
        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and device_type == 'cuda'
        extra_args = dict(fused=True) if use_fused else dict()
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
        print(f"using fused AdamW: {use_fused}")

        return optimizer

    def estimate_mfu(self, fwdbwd_per_iter, dt):
        """ estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """
        # first estimate the number of flops we do per iteration.
        # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
        N = sum(p.numel() for p in self.parameters())
        cfg = self.params
        L, H, Q, T = cfg.n_layers, cfg.n_heads, cfg.dim//cfg.n_heads, cfg.max_seq_len
        flops_per_token = 6*N + 12*L*H*Q*T
        flops_per_fwdbwd = flops_per_token * T
        flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
        # express our flops throughput as ratio of A100 bfloat16 peak flops
        flops_achieved = flops_per_iter * (1.0/dt) # per second
        flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS
        mfu = flops_achieved / flops_promised
        return mfu

    @torch.inference_mode()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """
        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
        the sequence max_new_tokens times, feeding the predictions back into the model each time.
        Most likely you'll want to make sure to be in model.eval() mode of operation for this.
        Also note this is a super inefficient version of sampling with no key/value cache.
        """
        for _ in range(max_new_tokens):
            # if the sequence context is growing too long we must crop it at block_size
            idx_cond = idx if idx.size(1) <= self.params.max_seq_len else idx[:, -self.params.max_seq_len:]
            # forward the model to get the logits for the index in the sequence
            logits = self(idx_cond).logits  # forward returns an output object, so take its logits field
            logits = logits[:, -1, :] # crop to just the final time step
            if temperature == 0.0:
                # "sample" the single most likely index
                _, idx_next = torch.topk(logits, k=1, dim=-1)
            else:
                # pluck the logits at the final step and scale by desired temperature
                logits = logits / temperature
                # optionally crop the logits to only the top k options
                if top_k is not None:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float('Inf')
                # apply softmax to convert logits to (normalized) probabilities
                probs = F.softmax(logits, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)
            # append sampled index to the running sequence and continue
            idx = torch.cat((idx, idx_next), dim=1)

        return idx

def get_lr(it, all):
    warmup_iters = 0
    lr_decay_iters = all
    min_lr = learning_rate / 10

    if it < warmup_iters:
        return learning_rate * it / warmup_iters
    if it > lr_decay_iters:
        return min_lr
    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return min_lr + coeff * (learning_rate - min_lr)

def init_model():
    def count_parameters(model):
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

    model = Transformer(lm_config).to(device)
    print(f'Total LLM parameters: {count_parameters(model) / 1e6:.3f} million')
    return model


if __name__ == "__main__":
    # -----------------------------------------------------------------------------
    lm_config = MyPretrainConfig()
    max_seq_len = lm_config.max_seq_len
    out_dir = 'out'
    epochs = 20             # number of training epochs
    batch_size = 8          # batch size
    learning_rate = 1e-4    # learning rate
    device = 'cuda:0'       # or 'cpu'
    dtype = 'bfloat16'
    save_dir = os.path.join(out_dir)
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(out_dir, exist_ok=True)
    tokens_per_iter = batch_size * max_seq_len
    torch.manual_seed(1337)
    device_type = device if "cuda" in device else "cpu"
    print(f"device_type: {device_type}")
    ctx = (
        nullcontext()
        if device_type == "cpu"
        else torch.cuda.amp.autocast()
    )
    # -----------------------------------------------------------------------------

    # -----init dataloader------
    data_path_list = [f'{basepath}/pretrain_data.bin']
    train_ds = PretrainDataset(data_path_list, max_length=max_seq_len, memmap=True)
    train_sampler = None
    num_workers = 16  # adjust to the number of CPU cores on your system
    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        pin_memory=True,
        drop_last=False,
        shuffle=False,
        num_workers=num_workers,
        sampler=train_sampler
    )

    # init model
    model = init_model()
    print(model)
    scaler = torch.cuda.amp.GradScaler(enabled=(dtype in ('float16', 'bfloat16')))
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # training loop
    accumulation_steps = 8
    iter_per_epoch = len(train_loader)
    for epoch in range(epochs):
        start_time = time.time()

        for step, (X, Y) in enumerate(train_loader):
            X = X.to(device)
            Y = Y.to(device)

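            # set the learning rate
            lr = get_lr(epoch * iter_per_epoch + step, epochs * iter_per_epoch)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

            # forward pass and loss computation
            with ctx:
                out = model(X, Y)
                loss = out.last_loss

            # backward pass; divide by accumulation_steps so the accumulated gradient is an average
            scaler.scale(loss / accumulation_steps).backward()

            # every accumulation_steps steps: clip gradients, update parameters, then clear gradients
            if (step + 1) % accumulation_steps == 0:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()

                # zero the gradients only after an update, otherwise nothing accumulates
                optimizer.zero_grad(set_to_none=True)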

            if step % 100 == 0:
                spend_time = time.time() - start_time
                print(
                    'Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.7f} epoch_Time:{}min:'.format(
                        epoch,
                        epochs,
                        step,
                        iter_per_epoch,
                        loss.item(),
                        optimizer.param_groups[-1]['lr'],
                        spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))
                model.eval()
                ckp = f'{save_dir}/pretrain_{lm_config.dim}.pth'
                state_dict = model.state_dict()
                torch.save(state_dict, ckp)
                model.train()

References

(1)https://github.com/jingyaogong/minimind?tab=readme-ov-file#%E6%95%B0%E6%8D%AE%E9%9B%86%E4%B8%8B%E8%BD%BD%E5%9C%B0%E5%9D%80
(2)https://github.com/karpathy/llama2.c/blob/master/train.py

Editor: 武晓燕  Source: 周末程序猿