五个PyTorch 中的处理张量的基本函数

人工智能 深度学习
每个深度学习初学者都应该知道这5个Pytorch 的基本函数。

 能够以准确有效的方式构建神经网络是招聘人员在深度学习工程师中最受追捧的技能之一。PyTorch 是一个 主要用于深度学习的Python 库。 PyTorch 最基本也是最重要的部分之一是创建张量,张量是数字、向量、矩阵或任何 n 维数组。在构建神经网络时为了降低计算速度必须避免使用显式循环,我们可以使用矢量化操作来避免这种循环。在构建神经网络时,足够快地计算矩阵运算的能力至关重要。

“为什么不使用 NumPy 库呢?”

对于深度学习,我们需要计算模型参数的导数。 PyTorch 提供了在反向传播时跟踪导数的能力而 NumPy 则没有,这在Pytorch中被称为“Auto Grad”。PyTorch 为使用 GPU 的快速执行提供了内置支持。这在训练模型方面至关重要。由于 Numpy 缺乏将其计算转移到 GPU 的能力,因此训练模型的时间最终会变得非常大。

所有使用 PyTorch 的深度学习项目都从创建张量开始。让我们看看一些必须知道的函数,它们是任何涉及构建神经网络的深度学习项目的支柱。

  • torch.tensor()
  • torch.sum()
  • torch.index_select()
  • torch.stack()
  • torch.mm()

在安装完Pytorch后,在代码中可以直接导入:

 

  1. # Import torch and other required modules 
  2. import torch 

 

torch.tensor()

首先,我们定义了一个辅助函数,describe (x),它将总结张量 x 的各种属性,例如张量的类型、张量的维度和张量的内容。

 

  1. # Helper function 
  2. def describe(x): 
  3.   print("Type: {}".format(x.type())) 
  4.   print("Shape/size: {}".format(x.shape)) 
  5.   print("Values: \n{}".format(x) 

 

使用 torch.Tensor 在 PyTorch 中创建张量

PyTorch 允许我们使用 torch 包以多种不同的方式创建张量。 创建张量的一种方法是通过指定其维度来初始化一个随机张量

 

  1. describe(torch.Tensor(2, 3)) 

使用 Python 列表以声明方式创建张量

我们还可以使用 python 列表创建张量。 我们只需要将列表作为参数传递给函数,我们就有了它的张量形式。

 

  1. x = torch.Tensor([[1, 2, 3],[4, 5, 6]])  
  2. describe(x) 

 

使用 NumPy 数组创建张量

我们也可以从NumPy 数组中创建PyTorch 张量。 张量的类型是 Double Tensor 而不是默认的 Float Tensor。 这对应于 NumPy 的数据类型是float64,如下所示。

 

  1. import numpy as np 
  2. npy = np.random.rand(2, 3) 
  3. describe(torch.from_numpy(npy)) 

 

我们不能用张量做什么?张量必须是实数或复数,不应是字符串或字符。

 

  1. torch.tensor([[1, 2], [3, 4, 5]]) 
  2.  
  3.  
  4. --------------------------------------------------------------------------- 
  5. ValueError                                Traceback (most recent call last
  6. <ipython-input-5-28787d136593> in <module> 
  7.       1 # Example 3 - breaking (to illustrate when it breaks) 
  8. ----> 2 torch.tensor([[1, 2], [3, 4, 5]]) 
  9.  
  10. ValueError: expected sequence of length 2 at dim 1 (got 3) 

 

torch.tensor() 构成了任何 PyTorch 项目的核心,从字面上看,因为它就是张量。

 

torch.sum()

此函数返回输入张量中所有元素的总和。

 

  1. describe(torch.sum(x, dim=0,keepdims=True)) 

如果你了解 NumPy ,可能已经注意到,对于 2D 张量,我们将行表示为维度 0,将列表示为维度 1。torch.sum() 函数允许我们计算行和列的总和。

我们还为 keepdims 传递 True 以保留结果中的维度。 通过定义 dim = 1 我们告诉函数按列折叠数组。

 

  1. torch.sum(npy,dim=1,keepdims=True
  2.  
  3. --------------------------------------------------------------------------- 
  4. TypeError                                 Traceback (most recent call last
  5. <ipython-input-17-1617bf9e8a37> in <module>() 
  6.       1 # Example 3 - breaking (to illustrate when it breaks) 
  7. ----> 2 torch.sum(npy,dim=1,keepdims=True) 
  8.  
  9. TypeError: sum() received an invalid combination of arguments - got (numpy.ndarray, keepdims=bool, dim=int), but expected one of
  10.  * (Tensor input, *, torch.dtype dtype) 
  11.       didn't match because some of the keywords were incorrect: keepdims, dim 
  12.  * (Tensor input, tuple of ints dim, bool keepdim, *, torch.dtype dtype, Tensor out
  13.  * (Tensor input, tuple of names dim, bool keepdim, *, torch.dtype dtype, Tensor out

 

该函数在计算指标和损失函数时非常有用。

torch.index_select()

这个函数返回一个新的张量,该张量使用索引中的条目(LongTensor)沿维度 dim 对输入张量进行索引。

 

  1. indices = torch.LongTensor([0, 2]) 
  2. describe(torch.index_select(x, dim=1, index=indices)) 

 

我们可以将索引作为张量传递并将轴定义为 1,该函数返回一个新的张量大小 rows_of_original_tensor x length_of_indices_tensor。

 

  1. indices = torch.LongTensor([0, 0]) 
  2. describe(torch.index_select(x, dim=0, index=indices)) 

 

我们可以将索引作为张量传递并将轴定义为 0,该函数返回大小为

columns_of_original_tensor x length_of_indices_tensor 的新张量。

 

  1. indices = torch.FloatTensor([0, 2]) 
  2. describe(torch.index_select(x, dim=1, index=indices)) 

此函数在张量的非连续索引这种复杂索引中很有用。

torch.stack()

这将沿新维度连接一系列张量。

 

  1. describe(torch.stack([x, x, x],dim = 0)) 

我们可以将我们想要连接的张量作为一个张量列表传递,dim 为 0,以沿着行堆叠它。

 

  1. describe(torch.stack([x, x, x],dim = 1)) 

我们可以将我们想要连接的张量作为一个张量列表传递,dim 为 1,以沿着列堆叠它。

 

  1. y = torch.tensor([3,3]) 
  2. describe(torch.stack([x, y, x],dim = 1)) 
  3.  
  4. -------------------------------------------------------------------------- 
  5. RuntimeError                              Traceback (most recent call last
  6. <ipython-input-37-c97227f5da5c> in <module>() 
  7.       1 # Example 3 - breaking (to illustrate when it breaks) 
  8.       2 y = torch.tensor([3,3]) 
  9. ----> 3 describe(torch.stack([x, y, x],dim = 1)) 
  10.  
  11. RuntimeError: stack expects each tensor to be equal size, but got [2, 3] at entry 0 and [2] at entry 1 

 

该函数与torch.index_select()结合使用非常有用,可以压扁矩阵。

torch.mm()

此函数执行矩阵的矩阵乘法。

 

  1. mat1 =torch.randn(3,2) 
  2. describe(torch.mm(x, mat1)) 

 

只需将矩阵作为参数传递,我们就可以轻松地执行矩阵乘法,该函数将产生一个新的张量作为两个矩阵的乘积。

 

  1. mat1 = np.random.randn(3,2) 
  2. mat1 = torch.from_numpy(mat1).to(torch.float32) 
  3. describe(torch.mm(x, mat1)) 

 

在上面的例子中,我们定义了一个 NumPy 数组然后将其转换为 float32 类型的张量。 现在我们可以成功地对张量执行矩阵乘法。 两个张量的数据类型必须匹配才能成功操作。

 

  1. mat1 =torch.randn(2,3) 
  2. describe(torch.mm(x, mat1)) 
  3.  
  4. --------------------------------------------------------------------------- 
  5. RuntimeError                              Traceback (most recent call last
  6. <ipython-input-62-18e7760efd23> in <module>() 
  7.       1 # Example 3 - breaking (to illustrate when it breaks) 
  8.       2 mat1 =torch.randn(2,3) 
  9. ----> 3 describe(torch.mm(x, mat1)) 
  10.  
  11. RuntimeError: mat1 and mat2 shapes cannot be multiplied (2x3 and 2x3) 

 

为了执行成功的矩阵乘法运算,矩阵1的列和矩阵2的行必须匹配。 torch.mm() 函数遵循的是矩阵乘法的基本规则。 即使矩阵的顺序相同,它仍然不会自动与另一个矩阵的转置相乘,用户必须手动定义它。

为了在反向传播时计算导数,必须能够有效地执行矩阵乘法,这就是 torch.mm () 出现的地方。

总结

我们对 5 个基本 PyTorch 函数的研究到此结束。 从基本的张量创建到具有特定用例的高级和鲜为人知的函数,如 torch.index_select (),PyTorch 提供了许多这样的函数,使数据科学爱好者的工作更轻松。

责任编辑:华轩 来源: 今日头条
相关推荐

2022-11-15 16:37:38

PyTorch抽样函数子集

2023-02-13 16:42:08

云计算CloudOps工具

2024-10-22 15:51:42

PyTorch张量

2024-08-14 16:06:02

2021-11-05 12:59:51

深度学习PytorchTenso

2023-12-27 14:19:33

Python内置函数开发

2024-03-01 20:55:40

Pytorch张量Tensor

2020-07-03 11:30:12

首席信息官数据资产

2020-07-03 14:06:37

大数据CIO技术

2021-08-11 09:33:15

Vue 技巧 开发工具

2024-07-29 10:46:50

2015-11-12 10:45:11

问题系统Linux

2024-10-07 08:37:34

PyPDF2PDF代码

2021-07-27 18:02:01

VueUse 函数开发

2010-05-27 17:45:13

MySQL存储过程

2023-05-09 15:01:43

JavaScript编程语言异常处理

2022-08-23 14:57:43

Python技巧函数

2022-02-23 21:22:52

首席数据官CIO

2022-08-29 00:37:53

Python技巧代码

2009-08-25 14:25:19

Eclipse 3.5
点赞
收藏

51CTO技术栈公众号