Pytorch中张量矩阵乘法函数(mm, bmm, matmul)使用说明,含高维张量实例及运行结果 您所在的位置:网站首页 不同维度的人决定不同维度人的命运 Pytorch中张量矩阵乘法函数(mm, bmm, matmul)使用说明,含高维张量实例及运行结果

Pytorch中张量矩阵乘法函数(mm, bmm, matmul)使用说明,含高维张量实例及运行结果

2024-07-07 22:47| 来源: 网络整理| 查看: 265

Pytorch中张量矩阵乘法函数使用说明 1 torch.mm() 函数1.1 torch.mm() 函数定义及参数1.2 torch.bmm() 官方示例 2 torch.bmm() 函数2.1 torch.bmm() 函数定义及参数2.2 torch.bmm() 官方示例 3 torch.matmul() 函数3.1 torch.matmul() 函数定义及参数3.2 torch.matmul() 规则约定3.3 torch.matmul() 官方示例3.4 高维数据实例解释 参考博文及感谢

1 torch.mm() 函数

全称为matrix-matrix product,对输入的张量做矩阵乘法运算,输入输出维度一定是2维;

1.1 torch.mm() 函数定义及参数

torch.bmm(input, mat2, , out=None) → Tensor input (Tensor) – – 第一个要相乘的矩阵 ** mat2* (Tensor) – – 第二个要相乘的矩阵 不支持广播到通用形状、类型推广以及整数、浮点和复杂输入。

1.2 torch.bmm() 官方示例 mat1 = torch.randn(2, 3) mat2 = torch.randn(3, 3) torch.mm(mat1, mat2) tensor([[ 0.4851, 0.5037, -0.3633], [-0.0760, -3.6705, 2.4784]]) 2 torch.bmm() 函数

全称为batch matrix-matrix product,对输入的张量做矩阵乘法运算,输入输出维度一定是3维;

2.1 torch.bmm() 函数定义及参数

torch.bmm(input, mat2, , out=None) → Tensor input (Tensor) – – 第一批要相乘的矩阵 ** mat2* (Tensor) – – 第二批要相乘的矩阵 不支持广播到通用形状、类型推广以及整数、浮点和复杂输入。

2.2 torch.bmm() 官方示例 input = torch.randn(10, 3, 4) mat2 = torch.randn(10, 4, 5) res = torch.bmm(input, mat2) res.size() torch.Size([10, 3, 5]) 3 torch.matmul() 函数

可进行多维矩阵运算,根据不同输入维度进行广播机制然后运算,和点积类似,广播机制可参考之前博文torch.mul()函数。

3.1 torch.matmul() 函数定义及参数

torch.matmul(input, mat2, , out=None) → Tensor input (Tensor) – – 第一个要相乘的张量 ** mat2* (Tensor) – – 第二个要相乘的张量 支持广播到通用形状、类型推广以及整数、浮点和复杂输入。

3.2 torch.matmul() 规则约定

(1)若两个都是1D(向量)的,则返回两个向量的点积;

(2)若两个都是2D(矩阵)的,则按照(矩阵相乘)规则返回2D;

(3)若input维度1D,other维度2D,则先将1D的维度扩充到2D(1D的维数前面+1),然后得到结果后再将此维度去掉,得到的与input的维度相同。即使作扩充(广播)处理,input的维度也要和other维度做对应关系;

(4)若input是2D,other是1D,则返回两者的点积结果;

(5)如果一个维度至少是1D,另外一个大于2D,则返回的是一个批矩阵乘法( a batched matrix multiply)

(a)若input是1D,other是大于2D的,则类似于规则(3);(b)若other是1D,input是大于2D的,则类似于规则(4);(c)若input和other都是3D的,则与torch.bmm()函数功能一样;(d)如果input中某一维度满足可以广播(扩充),那么也是可以进行相乘操作的。例如 input(j,1,n,m)* other (k,m,p) = output(j,k,n,p)

matmul() 根据输入矩阵自动决定如何相乘。低维根据高维需求,合理广播。

3.3 torch.matmul() 官方示例 # vector x vector tensor1 = torch.randn(3) tensor2 = torch.randn(3) torch.matmul(tensor1, tensor2).size() torch.Size([]) # matrix x vector tensor1 = torch.randn(3, 4) tensor2 = torch.randn(4) torch.matmul(tensor1, tensor2).size() torch.Size([3]) # batched matrix x broadcasted vector tensor1 = torch.randn(10, 3, 4) tensor2 = torch.randn(4) torch.matmul(tensor1, tensor2).size() torch.Size([10, 3]) # batched matrix x batched matrix tensor1 = torch.randn(10, 3, 4) tensor2 = torch.randn(10, 4, 5) torch.matmul(tensor1, tensor2).size() torch.Size([10, 3, 5]) # batched matrix x broadcasted matrix tensor1 = torch.randn(10, 3, 4) tensor2 = torch.randn(4, 5) torch.matmul(tensor1, tensor2).size() torch.Size([10, 3, 5]) 3.4 高维数据实例解释

直接看一个4维的二值例子,先看图(红虚线和实线是为了便于区分维度而添加),不懂再结合代码和结果分析,先做广播,然后对应矩阵进行乘积运算。 在这里插入图片描述

代码如下:

import torch import numpy as np np.random.seed(2022) a = np.random.randint(low=0, high=2, size=(2, 2, 3, 4)) a = torch.tensor(a) b = np.random.randint(low=0, high=2, size=(2, 1, 4, 3)) b = torch.tensor(b) c = torch.matmul(a, b) # or # c = a @ b print(a) print("=============================================") print(b) print("=============================================") print(c.size()) print("=============================================") print(c)

运行结果为:

tensor([[[[1, 0, 1, 0], [1, 1, 0, 1], [0, 0, 0, 0]], [[1, 1, 1, 1], [1, 1, 0, 0], [0, 1, 0, 1]]], [[[0, 0, 0, 1], [0, 0, 0, 1], [0, 1, 0, 0]], [[1, 1, 1, 1], [1, 1, 1, 1], [0, 0, 0, 0]]]], dtype=torch.int32) ============================================= tensor([[[[0, 1, 0], [1, 1, 0], [0, 0, 0], [1, 1, 0]]], [[[0, 1, 0], [1, 1, 1], [1, 1, 1], [1, 0, 1]]]], dtype=torch.int32) ============================================= torch.Size([2, 2, 3, 3]) ============================================= tensor([[[[0, 1, 0], [2, 3, 0], [0, 0, 0]], [[2, 3, 0], [1, 2, 0], [2, 2, 0]]], [[[1, 0, 1], [1, 0, 1], [1, 1, 1]], [[3, 3, 3], [3, 3, 3], [0, 0, 0]]]], dtype=torch.int32) 参考博文及感谢

部分内容参考以下链接,这里表示感谢 Thanks♪(・ω・)ノ 参考博文1 官方文档查询地址 https://pytorch.org/docs/stable/index.html 参考博文2 Pytorch矩阵乘法之torch.mul() 、 torch.mm() 及torch.matmul()的区别 https://blog.csdn.net/irober/article/details/113686080



【本文地址】

公司简介

联系我们

今日新闻

    推荐新闻

    专题文章
      CopyRight 2018-2019 实验室设备网 版权所有