深度度量学习的这十三年,难道是错付了吗?

开发 开发工具 深度学习
近日,Facebook AI 和 Cornell Tech 的研究者在论文预印本平台 arXiv 上公布了最新研究论文,声称这十三年来深度度量学习领域的研究进展「实际上并不存在」。

或许对于每一个领域来说,停下脚步去思考,与低头赶路一样重要。

[[326375]]

「度量学习(Metric Learning)」即学习一个度量空间,在该空间中的学习异常高效,这种方法用于小样本分类时效果很好,不过度量学习方法的效果尚未在回归或强化学习等其他元学习领域中验证。

在人脸识别、指纹识别等开集分类的任务中,类别数往往很多而类内样本数比较少。在这种情况下,基于深度学习的分类方法常表现出一些局限性,如缺少类内约束、分类器优化困难等。而这些局限可以通过深度度量学习来解决。

「四年来,深度度量学习领域的论文经常表示已经在准确性层面取得了很大的进展,基本是十年前方法的两倍以上。」事实上,我们真的取得了这么惊人的进展吗?

近日,Facebook AI 和 Cornell Tech 的研究者在论文预印本平台 arXiv 上公布了最新研究论文,声称这十三年来深度度量学习领域的研究进展「实际上并不存在」。

论文链接:https://arxiv.org/pdf/2003.08505.pdf

研究者发现,度量学习的这些论文在实验设置方面存在多种缺陷,比如不公平的实验比较、测试集标签泄露、不合理的评价指标等。于是,他们提出了一种新的评估方法来重新审视度量学习领域的多项研究。最后,他们通过实验表明,现有论文宣称的那些改进实在是「微不足道」,近几年的 ArcFace,、SoftTriple,、CosFace 等十种算法,和十三年前的 Contrastive、Triplet 基线方法相比,并没有什么实质性的提高。

也就是说,论文宣称的改进是节节攀升的:

但实际情况却是原地踏步:

之前的论文存在哪些缺陷?

1. 不公平的比较

为了宣称新算法的性能比已有的方法要好。尽可能多地保持参数不变是很重要的。这样便能够确定性能的优化是新算法带来的提升,而不是由额外的参数造成的。但现有的度量学习论文的研究情况却不是如此。

提高准确率最简单的方法之一是优化网络架构,但这些论文却没有保证这项基本参数固定不变。度量学习中架构的选择是非常重要的。在较小的数据集上的初始的准确率会随着所选择的网络而变化。2017 年一篇被广泛引用的论文用到了 ResNet50,然后声称性能得到了巨大的提升。这是值得质疑的,因为他们用的是 GoogleNet 作比较,初始准确率要低得多(见表 1)。

2. 通过测试集反馈进行训练

该领域大多数论文会将每个数据集分开,类中的前 50% 用作训练集,剩下的部分用作测试集。训练过程中,研究者会定期检查模型在测试集上的准确率。也就是,这里没有验证集,模型的选择和超参数的调整是通过来自测试集的直接反馈完成的。一些论文并不定期检查性能,而是在预先设置好的训练迭代次数之后报告准确率。在这种情况下,如何设置迭代次数并不确定,超参数也仍然是在测试集性能的基础上调整的。这种做法犯了机器学习研究的一个大忌。依靠测试集的反馈进行训练会导致在测试集上过拟合。因此度量学习论文中所阐述的准确率的持续提升会被质疑。

3. 常用的准确率度量的缺点

为了报告准确率,大多数度量学习论文用到的指标是 Recall@K、标准化互信息(NMI)以及 F1 分值。但这些真的是最佳度量标准吗?图 1 展示了三种嵌入空间,虽然它们有不同的特性,但每个 Recall@1 的分值都接近 100%,说明这个指标基本上提供不了什么信息。

新的评估方法

以上种种缺陷造成了度量学习领域的「虚假繁荣」。因此研究者提出了一种新的评估方法,希望能够对损失函数进行恰当的评估。为此,他们做了如下设置:

1. 公平的比较和复现

所有的实验都是在 PyTorch 上进行的,用到了 ImageNet 来预训练 BN-Inception 网络。训练过程中冻结 BatchNorm 参数,以减少过拟合。批大小设置为 32。

训练过程中,图像增强通过随机调整大小的裁剪策略来完成。所有的网络参数都用学习率为 1e-6 的 RMSprop 进行优化。在计算损失函数之前和评估过程中,对嵌入进行 L2 归一化。

2. 通过交叉验证进行超参数搜索

为了找到最好的损失函数超参数,研究运行了 50 次贝叶斯优化迭代,每次迭代均包括 4 折交叉验证:

类中的第一半用来交叉验证,创建 4 个分区,前 0-12.5% 是第一个分区,12.5-25% 是第二个分区,以此类推。

第二半用来做测试集,这和度量学习论文使用多年的设置相同,目的是便于和之前的论文结果做比较。

超参数都被优化到能最大化验证精确度的平均值。对于最佳超参数,将加载每个训练集分区的最高准确率检查点,测试集的嵌入是经过计算和 L2 归一化的,然后计算准确率。

3. 更有信息量的准确率度量指标

研究者用 Mean Average Precision at R (MAP@R) 来度量准确度,这一指标综合了平均精度均值和 R 精度的思想。

R 精度的一个弱点是,它没有说明正确检索的排序。因此,该研究使用 MAP@R。MAP@R 的好处是比 Recall@1 更有信息量(见图 1)。它可以直接从嵌入空间中计算出来,而不需要聚类步骤,也很容易理解。它奖励聚类良好的嵌入空间。

实验

1. 损失和数据集

研究者选择了近年来多个会议论文在度量学习领域提出的先进方法(如表 6 所示),在 11 种损失和一种损失+miner 组合上进行实验。

此前,度量学习领域的论文一直没有面向验证损失的内容,因此该研究加入了这方面的两项损失。

研究者选用了 3 个度量学习领域广泛使用的数据集:CUB200、Cars196 和 Stanford Online Products (SOP),选择这 3 个数据集也便于和之前的论文做比较。表 3-5 展示了训练运行的平均准确率,以及在适用时 95% 的置信区间,加粗部分代表了最好的平均准确率。同时也包括了预训练模型的准确率,用 PCA 将模型的嵌入值减少到 512 或 128。

2. 论文 vs 现实

首先,让我们看一下论文结果的普遍趋势,图 4(a) 展示了该领域中「本以为」的准确率提升,即新方法完全淘汰了旧方法。

但正如图 4(b) 所示,实验结果和预期并不一致。

研究者发现,这些论文过分夸大了自己相对于两种经典方法——对比损失(contrastive loss)和三元组损失(triplet loss)——的改进。许多论文表示,自己方法的性能超出了对比损失一倍还多,比三元组损失也高出 50% 以上。这些提升是因为这些损失造成了非常低的准确性。

这些数据有一些是来源于 2016 年的提升结构损失论文,在他们的对比损失和三元组损失的实现中,他们每批采样 N/2 样本对和 N/3 样本三元组(N 是批的大小)。因此,他们只用到了每批里的一小部分数据信息。

他们将三元组的 margin 设置为 1,而最优的值大约是 0.1。尽管有这些实现缺陷,大多数论文仍旧只是简单地引用这些较低的数字,而不是依靠自己实现损失去获得一个更有意义的基线。

通过这些基线损失所呈现的良好实现、公平竞争环境和机器学习实践,研究者获得了如图 4(b) 所示的趋势图——事实上它似乎是平滑的走向。这表明无论是在 2006 年还是在 2019 年,各种方法的性能都是相似的。换句话说,度量学习算法并没有取得论文中所说的那么夸张的进展,论文中没有提到的前沿论文也值得怀疑。

这十几年的研究投入,终究是错付了吗?

在这篇论文出现以后,很多人在讨论:度量学习是否已经到了一个瓶颈期?我们还要继续在这个研究方向上前进吗?

第一个问题的答案是肯定的,第二个问题的答案也是肯定的。

中科院计算所博士生、知乎用户 @ 王晋东认为:「其实大可不必心潮澎湃、攻击别人、对该领域前途失望。」

其实每个领域经历过一段长时间的发展以后,都必然会有研究者回过头来进行反思。学术研究也适用于这条定律:「走得太远,忘记了为什么出发。」

图源:知乎 @ 王晋东不在家。https://www.zhihu.com/question/394204248/answer/1219383067

也有深度度量学习领域研究者、CVPR 2019 论文一作前来回答,并将这篇论文放在了自身研究介绍项目的开篇,希望「能让做这个领域的人看到,引导新入这个坑的人向着正确的方向走。因为,我也曾是踩过这些坑过来的」。

图源:知乎 @ 王珣。https://www.zhihu.com/question/394204248/answer/1219001568

质疑会带来讨论,讨论则引起反思。停下脚步后的思考,与赶路一样重要。在你的领域,也曾经有过这样的讨论吗?

参考链接:https://www.zhihu.com/question/394204248

【本文是51CTO专栏机构“机器之心”的原创译文,微信公众号“机器之心( id: almosthuman2014)”】 

戳这里,看该作者更多好文

 

责任编辑:赵宁宁 来源: 51CTO专栏
相关推荐

2023-05-26 14:02:29

AI智能

2021-12-21 15:31:40

KubernetesDocker容器

2013-07-17 09:13:19

2019-01-17 05:14:07

深度学习人工智能AI

2023-10-10 15:33:55

机器学习相似性度量

2020-01-06 09:14:59

Java程序员线程

2023-09-20 09:56:18

深度学习人工智能

2018-09-29 10:05:54

深度学习神经网络神经元

2017-10-30 14:51:44

APP网页窗口

2021-05-10 11:40:51

函数NumpyPython

2020-05-28 15:35:07

人工智能

2017-05-09 08:18:11

机器学习损失函数正则化

2021-08-12 05:41:23

人工智能AI深度学习

2016-11-04 23:45:12

云安全信息安全

2020-06-24 08:26:10

编程语言Perl技术

2021-03-02 14:23:06

人工智能深度学习

2021-10-08 10:45:38

深度学习编程人工智能

2010-04-28 13:31:52

IT技术人员

2018-11-14 08:13:55

机房搬迁网络

2022-09-16 15:17:44

机器之心
点赞
收藏

51CTO技术栈公众号