.view()对x张量有什么作用?负值是什么意思?
x = x.view(-1, 16 * 5 * 5)
.view()对x张量有什么作用?负值是什么意思?
x = x.view(-1, 16 * 5 * 5)
当前回答
torch.Tensor.view ()
简单地说,torch.Tensor.view()受到numpy.ndarray.重塑()或numpy.重塑()的启发,只要新形状与原始张量的形状兼容,就会创建一个新的张量视图。
让我们通过一个具体的例子来详细理解这一点。
In [43]: t = torch.arange(18)
In [44]: t
Out[44]:
tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17])
使用这个张量t of shape(18,),只能为以下形状创建新的视图:
(1,18)或等效的(1,1)或(- 1,18) (2,9)或等效的(2,1)或(1,9) (3,6)或等效的(3,1)或(1,6) (6,3)或等效的(6,1)或(1,3) (9,2)或等效的(9,1)或(1,2) (18,1)或等效的(18,-1)或(- 1,1)
正如我们已经从上面的形状元组中观察到的,形状元组元素的乘法(例如2* 9,3 *6等)必须总是等于原始张量中元素的总数(在我们的例子中是18)。
另一件需要注意的事情是,我们在每个形状元组的其中一个地方使用了-1。通过使用-1,我们在自己进行计算时显得懒惰,而是将任务委托给PyTorch,在它创建新视图时为形状计算该值。需要注意的一件重要的事情是,我们只能在形状tuple中使用一个-1。其余的值应该由我们显式提供。否则PyTorch将通过抛出RuntimeError来投诉:
RuntimeError:只能推断一个维度
因此,对于上面提到的所有形状,PyTorch总是会返回原始张量t的一个新视图。这基本上意味着它只是为每个被请求的新视图改变张量的步幅信息。
下面是一些例子,说明张量的步幅是如何随着每个新视图而改变的。
# stride of our original tensor `t`
In [53]: t.stride()
Out[53]: (1,)
现在,我们将看到新观点的进展:
# shape (1, 18)
In [54]: t1 = t.view(1, -1)
# stride tensor `t1` with shape (1, 18)
In [55]: t1.stride()
Out[55]: (18, 1)
# shape (2, 9)
In [56]: t2 = t.view(2, -1)
# stride of tensor `t2` with shape (2, 9)
In [57]: t2.stride()
Out[57]: (9, 1)
# shape (3, 6)
In [59]: t3 = t.view(3, -1)
# stride of tensor `t3` with shape (3, 6)
In [60]: t3.stride()
Out[60]: (6, 1)
# shape (6, 3)
In [62]: t4 = t.view(6,-1)
# stride of tensor `t4` with shape (6, 3)
In [63]: t4.stride()
Out[63]: (3, 1)
# shape (9, 2)
In [65]: t5 = t.view(9, -1)
# stride of tensor `t5` with shape (9, 2)
In [66]: t5.stride()
Out[66]: (2, 1)
# shape (18, 1)
In [68]: t6 = t.view(18, -1)
# stride of tensor `t6` with shape (18, 1)
In [69]: t6.stride()
Out[69]: (1, 1)
这就是view()函数的神奇之处。它只是改变了每个新视图的(原始)张量的步幅,只要新视图的形状与原始形状兼容。
从strides元组可以观察到的另一件有趣的事情是,第0个位置的元素的值等于shape元组第1个位置的元素的值。
In [74]: t3.shape
Out[74]: torch.Size([3, 6])
|
In [75]: t3.stride() |
Out[75]: (6, 1) |
|_____________|
这是因为:
In [76]: t3
Out[76]:
tensor([[ 0, 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10, 11],
[12, 13, 14, 15, 16, 17]])
stride(6,1)表示沿着第0维从一个元素到下一个元素,我们必须跳跃或走6步。(即从0到6,要走6步。)但是在第一维中,从一个元素到下一个元素,我们只需要一步(例如从2到3)。
因此,步长信息是如何从内存中访问元素以执行计算的核心。
torch.reshape ()
这个函数将返回一个视图,并且与使用torch.Tensor.view()完全相同,只要新形状与原始张量的形状兼容。否则,它将返回一个副本。
然而,torch.重塑()的注释警告:
连续的输入和具有兼容步长的输入可以在不复制的情况下进行重塑,但是不应该依赖于复制与查看行为。
其他回答
View()在不复制内存的情况下重塑张量,类似于numpy的重塑()。
给定一个包含16个元素的张量a:
import torch
a = torch.range(1, 16)
为了重塑这个张量,使它成为一个4 x 4张量,使用:
a = a.view(4, 4)
现在a就是一个4 x 4张量。注意,在重塑之后,元素的总数需要保持不变。将张量a重塑为3 x 5张量是不合适的。
参数-1是什么意思?
如果有任何情况,你不知道你想要多少行,但确定列的数量,那么你可以指定这个-1。(注意,你可以将其扩展到更多维度的张量。只有一个轴值可以是-1)。这是一种告诉库的方式:“给我一个张量,它有这么多列,然后你计算出实现这一点所需的适当行数”。
这可以在模型定义代码中看到。在forward函数中的x = self.pool(F.relu(self.conv2(x)))行之后,您将得到一个16深度的特征映射。你必须把它压平,让它成为完全连接的层。所以你告诉PyTorch重塑你得到的张量,让它有特定的列数,并让它自己决定行数。
我们来做一些例题,从简单到难。
The view method returns a tensor with the same data as the self tensor (which means that the returned tensor has the same number of elements), but with a different shape. For example: a = torch.arange(1, 17) # a's shape is (16,) a.view(4, 4) # output below 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 [torch.FloatTensor of size 4x4] a.view(2, 2, 4) # output below (0 ,.,.) = 1 2 3 4 5 6 7 8 (1 ,.,.) = 9 10 11 12 13 14 15 16 [torch.FloatTensor of size 2x2x4] Assuming that -1 is not one of the parameters, when you multiply them together, the result must be equal to the number of elements in the tensor. If you do: a.view(3, 3), it will raise a RuntimeError because shape (3 x 3) is invalid for input with 16 elements. In other words: 3 x 3 does not equal 16 but 9. You can use -1 as one of the parameters that you pass to the function, but only once. All that happens is that the method will do the math for you on how to fill that dimension. For example a.view(2, -1, 4) is equivalent to a.view(2, 2, 4). [16 / (2 x 4) = 2] Notice that the returned tensor shares the same data. If you make a change in the "view" you are changing the original tensor's data: b = a.view(4, 4) b[0, 2] = 2 a[2] == 3.0 False Now, for a more complex use case. The documentation says that each new view dimension must either be a subspace of an original dimension, or only span d, d + 1, ..., d + k that satisfy the following contiguity-like condition that for all i = 0, ..., k - 1, stride[i] = stride[i + 1] x size[i + 1]. Otherwise, contiguous() needs to be called before the tensor can be viewed. For example: a = torch.rand(5, 4, 3, 2) # size (5, 4, 3, 2) a_t = a.permute(0, 2, 3, 1) # size (5, 3, 2, 4) # The commented line below will raise a RuntimeError, because one dimension # spans across two contiguous subspaces # a_t.view(-1, 4) # instead do: a_t.contiguous().view(-1, 4) # To see why the first one does not work and the second does, # compare a.stride() and a_t.stride() a.stride() # (24, 6, 2, 1) a_t.stride() # (24, 2, 1, 6) Notice that for a_t, stride[0] != stride[1] x size[1] since 24 != 2 x 3
我真的很喜欢@Jadiel de Armas的例子。
我想添加一个关于.view(…)元素如何排序的小见解。
对于形状为(a,b,c)的张量,其元素的顺序为 由编号系统确定:其中第一个数字有 数字,第二个数字是b,第三个数字是c。 由.view(…)返回的新张量中元素的映射 保持原始张量的阶数。
让我们通过下面的例子来理解view:
a=torch.range(1,16)
print(a)
tensor([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14.,
15., 16.])
print(a.view(-1,2))
tensor([[ 1., 2.],
[ 3., 4.],
[ 5., 6.],
[ 7., 8.],
[ 9., 10.],
[11., 12.],
[13., 14.],
[15., 16.]])
print(a.view(2,-1,4)) #3d tensor
tensor([[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.]],
[[ 9., 10., 11., 12.],
[13., 14., 15., 16.]]])
print(a.view(2,-1,2))
tensor([[[ 1., 2.],
[ 3., 4.],
[ 5., 6.],
[ 7., 8.]],
[[ 9., 10.],
[11., 12.],
[13., 14.],
[15., 16.]]])
print(a.view(4,-1,2))
tensor([[[ 1., 2.],
[ 3., 4.]],
[[ 5., 6.],
[ 7., 8.]],
[[ 9., 10.],
[11., 12.]],
[[13., 14.],
[15., 16.]]])
-1作为参数值是计算x值的一种简单方法,前提是我们知道y和z的值,反之亦然,对于3d和2d,同样是计算x值的一种简单方法,前提是我们知道y的值,反之亦然。
参数-1是什么意思?
您可以将-1读取为动态参数数或“任何东西”。因此,在view()中只能有一个参数-1。
如果你问x.view(-1,1),这将输出张量形状[任何东西,1]取决于x中的元素数量。例如:
import torch
x = torch.tensor([1, 2, 3, 4])
print(x,x.shape)
print("...")
print(x.view(-1,1), x.view(-1,1).shape)
print(x.view(1,-1), x.view(1,-1).shape)
将输出:
tensor([1, 2, 3, 4]) torch.Size([4])
...
tensor([[1],
[2],
[3],
[4]]) torch.Size([4, 1])
tensor([[1, 2, 3, 4]]) torch.Size([1, 4])