torch.cat函数在二维,三维数据中拼接时候 dim维度 理解 您所在的位置:网站首页 拼接鞋什么意思 torch.cat函数在二维,三维数据中拼接时候 dim维度 理解

torch.cat函数在二维,三维数据中拼接时候 dim维度 理解

2024-06-01 03:42| 来源: 网络整理| 查看: 265

文章目录 一、先来看看`torch.cat`函数的参数以及参数值二、torch.cat在二维数据中示例说明1. 在二维数据中 dim 取值 (0,1)2. 发现了什么问题总结: 三、从二维数据延伸到三维数据1. 三维数据`dim`取值为(0,1,2),依次类推2. 三维数据与二维数据的相关性3 . 逐一验证

一、先来看看torch.cat函数的参数以及参数值

torch中的cat函数用于沿着指定维度将张量连接起来。具体而言,如果给定一个包含多个张量的序列,通过指定dim参数可以将它们沿着指定维度连接在一起。

函数的常见形式如下:

torch.cat(seq, dim=0, out=None)

其中:

seq:一个Tensor序列,即要拼接的多个张量。 dim:连接的维度,默认为0(按行拼接)。可以是任何整数值,具体取值依赖于输入张量的维度。例如,对于二维张量,dim=0表示按行拼接,dim=1表示按列拼接。 out:输出张量。如果指定了此参数,则结果会被写入该张量中,不会创建新的张量。如果没有指定,则会创建新的张量作为结果返回。

二、torch.cat在二维数据中示例说明 1. 在二维数据中 dim 取值 (0,1)

举个例子,假设我们有两个二维张量A和B:

import torch A = torch.tensor([[1, 2], [3, 4]]) B = torch.tensor([[5, 6], [7, 8]])

如果要将它们按行拼接(即在第0维度上拼接),则可以这样做:

python C = torch.cat((A, B), dim=0) print(C) # 输出: # tensor([[1, 2], # [3, 4], # [5, 6], # [7, 8]])

如果要将它们按列拼接(即在第1维度上拼接),则可以这样做:

python C = torch.cat((A, B), dim=1) print(C) # 输出: # tensor([[1, 2, 5, 6], # [3, 4, 7, 8]])

总之,cat函数可以用于将张量沿着指定的维度连接在一起,非常灵活。需要根据具体情况选择合适的dim参数值来实现多种拼接方式。

2. 发现了什么问题 按行拼接(或者叫列合并)是指将多个二维张量沿着第0维度(即行)拼接在一起,形成一个更大的二维张量。例如,假设有两个二维张量A和B: A = [[1, 2], [3, 4]] B = [[5, 6], [7, 8]]

那么将它们按行拼接之后得到的结果就是:

C = [[1, 2], [3, 4], [5, 6], [7, 8]]

可以看到,新的张量C比原来的张量A和B都要长,因为它包含了两个输入张量中所有的行。

按列拼接(或者叫行合并)是指将多个二维张量沿着第1维度(即列)拼接在一起,形成一个更宽的二维张量。例如,假设有两个二维张量A和B: A = [[1, 2], [3, 4]] B = [[5, 6], [7, 8]]

那么将它们按列拼接之后得到的结果就是:

C = [[1, 2, 5, 6], [3, 4, 7, 8]]

可以看到,新的张量C比原来的张量A和B都要宽,因为它包含了两个输入张量中所有的列。

总结: 当 dim = 0时候, 按行拼接(或者叫列合并),此时 输出的结果中应当包含拼接数据的所有行 A = [[1, 2], [3, 4]] B = [[5, 6], [7, 8]]

那么将它们按行拼接之后得到的结果就是:

C = [[1, 2], [3, 4], [5, 6], [7, 8]] 当 dim = 1时候, 按列拼接(或者叫行合并),此时 输出的结果中应当包含拼接数据的所有列 A = [[1, 2], [3, 4]] B = [[5, 6], [7, 8]]

那么将它们按列拼接之后得到的结果就是:

C = [[1, 2, 5, 6], [3, 4, 7, 8]] 三、从二维数据延伸到三维数据 1. 三维数据dim取值为(0,1,2),依次类推

在三维张量中,dim参数的取值范围为0、1、2,具体的含义如下:

dim=0:表示沿着第0维度进行拼接。这意味着将两个包含多个矩阵的三维张量连接起来,形成一个更高的三维张量。 dim=1:表示沿着第1维度进行拼接。这意味着将两个包含多个行向量的三维张量连接起来,形成一个更宽的三维张量。 dim=2:表示沿着第2维度进行拼接。这意味着将两个包含多个列向量的三维张量连接起来,形成一个更深的三维张量。

2. 三维数据与二维数据的相关性

可以看到,三维比二维多了一个维度,0维度。事实上,三维数据中的 1和 2 维度,分别对应二维数据的 0和1维度,而三维数据中的 0 维度,含义就是 有多少个二维数据, 比如 :4x3x2 含义就是 4个 3x2的矩阵。

3 . 逐一验证 当dim取值为0的时候 例如,如果有两个三维张量A和B,每个张量都包含2个2x2的矩阵,我们可以按照第0维度(即沿着深度方向)将它们拼接在一起: import torch A = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) B = torch.tensor([[[9, 10], [11, 12]], [[13, 14], [15, 16]]]) C = torch.cat((A, B), dim=0) print(C.shape) # 输出:torch.Size([4, 3, 2])

可以看到,输出的结果 是不是 将 A和B的 2个矩阵拼接,就是4个2x2的矩阵,输出的结果也就是 4x2x2

当dim取值为 1 的时候

上面有说到:三维数据中的 1和 2 维度,分别对应二维数据的 0和1维度,而三维数据中的 0 维度,含义就是 有多少个二维数据,

C = torch.cat((A, B), dim=1) #沿着维度1拼接的结果: tensor([[[ 1, 2], [ 3, 4], [ 9, 10], [11, 12]], [[ 5, 6], [ 7, 8], [13, 14], [15, 16]]]) #沿着维度1拼接的结果的形状: torch.Size([2, 4, 2])

看到结果是不是验证了之前的说法,抛开0维度,当dim取值1时候。相当于二维数据中dim取值为0时候,也就是 当 dim = 0时候, 按行拼接(或者叫列合并),此时 输出的结果中应当包含拼接数据的所有行

当dim取值为 2 的时候 # 沿着维度2拼接 C2 = torch.cat((A, B), dim=2) print("沿着维度2拼接的结果:\n", C2) print("沿着维度2拼接的结果的形状:", C2.shape) #沿着维度2拼接的结果: tensor([[[ 1, 2, 9, 10], [ 3, 4, 11, 12]], [[ 5, 6, 13, 14], [ 7, 8, 15, 16]]]) #沿着维度2拼接的结果的形状: torch.Size([2, 2, 4])

看结果:抛开0维度,当dim取值2时候。相当于二维数据中dim取值为1时候,也就是:当 dim = 1时候, 按列拼接(或者叫行合并),此时 输出的结果中应当包含拼接数据的所有列



【本文地址】

公司简介

联系我们

今日新闻

    推荐新闻

    专题文章
      CopyRight 2018-2019 实验室设备网 版权所有