Pytorch的类(nn.Module的子类)中的forward函数 您所在的位置:网站首页 类中的forward函数 Pytorch的类(nn.Module的子类)中的forward函数

Pytorch的类(nn.Module的子类)中的forward函数

#Pytorch的类(nn.Module的子类)中的forward函数| 来源: 网络整理| 查看: 265

使用

直接通过类的实例对象就可以向类中的forward函数进行参数的传递(当然也可以通过调用forward函数进行传参)

import torch.nn as nn class MyModule(nn.Module): def __init__(self): super().__init__() def forward(self, x): return x data1 = 1 data2 = 2 module = MyModule() x1 = module(data1) # 不需要显示调用forward函数就可以传递参数 x2 = module.forward(data2) print(x1) print(x2) >> 1 >> 2 解释

nn.Module() 中包含了 __call__ 函数;

实现了 __call__ 函数的类,其类实例是一个可调用的对象,其可以简化对于类中某些方法的调用(写在__call__ 中的方法),模糊了实例对象和类成员函数的区别。使用类实例 module() 时 就相当于 module.__call__(),如果在 __call()__ 中写上函数,就可以直接通过类实例对象传参调用了。

而在 nn.Module() 中的 __call__ 函数中调用了 forward() 函数,

... # 例子 # def __call__(self, param): res = self.forward(param) return res ...

由于继承关系,对于MyModule(nn.Module) 类 同样具备了 __call__ 函数的功能,即可以通过类实例module 直接 调用 forward 并传参。



【本文地址】

公司简介

联系我们

今日新闻

    推荐新闻

    专题文章
      CopyRight 2018-2019 实验室设备网 版权所有