译者 | 布加迪
审校 | 重楼
近年来,Keras和Tensorflow组合遇到了一种与之竞争的框架:JAX,它在深度学习开发者社区逐渐变得很重要。那么JAX到底是什么?它有哪些功能?它与Keras API又有什么相似和不同之处?Keras API一直是使用Tensorflow(最庞大的Python深度学习库)的几乎通用的方法。本文逐一解答了这些问题。
Keras是什么?
Keras诞生于2015年,这种接口用于简化使用成熟的库来构建神经网络架构,比如Tensorflow。尽管Keras最初作为一种独立的框架而创建,但它最终成为了与Tensorflow结合使用的框架:Tensorflow是用于高效训练和使用可扩展深度神经网络的主要Python库。随后,Keras成为Tensorflow上面的抽象层:换句话说,它使“原始”Tensorflow用起来变得容易多了。
Keras便于实现神经网络架构最常见的构建模块:神经元层、目标及激活函数以及优化器等等。特殊类型的深度神经网络架构使用Keras抽象类和方法可以轻松构建,比如卷积神经网络(CNN)和循环神经网络(RNN)。
JAX是什么?
JAX是一种比较新的框架,不仅适用于深度学习,还适用于整个机器学习开发。它于2018年由谷歌发布,侧重于高性能数值计算。具体来说,JAX使Python和numpy(其最大的数值计算库)用起来更简单更快捷,同时无缝支持GPU和TPU高性能处理。就科学计算和数值计算而言,这是相对普通numpy的一个重要优势,因为numpy只支持CPU执行。
由于兼顾高性能执行模式的直观性和多功能,JAX正迅速名声大噪,成为机器学习和深度学习开发的最先进框架,有机会最终取代Tensorflow和PyTorch等其他框架。它的自动微分特性有助于高效地执行训练深度神经网络背后的基于梯度的复杂计算。
简而言之,JAX将科学计算和高性能计算的功能整合到单单一个框架中。
Keras和JAX的异同
现在我们对Keras和JAX已有了大致的了解,下面列出这两种框架共有的一些特性和诸多不同的方面。
相似之处:
- 深度学习模型开发:这两种框架都被广泛用于构建和训练深度学习模型。
- GPU/TPU加速:Keras和JAX都可以利用GPU和TPU等加速硬件高效地训练模型。
- 自动微分:这两种框架结合了自动计算梯度的机制,梯度计算是模型在训练过程中优化的关键过程。
- 与深度学习库的互操作性:这两种框架都与流行的深度学习库TensorFlow兼容。
差异之处:
- 抽象级别:虽然两种解决方案都提供了一定程度的抽象,但Keras更适合寻求高级API且易于使用的用户;而JAX更注重控制的灵活性,停留在较低的抽象级别,专注于数值计算。
- 后端:Keras完全基于并依赖Tensorflow作为后端。同时,JAX不依赖Tensorflow,而是使用一种名为适时(JIT)编译的方法。话虽如此,JAX和Tensorflow可以结合使用,它们在某些情况下可以很好地互补,比如将高级数学转换整合到高级深度学习架构中。
- 易用性:与抽象级别密切相关,Keras旨在易于快速使用。虽然JAX功能更强大,但需要更深入的技术知识才能顺利地使用它。
- 函数转换:这是JAX独有的特性,允许高级转换功能,比如自动向量化和并行执行。
- 自动优化:同样,JAX在这方面很突出,它更加灵活,便于在神经网络范围之外优化各种函数(这就是为什么它也适用于其他机器学习方法,比如集成学习),Keras专门专注于深度学习模型。
我该选择哪种框架?
了解了这两种框架之间的异同之后,根据手头的问题或场景决定选择哪种框架就不是什么麻烦事了。
如果用户寻求易用性、更平缓的学习曲线和更高的抽象级别,Keras是不二的选择。这个基于Tensorflow库的API将使用户能够在短时间内构建原型,并利用各种深度学习模型处理预测和推理任务。
另一方面,对于经验丰富的开发人员来说,JAX是一种更强大、更通用的选择,可以获得优化计算和高级函数转换之类的附加功能,而不是严格局限于Tensorflow或深度学习建模,不过它需要用户更大的控制度和低级工程决策。
原文标题:Keras vs. JAX: A Comparison,作者:Iván Palomares Carrascosa