1行代码消除PyTorch的CUDA内存溢出报错,这个GitHub项目揽星600+

新闻 前端
多少人用PyTorch“炼丹”时都会被这个bug困扰。一般情况下,你得找出当下占显存的没用的程序,然后kill掉。现在,有人写了一个PyTorch wrapper,用一行代码就能“无痛”消除这个bug。

[[441177]]

本文经AI新媒体量子位(公众号ID:QbitAI)授权转载,转载请联系出处。

 多少人用PyTorch“炼丹”时都会被这个bug困扰。

1行代码消除PyTorch的CUDA内存溢出报错,这个GitHub项目揽星600+

一般情况下,你得找出当下占显存的没用的程序,然后kill掉。

如果不行,还需手动调整batch size到合适的大小……

有点麻烦。

现在,有人写了一个PyTorch wrapper,用一行代码就能“无痛”消除这个bug。

[[441178]]

有多厉害?

相关项目在GitHub才发布没几天就收获了600+星。

1行代码消除PyTorch的CUDA内存溢出报错,这个GitHub项目揽星600+

一行代码解决内存溢出错误

软件包名叫koila,已经上传PyPI,先安装一下:

  1. pip install koila 

现在,假如你面对这样一个PyTorch项目:构建一个神经网络来对FashionMNIST数据集中的图像进行分类。

先定义input、label和model:

  1. # A batch of MNIST image 
  2. input = torch.randn(82828
  3.  
  4. # A batch of labels 
  5. label = torch.randn(010, [8]) 
  6.  
  7. class NeuralNetwork(Module): 
  8.     def __init__(self): 
  9.         super(NeuralNetwork, self).__init__() 
  10.         self.flatten = Flatten() 
  11.         self.linear_relu_stack = Sequential( 
  12.             Linear(28 * 28512), 
  13.             ReLU(), 
  14.             Linear(512512), 
  15.             ReLU(), 
  16.             Linear(51210), 
  17.         ) 
  18.  
  19.     def forward(self, x): 
  20.         x = self.flatten(x) 
  21.         logits = self.linear_relu_stack(x) 
  22.         return logits 

然后定义loss函数、计算输出和losses。

  1. loss_fn = CrossEntropyLoss() 
  2.  
  3. # Calculate losses 
  4. out = nn(t) 
  5. loss = loss_fn(out, label) 
  6.  
  7. # Backward pass 
  8. nn.zero_grad() 
  9. loss.backward() 

好了,如何使用koila来防止内存溢出?

超级简单!

只需在第一行代码,也就是把输入用lazy张量wrap起来,并指定bacth维度——

koila就能自动帮你计算剩余的GPU内存并使用正确的batch size了。

在本例中,batch=0,则修改如下:

  1. input = lazy(torch.randn(82828), batch=0

完事儿!就这样和PyTorch“炼丹”时的OOM报错说拜拜。

灵感来自TensorFlow的静态/懒惰评估

下面就来说说koila背后的工作原理。

“CUDA error: out of memory”这个报错通常发生在前向传递(forward pass)中,因为这时需要保存很多临时变量。

koila的灵感来自TensorFlow的静态/懒惰评估(static/lazy evaluation)。

它通过构建图,并仅在必要时运行访问所有相关信息,来确定模型真正需要多少资源。

而只需计算临时变量的shape就能计算各变量的内存使用情况;而知道了在前向传递中使用了多少内存,koila也就能自动选择最佳batch size了。

又是算shape又是算内存的,koila听起来就很慢?

[[441179]]

NO。

即使是像GPT-3这种具有96层的巨大模型,其计算图中也只有几百个节点。

而Koila的算法是在线性时间内运行,任何现代计算机都能够立即处理这样的图计算;再加上大部分计算都是单个张量,所以,koila运行起来一点也不慢。

你又会问了,PyTorch Lightning的batch size搜索功能不是也可以解决这个问题吗?

是的,它也可以。

但作者表示,该功能已深度集成在自己那一套生态系统中,你必须得用它的DataLoader,从他们的模型中继承子类,才能训练自己的模型,太麻烦了。

koila灵活又轻量,只需一行代码就能解决问题,非常“大快人心”有没有。

不过目前,koila还不适用于分布式数据的并行训练方法(DDP),未来才会支持多GPU

1行代码消除PyTorch的CUDA内存溢出报错,这个GitHub项目揽星600+

以及现在只适用于常见的nn.Module类。

1行代码消除PyTorch的CUDA内存溢出报错,这个GitHub项目揽星600+

ps. koila作者是一位叫做RenChu Wang的小哥。

1行代码消除PyTorch的CUDA内存溢出报错,这个GitHub项目揽星600+

项目地址:
https://github.com/rentruewang/koila

 

 

责任编辑:张燕妮 来源: 量子位
相关推荐

2020-04-14 15:00:04

PyTorchGitHub检测

2024-07-10 12:41:40

数据训练

2020-08-05 17:16:53

GitHub 技术开源

2021-09-18 11:28:29

GitHub代码开发者

2020-11-26 15:48:37

代码开发GitHub

2020-12-10 10:24:25

AI 数据人工智能

2020-12-30 10:35:49

程序员技能开发者

2023-07-22 13:47:57

开源项目

2020-12-07 16:14:40

GitHub 技术开源

2020-08-03 10:42:10

GitHub代码开发者

2021-08-09 15:56:43

机器学习人工智能计算机

2020-02-20 10:00:04

GitHubPyTorch开发者

2012-07-23 09:58:50

代码程序员

2021-04-09 16:25:00

GitHub代码开发者

2021-05-26 08:02:03

ThreadLocal多线程多线程并发安全

2015-03-30 11:18:50

内存管理Android

2020-08-21 13:55:56

微软开源PyTorch

2019-07-05 15:42:58

GitHub代码开发者

2012-05-15 02:04:22

JVMJava

2024-09-09 09:41:03

内存溢出golang开发者
点赞
收藏

51CTO技术栈公众号