可解释性是人工智能模型的关键话题之一。最近的复杂人工智能倾向于成为一个黑盒算法,使得人类难以理解人工智能为何提供这些结果。最近,我读了一篇论文,“通过增强开放词汇任务的可解释性进行CLIP手术”[1],主要关于CLIP的可解释技术。尽管这篇论文展示了CLIP的极佳可解释性,但很少有博客对此进行解释。因此,我将在这篇博客中介绍CLIP_Surgery的架构及其应用。
1. 快速回顾CLIP
CLIP是由OpenAI[2]开发的改变游戏规则的人工智能之一。得益于其独特的架构,它能够进行零样本图像分类。架构如下所示:
CLIP具有图像和文本编码器,用于创建图像和文本嵌入。训练数据是图像和文本对,例如带有文本“一只狗的照片”的狗的图像。它利用对比预训练来对齐图像和文本嵌入,如果图像和文本是一对,则对齐,如果不是一对则不进行对齐。为了直观理解,让我们考虑以下示例。在这个示例中,我们使用三个图像和文本对(上图中的N = 3)。
图像和文本编码器输出的嵌入维度始终是(1,512),每个图像和文本都是如此。在这个示例中,我们有维度为(3,512)的图像和文本嵌入。使用嵌入的余弦相似度,我们可以计算相似度矩阵,如上图中的矩阵。在对比预训练中,CLIP利用这个相似度矩阵来对齐匹配对(=对角线元素)使其相似,而其他对(=其他元素)则变得不相似。具体来说,论文[2]中的伪代码过程如下:
# image_encoder - ResNet or Vision Transformer
# text_encoder - CBOW or Text Transformer
# I[n, h, w, c] - minibatch of aligned images
# T[n, l] - minibatch of aligned texts
# W_i[d_i, d_e] - learned proj of image to embed
# W_t[d_t, d_e] - learned proj of text to embed
# t - learned temperature parameter
# extract feature representations of each modality
I_f = image_encoder(I) #[n, d_i]
T_f = text_encoder(T) #[n, d_t]
# joint multimodal embedding [n, d_e]
I_e = l2_normalize(np.dot(I_f, W_i), axis=1)
T_e = l2_normalize(np.dot(T_f, W_t), axis=1)
# scaled pairwise cosine similarities [n, n]
logits = np.dot(I_e, T_e.T) * np.exp(t)
# symmetric loss function
labels = np.arange(n)
loss_i = cross_entropy_loss(logits, labels, axis=0)
loss_t = cross_entropy_loss(logits, labels, axis=1)
loss = (loss_i + loss_t)/2
计算图像和文本嵌入的余弦相似度后,它们应用交叉熵损失,使得相似度矩阵中的对角线元素变为一,其他元素变为零。作者称这种计算为对比损失。CLIP仅通过这种对比损失进行训练。
对于零样本分类,过程如下。首先,我们输入n个候选文本并获得维度为(n,512)的嵌入。接下来,我们计算目标图像嵌入和候选文本嵌入之间的相似性。最后,我们可以选择最相似的候选作为类别。不是很简单吗?
这个过程简单直观,但我们需要用数百万的图像和文本对以及数百个GPU来训练CLIP。从原始论文中,他们使用了非常大的小批量大小32,768,并在592个V100 GPU上训练了18天。因此,许多公司将这个模型作为基础模型,而不是从头开始训练。
2. CLIP手术算法的解释
CLIP手术主要是为了增强CLIP结果的可解释性而开发的。令人惊讶的是,CLIP手术可以在没有任何额外训练的情况下可视化对应标签的激活图。由于其良好的激活图可视化,这种技术可以应用于“分割任何事物”,这是分割任务的基础模型。我将在后面的章节中介绍应用。
作者彻底检查了注意力层,以实现无需训练的良好可解释性。请参见下图。
左侧显示了原始CLIP的注意力层,而右侧显示了CLIP手术的注意力层。他们明确指出,查询-键自注意力激活了与标签相对应的相反语义区域。另一方面,值-值自注意力可以只关注语义区域。这是什么意思?下图显示了查询-键自注意力和值-值自注意力的激活图可视化。
如您所见,查询键自注意力除了目标标签区域外,还可视化了不相关的区域。相反,值-值自注意力可以专注于相应的目标标签区域。基于实验,查询-键自注意力可能导致特征图混淆。请注意,这一事实是启发式的,并非由数学定理推导出来。
此外,他们意识到激活图在所有标签中具有冗余特征。请参见下图。
如您所见,冗余区域在跨标签的相同位置出现。因此,他们想出了一个主意,即通过移除所有标签中的共同激活区域来移除冗余特征。
他们是如何实现的?具体来说,官方实现如下。
# weights to restrain influence of obvious classes on others
# (batch_size, 1, 512) @ (the number of labels, 512).T = (batch_size, 1, the number of labels)
prob = image_features[:, :1, :] @ text_features.t()
# prob has (batch_size, 1, the number of labels)
prob = (prob * 2).softmax(-1)
# w has (batch_size, 1, the number of labels)
w = prob / prob.mean(-1, keepdim=True)
# element-wise multiplied features
# b is batch_size
# n_t is the number of labels
# n_i is the number of tokens (=197)
# c is the feature dimension (=512)
b, n_t, n_i, c = image_features.shape[0], text_features.shape[0], image_features.shape[1], image_features.shape[2]
# feats has (batch_size, n_i, n_t, c)
feats = image_features.reshape(b, n_i, 1, c) * text_features.reshape(1, 1, n_t, c)
feats *= w.reshape(1, 1, n_t, 1)
# redundant_feats has (batch_size, n_i, n_t, c)
redundant_feats = feats.mean(2, keepdim=True) # along cls dim
feats = feats - redundant_feats
# sum the element-wise multiplied features as cosine similarity
# similarity has (batch_size, n_i, n_t)
similarity = feats.sum(-1)
为了更好地演示,我为代码中的每次计算添加了维度大小转换。现在,让我们一步一步地弄清楚。
第一个模块计算权重向量,以保持每个类别的影响相等。首先,我们从图像嵌入中提取类别标记。在变换器架构中,类别标记是标记维度中的第一个。请注意,类别标记应该包含有关所有其他标记的信息(如果您不熟悉视觉变换器,可以参考这篇博客[5])。然后,我们计算余弦相似度并获得相似度矩阵。接下来,我们将相似度矩阵的值转换为标签维度上的概率,并获取权重矩阵。
在第二个模块中,我们计算除了冗余特征之外的特征矩阵。首先,我们计算图像和文本嵌入的逐元素特征矩阵。直观地说,跨标签的激活区域在这张图中将具有更高的值,如上图所示。因此,我们可以通过在标签上计算平均值从特征矩阵中获得冗余特征。从原始特征矩阵中减去冗余特征后,我们可以获得纯净的特征矩阵。
在最后一个模块中,我们通过沿特征维度对特征矩阵求和来获得相似度矩阵。
对于特征图可视化,我们需要将相似度矩阵归一化、重塑和插值到输入图像大小(您可以稍后检查使用附加代码的实现)作为后处理。下图显示了CLIP手术的结果。
如您所见,它可以捕获与标签相对应的语义区域。您可以感受到这种可视化的强大之处。
到目前为止,我们已经看到了CLIP手术的详细算法。在最后一节中,我们将检查其对现实世界数据的能力及其应用。
3. 应用:检查现实世界数据的能力以及为“分割任何事物”提供点
在最后一节中,我将指导您了解CLIP手术在现实世界数据和“分割任何事物”(SAM)中的应用。让我们深入了解它们!
(1) 环境设置
作为第一步,您需要设置一个环境。我使用了ubuntu20.04、cuda11.7和Python3.10环境。首先,我使用conda创建虚拟环境。
conda create --name sam python==3.10 -y
conda activate sam
conda install pip
## optional: To avoid install libraries on the local environment,
## check the which pip will be used to store libraries
which pip
# I use /opt/conda/envs/sam/bin/pip in my enviornment.
接下来,您需要按照官方说明安装Pytorch和torchvision。您可以安装与您的环境相对应的版本。例如,下面的命令是我的案例。
conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia
然后,您需要使用以下命令安装SAM存储库和模型权重。
pip install git+https://github.com/facebookresearch/segment-anything.git
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
您还需要安装CLIP手术存储库。
git clone https://github.com/xmed-lab/CLIP_Surgery.git
最后,您需要安装几个包。您可以通过pip以“pip install <library>”的格式安装它们。
tqdm==4.66.5
ftfy==6.2.3
matplotlib
opencv-python
regex
现在,您已经完成了环境设置。
(2) CLIP手术对Flickr30k数据集的能力
首先,我想检查CLIP手术对现实世界数据的能力,使用Flickr30k数据集。因此,我将比较CLIP和CLIP手术激活图。我稍后会附上使用的代码。下图是比较的结果。
如您所见,原始CLIP无法精确检测对象,但当对象存在时,CLIP手术可以检测与标签相对应的对象。然而,当对象不存在时,例如猫和植物,CLIP手术仍然存在问题。这个问题的一个原因是后处理中的最小-最大归一化。当激活图中只有不相关区域时,最小-最大归一化可能会增强它们值之间的差异。为了解决这个问题,我们可以在最小-最大归一化之前简单地添加一个阈值。在Flickr数据集的情况下,相关区域值的阈值是0.1以上,这是通过相似度图的直方图检查的。结果如下所示。
多亏了阈值,我们可以移除不相关的区域。阈值可能根据数据集而变化;因此,我们应该使用直方图检查并找到该值。
(3) 为“分割任何事物”提供点
由于激活图可视化的精确性,CLIP手术可以应用于“分割任何事物”的点提供器。供您参考,SAM是Meta在2023年开发的分割基础模型之一。下图显示了架构。
SAM的分割能力令人难以置信。然而,它不是通过带有标签的分割数据集训练的,所以我们需要在指定对象时提供一些点、边界框或掩码。正如您所猜,这些类型的注释非常耗时。在这里,CLIP手术帮助我们自动找到点。让我们看看如何在实际实现中结合CLIP手术和SAM。
为了为SAM生成点,我们对激活图进行下采样并对值进行排序以选择相关区域。在官方实现中,他们使用维度为(7 x 7)的激活图来找到最相关的区域。当目标对象不存在时也存在问题,所以我稍微修改了原始实现,添加了一个阈值。结果如下所示。
橙色点指的是与标签相关的点,而蓝色点代表标签的负点。如您所见,它可以以相当准确的精度检测目标标签坐标。请注意,点的准确性来自CLIP的能力。因此,如果CLIP不理解目标,它无法准确提供目标点。我将附上在此应用中使用的Jupyter笔记本。
详细代码可以参考链接:https://gist.github.com/tanukon/55715b577a32998f3417e7cea268c658#file-clip_surgery_experiment-ipynb