torch.cat函数在二维,三维数据中拼接时候 dim维度 理解 | 您所在的位置:网站首页 › 拼接鞋什么意思 › torch.cat函数在二维,三维数据中拼接时候 dim维度 理解 |
文章目录
一、先来看看`torch.cat`函数的参数以及参数值二、torch.cat在二维数据中示例说明1. 在二维数据中 dim 取值 (0,1)2. 发现了什么问题总结:
三、从二维数据延伸到三维数据1. 三维数据`dim`取值为(0,1,2),依次类推2. 三维数据与二维数据的相关性3 . 逐一验证
一、先来看看torch.cat函数的参数以及参数值
torch中的cat函数用于沿着指定维度将张量连接起来。具体而言,如果给定一个包含多个张量的序列,通过指定dim参数可以将它们沿着指定维度连接在一起。 函数的常见形式如下: torch.cat(seq, dim=0, out=None)其中: seq:一个Tensor序列,即要拼接的多个张量。 dim:连接的维度,默认为0(按行拼接)。可以是任何整数值,具体取值依赖于输入张量的维度。例如,对于二维张量,dim=0表示按行拼接,dim=1表示按列拼接。 out:输出张量。如果指定了此参数,则结果会被写入该张量中,不会创建新的张量。如果没有指定,则会创建新的张量作为结果返回。 二、torch.cat在二维数据中示例说明 1. 在二维数据中 dim 取值 (0,1)举个例子,假设我们有两个二维张量A和B: import torch A = torch.tensor([[1, 2], [3, 4]]) B = torch.tensor([[5, 6], [7, 8]])如果要将它们按行拼接(即在第0维度上拼接),则可以这样做: python C = torch.cat((A, B), dim=0) print(C) # 输出: # tensor([[1, 2], # [3, 4], # [5, 6], # [7, 8]])如果要将它们按列拼接(即在第1维度上拼接),则可以这样做: python C = torch.cat((A, B), dim=1) print(C) # 输出: # tensor([[1, 2, 5, 6], # [3, 4, 7, 8]])总之,cat函数可以用于将张量沿着指定的维度连接在一起,非常灵活。需要根据具体情况选择合适的dim参数值来实现多种拼接方式。 2. 发现了什么问题 按行拼接(或者叫列合并)是指将多个二维张量沿着第0维度(即行)拼接在一起,形成一个更大的二维张量。例如,假设有两个二维张量A和B: A = [[1, 2], [3, 4]] B = [[5, 6], [7, 8]]那么将它们按行拼接之后得到的结果就是: C = [[1, 2], [3, 4], [5, 6], [7, 8]]可以看到,新的张量C比原来的张量A和B都要长,因为它包含了两个输入张量中所有的行。 按列拼接(或者叫行合并)是指将多个二维张量沿着第1维度(即列)拼接在一起,形成一个更宽的二维张量。例如,假设有两个二维张量A和B: A = [[1, 2], [3, 4]] B = [[5, 6], [7, 8]]那么将它们按列拼接之后得到的结果就是: C = [[1, 2, 5, 6], [3, 4, 7, 8]]可以看到,新的张量C比原来的张量A和B都要宽,因为它包含了两个输入张量中所有的列。 总结: 当 dim = 0时候, 按行拼接(或者叫列合并),此时 输出的结果中应当包含拼接数据的所有行 A = [[1, 2], [3, 4]] B = [[5, 6], [7, 8]]那么将它们按行拼接之后得到的结果就是: C = [[1, 2], [3, 4], [5, 6], [7, 8]] 当 dim = 1时候, 按列拼接(或者叫行合并),此时 输出的结果中应当包含拼接数据的所有列 A = [[1, 2], [3, 4]] B = [[5, 6], [7, 8]]那么将它们按列拼接之后得到的结果就是: C = [[1, 2, 5, 6], [3, 4, 7, 8]] 三、从二维数据延伸到三维数据 1. 三维数据dim取值为(0,1,2),依次类推在三维张量中,dim参数的取值范围为0、1、2,具体的含义如下: dim=0:表示沿着第0维度进行拼接。这意味着将两个包含多个矩阵的三维张量连接起来,形成一个更高的三维张量。 dim=1:表示沿着第1维度进行拼接。这意味着将两个包含多个行向量的三维张量连接起来,形成一个更宽的三维张量。 dim=2:表示沿着第2维度进行拼接。这意味着将两个包含多个列向量的三维张量连接起来,形成一个更深的三维张量。 2. 三维数据与二维数据的相关性可以看到,三维比二维多了一个维度,0维度。事实上,三维数据中的 1和 2 维度,分别对应二维数据的 0和1维度,而三维数据中的 0 维度,含义就是 有多少个二维数据, 比如 :4x3x2 含义就是 4个 3x2的矩阵。 3 . 逐一验证 当dim取值为0的时候 例如,如果有两个三维张量A和B,每个张量都包含2个2x2的矩阵,我们可以按照第0维度(即沿着深度方向)将它们拼接在一起: import torch A = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) B = torch.tensor([[[9, 10], [11, 12]], [[13, 14], [15, 16]]]) C = torch.cat((A, B), dim=0) print(C.shape) # 输出:torch.Size([4, 3, 2])可以看到,输出的结果 是不是 将 A和B的 2个矩阵拼接,就是4个2x2的矩阵,输出的结果也就是 4x2x2 当dim取值为 1 的时候上面有说到:三维数据中的 1和 2 维度,分别对应二维数据的 0和1维度,而三维数据中的 0 维度,含义就是 有多少个二维数据, C = torch.cat((A, B), dim=1) #沿着维度1拼接的结果: tensor([[[ 1, 2], [ 3, 4], [ 9, 10], [11, 12]], [[ 5, 6], [ 7, 8], [13, 14], [15, 16]]]) #沿着维度1拼接的结果的形状: torch.Size([2, 4, 2])看到结果是不是验证了之前的说法,抛开0维度,当dim取值1时候。相当于二维数据中dim取值为0时候,也就是 当 dim = 0时候, 按行拼接(或者叫列合并),此时 输出的结果中应当包含拼接数据的所有行 当dim取值为 2 的时候 # 沿着维度2拼接 C2 = torch.cat((A, B), dim=2) print("沿着维度2拼接的结果:\n", C2) print("沿着维度2拼接的结果的形状:", C2.shape) #沿着维度2拼接的结果: tensor([[[ 1, 2, 9, 10], [ 3, 4, 11, 12]], [[ 5, 6, 13, 14], [ 7, 8, 15, 16]]]) #沿着维度2拼接的结果的形状: torch.Size([2, 2, 4])看结果:抛开0维度,当dim取值2时候。相当于二维数据中dim取值为1时候,也就是:当 dim = 1时候, 按列拼接(或者叫行合并),此时 输出的结果中应当包含拼接数据的所有列 |
CopyRight 2018-2019 实验室设备网 版权所有 |