今天,我深入探讨视觉变换器之后计算机视觉领域最重要的突破之一:蒙版自编码器(MAE)。简要回顾一下它的工作原理:
以下是工作步骤:
- 图像被分割成块。
- 这些块的一个子集被随机蒙版。
- 只有可见的块被送入编码器(这很关键)。
- 解码器接收编码器的压缩表示,并尝试使用可见和蒙版的块重建整个图像。
- 仅在蒙版块上计算损失。
导入
- einops:用于其“repeat”函数
- architectures.vit:标准ViT变换器的架构,我使用的是在“如何训练ViT?”中提供的。
设置MAE类:
我们定义一个从PyTorch的nn.Module继承的MAE类。
- 编码器:我们的视觉变换器模型。
- decoder_dim:解码器嵌入空间的维度(例如512)。
- masking_ratio:要蒙版的块的比例(文章发现75%是最优的)。
- 其他解码器配置,如深度、头和头维度,这些都是变换器的标准。
- 我们断言蒙版比例在0和1之间。
块:
发生了什么?
我们存储编码器并提取必要信息,如块的数量和编码器的输出维度。
我们分离块嵌入过程:
- self.to_patch:这层将图像分割成较小的块。
- self.patch_to_emb:这将每个块嵌入到向量空间。
我们计算每个块中的像素值数量,稍后会需要。
设置解码器
self.enc_to_dec:如果编码器和解码器的维度不同,我们相应地映射它们。通常编码器较大且维度较高(例如1024),而解码器可以更浅且维度较小(例如512),但我们需要一个适配器将编码器的维度映射到解码器的维度。
self.mask_token:一个可学习的标记,代表解码器的蒙版块。当块被蒙版时,这是解码器看到的标记。
我们初始化解码器变换器和其他重建所需的层。
到目前为止,你的MAE类应该像这样初始化:
太好了!现在让我们看看如何在前向传递中使用这些不同的部分,这有点像拼图。
前向传递
让我们走过前向函数,它定义了我们的模型如何处理输入数据。
开始非常标准,我们只需要将“将图像块化”操作与“投影到标记”操作分解,因为我们使用原始块作为计算损失的基准。
- 前向方法以图像张量img作为输入。
- 我们获取张量所在的设备(CPU或GPU)。
- 我们将图像分割成块。
- 我们获得批量大小和块的数量。
- 每个块被嵌入到一个向量中。
位置编码:
我们为每个标记添加位置信息,以便模型知道每个块来自哪里。如果有额外的CLS标记,我们需要跳过它,因为它不是图像的一部分。
蒙版和编码
现在我们来到最有趣的部分,蒙版图像。
我们根据我们的蒙版比例计算我们将蒙版多少块。
我们为每个块序列生成一个随机排列。
我们相应地定义masked_indices和unmasked_indices。
我们选择刚刚定义的masked_indices对应的masked_patches。
我们只保留未蒙版块的标记以进行编码。
解码
现在让我们进入最令人兴奋但也最难的部分,解码!
- 我们调整编码标记以匹配解码器预期的输入大小self.enc_to_dec
- 我们向解码器标记添加位置嵌入。
- 对于蒙版位置,我们使用蒙版标记并添加位置嵌入。
- 我们通过将未蒙版和蒙版标记放回其原始位置来重建完整序列。
- 我们将完整序列传递给解码器。
- 我们提取对应于蒙版块的解码标记。
- 我们尝试重建蒙版块的原始像素值。
- 我们通过将重建的块与原始蒙版块进行比较来计算L2损失。
参考资料:
- 参考代码:https://github.com/FrancoisPorcher/awesome-ai-tutorials?source=post_page-----6f454b736087--------------------------------
- 论文:https://arxiv.org/abs/2111.06377
- 参考代码:https://github.com/lucidrains/vit-pytorch