基于 Faster ViT 进行图像分类

开发 人工智能
Faster Vision Transformer(FVT)是Vision Transformer(ViT)架构的一个变体,这是一种为计算机视觉任务设计的神经网络。

Faster Vision Transformer(FVT)是Vision Transformer(ViT)架构的一个变体,这是一种为计算机视觉任务设计的神经网络。FVT 是原始 ViT 模型的更快、更高效版本,原始模型由 Dosovitskiy 等人在 2020 年的论文 “一幅图像值 16x16 个词:用于大规模图像识别的转换器” 中引入。

FVT 的关键特性

  • 高效架构:FVT 旨在比原始 ViT 模型更快、更高效。它通过减少参数数量和计算复杂性,同时保持类似的性能来实现这一点。
  • 多尺度视觉转换器:FVT 使用多尺度视觉转换器架构,允许它以多种尺度和分辨率处理图像。这是通过使用层次结构实现的,其中较小的转换器用于处理图像的较小区域。
  • 自注意力机制:FVT 使用自注意力机制,允许它对图像的不同部分之间的复杂关系进行建模。这是通过使用在训练过程中学习到的注意力权重来实现的。
  • 位置编码:FVT 使用位置编码来保留图像的空间信息。这是通过使用学习到的位置嵌入来实现的,它们被添加到输入令牌中。

首先,让我们开始实现在自定义数据集上训练视觉转换器。为此,我们需要通过 pip 安装 fastervit。

pip install fastervit
  • 1.

让我们导入我们刚刚通过 pip 安装的 pytorch 库以及更快视觉转换器库。

import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import os
  • 1.
  • 2.
  • 3.
  • 4.

在这个实现中,我从 Kaggle 下载了损坏道路数据集。在这里检查。然后将它们分割为训练和验证数据集。之后加载数据集并应用数据转换。

data_dir = 'sih_road_dataset'

# Define data transformations
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

# Load datasets
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'val']}
dataloaders = {x: DataLoader(image_datasets[x], batch_size=32, shuffle=True, num_workers=4) for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes
接下来我们将加载更快视觉转换器模型。
# Load the FasterViT model and modify it for your number of classes.

from fastervit import create_model

# Load FasterViT model
model = create_model('faster_vit_0_224', 
                     pretrained=True,
                     model_path="faster_vit_0.pth.tar")

# Print the model architecture
print(model)
  • 1.
  • 2.
  • 3.
  • 4.
  • 5.
  • 6.
  • 7.
  • 8.
  • 9.
  • 10.
  • 11.
  • 12.
  • 13.
  • 14.
  • 15.
  • 16.
  • 17.
  • 18.
  • 19.
  • 20.
  • 21.
  • 22.
  • 23.
  • 24.
  • 25.
  • 26.
  • 27.
  • 28.
  • 29.
  • 30.
  • 31.
  • 32.
  • 33.
  • 34.
  • 35.

接下来我们将加载更快视觉转换器模型。

# Load the FasterViT model and modify it for your number of classes.

from fastervit import create_model

# Load FasterViT model
model = create_model('faster_vit_0_224', 
                     pretrained=True,
                     model_path="faster_vit_0.pth.tar")

# Print the model architecture
print(model)
  • 1.
  • 2.
  • 3.
  • 4.
  • 5.
  • 6.
  • 7.
  • 8.
  • 9.
  • 10.
  • 11.

当我们打印模型时,我们可以看到末尾的头部层,这是需要修改以进行微调的部分。

为了针对您的自定义分类任务修改这一层,您应该用一个具有适当数量输出类别的新线性层替换头部层。

# Modify the final layer for custom classification
num_ftrs = model.head.in_features
model.head = torch.nn.Linear(num_ftrs, len(class_names))

# Move the model to GPU if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)
  • 1.
  • 2.
  • 3.
  • 4.
  • 5.
  • 6.
  • 7.

接下来指定优化器和学习率,

import torch.optim as optim
from torch.optim import lr_scheduler

# Define loss function
criterion = torch.nn.CrossEntropyLoss()

# Define optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Learning rate scheduler
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
  • 1.
  • 2.
  • 3.
  • 4.
  • 5.
  • 6.
  • 7.
  • 8.
  • 9.
  • 10.
  • 11.

好的,现在一切都已定义,我们现在指定将用于训练我们模型的自定义数据集的训练函数。


import time
import copy

def train_model(model, criterion, optimizer, scheduler, num_epochs=5):
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Zero the parameter gradients
                optimizer.zero_grad()

                # Forward
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # Backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # Statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            # Deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:.4f}')

    # Load best model weights
    model.load_state_dict(best_model_wts)
    return model
  • 1.
  • 2.
  • 3.
  • 4.
  • 5.
  • 6.
  • 7.
  • 8.
  • 9.
  • 10.
  • 11.
  • 12.
  • 13.
  • 14.
  • 15.
  • 16.
  • 17.
  • 18.
  • 19.
  • 20.
  • 21.
  • 22.
  • 23.
  • 24.
  • 25.
  • 26.
  • 27.
  • 28.
  • 29.
  • 30.
  • 31.
  • 32.
  • 33.
  • 34.
  • 35.
  • 36.
  • 37.
  • 38.
  • 39.
  • 40.
  • 41.
  • 42.
  • 43.
  • 44.
  • 45.
  • 46.
  • 47.
  • 48.
  • 49.
  • 50.
  • 51.
  • 52.
  • 53.
  • 54.
  • 55.
  • 56.
  • 57.
  • 58.
  • 59.
  • 60.
  • 61.
  • 62.
  • 63.
  • 64.
  • 65.
  • 66.
  • 67.
  • 68.
  • 69.

下一步是启动训练过程!

# Train the model
model = train_model(model, criterion, optimizer, exp_lr_scheduler, num_epochs=5)

# Save the model
torch.save(model.state_dict(), 'faster_vit_custom_model.pth')
  • 1.
  • 2.
  • 3.
  • 4.
  • 5.

请注意,这不是最好的模型,因为我们可以看到模型在训练数据集上过拟合了。本文的主要目的是演示如何实现 Faster Vision Transformer 并在自定义数据集上训练它们。还有其他方法可以解决过拟合问题。

让我们对下面的图像进行训练过的模型的快速测试:

import torch
from torchvision import transforms
from PIL import Image
from fastervit import create_model

# Define the number of classes in your custom dataset
num_classes = 4  # Replace with your actual number of classes

# Create the model architecture
model = create_model('faster_vit_0_224', pretrained=False)

# Modify the final classification layer to match the number of classes in your custom dataset
model.head = torch.nn.Linear(model.head.in_features, num_classes)

# Move the model to GPU if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Load the trained model weights
model.load_state_dict(torch.load('faster_vit_custom_model.pth'))
model.eval()  # Set the model to evaluation mode

# Define data transformations for the input image
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Function to load and preprocess the image
def load_image(image_path):
    image = Image.open(image_path).convert('RGB')
    image = preprocess(image)
    image = image.unsqueeze(0)  # Add batch dimension
    return image.to(device)

# Function to make predictions
def predict(image_path, model, class_names):
    image = load_image(image_path)
    with torch.no_grad():
        outputs = model(image)
        _, preds = torch.max(outputs, 1)
        predicted_class = class_names[preds.item()]
    return predicted_class

# List of class names (ensure this matches your custom dataset's classes)
class_names = ['good', 'poor', 'satisfactory', 'very_poor']  # Replace with your actual class names

# Example usage
image_path = 'test_img.jpg'
predicted_class = predict(image_path, model, class_names)
print(predicted_class)
  • 1.
  • 2.
  • 3.
  • 4.
  • 5.
  • 6.
  • 7.
  • 8.
  • 9.
  • 10.
  • 11.
  • 12.
  • 13.
  • 14.
  • 15.
  • 16.
  • 17.
  • 18.
  • 19.
  • 20.
  • 21.
  • 22.
  • 23.
  • 24.
  • 25.
  • 26.
  • 27.
  • 28.
  • 29.
  • 30.
  • 31.
  • 32.
  • 33.
  • 34.
  • 35.
  • 36.
  • 37.
  • 38.
  • 39.
  • 40.
  • 41.
  • 42.
  • 43.
  • 44.
  • 45.
  • 46.
  • 47.
  • 48.
  • 49.
  • 50.
  • 51.
  • 52.
  • 53.

预测的类别是,

责任编辑:赵宁宁 来源: 小白玩转Python
相关推荐

2024-11-21 16:06:02

2023-01-05 16:51:04

机器学习人工智能

2022-09-29 23:53:06

机器学习迁移学习神经网络

2022-06-29 09:00:00

前端图像分类模型SQL

2022-10-30 15:00:40

小样本学习数据集机器学习

2023-11-30 09:55:27

鸿蒙邻分类器

2022-06-16 10:29:33

神经网络图像分类算法

2018-04-09 10:20:32

深度学习

2024-08-23 08:57:13

PyTorch视觉转换器ViT

2017-11-23 14:35:36

2022-11-11 15:07:50

深度学习函数鉴别器

2024-06-03 07:55:00

2022-01-12 17:53:52

Transformer数据人工智能

2022-06-05 21:16:08

机器学习Python

2025-01-17 10:30:00

2018-03-26 20:49:08

图像分类

2023-01-06 19:02:23

应用技术

2020-10-10 12:53:57

逻辑回归机器学习分析

2015-04-07 13:23:36

NEC

2021-03-01 10:01:22

开发技能编码
点赞
收藏

51CTO技术栈公众号