实现最先进的蒙版自编码器(MAE)

人工智能 机器视觉
今天,我深入探讨视觉变换器之后计算机视觉领域最重要的突破之一:蒙版自编码器(MAE)。

今天,我深入探讨视觉变换器之后计算机视觉领域最重要的突破之一:蒙版自编码器(MAE)。简要回顾一下它的工作原理:

以下是工作步骤:

  • 图像被分割成块。
  • 这些块的一个子集被随机蒙版。
  • 只有可见的块被送入编码器(这很关键)。
  • 解码器接收编码器的压缩表示,并尝试使用可见和蒙版的块重建整个图像。
  • 仅在蒙版块上计算损失。

导入

  • einops:用于其“repeat”函数
  • architectures.vit:标准ViT变换器的架构,我使用的是在“如何训练ViT?”中提供的。

import torch
from torch import nn
import torch.nn.functional as F
from einops import repeat

from architectures.vit import Transformer

设置MAE类:

class MAE(nn.Module):
    def __init__(
        self,
        *,
        encoder,
        decoder_dim,
        masking_ratio=0.75,
        decoder_depth=1,
        decoder_heads=8,
        decoder_dim_head=64
    ):
        super().__init__()
        # Ensure the masking ratio is valid
        assert 0 < masking_ratio < 1, 'masking ratio must be between 0 and 1'
        self.masking_ratio = masking_ratio

我们定义一个从PyTorch的nn.Module继承的MAE类。

  • 编码器:我们的视觉变换器模型。
  • decoder_dim:解码器嵌入空间的维度(例如512)。
  • masking_ratio:要蒙版的块的比例(文章发现75%是最优的)。
  • 其他解码器配置,如深度、头和头维度,这些都是变换器的标准。
  • 我们断言蒙版比例在0和1之间。

块:

        # Save the encoder (a Vision Transformer to be trained)
        self.encoder = encoder

        # Extract the number of patches and the encoder's dimensionality from the positional embeddings
        num_patches, encoder_dim = encoder.pos_embedding.shape[-2:]

        # Separate the patch embedding layers from the encoder
        # The first layer converts the image into patches
        self.to_patch = encoder.to_patch_embedding[0]
        # The remaining layers embed the patches
        self.patch_to_emb = nn.Sequential(*encoder.to_patch_embedding[1:])

发生了什么?

我们存储编码器并提取必要信息,如块的数量和编码器的输出维度。

我们分离块嵌入过程:

  • self.to_patch:这层将图像分割成较小的块。
  • self.patch_to_emb:这将每个块嵌入到向量空间。
# Determine the dimensionality of the pixel values per patch
pixel_values_per_patch = encoder.to_patch_embedding[2].weight.shape[-1]

我们计算每个块中的像素值数量,稍后会需要。

设置解码器

self.enc_to_dec:如果编码器和解码器的维度不同,我们相应地映射它们。通常编码器较大且维度较高(例如1024),而解码器可以更浅且维度较小(例如512),但我们需要一个适配器将编码器的维度映射到解码器的维度。

self.mask_token:一个可学习的标记,代表解码器的蒙版块。当块被蒙版时,这是解码器看到的标记。

我们初始化解码器变换器和其他重建所需的层。

self.decoder = Transformer(
    dim=decoder_dim,
    depth=decoder_depth,
    heads=decoder_heads,
    dim_head=decoder_dim_head,
    mlp_dim_ratio=4
)
# Positional embeddings for the decoder tokens
self.decoder_pos_emb = nn.Embedding(num_patches, decoder_dim)
# Linear layer to reconstruct pixel values from decoder outputs
self.to_pixels = nn.Linear(decoder_dim, pixel_values_per_patch)

到目前为止,你的MAE类应该像这样初始化:

class MAE(nn.Module):
    def __init__(
        self,
        *,
        encoder,
        decoder_dim,
        masking_ratio=0.75,
        decoder_depth=1,
        decoder_heads=8,
        decoder_dim_head=64
    ):
        super().__init__()
        # Ensure the masking ratio is valid
        assert 0 < masking_ratio < 1, 'masking ratio must be between 0 and 1'
        self.masking_ratio = masking_ratio

        # Save the encoder (a Vision Transformer to be trained)
        self.encoder = encoder

        # Extract the number of patches and the encoder's dimensionality from the positional embeddings
        num_patches, encoder_dim = encoder.pos_embedding.shape[-2:]

        # Separate the patch embedding layers from the encoder
        # The first layer converts the image into patches
        self.to_patch = encoder.to_patch_embedding[0]
        # The remaining layers embed the patches
        self.patch_to_emb = nn.Sequential(*encoder.to_patch_embedding[1:])

        # Determine the dimensionality of the pixel values per patch
        pixel_values_per_patch = encoder.to_patch_embedding[2].weight.shape[-1]

        # Set up decoder parameters
        self.decoder_dim = decoder_dim
        # Map encoder dimensions to decoder dimensions if they differ
        self.enc_to_dec = (
            nn.Linear(encoder_dim, decoder_dim)
            if encoder_dim != decoder_dim
            else nn.Identity()
        )
        # Learnable mask token for masked patches
        self.mask_token = nn.Parameter(torch.randn(decoder_dim))
        # Define the decoder transformer
        self.decoder = Transformer(
            dim=decoder_dim,
            depth=decoder_depth,
            heads=decoder_heads,
            dim_head=decoder_dim_head,
            mlp_dim_ratio=4
        )
        # Positional embeddings for the decoder tokens
        self.decoder_pos_emb = nn.Embedding(num_patches, decoder_dim)
        # Linear layer to reconstruct pixel values from decoder outputs
        self.to_pixels = nn.Linear(decoder_dim, pixel_values_per_patch)

太好了!现在让我们看看如何在前向传递中使用这些不同的部分,这有点像拼图。

前向传递

让我们走过前向函数,它定义了我们的模型如何处理输入数据。

def forward(self, img):
    device = img.device

    # Convert the input image into patches
    patches = self.to_patch(img)  # Shape: (batch_size, num_patches, patch_size)
    batch_size, num_patches, *_ = patches.shape

    # Embed the patches using the encoder's patch embedding layers
    tokens = self.patch_to_emb(patches)  # Shape: (batch_size, num_patches, encoder_dim)

开始非常标准,我们只需要将“将图像块化”操作与“投影到标记”操作分解,因为我们使用原始块作为计算损失的基准。

  • 前向方法以图像张量img作为输入。
  • 我们获取张量所在的设备(CPU或GPU)。
  • 我们将图像分割成块。
  • 我们获得批量大小和块的数量。
  • 每个块被嵌入到一个向量中。

位置编码:

# Add positional embeddings to the tokens
if self.encoder.pool == "cls":
    # If using CLS token, skip the first positional embedding
    tokens += self.encoder.pos_embedding[:, 1 : num_patches + 1]
elif self.encoder.pool == "mean":
    # If using mean pooling, use all positional embeddings
    tokens += self.encoder.pos_embedding.to(device, dtype=tokens.dtype)

我们为每个标记添加位置信息,以便模型知道每个块来自哪里。如果有额外的CLS标记,我们需要跳过它,因为它不是图像的一部分。

蒙版和编码

现在我们来到最有趣的部分,蒙版图像。

# Determine the number of patches to mask
num_masked = int(self.masking_ratio * num_patches)

# Generate random indices for masking
rand_indices = torch.rand(batch_size, num_patches, device=device).argsort(dim=-1)
masked_indices = rand_indices[:, :num_masked]
unmasked_indices = rand_indices[:, num_masked:]

我们根据我们的蒙版比例计算我们将蒙版多少块。

我们为每个块序列生成一个随机排列。

我们相应地定义masked_indices和unmasked_indices。

# Select the tokens corresponding to unmasked patches
batch_range = torch.arange(batch_size, device=device)[:, None]
tokens = tokens[batch_range, unmasked_indices]

# Select the original patches that are masked (for reconstruction loss)
masked_patches = patches[batch_range, masked_indices]

# Encode the unmasked tokens using the encoder's transformer
encoded_tokens = self.encoder.transformer(tokens)

我们选择刚刚定义的masked_indices对应的masked_patches。

我们只保留未蒙版块的标记以进行编码。

解码

现在让我们进入最令人兴奋但也最难的部分,解码!

# Map encoded tokens to decoder dimensions if necessary
decoder_tokens = self.enc_to_dec(encoded_tokens)

# Add positional embeddings to the decoder tokens of unmasked patches
unmasked_decoder_tokens = decoder_tokens + self.decoder_pos_emb(unmasked_indices)

# Create mask tokens for the masked patches and add positional embeddings
mask_tokens = repeat(self.mask_token, 'd -> b n d', b=batch_size, n=num_masked)
mask_tokens = mask_tokens + self.decoder_pos_emb(masked_indices)

# Initialize the full sequence of decoder tokens
decoder_sequence = torch.zeros(
  batch_size, num_patches, self.decoder_dim, device=device
)
# Place unmasked decoder tokens and mask tokens in their original positions
decoder_sequence[batch_range, unmasked_indices] = unmasked_decoder_tokens
decoder_sequence[batch_range, masked_indices] = mask_tokens

# Decode the full sequence
decoded_tokens = self.decoder(decoder_sequence)

# Extract the decoded tokens corresponding to the masked patches
masked_decoded_tokens = decoded_tokens[batch_range, masked_indices]
  • 我们调整编码标记以匹配解码器预期的输入大小self.enc_to_dec
  • 我们向解码器标记添加位置嵌入。
  • 对于蒙版位置,我们使用蒙版标记并添加位置嵌入。
  • 我们通过将未蒙版和蒙版标记放回其原始位置来重建完整序列。
  • 我们将完整序列传递给解码器。
  • 我们提取对应于蒙版块的解码标记。
# Reconstruct the pixel values from the masked decoded tokens
pred_pixel_values = self.to_pixels(masked_decoded_tokens)

# Compute the reconstruction loss (mean squared error)
recon_loss = F.mse_loss(pred_pixel_values, masked_patches)
return recon_loss
  • 我们尝试重建蒙版块的原始像素值。
  • 我们通过将重建的块与原始蒙版块进行比较来计算L2损失。

参考资料:

  • 参考代码:https://github.com/FrancoisPorcher/awesome-ai-tutorials?source=post_page-----6f454b736087--------------------------------
  • 论文:https://arxiv.org/abs/2111.06377
  • 参考代码:https://github.com/lucidrains/vit-pytorch
责任编辑:赵宁宁 来源: 小白玩转Python
相关推荐

2021-03-29 11:37:50

人工智能深度学习

2021-03-22 10:52:13

人工智能深度学习自编码器

2022-09-13 15:26:40

机器学习算法数据

2021-02-20 20:57:16

深度学习编程人工智能

2017-07-19 13:40:42

卷积自编码器降噪

2018-05-21 08:22:14

自编码器协同过滤深度学习

2024-06-18 08:52:50

LLM算法深度学习

2017-11-10 12:45:16

TensorFlowPython神经网络

2022-04-02 21:46:27

深度学习编码器图像修复

2020-04-26 11:26:02

人脸合成编码器数据

2017-12-26 10:48:37

深度学习原始数据

2017-07-03 07:14:49

深度学习无监督学习稀疏编码

2024-11-13 16:24:33

ViT架构PyTorch

2021-11-02 20:44:47

数字化

2014-08-07 10:49:20

debugdebug技巧

2013-09-16 09:41:13

400G网络处理器思科网络处理器

2014-08-07 10:03:31

debug技巧原则

2012-04-10 16:55:22

PowerSmart编码器

2012-04-01 16:40:45

编码器

2023-04-25 21:36:07

火山引擎
点赞
收藏

51CTO技术栈公众号