如何对机器学习代码进行单元测试?

人工智能 机器学习
过去一年我大部分时间都用来做深度学习的研究,和相关实习。大部分时间我都在犯错中学习机器学习,以及括如何完整、正确的实现这些系统。在谷歌大脑项目中,我学到的最主要一点就是单元测试不仅可以帮助实现算法,也能节约大量的调试、训练时间。

目前,关于神经网络代码,并没有一个特别完善的单元测试的在线教程。甚至像 OpenAI 这样的站点,也只能靠 盯着每一行看来思考哪里错了来寻找 bug。很明显,大多数人没有那样的时间,并且也讨厌这么做。所以希望这篇教程能帮助你开始稳健的测试系统。

首先来看一个简单的例子,尝试找出以下代码的 bug。

看出来了吗?网络并没有实际融合(stacking)。写这段代码时,只是复制、粘贴了 slim.conv2d(…) 这行,修改了核(kernel)大小,忘记修改实际的输入。

这个实际上是作者一周前刚刚碰到的状况,很尴尬,但是也是重要的一个教训!这些 bug 很难发现,有以下原因。

  • 这些代码不会崩溃,不会抛出异常,甚至不会变慢。
  • 这个网络仍然能训练,并且损失(loss)也会下降。
  • 运行多个小时后,值回归到很差的结果,让人抓耳挠腮不知如何修复。

只有最终的验证错误这一条线索情况下,必须回顾整个网络架构才能找到问题所在。很明显,你需要需要一个更好的处理方式。

比起在运行了很多天的训练后才发现,我们如何提前预防呢?这里可以明显注意到,层(layers)的值并没有到达函数外的任何张量(tensors)。在有损失和优化器情况下,如果这些张量从未被优化,它们会保持默认值。

因此,只需要比较值在训练步骤前后有没有发生变化,我们就可以发现这种情况。

哇。只需要短短 15 行不到的代码,就能保证至少所有创建的变量都被训练到了。

这个测试,简单但是却很有用。现在问题修复了,让我们来尝试添加批量标准化。看你能否用眼睛看出 bug 来。

发现了吗?这个 bug 很巧妙。在 tensorflow 中,batch_norm 的 is_training 默认值是 False,所以在训练过程中添加这行代码,会导致输入无法标准化!幸亏,我们刚刚添加的那个单元测试会立即捕捉到这个问题!(3 天前,它刚刚帮助我捕捉到这个问题。)

让我们看另外一个例子。这是我从 reddit 帖子中看来的。我们不会太深入原帖,简单的说,发帖的人想要创建一个分类器,输出的范围在 0 到 1 之间。看看你能否看出哪里不对。

发现问题了吗?这个问题很难发现,结果非常难以理解。简单的说,因为预测只有单个输出值,应用了 softmax 交叉熵函数后,损失就会永远是 0 了。

最简单的发现这个问题的测试方式,就是保证损失永远不等于 0。

我们***个实现的测试,也能发现这种错误,但是要反向检查:保证只训练需要训练的变量。就生成式对抗网络(GAN)为例,一个常见的 bug 就是在优化过程中不小心忘记设置需要训练哪个变量。这样的代码随处可见。

这段代码***的问题是,优化器默认会优化所有的变量。在像生成式对抗网络这样高级的结构中,这意味着遥遥无期的训练时间。然而只需要一个简单测试,就可以检查到这种错误:

也可以对判定模型(discriminator)写一个同类型的测试。同样的测试,也可以应用来加强大量其他的学习算法。很多演员评判家(actor-critic)模型,有不同的网络需要用不同的损失来优化。

这里列出一些作者推荐的测试模式。

  • 确保输入的确定性。如果发现一个诡异的失败测试,但是却再也无法重现,将会是很糟糕的事情。在特别需要随机输入的场景下,确保用了同一个随机数种子。这样出现了失败后,可以再次以同样的输入重现它。
  • 确保测试很精简。不要用同一个单元测试检查回归训练和检查一个验证集合。这样做只是浪费时间。
  • 确保每次测试时都重置了图。

作为总结,这些黑盒算法仍然有大量方法来测试!花一个小时写一个简单的测试,可以节约成天的重新运行时间,并且大大提升你的研究能力。天才的想法,永远不要因为一个充满 bug 的实现而无法成为现实。

这篇文章列出的测试远远没有完备,但是是一个很好的起步!如果你发现有其他的建议或者某种特定类型的测试,请在 twitter 上给我消息!我很乐意写这篇文章的续集。

文章中所有的观点,仅代表作者的个人经验,并没有 Google 的支持、赞助。

查看英文原文

https://medium.com/@keeper6928/how-to-unit-test-machine-learning-code-57cf6fd81765  

责任编辑:庞桂玉 来源: AI前线
相关推荐

2019-12-18 10:25:12

机器学习单元测试神经网络

2012-11-01 11:32:23

IBMdw

2012-11-01 11:37:05

JavaScript单元测试测试工具

2021-03-28 23:03:50

Python程序员编码

2017-01-14 23:26:17

单元测试JUnit测试

2017-01-16 12:12:29

单元测试JUnit

2013-12-18 09:56:20

AngularJS测试

2013-06-04 09:49:04

Spring单元测试软件测试

2017-03-23 16:02:10

Mock技术单元测试

2023-12-11 08:25:15

Java框架Android

2017-01-14 23:42:49

单元测试框架软件测试

2022-08-02 08:07:24

单元测试代码重构

2009-09-29 16:21:31

Hibernate单元

2024-03-29 08:03:48

单元测试流量

2009-08-19 09:00:48

单元测试框架自动化测试

2009-06-26 17:48:38

JSF项目单元测试JSFUnit

2011-04-18 13:20:40

单元测试软件测试

2020-09-30 08:08:15

单元测试应用

2021-03-24 09:30:02

Jupyter not单元测试代码

2023-08-02 13:59:00

GoogleTestCTest单元测试
点赞
收藏

51CTO技术栈公众号