一文读懂全新深度学习库Rust Burn

人工智能 深度学习
Rust Burn是一个全新的完全使用Rust编写的深度学习框架,具有灵活性、高性能和易用性的特点。

一、什么是Rust Burn?

Rust Burn是一个全新的深度学习框架,完全使用Rust编程语言编写。创建这个新框架而不是使用现有框架(如PyTorch或TensorFlow)的动机是为了构建一个适应多种用户需求的通用框架,包括研究人员、机器学习工程师和底层软件工程师。

Rust Burn的关键设计原则包括灵活性、高性能和易用性。

灵活性:能够快速实现前沿研究想法,并进行实验。

高性能:通过优化措施,例如利用特定硬件功能,如Nvidia GPU上的张量内核(Tensor Cores)。

易用性:简化训练、部署和运行模型的工作流程。

Rust Burn的主要特点:

  • 灵活而动态的计算图。
  • 线程安全的数据结构。
  • 直观的抽象,简化开发过程。
  • 在训练和推理过程中实现极快的性能。
  • 支持CPU和GPU的多种后端实现。
  • 完全支持训练过程中的日志记录、度量和检查点功能。
  • 小型但活跃的开发者社区。

二、快速入门

2.1、安装Rust

Burn是基于Rust编程语言的、强大的深度学习框架,需要对Rust有基本的了解,但一旦掌握了这些知识,用户将能够充分利用Burn提供的所有功能。

按照官方指南进行安装。也可以查看GeeksforGeeks在Windows和Linux上安装Rust的指南和截图。

【官方指南】:https://www.rust-lang.org/tools/install

图片来自Install Rust图片来自Install Rust

【安装指南和截图】:https://www.geeksforgeeks.org/how-to-install-rust-on-windows-and-linux-operating-system/

2.2、安装Burn

要使用Rust Burn,首先需要在系统上安装Rust。一旦正确设置了Rust,就可以使用cargo(Rust的软件包管理器)创建新的Rust应用程序。

在当前目录中运行以下命令:

cargo new new_burn_app

导航到这个新目录:

cd new_burn_app

接下来,添加Burn作为依赖项,并添加启用GPU操作的WGPU后端功能:

cargo add burn --features wgpu

最后,编译项目以安装Burn:

cargo build

这将安装Burn框架以及WGPU后端。WGPU允许Burn执行底层的GPU操作。

三、代码示例

3.1、逐元素相加

要运行以下代码,用户需要打开并替换src/main.rs中的内容:

use burn::tensor::Tensor;
use burn::backend::WgpuBackend;

// Type alias for the backend to use.
type Backend = WgpuBackend;

fn main() {
    // Creation of two tensors, the first with explicit values and the second one with ones, with the same shape as the first
    let tensor_1 = Tensor::::from_data([[2., 3.], [4., 5.]]);
    let tensor_2 = Tensor::::ones_like(&tensor_1);

    // Print the element-wise addition (done with the WGPU backend) of the two tensors.
    println!("{}", tensor_1 + tensor_2);
}

main函数使用WGPU后端创建了两个张量,并进行了相加运算。

在终端中运行cargo run,执行该代码。

输出:

查看相加的结果:

Tensor {
  data: [[3.0, 4.0], [5.0, 6.0]],
  shape:  [2, 2],
  device:  BestAvailable,
  backend:  "wgpu",
  kind:  "Float",
  dtype:  "f32",
}

3.2、位置智能前馈模块

以下是使用Burn框架的一个简单示例。示例创建了一个前馈模块,并使用以下代码片段定义了它的前向传播。

use burn::nn;
use burn::module::Module;
use burn::tensor::backend::Backend;

#[derive(Module, Debug)]
pub struct PositionWiseFeedForward<B: Backend> {
    linear_inner: Linear<B>,
    linear_outer: Linear<B>,
    dropout: Dropout,
    gelu: GELU,
}

impl PositionWiseFeedForward<B> {
    pub fn forward(&self, input: Tensor<B, D>) -> Tensor<B, D> {
        let x = self.linear_inner.forward(input);
        let x = self.gelu.forward(x);
        let x = self.dropout.forward(x);

        self.linear_outer.forward(x)
    }
}

3.3、项目示例

要了解更多示例并运行它们,请复制https://github.com/burn-rs/burn存储库,并运行以下项目:

  • MNIST:使用各种后端在CPU或GPU上训练模型。

【MNIST】:https://github.com/burn-rs/burn/tree/main/examples/mnist

  • MNIST网络推理:在浏览器中进行模型推理。

【MNIST网络推理】:https://github.com/burn-rs/burn/tree/main/examples/mnist-inference-web

  • 文本分类:在GPU上从头开始训练一个Transformer编码器。

【文本分类】:https://github.com/burn-rs/burn/tree/main/examples/text-classification

  • 文本生成:在GPU上从头开始构建和训练自回归Transformer。

【文本生成】:https://github.com/burn-rs/burn/tree/main/examples/text-generation

3.4、预训练模型

要构建AI应用程序,可以使用以下预训练模型,并根据数据集对其进行微调。

  • SqueezeNet:squeezenet-burn

【链接】:https://github.com/burn-rs/models/blob/main/squeezenet-burn/README.md

  • Llama 2:Gadersd/llama2-burn

【链接】:https://github.com/Gadersd/llama2-burn

  • Whisper:Gadersd/whisper-burn

【链接】:https://github.com/Gadersd/whisper-burn

  • Stable Diffusion v1.4:Gadersd/stable-diffusion-burn

【链接】:https://github.com/Gadersd/stable-diffusion-burn

四、结论

Rust Burn在深度学习框架领域提供了一个令人兴奋的新选择。如果你已经是一名Rust开发者,就可以利用Rust的速度、安全性和并发性来推动深度学习研究和生产的发展。Burn致力于在灵活性、性能和可用性方面找到合适的折衷方案,从而创建一个适用于各种用例的、独特的多功能框架。

尽管Burn还处于早期阶段,但它在解决现有框架的痛点并满足该领域内各种从业者的需求方面已显示出前景。随着该框架的成熟和社区的发展,它有可能成为与现有框架相媲美的生产就绪框架。其新颖的设计和语言选择为深度学习社区带来了新的可能性。

责任编辑:武晓燕 来源: Python学研大本营
相关推荐

2017-10-24 11:19:16

深度学习机器学习数据

2018-08-16 08:19:30

2023-05-11 15:24:12

2023-12-22 19:59:15

2021-08-04 16:06:45

DataOps智领云

2020-11-08 13:33:05

机器学习数据中毒人工智能

2018-09-28 14:06:25

前端缓存后端

2022-11-06 21:14:02

数据驱动架构数据

2022-09-22 09:00:46

CSS单位

2022-10-20 08:01:23

2021-12-29 18:00:19

无损网络网络通信网络

2022-07-26 00:00:03

语言模型人工智能

2023-05-20 17:58:31

低代码软件

2023-11-27 17:35:48

ComponentWeb外层

2022-07-05 06:30:54

云网络网络云原生

2022-12-01 17:23:45

2019-07-15 10:11:57

深度学习编程人工智能

2018-10-18 11:00:50

人工智能机器学习模型偏差

2023-12-26 14:12:12

人工智能机器学习Gen AI

2017-10-02 16:13:47

深度学习目标检测计算机视觉
点赞
收藏

51CTO技术栈公众号