近几年,深度学习发展的非常迅速。在雅虎,我们发现,为了从海量数据中获得洞察力,需要部署分布式深度学习。现有的深度学习框架常常要求为深度学习单独设定集群,迫使我们要为一个机器学习流程(见下图 1)创建多个程序。
设定独立的集群则需要我们转移大数据集,带来了不必要的系统复杂性和端到端的学习延迟。
去年我们通过开发和公开 CaffeOnSpark 解决了 scaleout 的问题,我们开源的框架支持在相同的 Spark 和 Hadoop 集群上进行分布式深度学习和大数据处理。我们在雅虎内部使用 CaffeOnSpark 改善了我们的 NSFW 图像检测,自动从实况录像中识别电竞比赛片段等等。在社区大量有价值的反馈和贡献下,CaffeOnSpark 已经得到了更新,现在可以支持 LSTM,有了一个新的数据层,可以训练与测试交错,有了一个 Python API,和 Docker container 的部署。这些都提升了我们的用户体验。但是那些使用 TensorFlow 框架的人怎么办?于是我们效仿了之前的做法,开发了 TensorFlowOnSpark。
TensorFlow 公开后,谷歌于 2016 年 4 月就开放了一个带有分布式学习功能的增强版 TensorFlow。2016 年 10 月,TensorFlow 开始支持 HDFS。然而在谷歌云之外,用户仍然需要一个 TensorFlow 应用的专用集群。TensorFlow 程序无法在现有的大数据集群上部署,这样一来,那些想大规模使用这个技术的人就需要花更多的成本和时间。
为了打破这个限制,几个社区项目将 TensorFlow 连接到 Spark 集群上。SparkNet 让 Spark 执行器获得了可以运行 TensorFlow 网络的能力。DataBricks 提出 tensorframe,用来使用 TensorFlow 程序操纵 Apache Spark 的数据帧。虽然这些方法都朝着正确的方向迈出了一步,但是我们检查他们的代码后发现,我们无法让多个 TensorFlow 过程直接相互沟通,我们也无法实现异步分布式学习,并且我们需要在迁移现有的 tensorflow 程序上花大功夫。
TensorFlowOnSpark
我们的新框架,TensorFlowOnSpark(TFoS),支持 TensorFlow 在 Spark 和 Hadoop 上的分布式运行。如上图(图 2)所示,TFoS 与 SparkSQL、MLlib 以及其他的 Spark 库一起在一个项目或线程(pipeline)中运行。
TFoS 支持所有类型的 TensorFlow 程序,能实现同步和异步的训练与推理。并且支持模型和数据的平行处理,以及 TensorFlow 工具(如 TensorBoard)在 Spark 群集上使用。
任何 TensorFlow 程序都能够很容易通过修改实现在 TFoS 上运行的。通常情况下,只需要修改少于 10 行的 Python 代码。很多在雅虎平台上使用 TensorFlow 的开发者,已经轻松将 TensorFlow 项目转移到 TFoS 上执行了。
TFoS 支持张量(tensor)在 TensorFlow 处理过程中(计算节点和参数服务节点)信息的直接沟通。过程到过程(Process-to-process)的直接沟通机制使 TFoS 项目很容易在增加的机器上进行扩展。如图 3 所示,TFoS 不需要 Spark 驱动器(driver)参与到张量沟通中来,因此也就与具备类似于独立 TensorFlow 群集的扩展能力。
TFoS 提供两种不同模式来「吞入」用于训练和推理的数据 :
1. TensorFlow QueueRunners:TFoS 利用 TensorFlow 的文件读取(file readers)和 QueueRunners 来直接从 HDFS 文件中读入数据。在数据获取过程中不需要 Spark 参与。
2. Spark 供给:Spark RDD 数据将会被传输至每一个 Spark 执行器里,Spark 执行器会进一步将数据传入 TensorFlow 图(通过 feed_dict 参数)。
图 4 展示了 Inception 图像分类网络中同时进行的分布式训练如何在 TFoS 中通过 QueueRunners 的一个简单设置进行扩展:将每个计算节点设置为 1 个 GPU,一个读入(reader)以及批处理数为 32。四个 TFoS 的任务同时进行以用于训练 10 万步。两天多后,当这些任务完成时,top-5 精确度(accuracy)分别为 0.730, 0.814, 0.854,0.879。0.730 的精确度需要单计算节点运行 46 小时得到,双计算节点需要 22.5 个小时,4 计算机点需要 13 小时,8 计算节点需要 7.5 个小时。在 Inception 模型训练上,TFoS 几乎能达到线性扩展。这是很鼓舞人心的,虽然 TFoS 在不同模型和超参数上的扩展能力不同。
用于分布式 TensorFlow 的 RDMA
在雅虎的 Hadoop 集群上,GPU 节点通过以太网和无线宽带相互连接。无线宽带提供了高速的连接,并支持在 RDMA 中直接访问其他服务器的存储。然而目前 TensorFlow 仅支持在以太网上使用 「gRPC」 的分布式学习。为了加速分布式学习,我们增强了 TensorFlowC++层,实现了无线宽带上的 RDMA。
为了结合我们发布的 TFoS,我们在 default「gRPC」协议之外,引进了一个新的 TensorFlow 服务器协议。任何分布式 tensorflow 程序都能通过指定 protocol=「grpc_rdma」in tf.train.ServerDef()or tf.train.Server() 来使用我们的增强版的 TensorFlow。
有了这个新协议后,就需要一个 RDMA 汇集管理器(rendezvous manager)来确保张量直接写入远程服务器的内存。我们最大限度地减少张量缓冲的创建:张量缓冲器在开始时分配一次,然后在一个 TensorFlow 工作任务的所用训练步骤中重复使用。从我们早期的大型模型实验,比如 VGG-19 开始,我们的就已经证明了,与现有的 gRPC 相比,我们的 RDMA 实现在训练时间上带来了显著的提速。
由于 RDMA 支持对性能要求很高(见 TensorFlow issue#2916),我们决定让我们现有的实现版本作为一个预览版向 TensorFlow 社区开放。在未来的几周内,我们将会进一步优化我们的 RDMA 实现,并分享一些基准结果细节。
简单的 CLI 和 API
TFoS 程序是通过标准的 ApacheSpark 命令 spark-submit 运行的。如下所示,用户可以在 CLI 中指定 Spark 执行器的数目,每个执行器所用的 GPU 数目以及参数服务节点数。用户还可以表明愿意使用 TensorBoard(–tensorboard)还是 RDMA(–rdma)。
- spark-submit –master ${MASTER} \
- ${TFoS_HOME}/examples/slim/train_image_classifier.py \
- –model_name inception_v3 \
- –train_dir hdfs://default/slim_train \
- –dataset_dir hdfs://default/data/imagenet \
- –dataset_name imagenet \
- –dataset_split_name train \
- –cluster_size ${NUM_EXEC} \
- –num_gpus ${NUM_GPU} \
- –num_ps_tasks ${NUM_PS} \
- –sync_replicas \
- –replicas_to_aggregate ${NUM_WORKERS} \
- –tensorboard \
- –rdma
TFoS 提供高层次的 Python API(在 Python notebook 的范例中有显示):
- TFCluster.reserve()... 从 Spark 执行器构建一个 TensorFlow 群集
- TFCluster.start()... 在执行器上加载 TensorFlow 程序
- TFCluster.train() or TFCluster.inference() …将 RDD 数据传入 TensorFlow 处理
- TFCluster.shutdown() …在执行器中结束 TensorFlow 的运行
开源
- TensorFlowOnSpark 开源地址: github.com/yahoo/TensorFlowOnSpark
- RDMA 增强版开源地址: github.com/yahoo/tensorflow/tree/yahoo
- 提供多示例程序(包括 MNIST,Cifar10,Inception,and VGG)以说明 TensorFlow 程序到TensorFlowOnspar转换过程,并且利用 RDMA。地址:https://github.com/yahoo/TensorFlowOnSpark/tree/master/examples
- 提供一张亚马逊机器图像用于在 AWS EC2 上应用 TensorFlowOnSpark。接着,与 CaffeOnSpark 一样,我们会推进 TensorFlowOnSpark。地址:https://github.com/yahoo/TensorFlowOnSpark/wiki/GetStarted_EC2
【本文是51CTO专栏机构机器之心的原创文章,微信公众号“机器之心( id: almosthuman2014)”】