在 CIFAR10 数据集上微调 Vision Transformer (ViT)

开发
在这篇文章中,我们将对预训练的 Vision Transformer (ViT) 模型进行微调,以适应 CIFAR10 数据集。

在这篇文章中,我们将对预训练的 Vision Transformer (ViT) 模型进行微调,以适应 CIFAR10 数据集。在之前的文章《在 CIFAR10 数据集上训练 Vision Transformer (ViT)》中,我们从头开始创建了一个 ViT 模型,并在 CIFAR10 数据集上进行了训练。然而,模型的准确率仅达到了67%,没有进行刻意的超参数微调。这是意料之中的,因为 ViT 模型的原始创建者指出,这些模型在小数据集上训练时,性能与卷积神经网络(CNNs)相比是中等的。

然而,当在大型数据集上进行扩展时,它们开始与 CNNs 相当,甚至更好。这就是为什么建议对已经在大型数据集(如 ImageNet)上预训练的 ViT 模型进行微调。而这正是我们在这篇文章中将要做的事情。

训练循环

我们首先编写训练和测试任何模型在 CIFAR10 数据集上的样板代码。您会注意到,我们在训练和测试图像转换中将图像大小调整为224,注意 CIFAR10 的原始图像大小是32。这是因为我们将要使用的模型需要输入大小为224,因为它已经在 ImageNet 上进行了训练。

transform_train = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

train_set = CIFAR10(root='./datasets', train=True, download=True, transform=transform_train)
test_set = CIFAR10(root='./datasets', train=False, download=True, transform=transform_test)

train_loader = DataLoader(train_set, shuffle=True, batch_size=64)
test_loader = DataLoader(test_set, shuffle=False, batch_size=64)
n_epochs = 10
lr = 0.0001

optimizer = Adam(model.parameters(), lr=lr)
criterion = CrossEntropyLoss()

for epoch in range(n_epochs):
    train_loss = 0.0
    for i,batch in enumerate(train_loader):
        x, y = batch
        x, y = x.to(device), y.to(device)
        y_hat = model(x)
        loss = criterion(y_hat, y)

        batch_loss = loss.detach().cpu().item()
        train_loss += batch_loss / len(train_loader)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i%100==0:
          print(f"Batch {i}/{len(train_loader)} loss: {batch_loss:.03f}")

    print(f"Epoch {epoch + 1}/{n_epochs} loss: {train_loss:.03f}")

加载模型

现在我们必须从 torchvision.models 加载 ViT_b_16 模型。torchvision 中可用的所有 ViT 模型都列在以下链接中。如果您查看链接,您会发现有几个带有 b、l 和 h 标签的模型。这些标签对应于我们拥有的基础、大型和巨型模型大小。这些模型的架构正是在第一篇 ViT 论文中发表的,标题为《 An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale》(原文链接:https://arxiv.org/abs/2010.11929)。与这些标签相关的数字,如16、32和14,对应于模型使用的补丁大小。所有这些模型都已在 ImageNet 上进行了训练。我们首先加载模型。默认提供的模型不是预训练的,为了确保我们加载一个预训练的模型,我们必须将权重参数传递为 ViT_B_16_Weights.IMAGENET1K_V1。

from torchvision.models import ViT_B_16_Weights, vit_b_16

model = vit_b_16(ViT_B_16_Weights.IMAGENET1K_V1)

默认情况下,这个模型输出来自1000个类别的对数几率,因为它已经在 ImageNet 上进行了训练。然而,我们的数据集只包含10个类别。因此,我们需要将这个模型的头部从1000个对数几率更改为10个。加载模型的外层是“heads”层,这是一个序列层,其中只包含一个线性层。为了适应模型,我们只需在保留层的输入特征的同时,将一个新的线性层分配给“heads”层,并将外部特征替换为10。

model = vit_b_16(ViT_B_16_Weights.IMAGENET1K_V1)

model.heads = nn.Sequential(
    nn.Linear(model.heads.head.in_features, 10)
)

我们不是训练或加载模型中的变换器块,我们可以冻结所有层,除了最后一个变换器层。通过这样做,我们使微调过程的计算强度降低。我们最后将模型移动到 GPU 设备,并使用之前的训练循环进行训练。

model = vit_b_16(ViT_B_16_Weights.IMAGENET1K_V1)

model.heads = nn.Sequential(
    nn.Linear(model.heads.head.in_features, 10)
)

# Freeze all layers
for param in model.parameters():
    param.requires_grad = False

# Unfreeze the last encoder layer and the head
for param in model.encoder.layers[-1].parameters():
    param.requires_grad = True
for param in model.heads.parameters():
    param.requires_grad = True

测试循环

我们最后在 CIFAR10 的测试数据集上测试我们的模型。您会发现,即使只训练了一个周期,模型也达到了非常高的准确率。这是因为在模型在 ImageNet 上训练时所打造的强大的特征。

with torch.no_grad():
    correct, total = 0, 0
    test_loss = 0.0
    for batch in tqdm(test_loader, desc="Testing"):
        x, y = batch
        x, y = x.to(device), y.to(device)
        y_hat = model(x)
        loss = criterion(y_hat, y)
        test_loss += loss.detach().cpu().item() / len(test_loader)

        correct += torch.sum(torch.argmax(y_hat, dim=1) == y).detach().cpu().item()
        total += len(x)
    print(f"Test loss: {test_loss:.2f}")
    print(f"Test accuracy: {correct / total * 100:.2f}%")
责任编辑:赵宁宁 来源: 小白玩转Python
相关推荐

2024-11-12 06:23:50

ViTCIFAR10模型

2024-07-17 09:27:28

2022-07-06 13:13:36

SWIL神经网络数据集

2024-11-21 16:06:02

2023-08-14 07:42:01

模型训练

2023-02-02 13:22:40

AICIFAR数据集

2018-04-11 09:30:41

深度学习

2024-06-20 08:52:10

2022-02-08 15:43:08

AITransforme模型

2022-05-30 11:39:55

论文谷歌AI

2023-06-02 15:47:49

2022-10-28 15:08:30

DeepMind数据

2023-09-12 13:59:41

OpenAI数据集

2024-12-05 08:30:00

2022-12-28 15:10:39

LinuxNginx服务器

2021-07-13 17:59:13

人工智能机器学习技术

2024-11-29 16:49:23

2023-08-04 13:34:00

人工智能深度学习

2023-12-01 16:23:52

大数据人工智能

2012-11-07 09:55:14

IE10Windows 8
点赞
收藏

51CTO技术栈公众号