图片来自 Lisha Li
Andrej Karpathy 在他的 Train AI 演讲中展示了这张胶片,我非常喜欢。这张胶片***地揭示了深度学习在研究与生产间的区别。通常来说,学术论文的主要精力是放在开发新的、先进的模型上面,在数据集方面一般都是从公开的数据集取一部分使用。而相反,那些我所知道的想用深度学习技术做实际应用的开发者们,他们绝大部分的精力都花在了担心他们的训练数据上面。
有许多好的原因可以解释为什么那些研究员会将精力放在模型架构上,而对于那些想要在实际生产过程中运用机器学习的人来说,相关的指导材料就比较少。为了解决这个问题,我在此次会议上的演讲主题是「关于训练数据上不可以思议的效果」,并且我想在这篇 blog 中进一步说明为什么优化训练数据是如此重要,并给出一些实用的建议。
在我的工作中经常与许多研究员与生产团队有紧密的合作,而我之所以如此相信优化数据的魔力是因为我亲眼目睹了它们在构建模型时所发挥的作用以及所带来的巨大收益。对于大多数应用来说,能否使用深度学习的***障碍是模型精度能否达到实际使用的要求,而据我所知,最快捷的提升精度的方式就是优化训练集。并且如果在部署过程中你还受限于例如时延或存储大小等因素,那么你可能需要在一个满足性能要求,经过折衷的较小模型架构上来提升模型精度。
语音控制
我不能分享我在所生产系统中观察到的,但是我能从一个开源项目的例子中来证明这点。去年我通过 Tensorflow 框架搭建了一个简单的语音识别系统,但在实现的过程中发现没有现成的数据集可以让我来训练模型。在大量志愿者的帮助下,我收集了 60000 个 1 秒钟的语音片段,这里要感谢 Open Speech 网站中的 AIY 团队。通过这个数据集,最终训练出来的模型是可用的,但是精度没有达到我的预期。为了看看这是否是由于我所设计模型所导致的,我在 Kaggle 上用同样的训练集发起了一个竞赛。在竞赛中,许多竞赛者设计的模型要比我的效果好,但即使是来自不同团队不同方法,最终的精度也就只能达到 91%。对我来说这暗示着数据集中存在一些基本的错误,确实,有竞赛者也向我指出了许多关于训练集的错误,如:有些音频标签打错了,有些音频不完整。这促使我有了发布一个新数据集的动力,解决他们所指出的那些问题并且再补充更多的样本。
我通过查看错误的度量标准去理解什么样的词汇是模型难以识别的。结果显示「其他」种类(当语音辨识系统识别语音时该单词却不在模型所训练到的词汇中)的是最容易识别错误的。为了解决这个问题,我们获取了更多不同的单词以确保训练数据的多样性。
自从 Kaggle 竞赛者报告了标签错误这一问题,我就请人做了额外的验证环节,请人去听每个语音片段然后确保它与预期的标签相符。另外由于 Kaggle 数据集还被发现了又很多几乎没有声音或者很短的声音,因此我写了能自动做一些音频分析并且剔除这些不太好的数据的小程序。最终,虽然删除了很多不太好的语音数据但还是将语音数据增加到了十万,这多亏了志愿者与一些受雇人员的努力。
为了帮助其他人更好地使用数据集(不要重蹈我的覆辙!)我将这一切相关的都写在一篇 Arixiv 论文当中了,以及更新后的准确率。最重要的结论是,在不改变模型和测试集数据的情况下,top-one(***次预测类)的准确率从 85.4% 提高到 89.7%,整整提高了超过 4%。这是一个令人惊叹的提升,这样的提升对于将模型部署在安卓或树莓派小应用的人来说这是十分满意的了。另外,我也很有自信在使用的模型架构落后于现有先进水平的情况下通过花费一定时间来调整模型获得进一步的少量提高。
这是在实际应用过程中一次又一次获得过好结果的过程,然而如果你想做一样的尝试,那么从哪开始对于你来说其实是比较困难的。你可以从我刚才对语音数据的处理中得到一些启发,但是我这里有一些更实用的方法。
首先,观察你的数据
这个似乎是显而易见的,但其实你首先需要做的是随机观察你的训练数据。将其中一部分复制到你的本地机,并且花费几个小时去预览它们。如果你的数据是一些图片,那么使用一些类似 MacOs 系统的查看器非常快速的查看数千张缩略图。对于音频来说,可以播放它们的预览,又或者对于文本来说,可以随机的转储一些片段到终端设备。我没有用足够的时间在***次语音中做这样的工作,因此导致 Kaggle 参赛者发现非常多的问题。
我总是会感觉到这个观察数据的过程有点傻,但是我从未后悔过。每次我这么做之后都会发现很多关于数据集的一些很关键的问题,像是类别不均衡问题,无法读入的数据(例如 PNG 格式的图片被加上了 JPG 格式的后缀),错误的标签亦或者令人奇怪的组合。Tom White 在预先观察数据的过程中在 ImageNet 数据集中有一些奇特的发现,例如将一个用于放大太阳光的古老设备标记为「太阳镜」,一张有魅力的图片被标记为」垃圾车」,一个穿着红斗篷的女人被标记为「斗篷」。Andrej 手工区分从 ImageNet 中的照片的工作也教会了我很多对于数据集的理解,包括即使对于一个人来说区分狗的种类也是很困难的事。
你所采取的行动取决于你的发现,但其实你应该坚持在做其他的数据清洗工作之前有这样的对于数据预先观察的过程,因为在这个过程中你会获取一些直觉信息以帮助你在其他步骤里作出决策。
快速选择模型
不要花太多时间在模型选择上。如果你需要做图像分类,可以看看 AutoML,TensorFlow 的 model repository,或者 Fast.AI 的 collection of examples,这些模型库里一般可以找到一些解决类似于你产品的问题的模型。最重要的是尽快开始迭代你的模型,这样你就可以尽早地让真实用户来测试你的模型。你总会有机会在后续的过程中来提高你的模型,也可能会得到更好的结果,但是首先你必须保证你的数据在一开始就是有效的。深度学习依然遵从最基本的计算定律,输入无效数据,那么就输出无效的结果。因此,即使***的模型也受限于训练数据中的瑕疵。通过选择一个模型并对其进行测试,你会找到训练数据中存在的问题并对这些问题进行改进。
为了进一步提高你的迭代速度,可以从一个已经在大样本数据上训练过的模型开始,利用迁移学习在一个你所收集的可能小很多的数据集上来对模型参数进行微调。相比直接用你的小样本数据来对模型进行训练,这样通常可以更快地得到更好的结果,你也可以由此找到一些感觉需要怎样对你收据收集的方式做一些必要的调整。这样做最重要的效果是你可以把你训练得到的结果反馈到数据收集的过程中去,边学习边调整,而不是把数据采集作为一个独立的在训练之前的一个步骤。
创造之前先模拟
研究建模和产品建模之间***的区别在于前者通常在一开始就有一个非常明确的问题描述,而后者对于模型的许多要求一开始并不明确,而是存在于用户的脑中,需要一步一步挖掘出来。例如,我们希望 Jetpac 能提供好的照片给自动化的城市旅游指南。开始的时候,我们让打分的人来给照片贴标签,如果他们觉得照片「好」,就给照片做标记。但是,我们***得到了很多带有微笑的人像的照片,因为打分者认为这些是「好」照片。我们把这些标记了的照片用到我们的产品模型中,看看测试用户的反应。可想而知,他们并不满意***的结果。
为了解决这个问题,我们把标记照片时的问题改成了:「这张照片会吸引你去这个地方旅游吗?」这使得我们得到的标记照片质量提高很多,但是,我们还是得到了一些这样的标记照片,一些东南亚寻找工作的人认为在大酒店里有很多西装革履拿着酒杯的会议照片非常吸引人。这些被不正确标记的照片提醒我们当下生活的泡沫时代,但是这是个实际的问题,因为我们的目标人群是美国人,他们会觉得会议照片非常的乏味又无趣。最终,我们 Jetpac 组的六个人手动标记了超过 200 万的照片,因为我们比任何其他人更清楚打分的标准是什么。
这是一个比较极端的例子,但是它说明了数据标记过程很大程度上取决于你应用的要求。对于大部分产品应用的案例,开发人员需要花很长的时间来搞清楚我们到底需要模型回答一个什么样的问题,而搞清楚这个问题是非常关键的。如果你的模型在回答一个错误的问题,那么在这个基础上,你将永远不能创造一个好的用户体验。
图片来自Thomas Hawk
我认为唯一能够确认你是否在问正确的问题的方法是模拟你的应用,但是不是用机器学习模型而是用人的反馈。这个方法有时被称为「Wizard-of-Oz-ing」,因为幕后有人在操作。在 Jetpac 这个案例上,我们请人手动从一些旅游指南的样本中选择一些照片,并用测试用户的反馈来调整我们选择照片的标准,而不是训练一个模型来做这件事。一旦我们的测试反馈都是肯定的,我们就把***的照片选择标准作为标记的准则来对训练集里数百万的照片进行打分做标记。这些照片之后训练的模型有能力对数以亿计的照片进行高精度的预测,但是这个模型的核心起源于我们手动标记确定出来的打分标准。
用同源数据做训练
在 Jetpac 这个项目上,我们用于训练模型的图片(主要来源于 Facebook 和 Instagram)和我们最终使用模型的图片是同源的。但是,一个普遍的问题在于,用于训练的数据常常和最终要应用模型处理的数据在一些很重要的特征上不一致。例如,我常常遇到有些团队在 ImageNet 上训练他们的模型,但是最终他们的模型却是要用于解决无人机或者机器人图片的问题。这样做也是有道理的,因为 ImageNet 上有很多人拍摄的照片,而这些照片和无人机或者机器人得到的照片有很多共性。ImageNet 上的照片很多是手机或者相机,使用中性镜头,在大概一人的高度,自然光或者人工打光下,并保证被标记的对象处于前景中心位置这些条件下拍摄的。机器人和无人机使用摄像照相机,通常用视角镜头,从地面或者高空在光线较弱,也不会使用智能定位的情况下拍摄照片,所以这些照片中的对象常常是不完整的。这些图片特性的差异最终会导致在 ImageNet 上训练的模型,当应用于这些器械得到的图片上时,精度是很低的。
还有一些更巧妙的方法可以让你的训练数据偏离最终的应用显示出来。想象一下,你正在建造一个相机来识别野生动物,并利用世界各地的动物数据集进行训练。如果你只用 Borneo 丛林中的数据来部署模型,那么分类标签为企鹅的概率会非常低。但假如训练集中加入南极的照片,那很有可能其他的动物会被误认为企鹅,所以比起训练集没有这些南极照片的时候,模型的整体错误率可能会更高。
有时候可以通过先验知识来校正模型的结果(比如在丛林环境中大量降低企鹅的概率),但是更简单而有效的方法是使用可以反映真实的产品环境的训练集。所以我认为***的方法还是始终使用从实际应用中直接获取的数据,这与我上面建议的「Wizard-of-Oz-ing」可以联系起来。人工干预可以是对初始数据集进行标记,即使收集到的标签数量非常少,它们也可以反映实际使用情况,并有望进行一些迁移学习的初步实验。
联系数据分析指标
当我在处理语音命令示例时,最频繁出现的报告之一就是训练过程中的混淆矩阵。下面是控制台中显示的例子:
这可能看起来很吓人,但实际上它只是一张表格,显示了神经网络出错的细节。这里有一个更漂亮一些的标签版本:
该表中的每一行表示某个标签下的语音被预测为所有类别标签的具体数量,每一列显示了被预测为某个具体的标签的原型的分布。例如,突出显示的这行表示这些音频样本实际上都属于 Slience,从左到右可以看到每一个样本都在 Slience 这一列中, 表示预测的这些标签都是正确的。这说明此模型非常善于预测 Slience 类,没有错误的否定任何一个 Slience 样本。如果我们查看 Slience 这一整列,可以看到有多少其他样本被预测为 Slience 标签,我们可以看到有很多单词片段被误认为是 Slience 类,假阳性数量相当多。事实证明这是有帮助的,因为它让我更仔细地观察那些被错误地归类为 Slience 的片段,这其中很多录音都出现了异常低音的情况。这帮助我通过删除低音量的剪辑来提高数据的质量,如果没有混乱矩阵,我不会知道应该这样做。
几乎任何一种总结都有助于改进实验结果,但是我发现混淆矩阵是一个很好的折衷方案,它给出的信息比一个精确的数字要多,但不会因为细节太多而使我困惑。在训练期间观察数字变化也很有用,因为它可以告诉您模型正在努力学习的类别,这样就可以给你提供需要清理和扩展的数据集区域。
物以类聚
我最喜欢的理解方式之一是让网络来解释我的训练数据----可视化聚类。TensorBoard 非常支持这样的探究方法,虽然它经常被用于查看词嵌入向量,但我发现它几乎适用于任何类似于嵌入的层。例如,图像分类网络通常把***的全连接层或 softmax 单元之前的倒数第二层当作一个嵌入层 (像 Tensorflow for poets 的简单迁移学习的示例一样)。这些并不是严格意义上的嵌入,因为在训练过程中没有采取任何手段来确保它们会在一个真实的嵌入层中有理想的空间属性,但是对它们的向量进行聚类确实会产生有趣的结果。
有一个真实的事例是,我工作的其中一个团队难以理解为什么某些动物在图像分类模型中有很高错误率。所以他们使用聚类可视化来查看他们的训练数据在不同类别下的分布情况,当他们看到「Jaguar」时,他们清楚地看到数据被分成两个不同的组,彼此之间有一定的距离。
图片来自 djblock99 and Dave Adams
这是他们看到的图。图中每个照片都映射成了两个数据集中的点,很明显,许多捷豹品牌的汽车被错误地贴上了美洲虎的标签。一旦他们知道这个方法,他们就能看到标记过程,这样就可以意识到他们的研究方向混淆了用户使用界面。有了这些信息,他们就能够改进标记人员的培训过程并修复这个工具,即只要将所有的汽车图像从 Jaguar 类别中删除,就可以使模型中这个类别获得更高的准确率。
聚类让你对训练集有了深入的了解,并且像直接查看数据一样非常便利,但是网络实际上只是通过它自己的学习理解将输入数据分组来指导你进行探究。作为人类,我们很擅长在视觉上发现异常,所以结合我们的直觉和计算机处理大量输入的能力,可以为追踪数据质量的问题提供了一个非常好的可扩展解决方案。关于使用 TensorBoard 进行此操作的完整教程超出了本文的范围 (我很感激您还在阅读本文),但是如果你真的想提高实验结果,我强烈建议你熟悉这个工具。
经常收集数据
我从来没有见过收集更多的数据而不能提高模型的准确性的情况,事实证明有很多研究支持我这一看法。
这张图来自于「重新审视数据的不合理有效性」,展示了当训练数据集的规模增长到数亿时,图像分类的模型精度是如何不断提高的。Facebook 最近更深入地进行了这方面的研究,使用数十亿张有标签的 Instagram 图片,刷新了 ImageNet 分类精度的记录。这表明,即使很难获得大型、高质量数据集,增加训练集的大小仍然会提高模型结果。
这意味着,只要有任何更好的模型精度使用户受益,就需要一个持续更新数据集的策略。如果可以,寻找一个创造性的方法去利用微弱的信号,以此访问更大的数据集。对此,Facebook 采用 Instagram 标签就是一个很好的案例。另一种途径是提高标记流水线的智能化,比如利用模型的初始版本预测标签,通过增强工具使贴标者做出更快的决定。这样会有初始偏差复现的风险,但在实践中,收益往往高于这种风险。在这个问题上投入大量资金去聘请更多的人为新的训练输入做标记,通常是一个有价值的投资。尽管传统上这类没有预算支出的项目组织可能会有困难。如果你是一个非盈利者,让你的支持者更容易通过一些开源工具提供数据,这是一种增加你数据规模且不会导致破产的好方法。
当然,对任何组织来说一个自然地生成更多标记数据的产品就是圣杯。我不会太在意这种观念,但这不适合现实世界的多数案例,人们只想免除复杂的标记问题而尽快得到答案。对于一家初创公司来说这是一个好的投资热点,因为它就像一个用于模型改善的永动机,但几乎总是会有一些单位成本用于清理或增加你将收到的数据,因此经济学往往最终将它看做一个比真正免费的商业众包更便宜的版本。
通往风险区的高速公路
对于应用程序的用户来说,模型误差几乎比损失函数捕获有更大的影响。你需要提前考虑最差的可能结果,并试图给模型设置一个反馈抑制来避免这些结果的发生。这可能只是一个你永远不想预测的类别的黑名单,因为错误的代价非常高,或者你可能有一个简单的算法规则来保证所做行为不会超出你所考虑到的一些边界参数。例如,你可能会保留一个永远不希望文本编辑器输出的誓言表,即使是训练集中也不行,因为它不适合你的产品。
不好的结果可能会被考虑到,但事先不总是如此明显,所以从现实错误中汲取教训是至关重要的。一旦你有一半的产品/市场是体面的,一个最简单的方法是运用有缺陷的报告。当人们使用你的系统应用,并从模型中得到一个不想要的结果,让他们容易告诉你。如果可能的话,获得模型的完整输入,但如果是敏感数据,只要知道不良输出是什么,这样可帮助指导你的调查。这些类别可以帮助你选择收集更多数据的位置和了解当前标签质量属于哪些级别。一旦你的模型有一个新的版本,除了正常的测试集外,利用先前产生不良结果的输入,并对这部分输入的测试结果进行单独评估。这种改进方法有点像回归测试,并给你提供一种方法来跟踪你改善用户体验的效果,因为单个模型精度指标将永远无法完全捕捉到人们所关心的所有信息。通过查看一小部分过去引发强烈反应的例子,你可以获得一些独立的证据来表明你确实正在为用户提供更好的服务。如果因为过于敏感以至于你无法给模型获得这些输入数据,可采用自我测试或内部实验的方式来确定哪些输入会产生这些错误,然后用回归集来替换它们。
故事是什么,牵牛花?
我希望我已成功说服你在数据上花费更多的时间,并且给了你一些如何投资改善它的观点。我对这个有价值的领域没有关注太多,仅在这里给了一些肤浅的建议,所以我感谢每一个与我分享他们策略的人,并且我希望我将会听到更多关于你已取得成功的方法。我认为将会有越来越多的专业工程师团队组织专注于数据集的优化改善,而不是留给 Ml 研究人员来推动进展,我期待看到整个领域因为这些而得到发展。我总是惊叹即使针对严重缺陷的训练数据,模型一样会运作良好,因此我迫不及待的想看到我们的数据集模型改进以后还能做些什么。