pytorch中获取模型参数:state 您所在的位置:网站首页 adx2650ck23gm参数 pytorch中获取模型参数:state

pytorch中获取模型参数:state

2024-07-09 15:24| 来源: 网络整理| 查看: 265

一、本文的模型案例

代码如下:

import torch import torch.nn.functional as F from torch.optim import SGD class MyNet(torch.nn.Module): def __init__(self): super(MyNet, self).__init__() # 第一句话,调用父类的构造函数 self.conv1 = torch.nn.Conv2d(3, 32, 3, 1, 1) self.relu1=torch.nn.ReLU() self.max_pooling1=torch.nn.MaxPool2d(2,1) self.conv2 = torch.nn.Conv2d(3, 32, 3, 1, 1) self.relu2=torch.nn.ReLU() self.max_pooling2=torch.nn.MaxPool2d(2,1) self.dense1 = torch.nn.Linear(32 * 3 * 3, 128) self.dense2 = torch.nn.Linear(128, 10) def forward(self, x): x = self.conv1(x) x = self.relu1(x) x = self.max_pooling1(x) x = self.conv2(x) x = self.relu2(x) x = self.max_pooling2(x) x = self.dense1(x) x = self.dense2(x) return x model = MyNet() # 构造模型 二、model.state_dict()方法

pytorch 中的 state_dict 是一个简单的python的字典对象,将每一层与它的对应参数建立映射关系.(如model的每一层的weights及偏置等等)。这个方法的作用一方面是方便查看某一个层的权值和偏置数据,另一方面更多的是在模型保存的时候使用。

2.1 Module的层的权值以及bias查看

print(type(model.state_dict())) # 查看state_dict所返回的类型,是一个“顺序字典OrderedDict” for param_tensor in model.state_dict(): # 字典的遍历默认是遍历 key,所以param_tensor实际上是键值 print(param_tensor,'\t',model.state_dict()[param_tensor].size()) ''' conv1.weight torch.Size([32, 3, 3, 3]) conv1.bias torch.Size([32]) conv2.weight torch.Size([32, 3, 3, 3]) conv2.bias torch.Size([32]) dense1.weight torch.Size([128, 288]) dense1.bias torch.Size([128]) dense2.weight torch.Size([10, 128]) dense2.bias torch.Size([10]) '''

2.2 优化器optimizer的state_dict()方法

优化器对象Optimizer也有一个state_dict,它包含了优化器的状态以及被使用的超参数(如lr, momentum,weight_decay等)

optimizer = SGD(model.parameters(),lr=0.001,momentum=0.9) for var_name in optimizer.state_dict(): print(var_name,'\t',optimizer.state_dict()[var_name]) ''' state {} param_groups [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [1412966600640, 1412966613064, 1412966613136, 1412966613208, 1412966613280, 1412966613352, 1412966613496, 1412966613568] }] ''' 三、model.parameters()方法

这个方法也会获得模型的参数信息,如下:

print(type(model.parameters())) # 返回的是一个generator for para in model.parameters(): print(para.size()) # 只查看形状 ''' torch.Size([32, 3, 3, 3]) torch.Size([32]) torch.Size([32, 3, 3, 3]) torch.Size([32]) torch.Size([128, 288]) torch.Size([128]) torch.Size([10, 128]) torch.Size([10]) '''

从这里可以看出,其实这个state_dict方法所得到结果差不多,不同的是,model.parameters()方法返回的是一个生成器generator,每一个元素是从开头到结尾的参数,parameters没有对应的key名称,是一个由纯参数组成的generator,而state_dict是一个字典,包含了一个key。

其实Module还有一个与parameters类似的函数,named_parameters,而且parameters正是通过named_parameters来实现的,

看一下parameters的定义,很简单:  

def parameters(self, recurse=True): for name, param in self.named_parameters(recurse=recurse): yield param

来一起看一下named_parameters的简单使用。

print(type(model.named_parameters())) # 返回的是一个generator for para in model.named_parameters(): # 返回的每一个元素是一个元组 tuple ''' 是一个元组 tuple ,元组的第一个元素是参数所对应的名称,第二个元素就是对应的参数值 ''' print(para[0],'\t',para[1].size()) ''' conv1.weight torch.Size([32, 3, 3, 3]) conv1.bias torch.Size([32]) conv2.weight torch.Size([32, 3, 3, 3]) conv2.bias torch.Size([32]) dense1.weight torch.Size([128, 288]) dense1.bias torch.Size([128]) dense2.weight torch.Size([10, 128]) dense2.bias torch.Size([10]) '''

总结:model.state_dict()、model.parameters()、model.named_parameters()这三个方法都可以查看Module的参数信息,用于更新参数,或者用于模型的保存。



【本文地址】

公司简介

联系我们

今日新闻

    推荐新闻

    专题文章
      CopyRight 2018-2019 实验室设备网 版权所有