谷歌正式开源 Hinton 胶囊理论代码,即刻用 TensorFlow 实现吧

新闻 开源
AI 研习社消息,相信大家对于「深度学习教父」Geoffery Hinton 在去年年底发表的胶囊网络还记忆犹新,在论文 Dynamic Routing between Capsules 中,Hinton 团队提出了一种全新的网络结构。

[[219212]]

AI 研习社消息,相信大家对于「深度学习教父」Geoffery Hinton 在去年年底发表的胶囊网络还记忆犹新,在论文 Dynamic Routing between Capsules 中,Hinton 团队提出了一种全新的网络结构。为了避免网络结构的杂乱无章,他们提出把关注同一个类别或者同一个属性的神经元打包集合在一起,好像胶囊一样。在神经网络工作时,这些胶囊间的通路形成稀疏激活的树状结构(整个树中只有部分路径上的胶囊被激活)。这样一来,Capsule 也就具有更好的解释性。

在实验结果上,CapsNet 在数字识别和健壮性上都取得了不错的效果。详情可以参见 终于盼来了Hinton的Capsule新论文,它能开启深度神经网络的新时代吗?

日前,该论文的第一作者 Sara Sabour 在 GitHub 上公布了论文代码,大家可以马上动手实践起来。雷锋网 AI 研习社将教程编译整理如下:

所需配置:

执行 test 程序,来验证安装是否正确,诸如:

python layers_test.py 

快速 MNIST 测试:

python experiment.py --data_dir=$DATA_DIR/mnist_data/ --train=false \ --summary_dir=/tmp/ --checkpoint=$CKPT_DIR/mnist_checkpoint/model.ckpt-1 

快速 CIFAR10 ensemble 测试:

python experiment.py --data_dir=$DATA_DIR --train=false --dataset=cifar10 \ --hparams_override=num_prime_capsules=64,padding=SAME,leaky=true,remake=false \ --summary_dir=/tmp/ --checkpoint=$CKPT_DIR/cifar/cifar{}/model.ckpt-600000 \ --num_trials=7 

CIFAR10 训练指令:

python experiment.py --data_dir=$DATA_DIR --dataset=cifar10 --max_steps=600000\ --hparams_override=num_prime_capsules=64,padding=SAME,leaky=true,remake=false \ --summary_dir=/tmp/ 

MNIST full 训练指令:

  • 也可以执行--validate=true as well 在训练-测试集上训练

  • 执行 --num_gpus=NUM_GPUS 在多块GPU上训练

python experiment.py --data_dir=$DATA_DIR/mnist_data/ --max_steps=300000\ --summary_dir=/tmp/attempt0/ 

MNIST baseline 训练指令:

python experiment.py --data_dir=$DATA_DIR/mnist_data/ --max_steps=300000\ --summary_dir=/tmp/attempt1/ --model=baseline 

To test on validation during training of the above model:

训练如上模型时,在验证集上进行测试(记住,在训练过程中会持续执行指令):

  • 在训练时执行 --validate=true 也一样

  • 可能需要两块 GPU,一块用于训练集,一块用于验证集

  • 如果所有的测试都在一台机器上,你需要对训练集、验证集的测试中限制 RAM 消耗。如果不这样,TensorFlow 会在一开始占用所有的 RAM,这样就不能执行其他工作了

python experiment.py --data_dir=$DATA_DIR/mnist_data/ --max_steps=300000\ --summary_dir=/tmp/attempt0/ --train=false --validate=true

大家可以通过 --num_targets=2 和 --data_dir=$DATA_DIR/multitest_6shifted_mnist.tfrecords@10 在 MultiMNIST 上进行测试或训练,生成 multiMNIST/MNIST 记录的代码在 input_data/mnist/mnist_shift.py 目录下。

multiMNIST 测试代码:

python mnist_shift.py --data_dir=$DATA_DIR/mnist_data/ --split=test --shift=6  --pad=4 --num_pairs=1000 --max_shard=100000 --multi_targets=true 

可以通过 --shift=6 --pad=6 来构造 affNIST expanded_mnist

论文地址: https://arxiv.org/pdf/1710.09829.pdf

GitHub 地址: https://github.com/Sarasra/models/tree/master/research/capsules

 

责任编辑:张燕妮 来源: 雷锋网
相关推荐

2018-04-30 18:07:51

谷歌开源编程

2018-01-27 21:26:46

谷歌GitHub功能

2020-09-30 16:15:46

ThreadLocal

2022-03-23 15:19:00

低代码开源阿里巴巴

2020-03-10 10:42:22

量子计算机芯片超算

2016-05-17 14:24:56

亚马逊机器学习

2018-05-04 14:11:34

SwiftPython开发

2013-10-15 09:21:40

2018-01-18 09:55:32

AI 大事件

2019-09-04 09:26:42

谷歌Android开发者

2020-03-12 12:31:01

开源谷歌量子AI

2019-05-14 09:53:31

代码开发工具

2017-11-22 19:00:51

人工智能深度学习胶囊网络

2015-11-12 13:11:17

TensorFlow人工智能系统谷歌

2019-09-04 15:07:15

代码开发开源

2021-04-26 09:04:13

Python 代码音乐

2021-07-12 09:11:23

华为谷歌除名

2022-06-20 14:36:49

TensorFlow机器学习

2017-09-29 09:57:20

2020-04-17 14:48:30

代码机器学习Python
点赞
收藏

51CTO技术栈公众号