pytorch 中的乘法*,@,torch.dot(), torch.matmul(), torch.mm(), torch.mul(), torch.bmm() 您所在的位置:网站首页 ot和mt的区别 pytorch 中的乘法*,@,torch.dot(), torch.matmul(), torch.mm(), torch.mul(), torch.bmm()

pytorch 中的乘法*,@,torch.dot(), torch.matmul(), torch.mm(), torch.mul(), torch.bmm()

2023-09-17 03:31| 来源: 网络整理| 查看: 265

简介

pytorch中实现乘法的操作有*,@,dot(),matmul(),mm(),mul(),bmm() *,@是两个运算符,他们分别映射到函数torch.mul和torch.matmul() 运算符映射函数表https://docs.python.org/3/library/operator.html#mapping-operators-to-functions

广播机制(摘自https://blog.csdn.net/MrR1ght/article/details/105660981 )

在具体叨叨这些函数之前先了解广播机制broadcasted:

numpy广播机制Broadcast

原理:python在进行numpy算术运算采用的是element-wise方式(逐元素操作的方式),此时要求两个数据的维度必须相同。

维度不同时,会触发广播操作使其维度相同。不满足广播操作的情况下会直接报错。

先理解下维度,便于理解broadcast,数据的维度指两个方面,维度的个数和维度的大小。

如:a = np.ones(4,3)维度个数是2,第一维大小是4,第二维大小是3

广播的执行过程:

1.如果维度个数不同,则在维度较少的左边补1,使得维度的个数相同。

2.各维度的维度大小不同时,如果有维度为1的,直接将该维拉伸至维度相同

torch.mul(a, b, *, out=None)

若a是tensor,b是标量,则torch.mul(a, b)=b乘a中每个元素,得到与a一样的tensor 若a是tensor,b是tensor,则torch.mul(a, b)会先对a、b进行广播,保持a、b维数一致,然后实现a和b elem wise相乘

import torch a = torch.randn((2,3)) print('a:',a) b = 100 c = torch.mul(a,b) print('torch.mul:',c) print('**************************************************') b = torch.randn((1,3)) #b的第二维度与a的第二维度若其中有一个不为1,则二者必须相等 print('b:',b) c = torch.mul(a,b) print('torch.mul:',c) print('**************************************************') a = torch.randn(4, 1) print('a:',a) b = torch.randn(1,4) print('b:',b) c = torch.mul(a,b) print('torch.mul:',c) a: tensor([[ 0.5374, 0.5964, -0.9717], [ 0.8812, 0.0650, 0.5432]]) torch.mul: tensor([[ 53.7366, 59.6445, -97.1669], [ 88.1228, 6.5001, 54.3242]]) ************************************************** b: tensor([[-0.5949, -1.2634, 0.5233]]) torch.mul: tensor([[-0.3197, -0.7535, -0.5084], [-0.5243, -0.0821, 0.2843]]) ************************************************** a: tensor([[ 1.1754], [ 0.0786], [-0.4219], [ 0.4158]]) b: tensor([[1.4896, 0.3157, 0.0024, 0.6818]]) torch.mul: tensor([[ 1.7509e+00, 3.7115e-01, 2.7752e-03, 8.0140e-01], [ 1.1702e-01, 2.4805e-02, 1.8548e-04, 5.3562e-02], [-6.2843e-01, -1.3321e-01, -9.9606e-04, -2.8763e-01], [ 6.1936e-01, 1.3129e-01, 9.8170e-04, 2.8348e-01]]) torch.matmul() #vector * vector = 相加相乘,最后得到一个数 a = torch.randn(3) b = torch.randn(3) c = torch.matmul(a,b) print('a:',a) print('b:',b) print('torch.matmul:',c) print('****************************************************') #matrix * vector = 矩阵相乘,matrix第二维需要与vector维度相同 a = torch.randn(3,4) b = torch.randn(4) c = torch.matmul(a,b) print('a:',a) print('b:',b) print('torch.matmul:',c) print('****************************************************') # batched matrix x broadcasted vector a = torch.randn(10, 3, 4) b = torch.randn(4) c = torch.matmul(a,b) print('a:',a.shape) print('b:',b.shape) print('torch.matmul:',c.shape) print('****************************************************') # batched matrix x batched matrix a = torch.randn(10,1,3,4) b = torch.randn(10,3,4,5) c = torch.matmul(a,b) print('a:',a.shape) print('b:',b.shape) print('torch.matmul:',c.shape) a = torch.randn(10,3,4) b = torch.randn(10,4,5) c = torch.matmul(a,b) print('a:',a.shape) print('b:',b.shape) print('torch.matmul:',c.shape) a = torch.randn(10,1,3,4) b = torch.randn(10,4,5) c = torch.matmul(a,b) print('a:',a.shape) print('b:',b.shape) print('torch.matmul:',c.shape) print('****************************************************') # batched matrix x broadcasted matrix a = torch.randn(10,3,4) b = torch.randn(4, 5) c = torch.matmul(a,b) print('a:',a.shape) print('b:',b.shape) print('torch.matmul:',c.shape) a: tensor([ 2.6767, -0.8028, 4.1741]) b: tensor([-1.0552, 0.2841, 0.8013]) torch.matmul: tensor(0.2923) **************************************************** a: tensor([[-1.2726, 0.6925, -0.3536, -0.2233], [-0.5659, 1.5294, 0.1152, -0.9903], [-0.2644, 0.5090, 0.7059, 0.2046]]) b: tensor([ 0.7085, -0.0952, 1.6654, -0.8139]) torch.matmul: tensor([-1.3747, 0.4513, 0.7733]) **************************************************** a: torch.Size([10, 3, 4]) b: torch.Size([4]) torch.matmul: torch.Size([10, 3]) **************************************************** a: torch.Size([10, 1, 3, 4]) b: torch.Size([10, 3, 4, 5]) torch.matmul: torch.Size([10, 3, 3, 5]) a: torch.Size([10, 3, 4]) b: torch.Size([10, 4, 5]) torch.matmul: torch.Size([10, 3, 5]) a: torch.Size([10, 1, 3, 4]) b: torch.Size([10, 4, 5]) torch.matmul: torch.Size([10, 10, 3, 5]) **************************************************** a: torch.Size([10, 3, 4]) b: torch.Size([4, 5]) torch.matmul: torch.Size([10, 3, 5]) torch.mm()

矩阵相乘,不会进行广播,必须满足矩阵相乘维数条件,两矩阵最多是2维

a = torch.randn(2, 3) b = torch.randn(3, 3) c = torch.mm(a,b) print('a:',a.shape) print('b:',b.shape) print('torch.mm:',c.shape) a: torch.Size([2, 3]) b: torch.Size([3, 3]) torch.mm: torch.Size([2, 3]) torch.bmm(a,b)

批矩阵相乘,不会进行广播,必须满足矩阵相乘维数条件,a,b最多只能3维,且a,b中必须包含相同的矩阵个数即a,b第一维度必须相同

a = torch.randn(10,2, 3) b = torch.randn(10,3, 3) c = torch.bmm(a,b) print('a:',a.shape) print('b:',b.shape) print('torch.bmm:',c.shape) a: torch.Size([10, 2, 3]) b: torch.Size([10, 3, 3]) torch.bmm: torch.Size([10, 2, 3]) torch.dot(a,b)

两向量相乘相加得到一个标量,必须都是一维的

a = torch.tensor([2, 3]) b = torch.tensor([1,2]) c = torch.dot(a,b) print('a:',a.shape) print('b:',b.shape) print('torch.dot:',c,c.shape) a: torch.Size([2]) b: torch.Size([2]) torch.dot: tensor(8) torch.Size([])


【本文地址】

公司简介

联系我们

今日新闻

    推荐新闻

    专题文章
      CopyRight 2018-2019 实验室设备网 版权所有