torch.cat()函数的官方解释,详解以及例子 您所在的位置:网站首页 cat的码数 torch.cat()函数的官方解释,详解以及例子

torch.cat()函数的官方解释,详解以及例子

2024-06-30 15:59| 来源: 网络整理| 查看: 265

可以直接看最下面的例子,再回头看前面的解释,就很明白了。

在pytorch中,常见的拼接函数主要是两个,分别是:

stack()cat()

一般torch.cat()是为了把多个tensor进行拼接而存在的。实际使用中,和torch.stack()使用场景不同:参考链接torch.stack(),但是本文主要说cat()。

torch.cat() 和python中的内置函数cat(), 在使用和目的上,是没有区别的,区别在于前者操作对象是tensor。

1. cat()

函数目的: 在给定维度上对输入的张量序列seq 进行连接操作。

outputs = torch.cat(inputs, dim=?) → Tensor

参数

inputs : 待连接的张量序列,可以是任意相同Tensor类型的python 序列dim : 选择的扩维, 必须在0到len(inputs[0])之间,沿着此维连接张量序列。 2. 重点 输入数据必须是序列,序列中数据是任意相同的shape的同类型tensor维度不可以超过输入数据的任一个张量的维度 3.举例子 准备数据,每个的shape都是[2,3] # x1 x1 = torch.tensor([[11,21,31],[21,31,41]],dtype=torch.int) x1.shape # torch.Size([2, 3]) # x2 x2 = torch.tensor([[12,22,32],[22,32,42]],dtype=torch.int) x2.shape # torch.Size([2, 3]) 合成inputs 'inputs为2个形状为[2 , 3]的矩阵 ' inputs = [x1, x2] print(inputs) '打印查看' [tensor([[11, 21, 31], [21, 31, 41]], dtype=torch.int32), tensor([[12, 22, 32], [22, 32, 42]], dtype=torch.int32)]

3.查看结果, 测试不同的dim拼接结果

In [1]: torch.cat(inputs, dim=0).shape Out[1]: torch.Size([4, 3]) In [2]: torch.cat(inputs, dim=1).shape Out[2]: torch.Size([2, 6]) In [3]: torch.cat(inputs, dim=2).shape IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)

大家可以复制代码运行一下就会发现其中规律了。

总结

通常用来,把torch.stack得到tensor进行拼接而存在的。



【本文地址】

公司简介

联系我们

今日新闻

    推荐新闻

    专题文章
      CopyRight 2018-2019 实验室设备网 版权所有