pytorch读取.pth文件 您所在的位置:网站首页 prt1怎么打开 pytorch读取.pth文件

pytorch读取.pth文件

2024-01-11 02:36| 来源: 网络整理| 查看: 265

1.pth文件中保存的是什么 import torch state_dict = torch.load("resnet18.pth") print(type(state_dict)) ---------------

如上打印输出所示,pth文件通过有序字典来保持模型参数。有序字典与常规字典一样,但是在排序操作方面有一些额外的功能。常规的dict是无序的,OrderedDict能够比dict更好地处理频繁的重新排序操作。 OrderedDict有一个方法popitem(last=True)用于有序字典的popitem()方法返回并删除一个(键,值)对。如果last为真,则按LIFO顺序返回对;如果为假,则按FIFO顺序返回对。 OrderedDict还有一个方法move_to_end(key,last=True),将现有的键移动到有序字典的两端。如果last为真,则将项目移动到右端(默认);如果last为假,则移动到开头。

import torch state_dict = torch.load("resnet18.pth") print(type(state_dict)) for i in state_dict: print(i) print(type(state_dict[i])) print("aa:",state_dict[i].data.size()) print("bb:",state_dict[i].requires_grad) break ------------------------------ conv1.weight aa: torch.Size([64, 3, 7, 7]) bb: True

如上打印所示,有序字典state_dict中每个元素都是Parameter参数,该参数是一种特殊的张量,包含data和requires_grad两个方法。其中data字段保存的是模型参数,requires_grad字段表示当前参数是否需要进行反向传播。

更多参考:https://www.jb51.net/article/168000.htm

2.torch.save()

先建立一个字典,保存三个参数:调用torch.save(),即可保存对应的pth文件。需要注意的是若模型是由nn.Moudle类继承的模型,保存pth文件时,state_dict参数需要由model.state_dict指定。

state_dict = {‘net':model.state_dict(), 'optimizer':optimizer.state_dict(), 'epoch':epoch} torch.save(state_dict , dir) -------------------------------- torch.save(model.state_dict,dir) 3.torch.load()

当你想恢复某一阶段的训练(或者进行测试)时,那么就可以读取之前保存的网络模型参数等。

checkpoint = torch.load(dir) model.load_state_dict(checkpoint['net']) optimizer.load_state_dict(checkpoint['optimizer']) start_epoch = checkpoint['epoch'] + 1

pytorch加载预训练模型部分参数 resnet = models.resnet50(pretrained=True) new_state_dict = resnet.state_dict() dd = net.state_dict() #net是自己定义的含有resnet backbone的模型 for k in new_state_dict.keys(): print(k) if k in dd.keys() and not k.startswith('fc'): #不使用全连接的参数 print('yes') dd[k] = new_state_dict[k] net.load_state_dict(dd)

更加全面参考:https://blog.csdn.net/weixin_41519463/article/details/103205665



【本文地址】

公司简介

联系我们

今日新闻

    推荐新闻

    专题文章
      CopyRight 2018-2019 实验室设备网 版权所有